Commit ed2c06c3 authored by zhuwenwen's avatar zhuwenwen
Browse files

fix run error

parent 1940460d
......@@ -210,4 +210,13 @@ def set_profilling(profiling):
def get_profilling() -> bool:
global _profiling
return _profiling
\ No newline at end of file
return _profiling
@contextmanager
def set_warming_up(warming_up):
global _warming_up
_warming_up = warming_up
def get_warming_up() -> bool:
global _warming_up
return _warming_up
\ No newline at end of file
......@@ -1578,6 +1578,11 @@ class RowParallelLinear(LinearBase):
# Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.enable_dp_attn_moe = enable_dp_attn_moe
if enable_dp_attn_moe:
self.tp_rank = get_moe_tp_rank()
self.tp_size = get_moe_tp_size()
if expect_tp_size is not None:
self.tp_rank = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment