Commit 7b78db5d authored by dongcl

Support DeepSeek V3 training

parent c099d843
Pipeline #2478 passed
@@ -461,7 +461,7 @@ class GPTModel(LanguageModule):
         if (
             self.num_nextn_predict_layers
-            and getattr(self.decoder, final_layernorm, None) is not None
+            and getattr(self.decoder, "final_layernorm", None) is not None
         ):
             # move block main model final norms here
             hidden_states = self.decoder.final_layernorm(hidden_states)
......
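The first hunk fixes an unquoted attribute name passed to getattr. Below is a minimal, self-contained sketch of the guarded pattern (TinyDecoder and apply_final_norm are illustrative stand-ins, not Megatron classes): the final norm is only applied here when MTP layers are configured and the decoder still owns final_layernorm.

```python
import torch
import torch.nn as nn


class TinyDecoder(nn.Module):
    """Stand-in for the decoder block; final_layernorm may be absent when
    the main model's final norm has been moved elsewhere (e.g. into MTP)."""

    def __init__(self, hidden: int, keep_final_norm: bool):
        super().__init__()
        self.layer = nn.Linear(hidden, hidden)
        if keep_final_norm:
            self.final_layernorm = nn.LayerNorm(hidden)


def apply_final_norm(decoder, hidden_states, num_nextn_predict_layers):
    # Mirrors the guarded pattern from the hunk above: only normalize here
    # when MTP layers are configured *and* the decoder still has the norm.
    if (
        num_nextn_predict_layers
        and getattr(decoder, "final_layernorm", None) is not None
    ):
        hidden_states = decoder.final_layernorm(hidden_states)
    return hidden_states


x = torch.randn(2, 4, 8)
print(apply_final_norm(TinyDecoder(8, True), x, num_nextn_predict_layers=1).shape)
print(apply_final_norm(TinyDecoder(8, False), x, num_nextn_predict_layers=1).shape)
```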
@@ -103,6 +103,10 @@ class Attention(MegatronModule, ABC):
         self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
         self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)
 
+        # To support both CUDA Graphs and key value with different hidden size
+        self.key_hidden_size = self.hidden_size_per_attention_head
+        self.val_hidden_size = self.hidden_size_per_attention_head
+
         self.core_attention = build_module(
             submodules.core_attention,
             config=self.config,
@@ -209,10 +213,10 @@ class Attention(MegatronModule, ABC):
             inf_max_seq_length = inference_params.max_sequence_length
             inf_max_batch_size = inference_params.max_batch_size
             inference_key_memory = self._allocate_memory(
-                inf_max_seq_length, inf_max_batch_size, key.shape[-1], key.dtype
+                inf_max_seq_length, inf_max_batch_size, self.key_hidden_size, key.dtype
             )
             inference_value_memory = self._allocate_memory(
-                inf_max_seq_length, inf_max_batch_size, value.shape[-1], value.dtype
+                inf_max_seq_length, inf_max_batch_size, self.val_hidden_size, value.dtype
             )
             inference_params.key_value_memory_dict[self.layer_number] = (
                 inference_key_memory,
@@ -234,7 +238,10 @@ class Attention(MegatronModule, ABC):
             assert batch_end <= inference_key_memory.size(1)
             sequence_start = inference_params.sequence_len_offset
             sequence_end = sequence_start + key.size(0)
-            assert sequence_end <= inference_key_memory.size(0)
+            assert sequence_end <= inference_key_memory.size(0), (
+                "Current sequence length is longer than expected maximum sequence length! "
+                "Increase inference_max_seq_length."
+            )
 
             if self.config.flash_decode:
                 assert (
@@ -245,7 +252,7 @@ class Attention(MegatronModule, ABC):
                 rotary_pos_sin_q = rotary_pos_sin[sequence_end - 1 : sequence_end]
                 rotary_pos_cos_k = rotary_pos_cos[sequence_end - 1 : sequence_end]
                 rotary_pos_sin_k = rotary_pos_sin[sequence_end - 1 : sequence_end]
-            else:
+            else:  # Prefill
                 rotary_pos_cos_q = rotary_pos_cos[:sequence_end]
                 rotary_pos_sin_q = rotary_pos_sin[:sequence_end]
                 rotary_pos_cos_k = rotary_pos_cos[:sequence_end]
@@ -394,7 +401,13 @@ class Attention(MegatronModule, ABC):
                 return output, bias
 
         query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
-            inference_params, query, key, value, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin
+            inference_params,
+            query,
+            key,
+            value,
+            rotary_pos_emb,
+            rotary_pos_cos,
+            rotary_pos_sin,
         )
 
         if packed_seq_params is not None:
......
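The attention hunks decouple the cached key and value hidden sizes, so the pre-allocated inference buffers no longer assume key.shape[-1] == value.shape[-1]. Here is a sketch of that idea under an assumed [seq, batch, kv_heads, head_dim] cache layout (allocate_kv_cache and the dims below are illustrative, not the Megatron _allocate_memory API):

```python
import torch


def allocate_kv_cache(
    max_seq_len: int,
    max_batch: int,
    num_kv_heads: int,
    key_hidden_size: int,
    val_hidden_size: int,
    dtype=torch.float16,
    device="cpu",
):
    """Pre-allocate inference key/value buffers whose last dims can differ.

    For standard attention both sizes equal the per-head hidden size; for MLA
    the key carries the rope dimensions (qk_head_dim + qk_pos_emb_head_dim)
    while the value uses v_head_dim, so the two buffers are shaped separately.
    """
    key_memory = torch.empty(
        max_seq_len, max_batch, num_kv_heads, key_hidden_size, dtype=dtype, device=device
    )
    value_memory = torch.empty(
        max_seq_len, max_batch, num_kv_heads, val_hidden_size, dtype=dtype, device=device
    )
    return key_memory, value_memory


# Illustrative DeepSeek-style head dims: 128 + 64 for keys, 128 for values.
k_cache, v_cache = allocate_kv_cache(16, 2, 4, 128 + 64, 128)
print(k_cache.shape, v_cache.shape)
```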
@@ -322,7 +322,7 @@ class TopKRouter(Router):
             scores, routing_map = self.aux_loss_load_balancing(logits)
         elif self.routing_type == "seq_aux_loss":
             scores, routing_map = self.seq_aux_loss_load_balancing(logits, bsz, seq_length)
-        elif self.routing_type == "none":
+        elif self.routing_type in ["none", "noaux_tc"]:
            # A naive top-k routing without load balancing
             scores, routing_map, _ = topk_softmax_with_capacity(
                 logits,
......
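In the router hunk, the new 'noaux_tc' routing type is simply mapped onto the same branch as 'none', i.e. plain top-k selection with no auxiliary load-balancing loss. Below is a simplified stand-in for that branch (naive_topk_routing is illustrative; the real topk_softmax_with_capacity also handles capacity limits and score scaling):

```python
import torch


def naive_topk_routing(logits: torch.Tensor, topk: int):
    """Naive top-k routing without any load-balancing auxiliary loss."""
    # logits: [num_tokens, num_experts]
    probs = torch.softmax(logits, dim=-1)
    topk_vals, topk_idx = torch.topk(probs, k=topk, dim=-1)

    # Dense routing map: True where a token is dispatched to an expert.
    routing_map = torch.zeros_like(probs, dtype=torch.bool)
    routing_map.scatter_(-1, topk_idx, True)

    # Scores keep only the selected experts' probabilities.
    scores = torch.zeros_like(probs).scatter_(-1, topk_idx, topk_vals)
    return scores, routing_map


scores, routing_map = naive_topk_routing(torch.randn(6, 8), topk=2)
print(routing_map.sum(dim=-1))  # each token is routed to exactly 2 experts
```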
@@ -173,7 +173,7 @@ class MultiTokenPredictor(MegatronModule):
 
         # Rotary positional embeddings (embedding is None for PP intermediate devices)
         rotary_pos_emb = None
-        if self.position_embedding_type == 'rope':
+        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
             if inference_params is not None:
                 rotary_seq_len = inference_params.max_sequence_length
             else:
@@ -184,6 +184,7 @@ class MultiTokenPredictor(MegatronModule):
                 rotary_seq_len *= self.config.context_parallel_size
             rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
 
+
         if self.recompute_layer_norm:
             self.enorm_ckpt = CheckpointWithoutOutput()
             enorm_output = self.enorm_ckpt.checkpoint(self.enorm, False, decoder_input)
......
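The second MTP hunk wraps enorm in CheckpointWithoutOutput when recompute_layer_norm is set, trading recomputation for activation memory. Here is a rough sketch of the same trade-off using plain torch.utils.checkpoint (MTPNormBlock is illustrative and not the Megatron module):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class MTPNormBlock(nn.Module):
    """Recompute a pre-projection layernorm during backward instead of
    storing its activation, analogous to the recompute_layer_norm path."""

    def __init__(self, hidden: int, recompute_layer_norm: bool = True):
        super().__init__()
        self.enorm = nn.LayerNorm(hidden)
        self.proj = nn.Linear(hidden, hidden)
        self.recompute_layer_norm = recompute_layer_norm

    def forward(self, decoder_input: torch.Tensor) -> torch.Tensor:
        if self.recompute_layer_norm:
            # The enorm output is not kept; it is recomputed in backward.
            enorm_output = checkpoint(self.enorm, decoder_input, use_reentrant=False)
        else:
            enorm_output = self.enorm(decoder_input)
        return self.proj(enorm_output)


block = MTPNormBlock(16)
out = block(torch.randn(4, 16, requires_grad=True))
out.sum().backward()
```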
@@ -68,6 +68,10 @@ class MultiLatentAttention(Attention):
         self.q_head_dim = self.config.qk_head_dim + self.config.qk_pos_emb_head_dim
 
+        # Overwrite the base class kv shape to support MLA inference
+        self.key_hidden_size = self.q_head_dim
+        self.val_hidden_size = self.config.v_head_dim
+
         mscale = _yarn_get_mscale(self.config.rotary_scaling_factor, self.config.mscale)
         self.softmax_scale = mscale * mscale / math.sqrt(self.q_head_dim)
......
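The MLA hunk derives the softmax scale from the YaRN mscale correction and the combined query head dimension. A small worked example, assuming _yarn_get_mscale follows the usual YaRN definition and using illustrative DeepSeek-style head dims:

```python
import math


def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    """Common YaRN attention-scaling term (assumed to match _yarn_get_mscale):
    no correction when the rotary scaling factor is <= 1."""
    if scale <= 1.0:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


# Illustrative MLA dims: 128 "no-pe" dims + 64 rope dims per query head.
qk_head_dim, qk_pos_emb_head_dim = 128, 64
rotary_scaling_factor, mscale_cfg = 40, 1.0

q_head_dim = qk_head_dim + qk_pos_emb_head_dim
m = yarn_get_mscale(rotary_scaling_factor, mscale_cfg)
softmax_scale = m * m / math.sqrt(q_head_dim)  # mirrors the hunk above
print(q_head_dim, round(m, 4), round(softmax_scale, 6))
```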
@@ -1858,7 +1858,8 @@ def _add_tokenizer_args(parser):
                                 'QwenTokenizer',
                                 'TikTokenizer',
                                 'MultimodalTokenizer',
-                                'NullTokenizer'],
+                                'NullTokenizer',
+                                'DeepSeekV2Tokenizer'],
                        help='What type of tokenizer to use.')
     group.add_argument('--tokenizer-model', type=str, default=None,
                        help='Sentencepiece tokenizer model.')
......
@@ -16,6 +16,7 @@ from .bert_tokenization import FullTokenizer as FullBertTokenizer
 from .gpt2_tokenization import GPT2Tokenizer
 from megatron.training.tokenizer.multimodal_tokenizer import MultimodalTokenizer
 from transformers import Qwen2Tokenizer
+from transformers import AutoTokenizer
 
 
 def build_tokenizer(args, **kwargs):
......
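The last two hunks only register a 'DeepSeekV2Tokenizer' choice and import transformers.AutoTokenizer; the wrapper class itself is not shown in this diff. A hypothetical sketch of how such a tokenizer could be backed by AutoTokenizer (_HFDeepSeekTokenizer and its methods are assumptions, not code from the commit):

```python
from transformers import AutoTokenizer


class _HFDeepSeekTokenizer:
    """Hypothetical wrapper exposing a Megatron-style tokenizer interface
    on top of a Hugging Face tokenizer loaded via AutoTokenizer."""

    def __init__(self, tokenizer_path: str):
        self._tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path, trust_remote_code=True
        )

    @property
    def vocab_size(self) -> int:
        return len(self._tokenizer)

    def tokenize(self, text: str):
        return self._tokenizer.encode(text, add_special_tokens=False)

    def detokenize(self, token_ids):
        return self._tokenizer.decode(token_ids)

    @property
    def eod(self) -> int:
        return self._tokenizer.eos_token_id


# Usage (requires local or hub access to the DeepSeek tokenizer files):
# tok = _HFDeepSeekTokenizer("deepseek-ai/DeepSeek-V3")
# print(tok.tokenize("hello world"))
```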