Commit f87b35b2 authored by jerrrrry's avatar jerrrrry
Browse files

Initial commit

parents
Pipeline #2648 failed with stages
in 0 seconds
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from megatron.core import parallel_state as mpu
from megatron.core import tensor_parallel
from megatron.core import ModelParallelConfig
from torch import nn
from transformers.activations import ACT2FN
from verl.models.llama.megatron.layers.parallel_linear import MergedColumnParallelLinear
from verl.utils.megatron import tensor_parallel as tp_utils
class ParallelLlamaMLP(nn.Module):
def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
# The weight is only [hidden_size, intermediate_size // model_parallel_world_size]
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()
if megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
assert row_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)
tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)
tp_size = mpu.get_tensor_model_parallel_world_size()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=self.hidden_size,
gate_ouput_size=self.intermediate_size,
up_output_size=self.intermediate_size,
bias=False,
gather_output=False,
skip_bias_add=False,
**column_kwargs,
)
self.gate_size = self.intermediate_size // tp_size
self.down_proj = tensor_parallel.RowParallelLinear(input_size=self.intermediate_size,
output_size=self.hidden_size,
bias=False,
input_is_parallel=True,
skip_bias_add=False,
**row_kwargs)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
gate_up = self.gate_up_proj(x)[0]
gate, up = gate_up.split(self.gate_size, dim=-1)
return self.down_proj(self.act_fn(gate) * up)[0]
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import torch
from megatron.core import ModelParallelConfig
from torch import nn
from transformers import LlamaConfig
from apex.normalization.fused_layer_norm import fused_rms_norm_affine
from verl.utils.megatron import sequence_parallel as sp_utils
class ParallelLlamaRMSNorm(nn.Module):
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
if isinstance(config.hidden_size, numbers.Integral):
normalized_shape = (config.hidden_size,)
self.normalized_shape = torch.Size(normalized_shape)
self.weight = nn.Parameter(torch.ones(self.normalized_shape))
self.variance_epsilon = config.rms_norm_eps
if megatron_config.sequence_parallel:
sp_utils.mark_parameter_as_sequence_parallel(self.weight)
def forward(self, hidden_states):
return fused_rms_norm_affine(input=hidden_states,
weight=self.weight,
normalized_shape=self.normalized_shape,
eps=self.variance_epsilon,
memory_efficient=True)
\ No newline at end of file
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch LLaMA model with Megatron-style acceleration."""
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from megatron.core import tensor_parallel
from megatron.core import ModelParallelConfig
from megatron.core import mpu
from torch import nn
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import CausalLMOutputWithPast
from verl.utils.megatron import sequence_parallel as sp_utils
from verl.utils.megatron import tensor_parallel as tp_utils
from verl.utils.megatron_utils import TransformerConfig, convert_config
from .layers import ParallelLlamaDecoderLayer, ParallelLlamaRMSNorm, ParallelLlamaDecoderLayerRmPad
"""
TODO:
1. Add weight initialization. Here we need to be careful on TP weight init.
2. Add sequence parallel
3. Load checkpoint from meta LLama pretrained checkpoint
"""
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class ParallelLlamaModel(nn.Module):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
super().__init__()
self.config: TransformerConfig = convert_config(config, megatron_config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
if megatron_config is not None:
assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
**embedding_kwargs)
self.layers = nn.ModuleList(
[ParallelLlamaDecoderLayer(config, megatron_config) for _ in range(config.num_hidden_layers)])
self.norm = ParallelLlamaRMSNorm(config, megatron_config)
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype,
tgt_len=input_shape[-1]).to(inputs_embeds.device)
combined_attention_mask = (expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask +
combined_attention_mask)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Args:
input_ids: input ids. shape (batch_size, seq_length)
attention_mask: attention_mask. shape (batch_size, seq_length)
position_ids: position ids. shape (batch_size, seq_length)
Returns:
"""
batch_size, seq_length = input_ids.shape
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds)
hidden_states = inputs_embeds
for idx, decoder_layer in enumerate(self.layers):
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
)
hidden_states = layer_outputs
hidden_states = self.norm(hidden_states)
return hidden_states
class ParallelLlamaForCausalLM(nn.Module):
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
super().__init__()
self.config: TransformerConfig = convert_config(config, megatron_config)
self.model = ParallelLlamaModel(config, megatron_config=megatron_config)
self.vocab_size = config.vocab_size
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
if megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=config.hidden_size,
output_size=config.vocab_size,
bias=False,
gather_output=False,
skip_bias_add=False,
**column_kwargs)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
```"""
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
)
hidden_states = outputs
logits = self.lm_head(hidden_states)[0]
logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)
logits = logits.float()
return CausalLMOutputWithPast(
loss=None,
logits=logits,
past_key_values=None,
hidden_states=None,
attentions=None,
)
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
class ParallelLlamaModelRmPad(nn.Module):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
super().__init__()
self.config: TransformerConfig = convert_config(config, megatron_config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
self.megatron_config = megatron_config
if megatron_config is not None:
assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
**embedding_kwargs)
self.layers = nn.ModuleList(
[ParallelLlamaDecoderLayerRmPad(config, megatron_config) for _ in range(config.num_hidden_layers)])
self.norm = ParallelLlamaRMSNorm(config, megatron_config)
def forward(self,
input_ids: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
sequence_length: int = None,
indices: torch.Tensor = None,
cu_seqlens: int = None,
max_seqlen_in_batch: int = None) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Args:
input_ids: input ids. shape (1, totol_nnz)
position_ids: position ids. shape (batch_size, seq_length)
Returns:
"""
inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size)
# (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)
inputs_embeds = inputs_embeds.transpose(0, 1)
if self.megatron_config.sequence_parallel:
inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)
hidden_states = inputs_embeds
for idx, decoder_layer in enumerate(self.layers):
layer_outputs = decoder_layer(hidden_states,
position_ids=position_ids,
sequence_length=sequence_length,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen_in_batch=max_seqlen_in_batch)
hidden_states = layer_outputs
hidden_states = self.norm(hidden_states)
return hidden_states
class ParallelLlamaForCausalLMRmPad(nn.Module):
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig):
super().__init__()
self.config: TransformerConfig = convert_config(config, megatron_config)
self.megatron_config = megatron_config
self.model = ParallelLlamaModelRmPad(config, megatron_config=megatron_config)
self.vocab_size = config.vocab_size
self._init_head(config)
def _init_head(self, config):
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
if self.megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=config.hidden_size,
output_size=config.vocab_size,
bias=False,
gather_output=False,
skip_bias_add=False,
**column_kwargs)
def _forward_head(self, hidden_states):
# all_gather from sequence parallel region is performed inside lm_head
logits = self.lm_head(hidden_states)[0]
logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp)
logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) # (total_nnz_padded, 1, vocab_size)
return logits
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
```"""
batch_size, sequence_length = input_ids.shape
# remove padding here
input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1),
attention_mask) # (total_nnz, 1)
# pad input_ids to multiple of tp for all tp ranks
# TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap
if self.megatron_config.sequence_parallel:
input_ids = sp_utils.pad_to_sequence_parallel(input_ids)
input_ids = input_ids.transpose(0, 1) # (1, total_nnz+pad)
outputs = self.model(input_ids=input_ids,
position_ids=position_ids,
sequence_length=sequence_length,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen_in_batch=max_seqlen_in_batch)
hidden_states = outputs
logits = self._forward_head(hidden_states)
# remove padding from sequence parallel
if self.megatron_config.sequence_parallel:
totol_nnz = cu_seqlens[-1]
logits = logits[:totol_nnz] # (total_nnz_padded)
logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension
# add removed padding back
logits = pad_input(logits, indices, batch_size,
seqlen=sequence_length) # (batch_size, sequence_length, vocab_size)
return CausalLMOutputWithPast(
loss=None,
logits=logits,
past_key_values=None,
hidden_states=None,
attentions=None,
)
class ParallelLlamaForValueRmPad(ParallelLlamaForCausalLMRmPad):
def _init_head(self, config):
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
if self.megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False)
# lm_head is effectively the same as sequence parallel
sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)
def _forward_head(self, hidden_states):
logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1)
logits = logits.float()
if self.megatron_config.sequence_parallel:
logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
return logits
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output = super().forward(input_ids, attention_mask, position_ids)
output.logits = torch.squeeze(output.logits, dim=-1)
return output
"""
Support pipeline parallelism
"""
class ParallelLlamaModelRmPadPP(nn.Module):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
This model definition supports pipeline parallelism. To support pp and vpp,
- This model only contains layer in this pp stage and vpp chunk
- When calling get_model in Megatron, this rank will instantiate all the vpp chunks in this pp.
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process):
super().__init__()
self.config: TransformerConfig = convert_config(config, megatron_config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.pre_process = pre_process
self.post_process = post_process
self.megatron_config = megatron_config
embedding_kwargs = tp_utils.get_default_kwargs_for_parallel_embedding()
if megatron_config is not None:
assert embedding_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(embedding_kwargs, self.megatron_config)
if pre_process:
self.embed_tokens = tensor_parallel.VocabParallelEmbedding(num_embeddings=config.vocab_size,
embedding_dim=config.hidden_size,
**embedding_kwargs)
else:
self.embed_tokens = None
pp_rank = mpu.get_pipeline_model_parallel_rank()
pp_size = megatron_config.pipeline_model_parallel_size
self.num_layer_per_pp = config.num_hidden_layers // pp_size
vpp_size = megatron_config.virtual_pipeline_model_parallel_size
vpp_rank = mpu.get_virtual_pipeline_model_parallel_rank()
if vpp_size is not None:
self.layers = nn.ModuleList()
self.num_layer_vpp_chunk = self.num_layer_per_pp // vpp_size
self.num_layer_this_model = self.num_layer_vpp_chunk
offset = vpp_rank * (
config.num_hidden_layers // vpp_size) + \
(pp_rank * self.num_layer_vpp_chunk)
else:
self.num_layer_this_model = self.num_layer_per_pp
offset = pp_rank * self.num_layer_per_pp
self.layers = nn.ModuleList()
for i in range(self.num_layer_this_model):
layer = ParallelLlamaDecoderLayerRmPad(config, megatron_config, layer_idx=offset + i)
self.layers.add_module(f'{i}', layer)
if post_process:
self.norm = ParallelLlamaRMSNorm(config, megatron_config)
else:
self.norm = None
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(self,
input_ids: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
sequence_length: int = None,
indices: torch.Tensor = None,
cu_seqlens: int = None,
max_seqlen_in_batch: int = None) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Args:
input_ids: input ids. shape (1, totol_nnz)
position_ids: position ids. shape (batch_size, seq_length)
Returns:
"""
if self.pre_process:
inputs_embeds = self.embed_tokens(input_ids) # (1, total_nnz) -> (1, total_nnz, hidden_size)
# vocab parallel embedding will not do sequence parallel reduce-scatter in open source megatron
# so need to deal with it by handle here:
# (1, total_nnz, hidden_size) -> (total_nnz, 1, hidden_size) -> (total_nnz // sp, 1, hidden_size)
inputs_embeds = inputs_embeds.transpose(0, 1)
if self.megatron_config.sequence_parallel:
inputs_embeds = tensor_parallel.scatter_to_sequence_parallel_region(inputs_embeds)
hidden_states = inputs_embeds
else:
# self.hidden_states should be passed by Megatron
hidden_states = self.input_tensor
for idx, decoder_layer in enumerate(self.layers):
layer_outputs = decoder_layer(hidden_states,
position_ids=position_ids,
sequence_length=sequence_length,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen_in_batch=max_seqlen_in_batch)
hidden_states = layer_outputs
if self.post_process:
hidden_states = self.norm(hidden_states)
return hidden_states
class ParallelLlamaForCausalLMRmPadPP(nn.Module):
def __init__(self,
config: LlamaConfig,
megatron_config: ModelParallelConfig,
pre_process,
post_process,
share_embeddings_and_output_weights=False):
super().__init__()
self.config: TransformerConfig = convert_config(config, megatron_config)
self.megatron_config = megatron_config
self.model = ParallelLlamaModelRmPadPP(config,
megatron_config=megatron_config,
pre_process=pre_process,
post_process=post_process)
assert share_embeddings_and_output_weights == False, f'Llama Model not supports sharing embedding and output weights'
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.vocab_size = config.vocab_size
self.pre_process = pre_process
self.post_process = post_process
if post_process:
self._init_head(config)
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
assert len(input_tensor) == 1
self.model.set_input_tensor(input_tensor[0])
def _init_head(self, config):
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
if self.megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=config.hidden_size,
output_size=config.vocab_size,
bias=False,
gather_output=False,
skip_bias_add=False,
**column_kwargs)
def _forward_head(self, hidden_states):
# all_gather from sequence parallel region is performed inside lm_head
# logits shape before forward_head hidden_states.shape: [4, 32, 4096]
logits = self.lm_head(hidden_states)[0]
# logits shape after forward_head logits.shape: [8, 32, 8]
logits = logits.float() # (total_nnz_padded, 1, vocab_size // tp)
return logits
def forward(
self,
# original input
*,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
```"""
# Note that input_ids, attention_mask and position_ids should be passed to every pp layer.
# In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model
batch_size, sequence_length = input_ids.shape
# remove padding here
input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1),
attention_mask) # (total_nnz, 1)
# pad input_ids to multiple of tp for all tp ranks
# TODO: for better performance, the sp padding should be removed at each layer. Not sure the performance gap
if self.megatron_config.sequence_parallel:
input_ids_rmpad = sp_utils.pad_to_sequence_parallel(input_ids_rmpad)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz+pad)
outputs = self.model(input_ids=input_ids_rmpad,
position_ids=position_ids,
sequence_length=sequence_length,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen_in_batch=max_seqlen_in_batch)
if self.post_process:
hidden_states = outputs
# print(f'hidden_states.shape = {hidden_states.shape}') # torch.Size([4, 32, 4096])
logits = self._forward_head(hidden_states)
logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16])
# remove padding from sequence parallel
if self.megatron_config.sequence_parallel:
totol_nnz = cu_seqlens[-1]
logits = logits[:totol_nnz] # (total_nnz_padded)
# add removed padding back. If input is already rmpad, we let the caller pad_input
logits = pad_input(logits, indices, batch_size,
seqlen=sequence_length) # (batch_size, sequence_length, vocab_size)
return CausalLMOutputWithPast(
loss=None,
logits=logits,
past_key_values=None,
hidden_states=None,
attentions=None,
)
else:
return outputs
class ParallelLlamaForValueRmPadPP(ParallelLlamaForCausalLMRmPadPP):
def _init_head(self, config):
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
if self.megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(column_kwargs, self.megatron_config)
self.lm_head = nn.Linear(in_features=config.hidden_size, out_features=1, bias=False)
# lm_head is effectively the same as sequence parallel
sp_utils.mark_parameter_as_sequence_parallel(self.lm_head.weight)
def _forward_head(self, hidden_states):
logits = self.lm_head(hidden_states) # (total_nnz_padded // tp, 1, 1)
logits = logits.float()
if self.megatron_config.sequence_parallel:
logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
return logits
def forward(
self,
*,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids)
if self.post_process:
output.logits = torch.squeeze(output.logits, dim=-1)
return output
else:
return output
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .gpt_model import gptmodel_forward
\ No newline at end of file
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from verl.utils.megatron import sequence_parallel as sp_utils
from verl.utils.megatron import tensor_parallel as tp_utils
import torch
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core import parallel_state as mpu
from verl.utils.megatron_utils import unwrap_model
def gptmodel_forward(model,
input_ids,
attention_mask,
position_ids,
sequence_parallel,
value_model=False,
pack_seqs=True):
pre_process = unwrap_model(model).pre_process
post_process = unwrap_model(model).post_process
if pack_seqs:
batch_size, seq_len = attention_mask.shape[:2]
input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process)
input_ids_rmpad = input_ids_rmpad.contiguous()
output_orig = model(input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids,
packed_seq_params=packed_seq_params)
output = postprocess_packed_seqs(output_orig,
packed_seq_params,
attention_mask,
batch_size,
seq_len,
post_process=post_process)
else:
batch_size, sequence_length = attention_mask.shape
new_input_ids, new_attention_mask, new_position_ids = remove_left_padding(input_ids,
attention_mask,
position_ids,
sequence_parallel,
pre_process=pre_process)
output = model(input_ids=new_input_ids, attention_mask=new_attention_mask, position_ids=new_position_ids)
output = recover_left_padding(output,
new_attention_mask,
attention_mask,
sequence_length,
post_process=post_process)
if value_model and post_process:
output = output[..., 0]
return output
def preprocess_packed_seqs(input_ids: torch.Tensor,
attention_mask: torch.Tensor,
pre_process: bool = True) -> tuple[torch.Tensor, PackedSeqParams]:
"""
Preprocess packed sequences
CP splits sequence into CP*2 chunks, and each GPU gets 2 chunks (GPU0 gets first and last chunks, GPU1 gets second and second last chunks, and so on), this is for load balancing with causal masking.
See https://github.com/NVIDIA/TransformerEngine/issues/1368
"""
batch_size = input_ids.shape[0]
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
tp_size = mpu.get_tensor_model_parallel_world_size()
cp_size = mpu.get_context_parallel_world_size()
cp_rank = mpu.get_context_parallel_rank()
if cp_size > 1:
align_size = tp_size * cp_size * 2
else:
align_size = tp_size
pad_size = (align_size - seqlens_in_batch % align_size) % align_size
seqlens_in_batch_padded = seqlens_in_batch + pad_size
cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
cu_seqlens[1:] = torch.cumsum(seqlens_in_batch, dim=0)
cu_seqlens_padded = torch.zeros(batch_size + 1, dtype=torch.int32, device=input_ids.device)
cu_seqlens_padded[1:] = torch.cumsum(seqlens_in_batch_padded, dim=0)
max_seqlen_in_batch = seqlens_in_batch_padded.max().item()
shape = list(input_ids.shape[1:])
shape[0] = seqlens_in_batch_padded.sum().item() // cp_size
if pre_process:
input_ids_rmpad = torch.zeros(shape, dtype=input_ids.dtype, device=input_ids.device)
for i in range(batch_size):
if cp_size <= 1:
seqlen = seqlens_in_batch[i]
input_ids_rmpad[cu_seqlens_padded[i]:cu_seqlens_padded[i] + seqlen] = input_ids[i, attention_mask[i]]
continue
seqlen = seqlens_in_batch_padded[i] // cp_size
half_seqlen = seqlen // 2
start_idx = cu_seqlens_padded[i] // cp_size
# split to 2 chunks
d = input_ids[i, attention_mask[i]]
input_ids_rmpad[start_idx:start_idx + half_seqlen] = d[half_seqlen * cp_rank:half_seqlen * (cp_rank + 1)]
remain_start = seqlens_in_batch_padded[i] - half_seqlen * (cp_rank + 1)
remain_end = seqlens_in_batch_padded[i] - half_seqlen * cp_rank
remain_end = min(remain_end, d.shape[0])
remain_len = remain_end - remain_start
if remain_len > 0:
input_ids_rmpad[start_idx + half_seqlen:start_idx + half_seqlen +
remain_len] = d[remain_start:remain_end]
packed_seq_params = PackedSeqParams(qkv_format='thd',
cu_seqlens_q=cu_seqlens_padded,
max_seqlen_q=max_seqlen_in_batch,
cu_seqlens_kv=cu_seqlens_padded,
max_seqlen_kv=max_seqlen_in_batch,
cu_seqlens_q_padded=cu_seqlens_padded,
cu_seqlens_kv_padded=cu_seqlens_padded)
if pre_process:
return input_ids_rmpad.unsqueeze(0), packed_seq_params
else:
return input_ids, packed_seq_params
def postprocess_packed_seqs(output: torch.Tensor,
packed_seq_params: PackedSeqParams,
attention_mask: torch.Tensor,
batch_size: int,
seq_len: int,
post_process: bool = True) -> torch.Tensor:
"""
Postprocess packed sequences
"""
if not post_process:
return output
shape = [batch_size, seq_len] + list(output.shape[2:]) # 1,packed, dim -> batch_size, seq_len, dim
output_new = torch.zeros(shape, dtype=output.dtype, device=output.device)
cp_size = mpu.get_context_parallel_world_size()
# all gather output across context parallel group
if cp_size > 1:
# output shape: [1, packed_len, hidden_dim]
# need to gather across cp group and concatenate in sequence dimension
output_list = [torch.empty_like(output) for _ in range(cp_size)]
torch.distributed.all_gather(output_list, output.detach(), group=mpu.get_context_parallel_group())
output_list[mpu.get_context_parallel_rank()] = output
else:
output_list = [output]
for i in range(batch_size):
if cp_size <= 1:
s = attention_mask[i].sum().item()
output_new[i,
attention_mask[i]] = output[0][packed_seq_params.
cu_seqlens_q_padded[i]:packed_seq_params.cu_seqlens_q_padded[i] +
s]
continue
s_len_padded_chunk = (packed_seq_params.cu_seqlens_q_padded[i + 1] -
packed_seq_params.cu_seqlens_q_padded[i]) // cp_size
half_seqlen = s_len_padded_chunk // 2
s_len = attention_mask[i].sum().item()
s_len_padded = s_len_padded_chunk * cp_size
tmp = torch.empty(s_len_padded, *output.shape[2:], device=output.device)
for j in range(cp_size):
o = output_list[j][0]
# split to 2 chunks
packed_start_idx = packed_seq_params.cu_seqlens_q_padded[i] // cp_size
o0, o1 = o[packed_start_idx:packed_start_idx +
half_seqlen], o[packed_start_idx + half_seqlen:packed_start_idx + s_len_padded_chunk]
tmp[j * half_seqlen:(j + 1) * half_seqlen] = o0
tmp[s_len_padded - (j + 1) * half_seqlen:s_len_padded - j * half_seqlen] = o1
output_new[i, attention_mask[i]] = tmp[:s_len]
return output_new
def remove_left_padding(input_ids: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: torch.Tensor,
sequence_parallel: bool = False,
pre_process: bool = True):
"""
Remove left padding from input_ids, attention_mask and position_ids
return new_input_ids, new_attention_mask, new_position_ids
"""
assert attention_mask.ndim == 2
assert position_ids.ndim == 2
cp_size = mpu.get_context_parallel_world_size()
assert cp_size == 1, 'Context parallel size without seq_pack is not supported'
batch_size = input_ids.shape[0]
shape = list(input_ids.shape) # batch_size, seq_len,...
seq_lens = attention_mask.sum(dim=1)
seq_len = seq_lens.max().item()
if sequence_parallel:
from megatron.core import parallel_state as mpu
sp_world_size = mpu.get_tensor_model_parallel_world_size()
pad_size = (sp_world_size - seq_len % sp_world_size) % sp_world_size
seq_len = seq_len + pad_size
shape[1] = seq_len
if pre_process:
new_input_ids = torch.zeros(dtype=input_ids.dtype, device=input_ids.device, size=shape)
new_attention_mask = torch.zeros(dtype=attention_mask.dtype,
device=attention_mask.device,
size=(batch_size, seq_len))
new_position_ids = torch.zeros(dtype=position_ids.dtype, device=position_ids.device, size=(batch_size, seq_len))
for i in range(batch_size):
if pre_process:
new_input_ids[i, :seq_lens[i]] = input_ids[i, attention_mask[i]]
new_attention_mask[i, :seq_lens[i]] = attention_mask[i, attention_mask[i]]
new_position_ids[i, :seq_lens[i]] = position_ids[i, attention_mask[i]]
if pre_process:
return new_input_ids, new_attention_mask, new_position_ids
else:
return input_ids, new_attention_mask, new_position_ids
def recover_left_padding(result,
attention_mask: torch.Tensor,
original_attention_mask: torch.Tensor,
origin_seqlen: int,
post_process: bool = True):
"""
Recover left padding from result
return result
"""
if not post_process:
return result
shape = list(result.shape)
batch_size = shape[0]
shape[1] = origin_seqlen
new_result = torch.zeros(dtype=result.dtype, device=result.device, size=shape)
for i in range(batch_size):
new_result[i, original_attention_mask[i]] = result[i, attention_mask[i]]
return new_result
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import time
import torch.distributed as dist
from .saver import _megatron_calc_global_rank
def _megatron_calc_layer_map(config):
"""Calculate the mapping of global layer_idx to local layer_idx
Returns:
layer_map (Dict: int -> tuple(int, int, int)):
mapping from the global layer index to
a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
"""
import megatron
from megatron.core import mpu
pp_size = mpu.get_pipeline_model_parallel_world_size()
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
layer_map = dict()
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
for pp_rank_idx in range(pp_size):
for virtual_pp_rank_idx in range(virtual_pp_size):
layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) +
pp_rank_idx * num_layers_per_model)
for layer_idx in range(num_layers_per_model):
layer_map[layer_offset + layer_idx] = (
pp_rank_idx,
virtual_pp_rank_idx,
layer_idx,
)
return layer_map
def load_state_dict_to_megatron_gptmodel(state_dict, wrapped_models, config, params_dtype, is_value_model=False):
"""Load merged state_dict to sharded Megatron module in training.
"""
import megatron
from megatron.core import mpu
from verl.utils.megatron_utils import print_rank_0, unwrap_model
from megatron.core.transformer.module import Float16Module
from megatron.core import DistributedDataParallel as LocalDDP
from torch.nn.parallel import DistributedDataParallel as torchDDP
start_time = time.time()
def _get_gpt_model(model):
return model
def broadcast_params(module):
for param in module.parameters():
torch.distributed.broadcast(param.data,
src=mpu.get_data_parallel_src_rank(),
group=mpu.get_data_parallel_group())
dp_rank = mpu.get_data_parallel_rank()
pp_rank = mpu.get_pipeline_model_parallel_rank()
cp_rank = mpu.get_context_parallel_rank()
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=0, cp_rank=cp_rank)
pp_size = mpu.get_pipeline_model_parallel_world_size()
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
mp_group = mpu.get_model_parallel_group()
if torch.distributed.get_rank() == src_rank:
assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0"
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
if not isinstance(wrapped_models, (list, tuple)):
wrapped_models = list(wrapped_models)
assert len(wrapped_models) == virtual_pp_size
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
models = [None] * len(wrapped_models)
for i, wrapped_model in enumerate(wrapped_models):
models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
gpt_model_module = _get_gpt_model(models[i])
assert len(gpt_model_module.decoder.layers) == num_layers_per_model
def _broadcast_tensor(tensor, name) -> torch.Tensor:
"""broadcast tensor from rank0 across mp_group"""
nonlocal state_dict
nonlocal mp_group
if torch.distributed.get_rank() == src_rank:
if name in state_dict:
weight = state_dict[name]
tensor_shape = weight.shape
else:
tensor_shape = None
else:
weight = None
tensor_shape = None
obj_list = [tensor_shape]
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
tensor_shape = obj_list[0]
if tensor_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tensor:[{name}] not in state_dict, skip load")
return
if tensor is None:
tensor = torch.empty(
tensor_shape,
dtype=params_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
if torch.distributed.get_rank() == src_rank:
tensor.data.copy_(weight)
dist.broadcast(tensor, src=src_rank, group=mp_group)
def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
"""broadcast tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
if torch.distributed.get_rank() == src_rank:
if name in state_dict:
full_weight = state_dict[name]
if mutate_func is not None:
full_weight = mutate_func(full_weight)
tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
chunk_shape = tensor_chunk[0].shape
else:
chunk_shape = None
else:
chunk_shape = None
obj_list = [chunk_shape]
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
chunk_shape = obj_list[0]
if chunk_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
return
if tensor is None:
sync_tensor = torch.empty(
chunk_shape,
dtype=params_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
assert (tensor.shape == chunk_shape
), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
for i in range(tp_size):
if torch.distributed.get_rank() == src_rank:
sync_tensor.data.copy_(tensor_chunk[i])
dist.broadcast(sync_tensor, src=src_rank, group=mp_group)
if (i == tp_rank) and (tensor is not None):
tensor.data.copy_(sync_tensor)
def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
"""broadcast tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
if torch.distributed.get_rank() == src_rank:
if name in state_dict:
full_weight = state_dict[name]
if mutate_func is not None:
full_weight = mutate_func(full_weight)
tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
chunk_shape = tensor_chunk[0].shape
else:
chunk_shape = None
else:
chunk_shape = None
obj_list = [chunk_shape]
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
chunk_shape = obj_list[0]
if chunk_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
return
if tensor is None:
sync_tensor = torch.empty(
chunk_shape,
dtype=params_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
assert (tensor.shape == chunk_shape
), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
for i in range(tp_size):
if torch.distributed.get_rank() == src_rank:
sync_tensor.data.copy_(tensor_chunk[i])
dist.broadcast(sync_tensor, src=src_rank, group=mp_group)
if (i == tp_rank) and (tensor is not None):
tensor.data.copy_(sync_tensor)
def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:
"""broadcast tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
if torch.distributed.get_rank() == src_rank:
gate_weight = state_dict[gate_name]
up_weight = state_dict[up_name]
new_gate_up_weight = torch.empty(config.intermediate_size * 2,
config.hidden_size,
dtype=params_dtype,
device=torch.cuda.current_device())
for i in range(tp_size):
intermediate_size_tp = config.intermediate_size // tp_size
gate_weight_tp = gate_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
up_weight_tp = up_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
new_gate_up_weight[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)].copy_(
torch.cat([gate_weight_tp, up_weight_tp], dim=0))
tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
chunk_shape = tensor_chunk[0].shape
else:
chunk_shape = None
obj_list = [chunk_shape]
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
chunk_shape = obj_list[0]
if chunk_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading")
return
if tensor is None:
sync_tensor = torch.empty(
chunk_shape,
dtype=params_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
assert (
tensor.shape == chunk_shape
), f"rank #{torch.distributed.get_rank() == src_rank:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}"
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
for i in range(tp_size):
if torch.distributed.get_rank() == src_rank:
sync_tensor.data.copy_(tensor_chunk[i])
dist.broadcast(sync_tensor, src=src_rank, group=mp_group)
if (i == tp_rank) and (tensor is not None):
tensor.data.copy_(sync_tensor)
def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor:
"""broadcast tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
if torch.distributed.get_rank() == src_rank:
assert (q_name in state_dict and k_name in state_dict and v_name in state_dict)
full_weight_q = state_dict[q_name]
full_weight_k = state_dict[k_name]
full_weight_v = state_dict[v_name]
hidden_size_per_head = config.hidden_size // config.num_attention_heads
if config.num_key_value_heads >= tp_size:
q_size_tp = config.hidden_size // tp_size
kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
total_size = q_size_tp + 2 * kv_size_tp
sizes = [total_size * tp_size]
if not bias:
sizes.append(config.hidden_size)
new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=torch.cuda.current_device())
for i in range(tp_size):
q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
k_part = full_weight_k[i * kv_size_tp:(i + 1) * kv_size_tp]
v_part = full_weight_v[i * kv_size_tp:(i + 1) * kv_size_tp]
num_query_groups_per_partition = models[0].config.num_query_groups // tp_size
new_weight_qkv_this_tp = new_weight_qkv[i * total_size:(i + 1) * total_size]
q_part_per_head = torch.chunk(q_part, num_query_groups_per_partition, dim=0)
k_part_per_head = torch.chunk(k_part, num_query_groups_per_partition, dim=0)
v_part_per_head = torch.chunk(v_part, num_query_groups_per_partition, dim=0)
total_size_per_head = total_size // num_query_groups_per_partition
for j in range(num_query_groups_per_partition):
new_weight_qkv_this_tp[j * total_size_per_head:(j + 1) * total_size_per_head].copy_(
torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0))
else:
q_size_tp = config.hidden_size // tp_size
kv_size_tp = hidden_size_per_head
total_size = q_size_tp + 2 * kv_size_tp
sizes = [total_size * tp_size]
if not bias:
sizes.append(config.hidden_size)
new_weight_qkv = torch.empty(*sizes, dtype=params_dtype, device=torch.cuda.current_device())
for i in range(tp_size):
q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
k_part = full_weight_k[start_idx:end_idx]
v_part = full_weight_v[start_idx:end_idx]
new_weight_qkv_this_tp = new_weight_qkv[i * total_size:(i + 1) * total_size]
q_part_per_head = torch.chunk(q_part, config.num_attention_heads, dim=0)
k_part_per_head = torch.chunk(k_part, config.num_attention_heads, dim=0)
v_part_per_head = torch.chunk(v_part, config.num_attention_heads, dim=0)
total_size_per_head = total_size // config.num_attention_heads
for j in range(config.num_attention_heads):
new_weight_qkv_this_tp[j * total_size_per_head:(j + 1) * total_size_per_head].copy_(
torch.cat([q_part_per_head[j], k_part_per_head[j], v_part_per_head[j]], dim=0))
tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
chunk_shape = tensor_chunk[0].shape
else:
chunk_shape = None
obj_list = [chunk_shape]
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
chunk_shape = obj_list[0]
if chunk_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading")
return
if tensor is None:
sync_tensor = torch.empty(
chunk_shape,
dtype=params_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
assert (tensor.shape == chunk_shape
), f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}"
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
for i in range(tp_size):
if torch.distributed.get_rank() == src_rank:
sync_tensor.data.copy_(tensor_chunk[i])
dist.broadcast(sync_tensor, src=src_rank, group=mp_group)
if (i == tp_rank) and (tensor is not None):
tensor.data.copy_(sync_tensor)
if dp_rank == 0:
# Embeddings
# -------------------
print_rank_0("loading embeddings...")
gpt_model_module = _get_gpt_model(models[0])
embed_tokens_weight = None
if pp_rank == 0:
embed_tokens_weight = gpt_model_module.embedding.word_embeddings.weight
_broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight")
# Transformer layers
# -------------------
layer_map = _megatron_calc_layer_map(config)
for layer in range(config.num_hidden_layers):
print_rank_0(f"loading layer #{layer}...")
layer_name = f"model.layers.{layer}"
dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]
gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])
sync_layer = gpt_model_module.decoder.layers[dst_layer_idx]
_broadcast_tensor(
sync_layer.self_attention.linear_qkv.layer_norm_weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.input_layernorm.weight",
)
_broadcast_tp_shard_tensor_qkv(
sync_layer.self_attention.linear_qkv.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.self_attn.q_proj.weight",
f"{layer_name}.self_attn.k_proj.weight",
f"{layer_name}.self_attn.v_proj.weight",
)
if f"{layer_name}.self_attn.q_proj.bias" in state_dict:
_broadcast_tp_shard_tensor_qkv(
sync_layer.self_attention.linear_qkv.bias if dst_pp_rank == pp_rank else None,
f"{layer_name}.self_attn.q_proj.bias",
f"{layer_name}.self_attn.k_proj.bias",
f"{layer_name}.self_attn.v_proj.bias",
bias=True)
_broadcast_tp_shard_tensor(
sync_layer.self_attention.linear_proj.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.self_attn.o_proj.weight",
chunk_dim=1,
)
_broadcast_tensor(
sync_layer.mlp.linear_fc1.layer_norm_weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.post_attention_layernorm.weight",
)
_broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.linear_fc1.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.mlp.gate_proj.weight", f"{layer_name}.mlp.up_proj.weight")
_broadcast_tp_shard_tensor(
sync_layer.mlp.linear_fc2.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.mlp.down_proj.weight",
chunk_dim=1,
)
# Final Layernorm
# -------------------
print_rank_0("loading final layernorm...")
gpt_model_module = _get_gpt_model(models[-1])
_broadcast_tensor(
getattr(gpt_model_module.decoder.final_layernorm, "weight", None),
"model.norm.weight",
)
print_rank_0("loading lm_head...")
lm_head_weight = None
if pp_rank + 1 == pp_size:
lm_head_weight = gpt_model_module.output_layer.weight
if is_value_model:
# if torch.distributed.get_rank() == src_rank:
if 'lm_head.weight' in state_dict and state_dict['lm_head.weight'].shape[0] == 1:
_broadcast_tensor(lm_head_weight, "lm_head.weight")
elif 'reward_head.weight' in state_dict and state_dict['reward_head.weight'].shape[0] == 1:
_broadcast_tensor(lm_head_weight, "reward_head.weight")
print_rank_0('load lm_head from value_head weight')
else:
_broadcast_tensor(None, "lm_head.weight")
print_rank_0('fail to match lm_head in value_model')
# else:
# _broadcast_tensor(lm_head_weight, "lm_head.weight")
else:
_broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight")
dist.barrier()
# Broadcast weights inside data parallel groups
for wrapped_model in wrapped_models:
broadcast_params(wrapped_model)
pass
torch.cuda.empty_cache()
print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s")
# veRL Megatron-Core Models
The earlier versions of veRL use `Megatron-LM` 0.4 and workaround huggingface model classes. To better use the latest features and speedup of modern Megatron, we are migrating to `Megatron-Core`(mcore), and use the recommended `GPTModel` class for all language models. With mcore `GPTModel`, we can use the latest features like `context parallel`, `expert parallel`, `dist_checkpointing`, etc. and we can update mcore with little effort in the future for new features.
The migration has been successful with the help of the mcore team and the community. What we have done is:
1. update `Megatron` version to `0.11.0`
2. migrate `LlamaForCausalLM` and `Qwen2ForCausalLM` to mcore `GPTModel`
3. support sequence packing/thd format.
4. support `tensor parallel`, `pipeline parallel`, `sequence parallel`, `virtual pipeline parallel`, `context parallel`.
5. support the mcore `dist_checkpointing` feature and a basic offline weighs conversion scipt from huggingface to mcore `dist_checkpointing` format.
We are working on the following features:
- support `Qwen2MoeForCausalLM`
- support `DeepseekV3ForCausalLM`
- support `expert parallel`
Features we invite the community to contribute:
- better scipts for offline weights conversion from huggingface to mcore `dist_checkpointing` format.
- conversion of large models with multiple GPUs
- conversion of large models with single GPU
- refactor the `megatron_checkpoint_manager.py` by `dist_checkpointing` format.
- support llama4
- support qwen2.5-vl
To track the progress of verl mcore integration, please refer to the [mcore integration issue](https://github.com/volcengine/verl/issues/1033).
## How things work now
To engage the community in contributing, here are the key steps in our mcore integration process and features under development.
The huggingface `transformers` is the de facto standard of model zoo while mcore is good at computation efficiency. The main challenge is conversion between the two.
main steps:
1. modelling the huggingface model with mcore `GPTModel`
- a. convert the huggingface config to mcore `TransformerConfig`
- b. init the mcore `GPTModel` with the converted config
- c. load the huggingface model weights to the `GPTModel`
2. online weight conversion from mcore to huggingface (due the the rollout engine `vLLM` is using huggingface format)
- a. bridge the gap between mcore and huggingface weights format and name mapping
- b. online resharding the mcore weights to rollout engine
- this part is very complicated with multiple parallel strategies composition between mcore and rollout engine
3. support the mcore features in verl
- a. support `tensor parallel`, `pipeline parallel`, `sequence parallel`, `virtual pipeline parallel`, `context parallel`
- b. support recompute and other mcore speed up features
4. checkpointing
- a. support recovering the verl training.
- b. support exporting the mcore checkpoint to huggingface format, for downstream inference.
### Modelling the huggingface model with mcore `GPTModel`
The first step is to convert huggingface config to mcore `TransformerConfig` and init the mcore `GPTModel` with the converted config. See code in `verl/models/mcore/config_converter.py` and `verl/verl/models/mcore/models/model_initializer.py`. The corresponding model forward code is in `verl/verl/models/mcore/models/model_forward.py`.
There are two ways of loading the huggingface model weights to the `GPTModel`
1. Runtime loading
- every rank loads the entire huggingface model weights and then shard and convert to mcore weights.
- speed is slow and memory consumption is high.
- this way is deprecated and will not support new models.
2. Offline loading
- use offline script to convert the huggingface model weights to mcore weights and save with mcore `dist_checkpointing` format.
- online loading and sharding is automatically done by mcore `dist_checkpointing` format. The speed is fast and memory consumption is low.
- the offline script is in `verl/scripts/converter_hf_to_mcore.py`.
### online weight conversion from mcore to huggingface
See function `convert_megatron_model_to_transformers_model` in `verl/utils/megatron_utils.py` for the details.
It should be refatored for extensibility and better performance.
### support the mcore features in verl
Most of the features of `GPTModel` is out-of-the-box supported in verl through changing the `TransformerConfig`, except those about parallel strategies, such as `expert parallel`.
Features about parallel strategies should be supported with changes about the online weights conversion(especially the resharding part) and verl work dispatching.
### checkpointing
The existing checkpointing code is in `verl/utils/checkpoint/megatron_checkpoint_manager.py`. And the script to convert checkpoint to huggingface format is in `verl/scripts/model_merger.py`.
The existing checkpoint format is simplely save every rank's weights and optimizer states. It should be refactored by `dist_checkpointing` format.
## How to support new models
1. make sure the model is supported by vLLM
2. modelling the huggingface model with mcore `GPTModel` (The [Pai-Megatron-Path](https://github.com/alibaba/Pai-Megatron-Patch/tree/main) is a good reference)
- a. convert the huggingface config to mcore `TransformerConfig`
- b. init the mcore `GPTModel` with the converted config
- c. load the huggingface model weights to the `GPTModel`
- d. for VLM the interface might be different, it is ok to add a new model class with GPTModel as its module.
3. offline weights conversion from huggingface to mcore `dist_checkpointing` format
4. support online weights conversion from mcore to huggingface
- it is recommended to initilize a vLLM model with the converted mcore weights, and then test if the generating sequence is correct.
## How to scale up to larger models like deepseek-v3 or other 100B+ models
The greatest challenge for scaling up to larger models is the memory consumption.
The necessary features under development for scaling up are
1. Training engine part
- expert parallel
2. Rollout engine part
- pipeline parallel
- expert parallel
- more efficient and general weight resharding and loading
3. Offline weights conversion
- support weights larger then single GPU memory
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from verl.utils.megatron_utils import print_rank_0, unwrap_model
from megatron.core import mpu
from megatron.core.transformer.module import Float16Module
from megatron.core.distributed import DistributedDataParallel as LocalDDP
from torch.nn.parallel import DistributedDataParallel as torchDDP
import torch
import time
import torch
import torch.distributed as dist
def _megatron_calc_global_rank(tp_rank: int = 0,
dp_rank: int = 0,
pp_rank: int = 0,
cp_rank: int = 0,
ep_rank: int = 0):
"""Calculate global rank with support for CP/EP parallelism"""
# Get parallel sizes for each dimension
tp_size = mpu.get_tensor_model_parallel_world_size()
dp_size = mpu.get_data_parallel_world_size()
pp_size = mpu.get_pipeline_model_parallel_world_size()
cp_size = mpu.get_context_parallel_world_size()
ep_size = mpu.get_expert_model_parallel_world_size()
# Verify total GPU count matches (must be consistent with parallel_state.py)
total_size = tp_size * dp_size * pp_size * cp_size
assert total_size == torch.distributed.get_world_size(), \
f"{tp_size}x{dp_size}x{pp_size}x{cp_size} != {torch.distributed.get_world_size()}"
# Core calculation logic (corresponds to RankGenerator order parameter)
# Assumes default order is "tp-cp-ep-dp-pp"
return ((pp_rank * dp_size + dp_rank) * cp_size + cp_rank) * tp_size + tp_rank
def _megatron_calc_layer_map(config):
"""Calculate the mapping of global layer_idx to local layer_idx
Returns:
layer_map (Dict: int -> tuple(int, int, int)):
mapping from the global layer index to
a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
"""
from megatron.core import mpu
pp_size = mpu.get_pipeline_model_parallel_world_size()
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
layer_map = dict()
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
for pp_rank_idx in range(pp_size):
for virtual_pp_rank_idx in range(virtual_pp_size):
layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) +
pp_rank_idx * num_layers_per_model)
for layer_idx in range(num_layers_per_model):
layer_map[layer_offset + layer_idx] = (
pp_rank_idx,
virtual_pp_rank_idx,
layer_idx,
)
return layer_map
def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):
"""Merge sharded parameters of a Megatron module into a merged checkpoint.
Args:
wrapped_models (list of megatron.core.distributed.DistributedDataParallel):
The local DDP wrapped megatron modules.
config (str or None):
HF config for model
dtype: model params type
is_value_model: if model is value model
tie_word_embeddings: tie_word_embeddings
Returns:
state_dict (dict):
The merged state_dict in rank 0, and an empty dictionary in other ranks.
"""
start_time = time.time()
def _get_gpt_model(model):
return model
dp_rank = mpu.get_data_parallel_rank()
pp_size = mpu.get_pipeline_model_parallel_world_size()
pp_rank = mpu.get_pipeline_model_parallel_rank()
cp_rank = mpu.get_context_parallel_rank()
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
mp_group = mpu.get_model_parallel_group()
if dist.get_rank() == 0:
assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0"
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
if not isinstance(wrapped_models, (list, tuple)):
wrapped_models = list(wrapped_models)
assert len(wrapped_models) == virtual_pp_size
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
models = [None] * len(wrapped_models)
for i, wrapped_model in enumerate(wrapped_models):
models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
assert len(models[i].decoder.layers
) == num_layers_per_model, 'len model layers {} not equal to num_layers_per_model {}'.format(
len(models[i].decoder.layers), num_layers_per_model)
state_dict = dict()
def _get_cpu_tensor(tensor: torch.Tensor):
if tensor is None:
return None
if tensor.device == torch.device("cpu"):
return tensor.detach().clone()
return tensor.detach().cpu()
def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor:
"""broadcast tensor across mp_group"""
nonlocal state_dict
nonlocal mp_group
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)
if torch.distributed.get_rank() == src_rank:
if tensor is None:
weight = None
tensor_shape = None
else:
weight = tensor
tensor_shape = weight.shape
else:
weight = None
tensor_shape = None
obj_list = [tensor_shape]
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
tensor_shape = obj_list[0]
if tensor_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tensor:[{name}] not exist, skip collect")
return
if weight is None:
weight = torch.empty(
tensor_shape,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
dist.broadcast(weight, src=src_rank, group=mp_group)
if torch.distributed.get_rank() == 0:
state_dict[name] = _get_cpu_tensor(weight)
def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor:
"""broadcast tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)
if torch.distributed.get_rank() == src_rank:
chunk_shape = tensor.shape
else:
chunk_shape = None
obj_list = [chunk_shape]
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
chunk_shape = obj_list[0]
if chunk_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting")
return
buffer_tensor = torch.empty(
chunk_shape,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
chunk_tensors = [None] * tp_size
for i in range(tp_size):
cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)
sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
if torch.distributed.get_rank() == 0:
chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
if torch.distributed.get_rank() == 0:
full_tensor = torch.concat(chunk_tensors, dim=concat_dim)
if mutate_func is not None:
full_tensor = mutate_func(full_tensor)
state_dict[name] = full_tensor
def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor:
"""broadcast tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)
if torch.distributed.get_rank() == src_rank:
chunk_shape = tensor.shape
else:
chunk_shape = None
obj_list = [chunk_shape]
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
chunk_shape = obj_list[0]
if chunk_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting")
return
buffer_tensor = torch.empty(
chunk_shape,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
chunk_tensors = [None] * tp_size
for i in range(tp_size):
cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)
sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
if torch.distributed.get_rank() == 0:
chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
if torch.distributed.get_rank() == 0:
full_tensor = torch.concat(chunk_tensors, dim=0)
intermediate_size_tp = config.intermediate_size // tp_size
gate_weight_list = []
up_weight_list = []
for i in range(tp_size):
gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)]
gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]
up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]
gate_weight_list.append(gate_weight_tp)
up_weight_list.append(up_weight_tp)
state_dict[gate_name] = torch.cat(gate_weight_list, dim=0)
state_dict[up_name] = torch.cat(up_weight_list, dim=0)
def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank):
"""broadcast tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)
if torch.distributed.get_rank() == src_rank:
chunk_shape = tensor.shape
else:
chunk_shape = None
obj_list = [chunk_shape]
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
chunk_shape = obj_list[0]
if chunk_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting")
return
buffer_tensor = torch.empty(
chunk_shape,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
chunk_tensors = [None] * tp_size
for i in range(tp_size):
cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank, cp_rank=cp_rank)
sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
if torch.distributed.get_rank() == 0:
chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
if torch.distributed.get_rank() == 0:
full_tensor = torch.concat(chunk_tensors, dim=0)
q_weight_list = []
k_weight_list = []
v_weight_list = []
hidden_size_per_head = config.hidden_size // config.num_attention_heads
if config.num_key_value_heads >= tp_size:
q_size_tp = config.hidden_size // tp_size
kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
total_size = q_size_tp + 2 * kv_size_tp
for i in range(tp_size):
num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size
qkv_part = full_tensor[i * total_size:(i + 1) * total_size]
q_size_chunk = q_size_tp // num_query_groups_per_partition
kv_size_chunk = kv_size_tp // num_query_groups_per_partition
for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition):
q_part = qkv_part_chunk[:q_size_chunk]
k_part = qkv_part_chunk[q_size_chunk:q_size_chunk + kv_size_chunk]
v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk:]
q_weight_list.append(q_part)
k_weight_list.append(k_part)
v_weight_list.append(v_part)
else:
q_size_tp = config.hidden_size // tp_size
kv_size_tp = hidden_size_per_head
total_size = q_size_tp + 2 * kv_size_tp
for i in range(tp_size):
num_query_groups_per_partition = wrapped_models[0].config.num_query_groups // tp_size
qkv_part = full_tensor[i * total_size:(i + 1) * total_size]
q_size_chunk = q_size_tp // num_query_groups_per_partition
kv_size_chunk = kv_size_tp // num_query_groups_per_partition
for qkv_part_chunk in qkv_part.chunk(num_query_groups_per_partition):
q_part = qkv_part_chunk[:q_size_chunk]
k_part = qkv_part_chunk[q_size_chunk:q_size_chunk + kv_size_chunk]
v_part = qkv_part_chunk[q_size_chunk + kv_size_chunk:]
q_weight_list.append(q_part)
if i * config.num_key_value_heads % tp_size == 0:
k_weight_list.append(k_part)
v_weight_list.append(v_part)
state_dict[q_name] = torch.cat(q_weight_list, dim=0)
state_dict[k_name] = torch.cat(k_weight_list, dim=0)
state_dict[v_name] = torch.cat(v_weight_list, dim=0)
# empty cache before collecting weights
torch.cuda.empty_cache()
# Embeddings
# -------------------
if dp_rank == 0 and cp_rank == 0: # models are identical across cp ranks
# Embeddings
# -------------------
print_rank_0("collecting embeddings...")
gpt_model_module = _get_gpt_model(models[0])
_broadcast_tp_shard_tensor(
gpt_model_module.embedding.word_embeddings.weight if pp_rank == 0 else None,
"model.embed_tokens.weight",
src_pp_rank=0,
)
# Transformer layers
# -------------------
layer_map = _megatron_calc_layer_map(config)
for layer in range(config.num_hidden_layers):
print_rank_0(f"collecting layer #{layer}...")
layer_name = f"model.layers.{layer}"
src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]
gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])
sync_layer = gpt_model_module.decoder.layers[src_layer_idx]
_broadcast_tensor(
sync_layer.self_attention.linear_qkv.layer_norm_weight,
f"{layer_name}.input_layernorm.weight",
src_pp_rank=src_pp_rank,
)
_broadcast_tp_shard_tensor_qkv(
sync_layer.self_attention.linear_qkv.weight,
f"{layer_name}.self_attn.q_proj.weight",
f"{layer_name}.self_attn.k_proj.weight",
f"{layer_name}.self_attn.v_proj.weight",
src_pp_rank=src_pp_rank,
)
if getattr(sync_layer.self_attention.linear_qkv, 'bias',
None) is not None and sync_layer.self_attention.linear_qkv.bias.numel() > 0:
_broadcast_tp_shard_tensor_qkv(
sync_layer.self_attention.linear_qkv.bias,
f"{layer_name}.self_attn.q_proj.bias",
f"{layer_name}.self_attn.k_proj.bias",
f"{layer_name}.self_attn.v_proj.bias",
src_pp_rank=src_pp_rank,
)
_broadcast_tp_shard_tensor(
sync_layer.self_attention.linear_proj.weight,
f"{layer_name}.self_attn.o_proj.weight",
concat_dim=1,
src_pp_rank=src_pp_rank,
)
_broadcast_tensor(
sync_layer.mlp.linear_fc1.layer_norm_weight,
f"{layer_name}.post_attention_layernorm.weight",
src_pp_rank=src_pp_rank,
)
_broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.linear_fc1.weight,
f"{layer_name}.mlp.gate_proj.weight",
f"{layer_name}.mlp.up_proj.weight",
src_pp_rank=src_pp_rank)
_broadcast_tp_shard_tensor(
sync_layer.mlp.linear_fc2.weight,
f"{layer_name}.mlp.down_proj.weight",
concat_dim=1,
src_pp_rank=src_pp_rank,
)
# Final Layernorm
# -------------------
print_rank_0("collecting final layernorm...")
gpt_model_module = _get_gpt_model(models[-1])
_broadcast_tensor(
getattr(gpt_model_module.decoder.final_layernorm, "weight", None),
"model.norm.weight",
src_pp_rank=pp_size - 1,
)
if tie_word_embeddings:
print_rank_0(f"tie word embedding skip load lm_head...")
else:
print_rank_0("collecting lm_head...")
if is_value_model:
lm_head_weight = None
if pp_rank == pp_size - 1:
lm_head_weight = getattr(gpt_model_module.output_layer, "weight", None)
_broadcast_tensor(lm_head_weight, "lm_head.weight", src_pp_rank=pp_size - 1)
else:
_broadcast_tp_shard_tensor(
getattr(gpt_model_module.output_layer, "weight", None) if pp_rank == pp_size - 1 else None,
"lm_head.weight",
src_pp_rank=pp_size - 1,
)
dist.barrier()
torch.cuda.empty_cache()
if torch.distributed.get_rank() == 0:
for k, v in state_dict.items():
if dtype != v.dtype:
state_dict[k] = v.to(dtype)
print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s")
return state_dict
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .modeling_qwen2_megatron import (
# original model with megatron
ParallelQwen2Model,
ParallelQwen2ForCausalLM,
# rmpad with megatron
ParallelQwen2ForCausalLMRmPad,
ParallelQwen2ForValueRmPad,
# rmpad with megatron and pipeline parallelism
ParallelQwen2ForCausalLMRmPadPP,
ParallelQwen2ForValueRmPadPP)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import time
from typing import Dict, Any, Callable, Optional
import torch.distributed as dist
def _megatron_calc_layer_map(config):
"""Calculate the mapping of global layer_idx to local layer_idx
Returns:
layer_map (Dict: int -> tuple(int, int, int)):
mapping from the global layer index to
a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
"""
from megatron.core import mpu
pp_size = mpu.get_pipeline_model_parallel_world_size()
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
layer_map = dict()
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
for pp_rank_idx in range(pp_size):
for virtual_pp_rank_idx in range(virtual_pp_size):
layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) +
pp_rank_idx * num_layers_per_model)
for layer_idx in range(num_layers_per_model):
layer_map[layer_offset + layer_idx] = (
pp_rank_idx,
virtual_pp_rank_idx,
layer_idx,
)
return layer_map
def load_state_dict_to_megatron_qwen2(state_dict,
wrapped_models,
config,
params_dtype,
is_value_model=False,
tie_word_embeddings=False):
"""Load merged state_dict to sharded Megatron module in training.
"""
from verl.utils.megatron_utils import print_rank_0, unwrap_model
from megatron.core import mpu
from megatron.core.transformer.module import Float16Module
from megatron.core import DistributedDataParallel as LocalDDP
from torch.nn.parallel import DistributedDataParallel as torchDDP
start_time = time.time()
def _get_gpt_model(model):
return model
def fetch_params(module):
for param in module.parameters():
torch.distributed.fetch(param.data,
src=mpu.get_data_parallel_src_rank(),
group=mpu.get_data_parallel_group())
dp_rank = mpu.get_data_parallel_rank()
pp_rank = mpu.get_pipeline_model_parallel_rank()
pp_size = mpu.get_pipeline_model_parallel_world_size()
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
mp_group = mpu.get_model_parallel_group()
if torch.distributed.get_rank() == 0:
assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0"
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
if not isinstance(wrapped_models, (list, tuple)):
wrapped_models = list(wrapped_models)
assert len(wrapped_models) == virtual_pp_size
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f'num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}'
models = [None] * len(wrapped_models)
for i, wrapped_model in enumerate(wrapped_models):
models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
gpt_model_module = _get_gpt_model(models[i])
assert len(gpt_model_module.model.layers) == num_layers_per_model
def _fetch_tensor(tensor, name) -> torch.Tensor:
"""fetch tensor"""
nonlocal state_dict
if tensor is not None:
tensor = tensor.data.copy_(state_dict[name], non_blocking=True)
def _fetch_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
"""fetch tensor in tp shards"""
nonlocal state_dict
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
if name in state_dict:
full_weight = state_dict[name]
if mutate_func is not None:
full_weight = mutate_func(full_weight)
tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
if tensor is not None:
tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)
else:
print(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
def _fetch_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
"""fetch tensor in tp shards"""
nonlocal state_dict
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
if name in state_dict:
full_weight = state_dict[name]
if mutate_func is not None:
full_weight = mutate_func(full_weight)
tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
if tensor is not None:
tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)
else:
print(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
def _fetch_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:
"""fetch gate_up tensor in tp shards"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
if gate_name in state_dict and up_name in state_dict:
gate_weight = state_dict[gate_name]
up_weight = state_dict[up_name]
new_gate_up_weight = torch.empty(config.intermediate_size * 2,
config.hidden_size,
dtype=params_dtype,
device=torch.cuda.current_device())
for i in range(tp_size):
intermediate_size_tp = config.intermediate_size // tp_size
gate_weight_tp = gate_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
up_weight_tp = up_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
new_gate_up_weight[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)].copy_(
torch.cat([gate_weight_tp, up_weight_tp], dim=0))
tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
if tensor is not None:
tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)
else:
print(f"tp_shard tensor:[{gate_name}, {up_name}] not in state_dict, skip loading")
def _fetch_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor:
"""fetch tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
assert (q_name in state_dict and k_name in state_dict and v_name in state_dict)
full_weight_q = state_dict[q_name]
full_weight_k = state_dict[k_name]
full_weight_v = state_dict[v_name]
hidden_size_per_head = config.hidden_size // config.num_attention_heads
if config.num_key_value_heads >= tp_size:
q_size_tp = config.hidden_size // tp_size
kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
total_size = q_size_tp + 2 * kv_size_tp
if not bias:
new_weight_qkv = torch.empty(total_size * tp_size,
config.hidden_size,
dtype=params_dtype,
device=torch.cuda.current_device())
else:
new_weight_qkv = torch.empty(total_size * tp_size,
dtype=params_dtype,
device=torch.cuda.current_device())
for i in range(tp_size):
q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
k_part = full_weight_k[i * kv_size_tp:(i + 1) * kv_size_tp]
v_part = full_weight_v[i * kv_size_tp:(i + 1) * kv_size_tp]
new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))
else:
q_size_tp = config.hidden_size // tp_size
kv_size_tp = hidden_size_per_head
total_size = q_size_tp + 2 * kv_size_tp
if not bias:
new_weight_qkv = torch.empty(total_size * tp_size,
config.hidden_size,
dtype=params_dtype,
device=torch.cuda.current_device())
else:
new_weight_qkv = torch.empty(total_size * tp_size,
dtype=params_dtype,
device=torch.cuda.current_device())
for i in range(tp_size):
q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
k_part = full_weight_k[start_idx:end_idx]
v_part = full_weight_v[start_idx:end_idx]
new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part], dim=0))
tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
if tensor is not None:
tensor = tensor.data.copy_(tensor_chunk[tp_rank], non_blocking=True)
# Embeddings
# -------------------
print_rank_0("loading embeddings...")
gpt_model_module = _get_gpt_model(models[0])
if pp_rank == 0:
embed_tokens_weight = gpt_model_module.model.embed_tokens.weight
_fetch_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight")
# Transformer layers
# -------------------
layer_map = _megatron_calc_layer_map(config)
pp_rank = mpu.get_pipeline_model_parallel_rank()
pp_size = mpu.get_pipeline_model_parallel_world_size()
num_layer_per_pp = config.num_hidden_layers // pp_size
vpp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
layer_list = []
if vpp_size is not None:
for vpp_rank in range(vpp_size):
num_layer_vpp_chunk = num_layer_per_pp // vpp_size
num_layer_this_model = num_layer_vpp_chunk
offset = vpp_rank * (
config.num_hidden_layers // mpu.get_virtual_pipeline_model_parallel_world_size()) + \
(mpu.get_pipeline_model_parallel_rank() * num_layer_vpp_chunk)
layer_list.extend(list(range(offset, offset + num_layer_this_model)))
else:
num_layer_this_model = num_layer_per_pp
offset = pp_rank * num_layer_per_pp
layer_list.extend(list(range(offset, offset + num_layer_this_model)))
for layer in layer_list:
print(f"{torch.distributed.get_rank()} loading layer #{layer}...")
layer_name = f"model.layers.{layer}"
dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]
print(
f'{torch.distributed.get_rank()} offset: {offset}, num_layer_this_model: {num_layer_this_model}, layer_name: {layer_name}, layer_map[layer]: {layer_map[layer]}'
)
gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])
sync_layer = gpt_model_module.model.layers[dst_layer_idx]
_fetch_tensor(
sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.input_layernorm.weight",
)
_fetch_tp_shard_tensor_qkv(
sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.self_attn.q_proj.weight",
f"{layer_name}.self_attn.k_proj.weight",
f"{layer_name}.self_attn.v_proj.weight",
)
_fetch_tp_shard_tensor_qkv(sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None,
f"{layer_name}.self_attn.q_proj.bias",
f"{layer_name}.self_attn.k_proj.bias",
f"{layer_name}.self_attn.v_proj.bias",
bias=True)
_fetch_tp_shard_tensor(
sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.self_attn.o_proj.weight",
chunk_dim=1,
)
_fetch_tensor(
sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.post_attention_layernorm.weight",
)
_fetch_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.mlp.gate_proj.weight", f"{layer_name}.mlp.up_proj.weight")
_fetch_tp_shard_tensor(
sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.mlp.down_proj.weight",
chunk_dim=1,
)
# Final Layernorm
# -------------------
print_rank_0("loading final layernorm...")
gpt_model_module = _get_gpt_model(models[-1])
_fetch_tensor(
getattr(gpt_model_module.model.norm, "weight", None),
"model.norm.weight",
)
if tie_word_embeddings:
print_rank_0("tie_word_embeddings skip load lm_head")
else:
print_rank_0("loading lm_head...")
if pp_rank + 1 == pp_size:
lm_head_weight = gpt_model_module.lm_head.weight
if is_value_model:
if 'lm_head.weight' in state_dict and state_dict['lm_head.weight'].shape[0] == 1:
_fetch_tensor(lm_head_weight, "lm_head.weight")
print_rank_0('load lm_head from value_head weight')
elif 'reward_head.weight' in state_dict and state_dict['reward_head.weight'].shape[0] == 1:
_fetch_tensor(lm_head_weight, "reward_head.weight")
print_rank_0('load lm_head from value_head weight')
else:
_fetch_tensor(None, "lm_head.weight")
print_rank_0('fail to match lm_head in value_model')
else:
_fetch_tp_shard_tensor(lm_head_weight, "lm_head.weight")
dist.barrier()
torch.cuda.empty_cache()
print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s")
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import time
from typing import Dict, Any, Callable, Optional
import torch.distributed as dist
def _megatron_calc_layer_map(config):
"""Calculate the mapping of global layer_idx to local layer_idx
Returns:
layer_map (Dict: int -> tuple(int, int, int)):
mapping from the global layer index to
a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
"""
from megatron.core import mpu
pp_size = mpu.get_pipeline_model_parallel_world_size()
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
layer_map = dict()
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
for pp_rank_idx in range(pp_size):
for virtual_pp_rank_idx in range(virtual_pp_size):
layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) +
pp_rank_idx * num_layers_per_model)
for layer_idx in range(num_layers_per_model):
layer_map[layer_offset + layer_idx] = (
pp_rank_idx,
virtual_pp_rank_idx,
layer_idx,
)
return layer_map
def load_state_dict_to_megatron_qwen2(state_dict,
wrapped_models,
config,
params_dtype,
is_value_model=False,
tie_word_embeddings=False):
"""Load merged state_dict to sharded Megatron module in training.
"""
from verl.utils.megatron_utils import print_rank_0, unwrap_model
from megatron.core import mpu
from megatron.core.transformer.module import Float16Module
from megatron.core import DistributedDataParallel as LocalDDP
from torch.nn.parallel import DistributedDataParallel as torchDDP
start_time = time.time()
def _get_gpt_model(model):
return model
def broadcast_params(module):
for param in module.parameters():
torch.distributed.broadcast(param.data,
src=mpu.get_data_parallel_src_rank(),
group=mpu.get_data_parallel_group())
dp_rank = mpu.get_data_parallel_rank()
pp_rank = mpu.get_pipeline_model_parallel_rank()
pp_size = mpu.get_pipeline_model_parallel_world_size()
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
mp_group = mpu.get_model_parallel_group()
if torch.distributed.get_rank() == 0:
assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0"
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
if not isinstance(wrapped_models, (list, tuple)):
wrapped_models = list(wrapped_models)
assert len(wrapped_models) == virtual_pp_size
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers, f'num_layers_per_model: {num_layers_per_model} * pp_size: {pp_size} * virtual_pp_size: {virtual_pp_size} != config.num_hidden_layers: {config.num_hidden_layers}'
models = [None] * len(wrapped_models)
for i, wrapped_model in enumerate(wrapped_models):
models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
gpt_model_module = _get_gpt_model(models[i])
assert len(gpt_model_module.model.layers) == num_layers_per_model
def _broadcast_tensor(tensor, name) -> torch.Tensor:
"""broadcast tensor from rank0 across mp_group"""
nonlocal state_dict
nonlocal mp_group
if torch.distributed.get_rank() == 0:
if name in state_dict:
weight = state_dict[name]
tensor_shape = weight.shape
else:
tensor_shape = None
else:
weight = None
tensor_shape = None
obj_list = [tensor_shape]
dist.broadcast_object_list(obj_list, src=0, group=mp_group)
tensor_shape = obj_list[0]
if tensor_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tensor:[{name}] not in state_dict, skip load")
return
if tensor is None:
tensor = torch.empty(
tensor_shape,
dtype=params_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
if torch.distributed.get_rank() == 0:
tensor.data.copy_(weight)
dist.broadcast(tensor, src=0, group=mp_group)
def _broadcast_tp_shard_tensor_vocab(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
"""broadcast tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
if torch.distributed.get_rank() == 0:
if name in state_dict:
full_weight = state_dict[name]
if mutate_func is not None:
full_weight = mutate_func(full_weight)
tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
chunk_shape = tensor_chunk[0].shape
else:
chunk_shape = None
else:
chunk_shape = None
obj_list = [chunk_shape]
dist.broadcast_object_list(obj_list, src=0, group=mp_group)
chunk_shape = obj_list[0]
if chunk_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
return
if tensor is None:
sync_tensor = torch.empty(
chunk_shape,
dtype=params_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
assert (tensor.shape == chunk_shape
), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
for i in range(tp_size):
if torch.distributed.get_rank() == 0:
sync_tensor.data.copy_(tensor_chunk[i])
dist.broadcast(sync_tensor, src=0, group=mp_group)
if (i == tp_rank) and (tensor is not None):
tensor.data.copy_(sync_tensor)
def _broadcast_tp_shard_tensor(tensor, name, chunk_dim=0, mutate_func=None) -> torch.Tensor:
"""broadcast tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
if torch.distributed.get_rank() == 0:
if name in state_dict:
full_weight = state_dict[name]
if mutate_func is not None:
full_weight = mutate_func(full_weight)
tensor_chunk = torch.chunk(full_weight, tp_size, dim=chunk_dim)
chunk_shape = tensor_chunk[0].shape
else:
chunk_shape = None
else:
chunk_shape = None
obj_list = [chunk_shape]
dist.broadcast_object_list(obj_list, src=0, group=mp_group)
chunk_shape = obj_list[0]
if chunk_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tp_shard tensor:[{name}] not in state_dict, skip loading")
return
if tensor is None:
sync_tensor = torch.empty(
chunk_shape,
dtype=params_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
assert (tensor.shape == chunk_shape
), f"rank #{torch.distributed.get_rank()} tensor {name} shape {tensor.shape} != {chunk_shape}"
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
for i in range(tp_size):
if torch.distributed.get_rank() == 0:
sync_tensor.data.copy_(tensor_chunk[i])
dist.broadcast(sync_tensor, src=0, group=mp_group)
if (i == tp_rank) and (tensor is not None):
tensor.data.copy_(sync_tensor)
def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name) -> torch.Tensor:
"""broadcast tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
if torch.distributed.get_rank() == 0:
gate_weight = state_dict[gate_name]
up_weight = state_dict[up_name]
new_gate_up_weight = torch.empty(config.intermediate_size * 2,
config.hidden_size,
dtype=params_dtype,
device=torch.cuda.current_device())
for i in range(tp_size):
intermediate_size_tp = config.intermediate_size // tp_size
gate_weight_tp = gate_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
up_weight_tp = up_weight[i * intermediate_size_tp:(i + 1) * intermediate_size_tp]
new_gate_up_weight[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)].copy_(
torch.cat([gate_weight_tp, up_weight_tp], dim=0))
tensor_chunk = torch.chunk(new_gate_up_weight, tp_size, dim=0)
chunk_shape = tensor_chunk[0].shape
else:
chunk_shape = None
obj_list = [chunk_shape]
dist.broadcast_object_list(obj_list, src=0, group=mp_group)
chunk_shape = obj_list[0]
if chunk_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not in state_dict, skip loading")
return
if tensor is None:
sync_tensor = torch.empty(
chunk_shape,
dtype=params_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
assert (
tensor.shape == chunk_shape
), f"rank #{torch.distributed.get_rank() == 0:} tensor {gate_name, up_name} shape {tensor.shape} != {chunk_shape}"
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
for i in range(tp_size):
if torch.distributed.get_rank() == 0:
sync_tensor.data.copy_(tensor_chunk[i])
dist.broadcast(sync_tensor, src=0, group=mp_group)
if (i == tp_rank) and (tensor is not None):
tensor.data.copy_(sync_tensor)
def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, bias=False) -> torch.Tensor:
"""broadcast tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
if torch.distributed.get_rank() == 0:
assert (q_name in state_dict and k_name in state_dict and v_name in state_dict)
full_weight_q = state_dict[q_name]
full_weight_k = state_dict[k_name]
full_weight_v = state_dict[v_name]
hidden_size_per_head = config.hidden_size // config.num_attention_heads
if config.num_key_value_heads >= tp_size:
q_size_tp = config.hidden_size // tp_size
kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
total_size = q_size_tp + 2 * kv_size_tp
if not bias:
new_weight_qkv = torch.empty(total_size * tp_size,
config.hidden_size,
dtype=params_dtype,
device=torch.cuda.current_device())
else:
new_weight_qkv = torch.empty(total_size * tp_size,
dtype=params_dtype,
device=torch.cuda.current_device())
for i in range(tp_size):
q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
k_part = full_weight_k[i * kv_size_tp:(i + 1) * kv_size_tp]
v_part = full_weight_v[i * kv_size_tp:(i + 1) * kv_size_tp]
new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part],
dim=0))
else:
q_size_tp = config.hidden_size // tp_size
kv_size_tp = hidden_size_per_head
total_size = q_size_tp + 2 * kv_size_tp
if not bias:
new_weight_qkv = torch.empty(total_size * tp_size,
config.hidden_size,
dtype=params_dtype,
device=torch.cuda.current_device())
else:
new_weight_qkv = torch.empty(total_size * tp_size,
dtype=params_dtype,
device=torch.cuda.current_device())
for i in range(tp_size):
q_part = full_weight_q[i * q_size_tp:(i + 1) * q_size_tp]
start_idx = i * config.num_key_value_heads // tp_size * hidden_size_per_head
end_idx = (i * config.num_key_value_heads // tp_size + 1) * hidden_size_per_head
k_part = full_weight_k[start_idx:end_idx]
v_part = full_weight_v[start_idx:end_idx]
new_weight_qkv[i * total_size:(i + 1) * total_size].copy_(torch.cat([q_part, k_part, v_part],
dim=0))
tensor_chunk = torch.chunk(new_weight_qkv, tp_size, dim=0)
chunk_shape = tensor_chunk[0].shape
else:
chunk_shape = None
obj_list = [chunk_shape]
dist.broadcast_object_list(obj_list, src=0, group=mp_group)
chunk_shape = obj_list[0]
if chunk_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tp_shard tensor:[{q_name, k_name, v_name}] not in state_dict, skip loading")
return
if tensor is None:
sync_tensor = torch.empty(
chunk_shape,
dtype=params_dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
assert (tensor.shape == chunk_shape
), f"rank #{torch.distributed.get_rank()} tensor {q_name} shape {tensor.shape} != {chunk_shape}"
sync_tensor = torch.empty_like(tensor, device=torch.cuda.current_device(), requires_grad=False)
for i in range(tp_size):
if torch.distributed.get_rank() == 0:
sync_tensor.data.copy_(tensor_chunk[i])
dist.broadcast(sync_tensor, src=0, group=mp_group)
if (i == tp_rank) and (tensor is not None):
tensor.data.copy_(sync_tensor)
if dp_rank == 0:
# Embeddings
# -------------------
print_rank_0("loading embeddings...")
gpt_model_module = _get_gpt_model(models[0])
embed_tokens_weight = None
if pp_rank == 0:
embed_tokens_weight = gpt_model_module.model.embed_tokens.weight
_broadcast_tp_shard_tensor_vocab(embed_tokens_weight, "model.embed_tokens.weight")
# Transformer layers
# -------------------
layer_map = _megatron_calc_layer_map(config)
for layer in range(config.num_hidden_layers):
print_rank_0(f"loading layer #{layer}...")
layer_name = f"model.layers.{layer}"
dst_pp_rank, dst_virtual_pp_rank, dst_layer_idx = layer_map[layer]
gpt_model_module = _get_gpt_model(models[dst_virtual_pp_rank])
sync_layer = gpt_model_module.model.layers[dst_layer_idx]
_broadcast_tensor(
sync_layer.input_layernorm.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.input_layernorm.weight",
)
_broadcast_tp_shard_tensor_qkv(
sync_layer.self_attn.qkv_proj.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.self_attn.q_proj.weight",
f"{layer_name}.self_attn.k_proj.weight",
f"{layer_name}.self_attn.v_proj.weight",
)
_broadcast_tp_shard_tensor_qkv(sync_layer.self_attn.qkv_proj.bias if dst_pp_rank == pp_rank else None,
f"{layer_name}.self_attn.q_proj.bias",
f"{layer_name}.self_attn.k_proj.bias",
f"{layer_name}.self_attn.v_proj.bias",
bias=True)
_broadcast_tp_shard_tensor(
sync_layer.self_attn.o_proj.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.self_attn.o_proj.weight",
chunk_dim=1,
)
_broadcast_tensor(
sync_layer.post_attention_layernorm.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.post_attention_layernorm.weight",
)
_broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.mlp.gate_proj.weight", f"{layer_name}.mlp.up_proj.weight")
_broadcast_tp_shard_tensor(
sync_layer.mlp.down_proj.weight if dst_pp_rank == pp_rank else None,
f"{layer_name}.mlp.down_proj.weight",
chunk_dim=1,
)
# Final Layernorm
# -------------------
print_rank_0("loading final layernorm...")
gpt_model_module = _get_gpt_model(models[-1])
_broadcast_tensor(
getattr(gpt_model_module.model.norm, "weight", None),
"model.norm.weight",
)
if tie_word_embeddings:
print_rank_0("tie_word_embeddings skip load lm_head")
else:
print_rank_0("loading lm_head...")
lm_head_weight = None
if pp_rank + 1 == pp_size:
lm_head_weight = gpt_model_module.lm_head.weight
if is_value_model:
if 'lm_head.weight' in state_dict and state_dict['lm_head.weight'].shape[0] == 1:
_broadcast_tensor(lm_head_weight, "lm_head.weight")
print_rank_0('load lm_head from value_head weight')
elif 'reward_head.weight' in state_dict and state_dict['reward_head.weight'].shape[0] == 1:
_broadcast_tensor(lm_head_weight, "reward_head.weight")
print_rank_0('load lm_head from value_head weight')
else:
_broadcast_tensor(None, "lm_head.weight")
print_rank_0('fail to match lm_head in value_model')
else:
_broadcast_tp_shard_tensor(lm_head_weight, "lm_head.weight")
dist.barrier()
# Broadcast weights inside data parallel groups
for wrapped_model in wrapped_models:
broadcast_params(wrapped_model)
torch.cuda.empty_cache()
print_rank_0(f"loading megatron ckpt done, time elapsed {time.time() - start_time}s")
\ No newline at end of file
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from verl.utils.megatron_utils import print_rank_0, unwrap_model
from megatron.core import mpu
from megatron.core.transformer.module import Float16Module
from megatron.core.distributed import DistributedDataParallel as LocalDDP
from torch.nn.parallel import DistributedDataParallel as torchDDP
import torch
import time
import torch
import torch.distributed as dist
def _megatron_calc_global_rank(tp_rank: int = 0, dp_rank: int = 0, pp_rank: int = 0):
"""given TP,DP,PP rank to get the global rank."""
tp_size = mpu.get_tensor_model_parallel_world_size()
dp_size = mpu.get_data_parallel_world_size()
pp_size = mpu.get_pipeline_model_parallel_world_size()
assert (tp_size * dp_size * pp_size == torch.distributed.get_world_size()
), f"{tp_size} x {dp_size} x {pp_size} != {torch.distributed.get_world_size()}"
# We only support TP-DP-PP grouping, for correctness when resharding
return (pp_rank * dp_size + dp_rank) * tp_size + tp_rank
def _megatron_calc_layer_map(config):
"""Calculate the mapping of global layer_idx to local layer_idx
Returns:
layer_map (Dict: int -> tuple(int, int, int)):
mapping from the global layer index to
a tuple of (pp_rank, virtual_pp_rank, layer_idx inside model)
"""
from megatron.core import mpu
pp_size = mpu.get_pipeline_model_parallel_world_size()
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
layer_map = dict()
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
for pp_rank_idx in range(pp_size):
for virtual_pp_rank_idx in range(virtual_pp_size):
layer_offset = (virtual_pp_rank_idx * (config.num_hidden_layers // virtual_pp_size) +
pp_rank_idx * num_layers_per_model)
for layer_idx in range(num_layers_per_model):
layer_map[layer_offset + layer_idx] = (
pp_rank_idx,
virtual_pp_rank_idx,
layer_idx,
)
return layer_map
def merge_megatron_ckpt_qwen2(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):
"""Merge sharded parameters of a Megatron module into a merged checkpoint.
Args:
wrapped_models (list of megatron.core.distributed.DistributedDataParallel):
The local DDP wrapped megatron modules.
config (str or None):
HF config for model
dtype: model params type
is_value_model: if model is value model
tie_word_embeddings: tie_word_embeddings
Returns:
state_dict (dict):
The merged state_dict in rank 0, and an empty dictionary in other ranks.
"""
start_time = time.time()
def _get_gpt_model(model):
return model
dp_rank = mpu.get_data_parallel_rank()
pp_size = mpu.get_pipeline_model_parallel_world_size()
pp_rank = mpu.get_pipeline_model_parallel_rank()
virtual_pp_size = mpu.get_virtual_pipeline_model_parallel_world_size() or 1
mp_group = mpu.get_model_parallel_group()
if dist.get_rank() == 0:
assert mp_group.rank() == 0, f"mp_rank:[{mp_group.rank}] != 0 on rank #0"
assert pp_rank == 0, f"pp_rank:[{pp_rank}] != 0 on rank #0"
assert dp_rank == 0, f"dp_rank:[{dp_rank}] != 0 on rank #0"
if not isinstance(wrapped_models, (list, tuple)):
wrapped_models = list(wrapped_models)
assert len(wrapped_models) == virtual_pp_size
num_layers_per_model = config.num_hidden_layers // pp_size // virtual_pp_size
assert num_layers_per_model * pp_size * virtual_pp_size == config.num_hidden_layers
models = [None] * len(wrapped_models)
for i, wrapped_model in enumerate(wrapped_models):
models[i] = unwrap_model(wrapped_model, (torchDDP, LocalDDP, Float16Module))
assert len(models[i].model.layers
) == num_layers_per_model, 'len model layers {} not equal to num_layers_per_model {}'.format(
len(models[i].model.layers), num_layers_per_model)
state_dict = dict()
def _get_cpu_tensor(tensor: torch.Tensor):
if tensor is None:
return None
if tensor.device == torch.device("cpu"):
return tensor.detach().clone()
return tensor.detach().cpu()
def _broadcast_tensor(tensor, name, src_pp_rank) -> torch.Tensor:
"""broadcast tensor across mp_group"""
nonlocal state_dict
nonlocal mp_group
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
if torch.distributed.get_rank() == src_rank:
if tensor is None:
weight = None
tensor_shape = None
else:
weight = tensor
tensor_shape = weight.shape
else:
weight = None
tensor_shape = None
obj_list = [tensor_shape]
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
tensor_shape = obj_list[0]
if tensor_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tensor:[{name}] not exist, skip collect")
return
if weight is None:
weight = torch.empty(
tensor_shape,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
dist.broadcast(weight, src=src_rank, group=mp_group)
if torch.distributed.get_rank() == 0:
state_dict[name] = _get_cpu_tensor(weight)
def _broadcast_tp_shard_tensor(tensor, name, src_pp_rank, concat_dim=0, mutate_func=None) -> torch.Tensor:
"""broadcast tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
if torch.distributed.get_rank() == src_rank:
chunk_shape = tensor.shape
else:
chunk_shape = None
obj_list = [chunk_shape]
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
chunk_shape = obj_list[0]
if chunk_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tp_shard tensor:[{name}] not exist, skip collecting")
return
buffer_tensor = torch.empty(
chunk_shape,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
chunk_tensors = [None] * tp_size
for i in range(tp_size):
cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
if torch.distributed.get_rank() == 0:
chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
if torch.distributed.get_rank() == 0:
full_tensor = torch.concat(chunk_tensors, dim=concat_dim)
if mutate_func is not None:
full_tensor = mutate_func(full_tensor)
state_dict[name] = full_tensor
def _broadcast_tp_shard_tensor_gate_up(tensor, gate_name, up_name, src_pp_rank) -> torch.Tensor:
"""broadcast tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
if torch.distributed.get_rank() == src_rank:
chunk_shape = tensor.shape
else:
chunk_shape = None
obj_list = [chunk_shape]
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
chunk_shape = obj_list[0]
if chunk_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tp_shard tensor:[{gate_name, up_name}] not exist, skip collecting")
return
buffer_tensor = torch.empty(
chunk_shape,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
chunk_tensors = [None] * tp_size
for i in range(tp_size):
cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
if torch.distributed.get_rank() == 0:
chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
if torch.distributed.get_rank() == 0:
full_tensor = torch.concat(chunk_tensors, dim=0)
intermediate_size_tp = config.intermediate_size // tp_size
gate_weight_list = []
up_weight_list = []
for i in range(tp_size):
gate_up_weight_tp = full_tensor[intermediate_size_tp * 2 * i:intermediate_size_tp * 2 * (i + 1)]
gate_weight_tp = gate_up_weight_tp[:intermediate_size_tp]
up_weight_tp = gate_up_weight_tp[intermediate_size_tp:]
gate_weight_list.append(gate_weight_tp)
up_weight_list.append(up_weight_tp)
state_dict[gate_name] = torch.cat(gate_weight_list, dim=0)
state_dict[up_name] = torch.cat(up_weight_list, dim=0)
def _broadcast_tp_shard_tensor_qkv(tensor, q_name, k_name, v_name, src_pp_rank):
"""broadcast tensor in tp shards across mp_group"""
nonlocal state_dict
nonlocal mp_group
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
src_rank = _megatron_calc_global_rank(tp_rank=0, dp_rank=0, pp_rank=src_pp_rank)
if torch.distributed.get_rank() == src_rank:
chunk_shape = tensor.shape
else:
chunk_shape = None
obj_list = [chunk_shape]
dist.broadcast_object_list(obj_list, src=src_rank, group=mp_group)
chunk_shape = obj_list[0]
if chunk_shape is None:
# all or none ranks in the mp_group should reach here
print_rank_0(f"tp_shard tensor:[{q_name}] not exist, skip collecting")
return
buffer_tensor = torch.empty(
chunk_shape,
dtype=dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
chunk_tensors = [None] * tp_size
for i in range(tp_size):
cur_src_rank = _megatron_calc_global_rank(tp_rank=i, dp_rank=0, pp_rank=src_pp_rank)
sync_tensor = tensor if torch.distributed.get_rank() == cur_src_rank else buffer_tensor
dist.broadcast(sync_tensor, src=cur_src_rank, group=mp_group)
if torch.distributed.get_rank() == 0:
chunk_tensors[i] = _get_cpu_tensor(sync_tensor)
if torch.distributed.get_rank() == 0:
full_tensor = torch.concat(chunk_tensors, dim=0)
q_weight_list = []
k_weight_list = []
v_weight_list = []
hidden_size_per_head = config.hidden_size // config.num_attention_heads
if config.num_key_value_heads >= tp_size:
q_size_tp = config.hidden_size // tp_size
kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
total_size = q_size_tp + 2 * kv_size_tp
for i in range(tp_size):
qkv_part = full_tensor[i * total_size:(i + 1) * total_size]
q_part = qkv_part[:q_size_tp]
k_part = qkv_part[q_size_tp:q_size_tp + kv_size_tp]
v_part = qkv_part[q_size_tp + kv_size_tp:total_size]
q_weight_list.append(q_part)
k_weight_list.append(k_part)
v_weight_list.append(v_part)
else:
q_size_tp = config.hidden_size // tp_size
kv_size_tp = hidden_size_per_head
total_size = q_size_tp + 2 * kv_size_tp
for i in range(tp_size):
qkv_part = full_tensor[i * total_size:(i + 1) * total_size]
q_part = qkv_part[:q_size_tp]
k_part = qkv_part[q_size_tp:q_size_tp + kv_size_tp]
v_part = qkv_part[q_size_tp + kv_size_tp:total_size]
q_weight_list.append(q_part)
if i * config.num_key_value_heads % tp_size == 0:
k_weight_list.append(k_part)
v_weight_list.append(v_part)
state_dict[q_name] = torch.cat(q_weight_list, dim=0)
state_dict[k_name] = torch.cat(k_weight_list, dim=0)
state_dict[v_name] = torch.cat(v_weight_list, dim=0)
# empty cache before collecting weights
torch.cuda.empty_cache()
# Embeddings
# -------------------
if dp_rank == 0:
# Embeddings
# -------------------
print_rank_0("collecting embeddings...")
gpt_model_module = _get_gpt_model(models[0])
_broadcast_tp_shard_tensor(
gpt_model_module.model.embed_tokens.weight if pp_rank == 0 else None,
"model.embed_tokens.weight",
src_pp_rank=0,
)
# Transformer layers
# -------------------
layer_map = _megatron_calc_layer_map(config)
for layer in range(config.num_hidden_layers):
print_rank_0(f"collecting layer #{layer}...")
layer_name = f"model.layers.{layer}"
src_pp_rank, src_virtual_pp_rank, src_layer_idx = layer_map[layer]
gpt_model_module = _get_gpt_model(models[src_virtual_pp_rank])
sync_layer = gpt_model_module.model.layers[src_layer_idx]
_broadcast_tensor(
sync_layer.input_layernorm.weight,
f"{layer_name}.input_layernorm.weight",
src_pp_rank=src_pp_rank,
)
_broadcast_tp_shard_tensor_qkv(
sync_layer.self_attn.qkv_proj.weight,
f"{layer_name}.self_attn.q_proj.weight",
f"{layer_name}.self_attn.k_proj.weight",
f"{layer_name}.self_attn.v_proj.weight",
src_pp_rank=src_pp_rank,
)
_broadcast_tp_shard_tensor_qkv(
sync_layer.self_attn.qkv_proj.bias,
f"{layer_name}.self_attn.q_proj.bias",
f"{layer_name}.self_attn.k_proj.bias",
f"{layer_name}.self_attn.v_proj.bias",
src_pp_rank=src_pp_rank,
)
_broadcast_tp_shard_tensor(
sync_layer.self_attn.o_proj.weight,
f"{layer_name}.self_attn.o_proj.weight",
concat_dim=1,
src_pp_rank=src_pp_rank,
)
_broadcast_tensor(
sync_layer.post_attention_layernorm.weight,
f"{layer_name}.post_attention_layernorm.weight",
src_pp_rank=src_pp_rank,
)
_broadcast_tp_shard_tensor_gate_up(sync_layer.mlp.gate_up_proj.weight,
f"{layer_name}.mlp.gate_proj.weight",
f"{layer_name}.mlp.up_proj.weight",
src_pp_rank=src_pp_rank)
_broadcast_tp_shard_tensor(
sync_layer.mlp.down_proj.weight,
f"{layer_name}.mlp.down_proj.weight",
concat_dim=1,
src_pp_rank=src_pp_rank,
)
# Final Layernorm
# -------------------
print_rank_0("collecting final layernorm...")
gpt_model_module = _get_gpt_model(models[-1])
_broadcast_tensor(
getattr(gpt_model_module.model.norm, "weight", None),
"model.norm.weight",
src_pp_rank=pp_size - 1,
)
if tie_word_embeddings:
print_rank_0(f"tie word embedding skip load lm_head...")
else:
print_rank_0("collecting lm_head...")
if is_value_model:
_broadcast_tensor(gpt_model_module.lm_head.weight if pp_rank == pp_size - 1 else None,
"lm_head.weight",
src_pp_rank=pp_size - 1)
_broadcast_tensor(gpt_model_module.reward_head.weight if pp_rank == pp_size - 1 and
getattr(gpt_model_module, "reward_weight", None) is not None else None,
"reward_head.weight",
src_pp_rank=pp_size - 1)
else:
_broadcast_tp_shard_tensor(
getattr(gpt_model_module.lm_head, "weight", None) if pp_rank == pp_size - 1 else None,
"lm_head.weight",
src_pp_rank=pp_size - 1,
)
dist.barrier()
torch.cuda.empty_cache()
if torch.distributed.get_rank() == 0:
for k, v in state_dict.items():
if dtype != v.dtype:
state_dict[k] = v.to(dtype)
print_rank_0(f"merge megatron ckpt done, time elapsed {time.time() - start_time}s")
return state_dict
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .parallel_attention import ParallelQwen2Attention
from .parallel_decoder import ParallelQwen2DecoderLayer, ParallelQwen2DecoderLayerRmPad
from .parallel_mlp import ParallelQwen2MLP
from .parallel_rmsnorm import ParallelQwen2RMSNorm
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple
import torch
from megatron.core import parallel_state as mpu
from megatron.core import tensor_parallel
from megatron.core import ModelParallelConfig
from torch import nn
from transformers import Qwen2Config
from verl.models.qwen2.megatron.layers.parallel_linear import QKVParallelLinear
from verl.utils.megatron import tensor_parallel as tp_utils
class Qwen2RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.get_default_dtype())
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
class Qwen2LinearScalingRotaryEmbedding(Qwen2RotaryEmbedding):
"""Qwen2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
class Qwen2DynamicNTKScalingRotaryEmbedding(Qwen2RotaryEmbedding):
"""Qwen2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
if seq_len > self.max_position_embeddings:
base = self.base * ((self.scaling_factor * seq_len / self.max_position_embeddings) -
(self.scaling_factor - 1))**(self.dim / (self.dim - 2))
inv_freq = 1.0 / (base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class ParallelQwen2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):
super().__init__()
self.config = config
self.megatron_config = megatron_config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
# assign values after tp
tp_size = mpu.get_tensor_model_parallel_world_size()
assert self.num_heads % tp_size == 0, f'num_head must be divisible by tp_size. Got num_head={self.num_heads}, tp_size={tp_size}'
assert self.num_key_value_heads % tp_size == 0, \
f'num_key_value_heads must be divisible by tp_size. Got num_key_value_heads={self.num_key_value_heads}, tp_size={tp_size}'
self.num_heads_per_tp = self.num_heads // tp_size
self.num_key_value_heads_per_tp = self.num_key_value_heads // tp_size
self.hidden_size_per_tp = self.hidden_size // tp_size
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads}).")
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()
if megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
assert row_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)
tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)
# [self.q_size, self.k_size, self.v_size]
self.qkv_proj = QKVParallelLinear(
input_size=self.hidden_size,
num_heads=self.num_heads,
num_key_value_heads=self.num_key_value_heads,
head_dim=self.head_dim,
# bias=config.attention_bias,
bias=True,
gather_output=False,
skip_bias_add=False,
**column_kwargs)
self.q_size = self.num_heads_per_tp * self.head_dim
self.k_size = self.num_key_value_heads_per_tp * self.head_dim
self.v_size = self.num_key_value_heads_per_tp * self.head_dim
self.o_proj = tensor_parallel.RowParallelLinear(
input_size=self.num_heads * self.head_dim,
output_size=self.hidden_size,
# bias=config.attention_bias,
bias=False,
input_is_parallel=True,
skip_bias_add=False,
**row_kwargs)
self._init_rope()
def _init_rope(self):
self.rotary_emb = Qwen2RotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
qkv = self.qkv_proj(hidden_states)[0]
query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1)
query_states = query_states.view(bsz, q_len, self.num_heads_per_tp, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads_per_tp, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads_per_tp, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads_per_tp, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}")
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}")
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads_per_tp, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads_per_tp, q_len, self.head_dim)}, but is"
f" {attn_output.size()}")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size_per_tp)
attn_output = self.o_proj(attn_output)[0]
return attn_output
"""
Remove padding Attention
- Using Flash-attn 2
- Compatible with sequence parallel
"""
from transformers.utils import is_flash_attn_2_available
import torch.nn.functional as F
from einops import rearrange
if is_flash_attn_2_available():
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
def apply_rotary_pos_emb_rmpad(q, k, cos, sin, position_ids, indices, sequence_length):
batch_size = position_ids.shape[0]
q = pad_input(q, indices, batch_size, sequence_length) # (batch_size, seqlen, num_head, head_dim)
k = pad_input(k, indices, batch_size, sequence_length)
cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
q_embed = index_first_axis(rearrange(q_embed, "b s ... -> (b s) ..."), indices)
k_embed = index_first_axis(rearrange(k_embed, "b s ... -> (b s) ..."), indices)
return q_embed, k_embed
from flash_attn.layers.rotary import apply_rotary_emb
# use flash-attn rotary embeddings with rmpad
# cos/sin shoudl be: (seq_length, rotary_dim / 2)
def apply_rotary_pos_emb_rmpad_flash(q, k, cos, sin, cu_seqlens, max_seqlen):
q_embed = apply_rotary_emb(q,
cos,
sin,
interleaved=False,
inplace=False,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen)
k_embed = apply_rotary_emb(k,
cos,
sin,
interleaved=False,
inplace=False,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen)
return q_embed, k_embed
class ParallelQwen2AttentionRmPad(ParallelQwen2Attention):
def forward(self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
sequence_length: int = None,
indices: torch.Tensor = None,
cu_seqlens: torch.Tensor = None,
max_seqlen_in_batch: int = None):
total_nnz, _, _ = hidden_states.size() # This is the total_nnz padded after sequence parallel
if self.megatron_config.sequence_parallel:
total_nnz = total_nnz * mpu.get_tensor_model_parallel_world_size()
qkv = self.qkv_proj(hidden_states)[0]
query_states, key_states, value_states = qkv.split([self.q_size, self.k_size, self.v_size],
dim=-1) # (total_nnz, 1, hidden_size)
if self.megatron_config.sequence_parallel:
sequence_parallel_pad = total_nnz - cu_seqlens[-1]
total_nnz = cu_seqlens[-1] # total_nnz before sp padding
query_states = query_states[:total_nnz]
key_states = key_states[:total_nnz]
value_states = value_states[:total_nnz]
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dime x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(total_nnz, self.num_heads_per_tp, self.head_dim)
key_states = key_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)
value_states = value_states.view(total_nnz, self.num_key_value_heads_per_tp, self.head_dim)
cos, sin = self.rotary_emb(value_states, seq_len=sequence_length)
cos, sin = cos[:, :cos.shape[1] // 2], sin[:, :sin.shape[1] // 2] # flash attn only needs half
query_states, key_states = apply_rotary_pos_emb_rmpad_flash(query_states,
key_states,
cos,
sin,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen_in_batch)
# query_states, key_states = apply_rotary_pos_emb_rmpad(query_states, key_states, cos, sin, position_ids, indices,
# It is recommended to use dropout with FA according to the docs
# when training.
dropout_rate = 0.0 # if not self.training else self.attn_dropout
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (Qwen2RMSNorm handles it correctly)
input_dtype = query_states.dtype
if input_dtype == torch.float32:
query_states = query_states.to(torch.float16)
key_states = key_states.to(torch.float16)
value_states = value_states.to(torch.float16)
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen_in_batch,
max_seqlen_k=max_seqlen_in_batch,
dropout_p=dropout_rate,
softmax_scale=None,
causal=True,
)
attn_output_unpad = attn_output_unpad.to(input_dtype)
attn_output_unpad = attn_output_unpad.reshape(total_nnz, 1, self.hidden_size_per_tp).contiguous()
# sequence parallel reduce_scatter is performed inside RowColumnParallel if enabled
# Here we need to repad
if self.megatron_config.sequence_parallel:
attn_output_unpad = F.pad(attn_output_unpad, pad=(0, 0, 0, 0, 0, sequence_parallel_pad))
attn_output_unpad = self.o_proj(attn_output_unpad)[0]
return attn_output_unpad
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
from torch import nn
from transformers import Qwen2Config
from megatron.core import ModelParallelConfig
from .parallel_attention import ParallelQwen2Attention, ParallelQwen2AttentionRmPad
from .parallel_mlp import ParallelQwen2MLP
from .parallel_rmsnorm import ParallelQwen2RMSNorm
from verl.utils.megatron_utils import TransformerConfig, convert_config
class ParallelQwen2DecoderLayer(nn.Module):
def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int):
super().__init__()
self.config: TransformerConfig = convert_config(config, megatron_config)
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.self_attn = ParallelQwen2Attention(config=config, megatron_config=megatron_config)
self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config)
self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config)
self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Note: sequence parallel is hidden inside ColumnParallelLinear
# reduce scatter is hidden inside RowParallelLinear
# Self Attention
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
)
# TODO: add sequence parallel operator reduce_scatter here
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
# TODO: add sequence parallel operator all_gather here
hidden_states = self.mlp(hidden_states)
# TODO: add sequence parallel operator reduce_scatter here
hidden_states = residual + hidden_states
outputs = hidden_states
return outputs
class ParallelQwen2DecoderLayerRmPad(nn.Module):
def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig, layer_idx: int):
super().__init__()
self.config: TransformerConfig = convert_config(config, megatron_config)
self.hidden_size = config.hidden_size
self.layer_idx = layer_idx
self.self_attn = ParallelQwen2AttentionRmPad(config=config, megatron_config=megatron_config)
self.mlp = ParallelQwen2MLP(config, megatron_config=megatron_config)
self.input_layernorm = ParallelQwen2RMSNorm(config, megatron_config)
self.post_attention_layernorm = ParallelQwen2RMSNorm(config, megatron_config)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
sequence_length: int = None,
indices: torch.Tensor = None,
cu_seqlens: int = None,
max_seqlen_in_batch: int = None
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states # (total_nnz // sp, 1, hidden_size)
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
# (total_nnz // sp, 1, hidden_size) -> all-gather (total_nnz, 1, hidden_size)
# -> col + row -> reduce-scatter -> (total_nnz // sp, 1, hidden_size)
hidden_states = self.self_attn(hidden_states=hidden_states,
position_ids=position_ids,
sequence_length=sequence_length,
indices=indices,
cu_seqlens=cu_seqlens,
max_seqlen_in_batch=max_seqlen_in_batch)
hidden_states = residual + hidden_states
# Fully Connected
# shape changes same as attn
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = hidden_states
return outputs
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023 The vLLM team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py
from typing import Optional, Tuple
from megatron.core import tensor_parallel
class QKVParallelLinear(tensor_parallel.ColumnParallelLinear):
def __init__(self,
input_size,
num_heads,
num_key_value_heads,
head_dim,
*,
bias=True,
gather_output=True,
skip_bias_add=False,
**kwargs):
# Keep input parameters, and already restrict the head numbers
self.input_size = input_size
self.q_output_size = num_heads * head_dim
self.kv_output_size = num_key_value_heads * head_dim
self.head_dim = head_dim
self.gather_output = gather_output
self.skip_bias_add = skip_bias_add
input_size = self.input_size
output_size = (num_heads + 2 * num_key_value_heads) * self.head_dim
super().__init__(input_size=input_size,
output_size=output_size,
bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
**kwargs)
class MergedColumnParallelLinear(tensor_parallel.ColumnParallelLinear):
def __init__(self,
input_size,
gate_ouput_size,
up_output_size,
*,
bias=True,
gather_output=True,
skip_bias_add=False,
**kwargs):
# Keep input parameters, and already restrict the head numbers
self.input_size = input_size
self.output_size = gate_ouput_size + up_output_size
self.gather_output = gather_output
self.skip_bias_add = skip_bias_add
super().__init__(input_size=self.input_size,
output_size=self.output_size,
bias=bias,
gather_output=gather_output,
skip_bias_add=skip_bias_add,
**kwargs)
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from megatron.core import parallel_state as mpu
from megatron.core import tensor_parallel
from megatron.core import ModelParallelConfig
from torch import nn
from transformers.activations import ACT2FN
from verl.models.qwen2.megatron.layers.parallel_linear import MergedColumnParallelLinear
from verl.utils.megatron import tensor_parallel as tp_utils
class ParallelQwen2MLP(nn.Module):
def __init__(self, config, megatron_config: ModelParallelConfig = None) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
# The weight is only [hidden_size, intermediate_size // model_parallel_world_size]
column_kwargs = tp_utils.get_default_kwargs_for_column_parallel_linear()
row_kwargs = tp_utils.get_default_kwargs_for_row_parallel_linear()
if megatron_config is not None:
assert column_kwargs.get('config', False), 'must have ModelParallelConfig'
assert row_kwargs.get('config', False), 'must have ModelParallelConfig'
tp_utils.update_kwargs_with_config(row_kwargs, megatron_config)
tp_utils.update_kwargs_with_config(column_kwargs, megatron_config)
tp_size = mpu.get_tensor_model_parallel_world_size()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=self.hidden_size,
gate_ouput_size=self.intermediate_size,
up_output_size=self.intermediate_size,
bias=False,
gather_output=False,
skip_bias_add=False,
**column_kwargs,
)
self.gate_size = self.intermediate_size // tp_size
self.down_proj = tensor_parallel.RowParallelLinear(input_size=self.intermediate_size,
output_size=self.hidden_size,
bias=False,
input_is_parallel=True,
skip_bias_add=False,
**row_kwargs)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
gate_up = self.gate_up_proj(x)[0]
gate, up = gate_up.split(self.gate_size, dim=-1)
return self.down_proj(self.act_fn(gate) * up)[0]
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
import torch
from megatron.core import ModelParallelConfig
from torch import nn
from transformers import Qwen2Config
from apex.normalization.fused_layer_norm import fused_rms_norm_affine
from verl.utils.megatron import sequence_parallel as sp_utils
class ParallelQwen2RMSNorm(nn.Module):
def __init__(self, config: Qwen2Config, megatron_config: ModelParallelConfig):
"""
Qwen2RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
if isinstance(config.hidden_size, numbers.Integral):
normalized_shape = (config.hidden_size,)
self.normalized_shape = torch.Size(normalized_shape)
self.weight = nn.Parameter(torch.ones(self.normalized_shape))
self.variance_epsilon = config.rms_norm_eps
if megatron_config.sequence_parallel:
sp_utils.mark_parameter_as_sequence_parallel(self.weight)
def forward(self, hidden_states):
return fused_rms_norm_affine(input=hidden_states,
weight=self.weight,
normalized_shape=self.normalized_shape,
eps=self.variance_epsilon,
memory_efficient=True)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment