Unverified Commit 5d9a0ae7 authored by Zhongkai Zhao's avatar Zhongkai Zhao Committed by GitHub
Browse files

[hotfix] Fix ShardFormer test execution path when using sequence parallelism (#5230)

parent 46e09165
...@@ -154,7 +154,7 @@ def run_forward_backward_with_hybrid_plugin( ...@@ -154,7 +154,7 @@ def run_forward_backward_with_hybrid_plugin(
data = data_gen_fn() data = data_gen_fn()
if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0: if booster.plugin.shard_config.enable_sequence_parallelism and booster.plugin.tp_size != 0:
seq_len = data["input_ids"].shape[-1] seq_len = data["input_ids"].shape[-1]
lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len) lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
times = lcm // seq_len times = lcm // seq_len
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment