Unverified Commit a43d9352 authored by Baizhou Huang's avatar Baizhou Huang Committed by GitHub
Browse files

replace assert with exception in src/transformers/utils/model_pararallel_utils.py (#14072)



* replace assert with exception in src/transformers/utils/model_parallel_utils.py

* fix some code style

* fix typo
Co-authored-by: default avatarskpig <1900012999@pku.edu.cn>
parent 53dc39d8
......@@ -30,19 +30,21 @@ def assert_device_map(device_map, num_blocks):
missing_blocks = [i for i in blocks if i not in device_map_blocks]
extra_blocks = [i for i in device_map_blocks if i not in blocks]
assert len(duplicate_blocks) == 0, (
"Duplicate attention blocks specified in device_map. Attention blocks must be specified to one device. These "
"attention blocks were specified more than once: " + str(duplicate_blocks)
)
assert len(missing_blocks) == 0, (
"There are attention blocks for this model that are not specified in the device_map. Add these attention "
"blocks to a device on the device_map: " + str(missing_blocks)
)
assert (
len(extra_blocks) == 0
), "The device_map contains more attention blocks than this model has. Remove these from the device_map:" + str(
extra_blocks
)
if len(duplicate_blocks) != 0:
raise ValueError(
"Duplicate attention blocks specified in device_map. Attention blocks must be specified to one device. These "
"attention blocks were specified more than once: " + str(duplicate_blocks)
)
if len(missing_blocks) != 0:
raise ValueError(
"There are attention blocks for this model that are not specified in the device_map. Add these attention "
"blocks to a device on the device_map: " + str(missing_blocks)
)
if len(extra_blocks) != 0:
raise ValueError(
"The device_map contains more attention blocks than this model has. Remove these from the device_map:"
+ str(extra_blocks)
)
def get_device_map(n_layers, devices):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment