Commit 7b78db5d authored by dongcl

Support DeepSeek V3 training

parent c099d843
Pipeline #2478 passed
@@ -461,7 +461,7 @@ class GPTModel(LanguageModule):

         if (
             self.num_nextn_predict_layers
-            and getattr(self.decoder, final_layernorm, None) is not None
+            and getattr(self.decoder, "final_layernorm", None) is not None
         ):
             # move block main model final norms here
             hidden_states = self.decoder.final_layernorm(hidden_states)
...
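Note on the hunk above: `getattr` expects the attribute name as a string, so the unquoted `final_layernorm` on the removed line is evaluated as an undefined local name and raises `NameError` before the `None` default can ever apply. A minimal sketch of the difference; the `Decoder` class here is a stand-in, not the real module:

class Decoder:
    pass  # deliberately has no final_layernorm attribute

decoder = Decoder()

# Fixed form: attribute name passed as a string, falls back to None cleanly.
print(getattr(decoder, "final_layernorm", None))  # -> None

# Buggy form from the removed line: `final_layernorm` is read as an
# undefined variable, so execution never reaches getattr's default.
try:
    getattr(decoder, final_layernorm, None)  # noqa: F821
except NameError as exc:
    print("NameError as expected:", exc)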
@@ -103,6 +103,10 @@ class Attention(MegatronModule, ABC):
         self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
         self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)

+        # To support both CUDA Graphs and key/value tensors with different hidden sizes
+        self.key_hidden_size = self.hidden_size_per_attention_head
+        self.val_hidden_size = self.hidden_size_per_attention_head
+
         self.core_attention = build_module(
             submodules.core_attention,
             config=self.config,
@@ -209,10 +213,10 @@ class Attention(MegatronModule, ABC):
             inf_max_seq_length = inference_params.max_sequence_length
             inf_max_batch_size = inference_params.max_batch_size
             inference_key_memory = self._allocate_memory(
-                inf_max_seq_length, inf_max_batch_size, key.shape[-1], key.dtype
+                inf_max_seq_length, inf_max_batch_size, self.key_hidden_size, key.dtype
             )
             inference_value_memory = self._allocate_memory(
-                inf_max_seq_length, inf_max_batch_size, value.shape[-1], value.dtype
+                inf_max_seq_length, inf_max_batch_size, self.val_hidden_size, value.dtype
             )
             inference_params.key_value_memory_dict[self.layer_number] = (
                 inference_key_memory,
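Why the change: the two hunks above stop reading the cache width from runtime tensors (`key.shape[-1]`, `value.shape[-1]`) and use `self.key_hidden_size` / `self.val_hidden_size` set in `__init__` instead. The KV-cache allocation shape is then known before any forward pass, which CUDA Graph capture requires, and subclasses whose key and value heads have different widths (multi-latent attention) can simply override the two attributes. A simplified sketch of the pattern with the Megatron plumbing stripped out; the class names and the DeepSeek-V3-style head dimensions are for illustration only:

import torch

class TinyAttention:
    """Sketch: KV-cache sizes come from attributes, not from runtime tensors."""

    def __init__(self, head_dim: int):
        self.key_hidden_size = head_dim
        self.val_hidden_size = head_dim

    def allocate_kv_cache(self, max_seq_len, max_batch, num_heads, dtype=torch.float16):
        # Shapes are fixed up front, so the buffers can be allocated once and
        # reused across steps (a prerequisite for CUDA Graph capture).
        key_cache = torch.empty(max_seq_len, max_batch, num_heads, self.key_hidden_size, dtype=dtype)
        val_cache = torch.empty(max_seq_len, max_batch, num_heads, self.val_hidden_size, dtype=dtype)
        return key_cache, val_cache

class TinyMLA(TinyAttention):
    """Sketch: MLA keys carry the rotary slice, so key and value widths differ."""

    def __init__(self, qk_head_dim, qk_pos_emb_head_dim, v_head_dim):
        super().__init__(qk_head_dim)
        self.key_hidden_size = qk_head_dim + qk_pos_emb_head_dim
        self.val_hidden_size = v_head_dim

k, v = TinyMLA(128, 64, 128).allocate_kv_cache(max_seq_len=1024, max_batch=2, num_heads=4)
print(k.shape[-1], v.shape[-1])  # 192 128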
@@ -234,7 +238,10 @@ class Attention(MegatronModule, ABC):
             assert batch_end <= inference_key_memory.size(1)
             sequence_start = inference_params.sequence_len_offset
             sequence_end = sequence_start + key.size(0)
-            assert sequence_end <= inference_key_memory.size(0)
+            assert sequence_end <= inference_key_memory.size(0), (
+                "Current sequence length is longer than expected maximum sequence length! "
+                "Increase inference_max_seq_length."
+            )

             if self.config.flash_decode:
                 assert (
@@ -245,7 +252,7 @@ class Attention(MegatronModule, ABC):
                     rotary_pos_sin_q = rotary_pos_sin[sequence_end - 1 : sequence_end]
                     rotary_pos_cos_k = rotary_pos_cos[sequence_end - 1 : sequence_end]
                     rotary_pos_sin_k = rotary_pos_sin[sequence_end - 1 : sequence_end]
-                else:
+                else:  # Prefill
                     rotary_pos_cos_q = rotary_pos_cos[:sequence_end]
                     rotary_pos_sin_q = rotary_pos_sin[:sequence_end]
                     rotary_pos_cos_k = rotary_pos_cos[:sequence_end]
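The `# Prefill` comment labels the two phases of flash-decode inference: single-token decode slices out only the rotary table entry for the current position (`sequence_end - 1`), while prefill takes every position up to `sequence_end`. A small illustration of the slicing, assuming a precomputed per-position cos table like `rotary_pos_cos` above:

import torch

max_seq_len, rot_dim = 16, 8
rotary_pos_cos = torch.randn(max_seq_len, rot_dim)  # stand-in for the precomputed table

sequence_end = 5  # positions 0..4 are in (or entering) the KV cache

# Decode: a single new query token sits at position sequence_end - 1.
decode_cos_q = rotary_pos_cos[sequence_end - 1 : sequence_end]  # shape [1, rot_dim]

# Prefill: embed every prompt position at once.
prefill_cos_q = rotary_pos_cos[:sequence_end]  # shape [5, rot_dim]

print(decode_cos_q.shape, prefill_cos_q.shape)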
@@ -394,7 +401,13 @@ class Attention(MegatronModule, ABC):
                 return output, bias

         query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
-            inference_params, query, key, value, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin
+            inference_params,
+            query,
+            key,
+            value,
+            rotary_pos_emb,
+            rotary_pos_cos,
+            rotary_pos_sin,
         )

         if packed_seq_params is not None:
...
@@ -322,7 +322,7 @@ class TopKRouter(Router):
             scores, routing_map = self.aux_loss_load_balancing(logits)
         elif self.routing_type == "seq_aux_loss":
             scores, routing_map = self.seq_aux_loss_load_balancing(logits, bsz, seq_length)
-        elif self.routing_type == "none":
+        elif self.routing_type in ["none", "noaux_tc"]:
             # A naive top-k routing without load balancing
             scores, routing_map, _ = topk_softmax_with_capacity(
                 logits,
...
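`noaux_tc` is the DeepSeek-V3 routing mode (named after `topk_method: "noaux_tc"` in the reference checkpoints), where load balancing is handled by a per-expert bias rather than an auxiliary loss, so at this point in the router it can share the plain top-k path with `"none"`. A simplified sketch of what a naive top-k routing step computes; this is not the actual `topk_softmax_with_capacity` implementation:

import torch

def naive_topk_routing(logits: torch.Tensor, top_k: int):
    """Pick top_k experts per token; return their probs and a boolean routing map.

    logits: [num_tokens, num_experts] raw gate outputs.
    """
    scores = torch.softmax(logits, dim=-1)
    topk_scores, topk_idx = torch.topk(scores, k=top_k, dim=-1)
    routing_map = torch.zeros_like(scores, dtype=torch.bool)
    routing_map.scatter_(1, topk_idx, True)  # True where a token is routed to an expert
    return topk_scores, routing_map

probs, routing_map = naive_topk_routing(torch.randn(4, 8), top_k=2)
print(probs.shape, routing_map.sum(dim=-1))  # torch.Size([4, 2]), 2 experts per token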
@@ -173,7 +173,7 @@ class MultiTokenPredictor(MegatronModule):

         # Rotary positional embeddings (embedding is None for PP intermediate devices)
         rotary_pos_emb = None
-        if self.position_embedding_type == 'rope':
+        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
             if inference_params is not None:
                 rotary_seq_len = inference_params.max_sequence_length
             else:
@@ -184,6 +184,7 @@ class MultiTokenPredictor(MegatronModule):
                 rotary_seq_len *= self.config.context_parallel_size
             rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)

         if self.recompute_layer_norm:
             self.enorm_ckpt = CheckpointWithoutOutput()
             enorm_output = self.enorm_ckpt.checkpoint(self.enorm, False, decoder_input)
...
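The first MTP hunk skips building `rotary_pos_emb` when multi-latent attention is enabled, because the MLA module constructs and applies its rotary embedding internally, and only to the positional slice of each head (`qk_pos_emb_head_dim`) rather than the full head. A rough sketch of that split-and-rotate idea, not the actual MLA code:

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def mla_style_query(q_nope, q_pos, cos, sin):
    """Apply RoPE only to the positional part of the query, then concatenate.

    q_nope: [..., qk_head_dim] content part, no position information.
    q_pos:  [..., qk_pos_emb_head_dim] part that receives the rotary embedding.
    """
    q_pos = q_pos * cos + rotate_half(q_pos) * sin
    return torch.cat([q_nope, q_pos], dim=-1)  # width = qk_head_dim + qk_pos_emb_head_dim

q = mla_style_query(torch.randn(2, 128), torch.randn(2, 64), torch.ones(2, 64), torch.zeros(2, 64))
print(q.shape)  # torch.Size([2, 192])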
@@ -68,6 +68,10 @@ class MultiLatentAttention(Attention):
         self.q_head_dim = self.config.qk_head_dim + self.config.qk_pos_emb_head_dim

+        # Overwrite the base class KV cache sizes to support MLA inference
+        self.key_hidden_size = self.q_head_dim
+        self.val_hidden_size = self.config.v_head_dim
+
         mscale = _yarn_get_mscale(self.config.rotary_scaling_factor, self.config.mscale)
         self.softmax_scale = mscale * mscale / math.sqrt(self.q_head_dim)
...
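For context, the `softmax_scale` line above is where the YaRN attention-scale correction enters MLA. A worked example with DeepSeek-V3-style values (qk_head_dim=128, qk_pos_emb_head_dim=64, rotary_scaling_factor=40, mscale=1.0); the body of `yarn_get_mscale` follows the DeepSeek reference implementation and is an assumption here, since the hunk only shows the call site:

import math

def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    # Assumed to match _yarn_get_mscale in the commit; only the call site is visible.
    if scale <= 1.0:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

q_head_dim = 128 + 64                      # qk_head_dim + qk_pos_emb_head_dim
mscale = yarn_get_mscale(40, 1.0)
softmax_scale = mscale * mscale / math.sqrt(q_head_dim)
print(round(mscale, 4), round(softmax_scale, 4))  # 1.3689 0.1352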
@@ -1858,7 +1858,8 @@ def _add_tokenizer_args(parser):
                                 'QwenTokenizer',
                                 'TikTokenizer',
                                 'MultimodalTokenizer',
-                                'NullTokenizer'],
+                                'NullTokenizer',
+                                'DeepSeekV2Tokenizer'],
                        help='What type of tokenizer to use.')
     group.add_argument('--tokenizer-model', type=str, default=None,
                        help='Sentencepiece tokenizer model.')
...
@@ -16,6 +16,7 @@ from .bert_tokenization import FullTokenizer as FullBertTokenizer
 from .gpt2_tokenization import GPT2Tokenizer
 from megatron.training.tokenizer.multimodal_tokenizer import MultimodalTokenizer
 from transformers import Qwen2Tokenizer
+from transformers import AutoTokenizer


 def build_tokenizer(args, **kwargs):
...
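The visible hunks only add the `AutoTokenizer` import and register the `DeepSeekV2Tokenizer` choice; the wrapper class itself lies outside the diff context. Below is a plausible minimal sketch of such a HuggingFace-backed wrapper; the class name `_HFTokenizerSketch` and its methods are hypothetical, only `AutoTokenizer` comes from the diff, and the wiring into `build_tokenizer` is assumed:

from transformers import AutoTokenizer

class _HFTokenizerSketch:
    """Hypothetical HuggingFace-backed tokenizer wrapper, not the committed code."""

    def __init__(self, tokenizer_path: str):
        # trust_remote_code is typically needed for DeepSeek tokenizers on the Hub.
        self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)

    @property
    def vocab_size(self) -> int:
        return len(self._tokenizer)

    def tokenize(self, text: str):
        return self._tokenizer.encode(text)

    def detokenize(self, token_ids) -> str:
        return self._tokenizer.decode(token_ids)

    @property
    def eod(self) -> int:
        return self._tokenizer.eos_token_id

# build_tokenizer would presumably return an instance of something like this when
# args.tokenizer_type == 'DeepSeekV2Tokenizer', using args.tokenizer_model as the path.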