"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "28ec07d8add1698590632aaa35d0f3b95c721bb2"
Unverified Commit 4bb07647 authored by Younes Belkada, committed by GitHub

refactor test (#20300)

- simplifies the device checking test
parent 700e0cd6
@@ -215,23 +215,8 @@ class MixedInt8TestMultiGpu(BaseMixedInt8Test):
             self.model_name, load_in_8bit=True, max_memory=memory_mapping, device_map="auto"
         )
 
-        def get_list_devices(model):
-            list_devices = []
-            for _, module in model.named_children():
-                if len(list(module.children())) > 0:
-                    list_devices.extend(get_list_devices(module))
-                else:
-                    # Do a try except since we can encounter Dropout modules that does not
-                    # have any device set
-                    try:
-                        list_devices.append(next(module.parameters()).device.index)
-                    except BaseException:
-                        continue
-            return list_devices
-
-        list_devices = get_list_devices(model_parallel)
-        # Check that we have dispatched the model into 2 separate devices
-        self.assertTrue((1 in list_devices) and (0 in list_devices))
+        # Check correct device map
+        self.assertEqual(set(model_parallel.hf_device_map.values()), {0, 1})
 
         # Check that inference pass works on the model
         encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
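For context, the simplified assertion works because from_pretrained(..., device_map="auto") stores the placement chosen by Accelerate in the model's hf_device_map attribute (a dict mapping submodule names to devices), so the test no longer needs to walk the module tree and read parameter devices. Below is a minimal sketch of the same check outside the test suite, assuming two visible GPUs and that bitsandbytes and accelerate are installed; the checkpoint name and memory limits are illustrative, not taken from this commit:

# Minimal sketch (not part of the commit): verifying multi-GPU dispatch
# through hf_device_map instead of walking module parameters.
# Checkpoint name and memory limits below are assumptions for illustration.
from transformers import AutoModelForCausalLM

model_parallel = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-1b7",           # assumed checkpoint; any causal LM works
    load_in_8bit=True,                # int8 weights via bitsandbytes
    max_memory={0: "1GB", 1: "2GB"},  # cap per-GPU memory so the model is split
    device_map="auto",                # let accelerate assign submodules to devices
)

# hf_device_map maps submodule names to device indices, e.g.
# {"transformer.word_embeddings": 0, ..., "lm_head": 1}
assert set(model_parallel.hf_device_map.values()) == {0, 1}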