"examples/pytorch/test_pytorch_examples.py" did not exist on "1381b6d01dca9c84c3cacd3eae5155cda8e03c18"
Unverified Commit a9cdb059 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix device issue in `OpenLlamaModelTest::test_model_parallelism` (#24195)



fix
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 9f81f4f6
...@@ -736,12 +736,16 @@ class OpenLlamaForCausalLM(OpenLlamaPreTrainedModel): ...@@ -736,12 +736,16 @@ class OpenLlamaForCausalLM(OpenLlamaPreTrainedModel):
hidden_states = outputs[0] hidden_states = outputs[0]
if self.config.shared_input_output_embedding: if self.config.shared_input_output_embedding:
logits = torch.einsum("blh,vh->blv", hidden_states, self.model.embed_tokens.weight) logits = torch.einsum(
"blh,vh->blv", hidden_states.to(self.model.embed_tokens.weight.device), self.model.embed_tokens.weight
)
else: else:
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
loss = None loss = None
if labels is not None: if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# Shift so that tokens < n predict n # Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous() shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous() shift_labels = labels[..., 1:].contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment