"...composable_kernel_rocm.git" did not exist on "f0fd02634c2f8f8c70f5a0ab2a8c84db5e36eeca"
Unverified Commit 6bf94bc0 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

correctly handle mt5 (#9879)

parent 7eadfe16
...@@ -563,7 +563,7 @@ def freeze_embeds(model): ...@@ -563,7 +563,7 @@ def freeze_embeds(model):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
model_type = model.config.model_type model_type = model.config.model_type
if model_type == "t5": if model_type in ["t5", "mt5"]:
freeze_params(model.shared) freeze_params(model.shared)
for d in [model.encoder, model.decoder]: for d in [model.encoder, model.decoder]:
freeze_params(d.embed_tokens) freeze_params(d.embed_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment