"docs/source/ko/tasks/image_classification.mdx" did not exist on "2e90c3df8f965ec616faa08e3fb2a1857a1e64b6"
Unverified Commit 5be1fb6d authored by Marc Sun's avatar Marc Sun Committed by GitHub
Browse files

Fix no split modules underlying modules (#27090)



* fix no split

* style

* remove comm

* Update src/transformers/modeling_utils.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* rename modules

---------
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent 66b088fa
......@@ -1520,21 +1520,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
Returns:
`List[str]`: List of modules that should not be split
"""
if self._no_split_modules is None:
raise ValueError(
f"{self.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
"class needs to implement the `_no_split_modules` attribute."
)
_no_split_modules = set(self._no_split_modules)
for module in self.modules():
if isinstance(module, PreTrainedModel):
if module._no_split_modules is None:
raise ValueError(
f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
"class needs to implement the `_no_split_modules` attribute."
)
else:
_no_split_modules = _no_split_modules | set(module._no_split_modules)
_no_split_modules = set()
modules_to_check = [self]
while len(modules_to_check) > 0:
module = modules_to_check.pop(-1)
# if the module does not appear in _no_split_modules, we also check the children
if module.__class__.__name__ not in _no_split_modules:
if isinstance(module, PreTrainedModel):
if module._no_split_modules is None:
raise ValueError(
f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
"class needs to implement the `_no_split_modules` attribute."
)
else:
_no_split_modules = _no_split_modules | set(module._no_split_modules)
modules_to_check += list(module.children())
return list(_no_split_modules)
def resize_token_embeddings(
......
......@@ -2641,6 +2641,7 @@ class SeamlessM4THifiGan(nn.Module):
class SeamlessM4TCodeHifiGan(PreTrainedModel):
config_class = SeamlessM4TConfig
main_input_name = "input_embeds"
_no_split_modules = []
def __init__(self, config):
super().__init__(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment