Unverified Commit c7d6975d authored by flybird11111's avatar flybird11111 Committed by GitHub
Browse files

[shardformer] fix GPT2DoubleHeadsModel (#4703)

parent 068372a7
......@@ -94,9 +94,9 @@ class GPT2PipelineForwards:
if hidden_states is None:
raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
input_shape = hidden_states.size()[:-1]
batch_size = input_shape[0]
device = hidden_states.device
hidden_states = hidden_states.view((-1,) + hidden_states.shape[-2:])
batch_size = hidden_states.shape[0]
# GPT2Attention mask.
if attention_mask is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment