Unverified commit b2e95f62, authored by Cheng Wan, committed by GitHub

Fix two issues related to `--moe-dense-tp-size=1` (#5657)


Co-authored-by: liusy58 <liusy58@linux.alibaba.com>
Co-authored-by: 颉沆 <xiehang.lsy@alibaba-inc.com>
parent 1ab14c4c
@@ -24,8 +24,10 @@ if TYPE_CHECKING:
 _ATTN_TP_GROUP = None
 _ATTN_TP_RANK = None
 _ATTN_TP_SIZE = None
-_DP_RANK = None
-_DP_SIZE = None
+_ATTN_DP_RANK = None
+_ATTN_DP_SIZE = None
+_LOCAL_ATTN_DP_SIZE = None
+_LOCAL_ATTN_DP_RANK = None
 
 
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
@@ -33,9 +35,27 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
         return tp_rank, tp_size, 0
 
     attn_tp_size = tp_size // dp_size
-    dp_rank = tp_rank // attn_tp_size
+    attn_dp_rank = tp_rank // attn_tp_size
     attn_tp_rank = tp_rank % attn_tp_size
-    return attn_tp_rank, attn_tp_size, dp_rank
+    return attn_tp_rank, attn_tp_size, attn_dp_rank
+
+
+def compute_dp_attention_local_info(
+    enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
+):
+    if not enable_dp_attention:
+        return tp_rank, tp_size, 0
+
+    local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
+    local_tp_rank = tp_rank % local_tp_size
+    local_dp_size = max(1, dp_size // (tp_size // local_tp_size))
+    local_attn_tp_size = local_tp_size // local_dp_size
+    local_attn_dp_rank = local_tp_rank // local_attn_tp_size
+    local_attn_tp_rank = local_tp_rank % local_attn_tp_size
+    return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank
 
 
 def initialize_dp_attention(
@@ -43,22 +63,32 @@ def initialize_dp_attention(
     tp_rank: int,
     tp_size: int,
     dp_size: int,
+    moe_dense_tp_size: int,
     pp_size: int,
 ):
-    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
+    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
+    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
 
     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
 
-    _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
+    _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
+    _, _, _LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info(
+        enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
+    )
 
     if enable_dp_attention:
         local_rank = tp_rank % (tp_size // dp_size)
-        _DP_SIZE = dp_size
+        _ATTN_DP_SIZE = dp_size
+        if moe_dense_tp_size is None:
+            _LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
+        else:
+            _LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
     else:
         local_rank = tp_rank
-        _DP_SIZE = 1
+        _ATTN_DP_SIZE = 1
+        _LOCAL_ATTN_DP_SIZE = 1
 
     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
@@ -93,13 +123,33 @@ def get_attention_tp_size():
 def get_attention_dp_rank():
-    assert _DP_RANK is not None, "dp attention not initialized!"
-    return _DP_RANK
+    assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _ATTN_DP_RANK
 
 
 def get_attention_dp_size():
-    assert _DP_SIZE is not None, "dp attention not initialized!"
-    return _DP_SIZE
+    assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _ATTN_DP_SIZE
+
+
+def get_local_attention_dp_rank():
+    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_RANK
+
+
+def get_local_attention_dp_size():
+    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_SIZE
 
 
 @contextmanager
@@ -112,19 +162,19 @@ def disable_dp_size():
     Args:
         tp_group (GroupCoordinator): the tp group coordinator
     """
-    global _DP_SIZE
-    assert _DP_SIZE is not None, "dp attention not initialized!"
+    global _ATTN_DP_SIZE
+    assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
 
-    old_dp_size = _DP_SIZE
-    _DP_SIZE = 1
+    old_dp_size = _ATTN_DP_SIZE
+    _ATTN_DP_SIZE = 1
     try:
         yield
     finally:
-        _DP_SIZE = old_dp_size
+        _ATTN_DP_SIZE = old_dp_size
 
 
 def get_dp_local_info(forward_batch: ForwardBatch):
-    dp_rank = get_attention_dp_rank()
+    dp_rank = get_local_attention_dp_rank()
 
     if forward_batch.dp_local_start_pos is None:
         cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
...
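The two helpers above are pure functions of the parallel configuration, so their effect is easy to check offline. A minimal standalone sketch (not part of the patch; the configuration numbers are hypothetical) for `tp_size=8`, `dp_size=4`, `--moe-dense-tp-size=1`: the world-level mapping still spreads ranks across 4 attention DP groups, while the local mapping collapses every rank to local DP rank 0 with local DP size 1, which is what lets the dense layers treat their input as replicated.

```python
# Standalone sketch: re-implements the two helpers from the hunk above and
# prints the mapping for a hypothetical tp_size=8, dp_size=4,
# --moe-dense-tp-size=1 deployment.

def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
    if not enable_dp_attention:
        return tp_rank, tp_size, 0
    attn_tp_size = tp_size // dp_size
    attn_dp_rank = tp_rank // attn_tp_size
    attn_tp_rank = tp_rank % attn_tp_size
    return attn_tp_rank, attn_tp_size, attn_dp_rank


def compute_dp_attention_local_info(
    enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
):
    if not enable_dp_attention:
        return tp_rank, tp_size, 0
    local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
    local_tp_rank = tp_rank % local_tp_size
    local_dp_size = max(1, dp_size // (tp_size // local_tp_size))
    local_attn_tp_size = local_tp_size // local_dp_size
    local_attn_dp_rank = local_tp_rank // local_attn_tp_size
    local_attn_tp_rank = local_tp_rank % local_attn_tp_size
    return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank


if __name__ == "__main__":
    tp_size, dp_size, moe_dense_tp_size = 8, 4, 1
    for tp_rank in range(tp_size):
        # world -> (attn_tp_rank, attn_tp_size, attn_dp_rank)
        world = compute_dp_attention_world_info(True, tp_rank, tp_size, dp_size)
        # local -> (local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank)
        local = compute_dp_attention_local_info(
            True, tp_rank, tp_size, dp_size, moe_dense_tp_size
        )
        print(f"tp_rank={tp_rank}: world={world}, local={local}")
```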
@@ -30,9 +30,10 @@ from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather,
     dp_gather_replicate,
     dp_scatter,
-    get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
+    get_local_attention_dp_rank,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -46,6 +47,18 @@ from sglang.srt.utils import dump_to_file
 logger = logging.getLogger(__name__)
+
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
+from sglang.srt.utils import dump_to_file
+
+logger = logging.getLogger(__name__)
 
 
 @dataclasses.dataclass
 class LogitsProcessorOutput:
     ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
@@ -170,7 +183,7 @@ class LogitsMetadata:
             return
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
-        dp_rank = get_attention_dp_rank()
+        dp_rank = get_local_attention_dp_rank()
         if dp_rank == 0:
             dp_local_start_pos = torch.zeros_like(
                 self.global_num_tokens_for_logprob_gpu[0]
@@ -324,7 +337,7 @@ class LogitsProcessor(nn.Module):
         if self.debug_tensor_dump_output_folder:
             assert (
-                not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1
+                not self.do_tensor_parallel_all_gather
+                or get_local_attention_dp_size() == 1
             ), "dp attention + sharded lm_head doesn't support full logits"
             full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
             dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
...
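The switch to `get_local_attention_dp_rank()` above changes which slice of the gathered logprob tokens each rank reads back. A small illustration of the slicing arithmetic (token counts are made up; only the cumsum/offset logic mirrors the hunk): with `--moe-dense-tp-size=1` the local DP rank is always 0, so every rank now starts its slice at position 0.

```python
# Illustration of the dp-local slicing that now keys off the *local*
# attention DP rank. Token counts are hypothetical.
import torch

global_num_tokens_for_logprob = torch.tensor([3, 5, 2, 4])  # one entry per DP rank
cumtokens = torch.cumsum(global_num_tokens_for_logprob, dim=0)  # [3, 8, 10, 14]


def dp_local_slice(dp_rank: int):
    # Mirrors the branch in the hunk: rank 0 starts at 0, rank r starts at
    # cumtokens[r - 1]; the slice length is this rank's own token count.
    if dp_rank == 0:
        start = torch.zeros_like(global_num_tokens_for_logprob[0])
    else:
        start = cumtokens[dp_rank - 1]
    num = global_num_tokens_for_logprob[dp_rank]
    return int(start), int(num)


print([dp_local_slice(r) for r in range(4)])  # [(0, 3), (3, 5), (8, 2), (10, 4)]
```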
@@ -207,7 +207,8 @@ class Scheduler(
         self.page_size = server_args.page_size
 
         # Distributed rank info
-        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
+        self.dp_size = server_args.dp_size
+        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
             compute_dp_attention_world_info(
                 server_args.enable_dp_attention,
                 self.tp_rank,
@@ -768,7 +769,7 @@ class Scheduler(
         )
 
         # send out reqs to the next stage
-        dp_offset = self.dp_rank * self.attn_tp_size
+        dp_offset = self.attn_dp_rank * self.attn_tp_size
         if self.attn_tp_rank == 0:
             point_to_point_pyobj(
                 recv_reqs,
@@ -815,7 +816,7 @@ class Scheduler(
             recv_reqs = None
         else:
             if self.attn_tp_rank == 0:
-                dp_offset = self.dp_rank * self.attn_tp_size
+                dp_offset = self.attn_dp_rank * self.attn_tp_size
                 recv_reqs = point_to_point_pyobj(
                     [],
                     self.pp_rank * self.tp_size + dp_offset,
@@ -1610,6 +1611,7 @@ class Scheduler(
                 local_batch,
                 dp_size=self.server_args.dp_size,
                 attn_tp_size=self.attn_tp_size,
+                moe_dense_tp_size=self.server_args.moe_dense_tp_size,
                 tp_cpu_group=self.tp_cpu_group,
                 get_idle_batch=self.get_idle_batch,
                 disable_cuda_graph=self.server_args.disable_cuda_graph,
@@ -1622,6 +1624,7 @@ class Scheduler(
         local_batch: ScheduleBatch,
         dp_size,
         attn_tp_size: int,
+        moe_dense_tp_size: Optional[int],
         tp_cpu_group,
         get_idle_batch,
         disable_cuda_graph: bool,
@@ -1631,15 +1634,15 @@ class Scheduler(
         # Check if other DP workers have running batches
         if local_batch is None:
             num_tokens = 0
-            global_num_tokens_for_logprob = 0
+            num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
             if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
                 num_tokens = num_tokens * speculative_num_draft_tokens
-            global_num_tokens_for_logprob = num_tokens
+            num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
-            global_num_tokens_for_logprob = sum(
+            num_tokens_for_logprob = sum(
                 [
                     # We should have at least 1 token for sample in every case.
                     max(extend_len - logprob_start_len, 1)
@@ -1666,7 +1669,7 @@ class Scheduler(
             [
                 num_tokens,
                 can_cuda_graph,
-                global_num_tokens_for_logprob,
+                num_tokens_for_logprob,
                 is_extend_in_batch,
             ],
             dtype=torch.int64,
@@ -1689,8 +1692,15 @@ class Scheduler(
             local_batch = get_idle_batch()
 
         if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens
-            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
+            # TODO: handle the case when moe_dense_tp_size != 1
+            if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
+                local_batch.global_num_tokens = [num_tokens]
+                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
+            else:
+                local_batch.global_num_tokens = global_num_tokens
+                local_batch.global_num_tokens_for_logprob = (
+                    global_num_tokens_for_logprob
+                )
 
             # Check forward mode for cuda graph
             if not disable_cuda_graph:
@@ -2177,8 +2187,8 @@ class Scheduler(
     def get_print_prefix(self):
         prefix = ""
-        if self.dp_rank is not None:
-            prefix += f" DP{self.dp_rank}"
+        if self.attn_dp_rank is not None:
+            prefix += f" DP{self.attn_dp_rank}"
         if self.server_args.tp_size > 1:
             prefix += f" TP{self.tp_rank}"
         if self.pp_size > 1:
...
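The scheduler hunk above carries the second fix: when the dense parts run with `--moe-dense-tp-size=1` and the DP lm_head option (`enable_dp_lm_head`) is on, a rank no longer needs the all-gathered per-DP-rank token lists, only its own counts. A hypothetical helper distilling just that branch (names and values are illustrative, not the scheduler's API):

```python
# Distilled version of the new branch in the scheduler's DP-attention batch
# preparation. Inputs are hypothetical.
from typing import List, Optional


def pick_global_num_tokens(
    num_tokens: int,
    gathered_num_tokens: List[int],
    moe_dense_tp_size: Optional[int],
    enable_dp_lm_head: bool,
) -> List[int]:
    # TODO in the patch: handle the case when moe_dense_tp_size != 1
    if moe_dense_tp_size == 1 and enable_dp_lm_head:
        # Dense layers and lm_head stay local, so only this rank's count matters.
        return [num_tokens]
    # Otherwise keep the all-gathered list, one entry per attention DP rank.
    return gathered_num_tokens


print(pick_global_num_tokens(7, [7, 3, 5, 2], moe_dense_tp_size=1, enable_dp_lm_head=True))
# [7]
print(pick_global_num_tokens(7, [7, 3, 5, 2], moe_dense_tp_size=None, enable_dp_lm_head=True))
# [7, 3, 5, 2]
```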
@@ -401,6 +401,7 @@ class ModelRunner:
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             dp_size=self.server_args.dp_size,
+            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
             pp_size=self.server_args.pp_size,
         )
...
@@ -40,9 +40,9 @@ from sglang.srt.layers.dp_attention import (
     attn_tp_reduce_scatter,
     dp_gather_partial,
     dp_scatter,
-    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -438,7 +438,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
-        self.dp_size = get_attention_dp_size()
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()
@@ -1133,7 +1132,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
-        self.dp_size = get_attention_dp_size()
+        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
         self.self_attn = DeepseekV2AttentionMLA(
@@ -1184,7 +1183,8 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
 
         self.input_is_scattered = (
-            previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
+            layer_id > 0
+            and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
         )
         self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
@@ -1264,7 +1264,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
-            if self.dp_size != 1:
+            if self.local_dp_size != 1:
                 if self.attn_tp_rank == 0:
                     hidden_states += residual
                 hidden_states, local_hidden_states = (
@@ -1289,7 +1289,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
         # Scatter
-        if self.dp_size != 1:
+        if self.local_dp_size != 1:
             # important: forward batch.gathered_buffer is used both after scatter and after gather.
             # be careful about this!
             hidden_states, global_hidden_states = (
@@ -1413,7 +1413,7 @@ class DeepseekV2Model(nn.Module):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.dp_size = get_attention_dp_size()
+        self.dp_size = get_local_attention_dp_size()
 
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
@@ -1478,7 +1478,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_attention_dp_size()
+        self.dp_size = get_local_attention_dp_size()
 
     def determine_n_share_experts_fusion(
         self, architecture: str = "DeepseekV3ForCausalLM"
...
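Aside from switching to the local DP size, the DeepSeek-V2 hunks also tighten `input_is_scattered` with a `layer_id > 0` guard, so the first decoder layer is never treated as receiving scattered input even if the "previous layer" info would say so. A sketch of just that predicate (the enum and info object are stand-ins; only the boolean logic comes from the patch):

```python
# Sketch of the corrected predicate from the DeepseekV2DecoderLayer hunk above.
import enum
from dataclasses import dataclass


class _FFNInputMode(enum.Enum):
    SCATTERED = enum.auto()
    FULL = enum.auto()


@dataclass
class LayerInfo:
    ffn_input_mode: _FFNInputMode


def input_is_scattered(layer_id: int, previous_layer_info: LayerInfo) -> bool:
    # New behaviour: layer 0 has no real previous decoder layer, so it never
    # considers its input scattered.
    return (
        layer_id > 0
        and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
    )


print(input_is_scattered(0, LayerInfo(_FFNInputMode.SCATTERED)))  # False
print(input_is_scattered(1, LayerInfo(_FFNInputMode.SCATTERED)))  # True
```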
@@ -30,9 +30,9 @@ from sglang.srt.distributed import (
 from sglang.srt.layers.dp_attention import (
     dp_gather_partial,
     dp_scatter,
-    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -198,7 +198,6 @@ class Llama4Attention(nn.Module):
         self.use_rope = int((layer_id + 1) % 4 != 0)
         self.use_qk_norm = config.use_qk_norm and self.use_rope
 
-        self.dp_size = get_attention_dp_size()
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()
@@ -342,7 +341,7 @@ class Llama4DecoderLayer(nn.Module):
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
         max_position_embeddings = config.max_position_embeddings
-        self.dp_size = get_attention_dp_size()
+        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
@@ -405,7 +404,7 @@ class Llama4DecoderLayer(nn.Module):
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
-            if self.dp_size != 1:
+            if self.local_dp_size != 1:
                 if self.attn_tp_rank == 0:
                     hidden_states += residual
                 hidden_states, local_hidden_states = (
@@ -430,7 +429,7 @@ class Llama4DecoderLayer(nn.Module):
         # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
         # Scatter
-        if self.dp_size != 1:
+        if self.local_dp_size != 1:
             # important: forward batch.gathered_buffer is used both after scatter and after gather.
             # be careful about this!
             hidden_states, global_hidden_states = (
...
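The Llama-4 hunks apply the same guard as DeepSeek-V2: the DP gather before the dense MLP and the scatter after it now trigger on the local attention DP size, which is 1 under `--moe-dense-tp-size=1`. A rough sketch of the resulting control flow (the step strings are descriptive placeholders; `dp_gather_partial`/`dp_scatter` are the real helpers, the wrapper function here is hypothetical):

```python
# Hypothetical summary of the gather/scatter decision changed above.
from typing import List


def mlp_comm_plan(local_dp_size: int, attn_tp_rank: int) -> List[str]:
    steps = []
    if local_dp_size != 1:
        if attn_tp_rank == 0:
            steps.append("hidden_states += residual")
        steps.append("dp_gather_partial(hidden_states, local_hidden_states, forward_batch)")
        steps.append("... dense MLP on the gathered buffer ...")
        steps.append("dp_scatter(hidden_states, global_hidden_states, forward_batch)")
    else:
        steps.append("dense MLP on the local shard; no DP gather/scatter")
    return steps


print(mlp_comm_plan(local_dp_size=1, attn_tp_rank=0))  # the --moe-dense-tp-size=1 path
print(mlp_comm_plan(local_dp_size=4, attn_tp_rank=0))  # dense layers still span DP ranks
```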