Unverified commit db4efbf4, authored by Nicolas Patry, committed by GitHub
Browse files

fix(server): T5 weights names. (#582)

Fixes #541
parent f063ebde
...@@ -1006,12 +1006,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): ...@@ -1006,12 +1006,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
super().__init__(config) super().__init__(config)
self.model_dim = config.d_model self.model_dim = config.d_model
try: self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
self.shared = TensorParallelEmbedding(prefix="shared", weights=weights)
except RuntimeError:
self.shared = TensorParallelEmbedding(
prefix="encoder.embed_tokens", weights=weights
)
encoder_config = copy.deepcopy(config) encoder_config = copy.deepcopy(config)
encoder_config.is_decoder = False encoder_config.is_decoder = False
......
...@@ -55,7 +55,16 @@ class T5Sharded(Seq2SeqLM): ...@@ -55,7 +55,16 @@ class T5Sharded(Seq2SeqLM):
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
filenames = weight_files(model_id, revision=revision, extension=".safetensors") filenames = weight_files(model_id, revision=revision, extension=".safetensors")
weights = Weights( weights = Weights(
filenames, device=device, dtype=dtype, process_group=self.process_group filenames,
device=device,
dtype=dtype,
process_group=self.process_group,
aliases={
"shared.weight": [
"encoder.embed_tokens.weight",
"decoder.embed_tokens.weight",
]
},
) )
model = T5ForConditionalGeneration(config, weights) model = T5ForConditionalGeneration(config, weights)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment