Unverified Commit ffd9a3cb authored by littsk's avatar littsk Committed by GitHub
Browse files

[hotfix] fix bug in sequence parallel test (#4887)

parent fdec650b
...@@ -160,7 +160,7 @@ def run_forward_backward_with_hybrid_plugin( ...@@ -160,7 +160,7 @@ def run_forward_backward_with_hybrid_plugin(
input_shape = data["input_ids"].shape input_shape = data["input_ids"].shape
for k, v in data.items(): for k, v in data.items():
if v.shape == input_shape: if v.shape == input_shape:
data[k] = v.repeat(input_shape[:-1] + (input_shape[-1] * times,)) data[k] = v.repeat((1, ) * (v.dim() - 1) + (times,))
sharded_model.train() sharded_model.train()
if booster.plugin.stage_manager is not None: if booster.plugin.stage_manager is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment