Commit dc6f80a5 authored by zhuwenwen's avatar zhuwenwen
Browse files

update tp_size

parent a441a5d9
...@@ -718,12 +718,12 @@ def main(args: argparse.Namespace): ...@@ -718,12 +718,12 @@ def main(args: argparse.Namespace):
intermediate_size = config.intermediate_size intermediate_size = config.intermediate_size
enable_ep = bool(args.enable_expert_parallel) enable_ep = bool(args.enable_expert_parallel)
if enable_ep: if enable_ep:
ensure_divisibility(E, args.tp_size, "Number of experts") ensure_divisibility(E, tp_size, "Number of experts")
E = E // args.tp_size E = E // tp_size
shard_intermediate_size = 2 * intermediate_size shard_intermediate_size = 2 * intermediate_size
else: else:
ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size") ensure_divisibility(intermediate_size, tp_size, "intermediate_size")
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // tp_size
hidden_size = config.hidden_size hidden_size = config.hidden_size
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_fp8_w8a8 = args.dtype == "fp8_w8a8"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment