OpenDAS / Megatron-LM

Commit 688448db, authored Mar 14, 2025 by silencealiang

    Update code

Parent: a02a5490
Pipeline #2503 passed
Changes: 172
Showing 20 changed files with 3236 additions and 2198 deletions (+3236 -2198):

- megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py  +208 -0
- megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py  +225 -215
- megatron/core/inference/modelopt_support/__init__.py  +10 -8
- megatron/core/inference/modelopt_support/gpt/model_specs.py  +68 -63
- megatron/core/inference/modelopt_support/mamba/__init__.py  +1 -0
- megatron/core/inference/modelopt_support/mamba/model_specs.py  +89 -0
- megatron/core/inference/sampling_params.py  +36 -35
- megatron/core/inference/scheduler.py  +175 -127
- megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py  +38 -35
- megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py  +5 -5
- megatron/core/inference/text_generation_controllers/text_generation_controller.py  +674 -400
- megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py  +40 -0
- megatron/core/inference_params.py  +100 -31
- megatron/core/jit.py  +1 -15
- megatron/core/model_parallel_config.py  +392 -387
- megatron/core/models/T5/t5_model.py  +72 -5
- megatron/core/models/common/embeddings/relative_pos_embedding.py  +173 -0
- megatron/core/models/common/embeddings/rotary_pos_embedding.py  +215 -213
- megatron/core/models/gpt/gpt_layer_specs.py  +383 -350
- megatron/core/models/gpt/gpt_model.py  +331 -309
megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py (new file, mode 100644)
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Any, Dict

import torch

from megatron.core import parallel_state
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
    GPTInferenceWrapper,
)
from megatron.core.inference_params import InferenceParams

# pylint: disable=line-too-long


class VLMInferenceWrapper(GPTInferenceWrapper):
    """Inference wrapper for VLMs"""

    def prep_model_for_inference(self, prompts_tokens: torch.Tensor):
        """A utility function for preparing the model for inference

        The function gets called once before the auto-regressive inference loop.
        It puts the model in eval mode.

        Args:
            prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len]
        """
        super().prep_model_for_inference(prompts_tokens)

        # For a TP-only model, both is_pipeline_first_stage and is_pipeline_last_stage return True
        self.model_is_pipeline_parallel = not (
            parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
        )

        self._recv_only_vision_embeds = False
        pp_rank = parallel_state.get_pipeline_model_parallel_rank()
        # Checks if the previous stage only has a vision encoder, and that the current stage
        # has part of the LM decoder. In this case, the current stage should only receive
        # vision embeddings.
        if pp_rank > 0:
            self._recv_only_vision_embeds = (
                parallel_state.is_inside_encoder(pp_rank - 1)
                and (not parallel_state.is_inside_decoder(pp_rank - 1))
                and parallel_state.is_inside_decoder()
            )

        # Checks if the current stage only has a vision encoder
        self._encoder_only = (
            parallel_state.is_inside_encoder() and not parallel_state.is_inside_decoder()
        )

        # For a TP-only model, both is_pipeline_first_stage and is_pipeline_last_stage return True
        self.model_is_pipeline_parallel = not (
            parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
        )

    def prep_inference_input(
        self,
        prompts_tokens: torch.Tensor,
        num_img_embeddings_per_tile: int,
        images: torch.Tensor,
        num_tiles: torch.Tensor,
        decoder_seq_length: int,
    ):
        """Prepares the inference input data.

        Args:
            prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len]
            num_img_embeddings_per_tile (int): The number of image embeddings per tile
            images (torch.Tensor): The image embeddings
            num_tiles (torch.Tensor): The number of tiles for each input image
            decoder_seq_length (int): The decoder sequence length
        """
        inference_input = super().prep_inference_input(prompts_tokens)

        total_num_tiles = torch.sum(num_tiles).item()
        num_img_embeddings = num_img_embeddings_per_tile * total_num_tiles

        batch_size, max_sequence_length = prompts_tokens.shape
        self.inference_params = InferenceParams(
            batch_size, max_sequence_length + num_img_embeddings
        )

        inference_input["images"] = images
        inference_input["num_tiles"] = num_tiles
        inference_input["num_img_embeddings"] = num_img_embeddings
        inference_input["decoder_seq_length"] = decoder_seq_length

        return inference_input

    def get_batch_for_context_window(
        self,
        inference_input: Dict[str, Any],
        context_start_position: int,
        context_end_position: int,
    ) -> Dict[str, Any]:
        """Returns the inference data given the context window

        This function gets called iteratively in a loop. Given the start and end context
        positions, it extracts the appropriate data.

        Args:
            inference_input (Dict[str, Any]): The inference input for the batch.
            context_start_position (int): Start of the context window. During the first
                inference step it is mostly 0.
            context_end_position (int): End of the context window. During the last inference
                step it will mostly be the max generated sequence length.

        Returns:
            Dict[str, Any]: A dict of inputs that will be used by your model in the forward step
        """
        tokens = inference_input["tokens"]
        position_ids = inference_input["position_ids"]
        images = inference_input["images"]
        num_tiles = inference_input["num_tiles"]
        num_img_embeddings = inference_input["num_img_embeddings"]
        decoder_seq_length = inference_input["decoder_seq_length"]

        tokens2use = tokens[:, context_start_position:context_end_position]
        positions2use = position_ids[:, context_start_position:context_end_position]

        return {
            "tokens": tokens2use,
            "position_ids": positions2use,
            "images": images,
            "num_tiles": num_tiles,
            "num_img_embeddings": num_img_embeddings,
            "decoder_seq_length": decoder_seq_length,
        }

    def _forward(self, inference_input: Dict[str, Any]):
        """Runs a forward pass of the model.

        Args:
            inference_input (Dict[str, Any]): The input data.

        Returns:
            The model output logits.
        """
        images = inference_input["images"]
        tokens = inference_input["tokens"]
        position_ids = inference_input["position_ids"]
        num_image_tiles = inference_input["num_tiles"]

        output = self.model(
            images,
            tokens,
            position_ids=position_ids,
            attention_mask=None,
            inference_params=self.inference_params,
            num_image_tiles=num_image_tiles,
            runtime_gather_output=True,
        )
        if isinstance(output, tuple):
            logits, _ = output
        else:
            logits = output
        return logits

    def run_one_forward_step(self, inference_input: Dict[str, Any]) -> torch.Tensor:
        tokens = inference_input["tokens"]
        num_image_tokens = (tokens == self.model.module.image_token_index).sum().item()
        num_img_embeddings = inference_input["num_img_embeddings"]
        decoder_seq_length = inference_input["decoder_seq_length"]
        num_tokens = tokens.size(1)
        recv_buffer_seq_len = None
        if num_image_tokens > 0:
            # When there are image tokens and this stage only receives vision embeddings,
            # adjust the recv buffer seq length to match the image embeddings sequence length.
            # If there are image tokens and this stage receives full embeddings, make sure we
            # compensate for expansion of image tokens.
            # Note that this will set a recv_buffer_seq_len for the encoder stage;
            # this length is irrelevant since that recv buffer is never allocated.
            if self._recv_only_vision_embeds:
                recv_buffer_seq_len = num_img_embeddings
            else:
                recv_buffer_seq_len = min(
                    num_img_embeddings + num_tokens - num_image_tokens, decoder_seq_length
                )
        elif self._recv_only_vision_embeds:
            # If this stage only receives vision embeddings and there are no image tokens,
            # we won't run the encoder and therefore shouldn't try to recv.
            recv_buffer_seq_len = 0

        # If the pipeline stage only has a vision encoder, then it only needs to
        # run when there are image tokens
        if not (self._encoder_only and num_image_tokens == 0):
            output = super().run_one_forward_step(
                inference_input, recv_buffer_seq_len=recv_buffer_seq_len
            )
        else:
            output = None
        logits = output

        # On the first inference iteration, we compute image tokens.
        # On every PP stage (although inference params should only matter for the decoder),
        # update the sequence length offset by the number of image tokens.
        if num_tokens > 1 and num_image_tokens > 0:
            if "image_tokens_count" not in self.inference_params.key_value_memory_dict:
                self.inference_params.key_value_memory_dict["image_tokens_count"] = (
                    num_img_embeddings
                )

            if num_img_embeddings + num_tokens - num_image_tokens > decoder_seq_length:
                self.inference_params.sequence_len_offset += decoder_seq_length - num_tokens
            else:
                self.inference_params.sequence_len_offset += (
                    self.inference_params.key_value_memory_dict["image_tokens_count"]
                    - num_image_tokens
                )

        return logits
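A minimal usage sketch of the new wrapper, not part of the diff. It assumes a `vlm_model`, an `inference_wrapper_config`, tokenized `prompt_tokens`, image tensors `images`, and `num_tiles` have already been built elsewhere (all hypothetical placeholders; in Megatron-LM the engine/controller stack normally drives these calls), and the per-tile embedding count and decoder length are illustrative values only.

# Sketch only: names below are placeholders, not APIs introduced by this commit.
wrapper = VLMInferenceWrapper(vlm_model, inference_wrapper_config)  # same ctor args as GPTInferenceWrapper
wrapper.prep_model_for_inference(prompt_tokens)
inference_input = wrapper.prep_inference_input(
    prompts_tokens=prompt_tokens,
    num_img_embeddings_per_tile=576,   # assumed; depends on the vision encoder
    images=images,
    num_tiles=num_tiles,
    decoder_seq_length=4096,           # assumed decoder sequence length
)
# One decoding step over the prompt window [0, prompt_len).
step_input = wrapper.get_batch_for_context_window(inference_input, 0, prompt_tokens.size(1))
logits = wrapper.run_one_forward_step(step_input)  # None on encoder-only stages with no image tokens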
megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py
Before:

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from collections import deque
from typing import Any, List, Tuple

import numpy
import torch

from megatron.core import tensor_parallel
from megatron.core.datasets.t5_dataset import T5MaskedWordPieceDataset
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
    AbstractModelInferenceWrapper,
)
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
    InferenceWrapperConfig,
)
from megatron.core.models.T5 import T5Model

# pylint: disable=line-too-long


class T5InferenceWrapper(AbstractModelInferenceWrapper):
    """Constructor for the model inference wrapper

    The wrapper prepares the model for inference, provides the required input
    data, and runs the forward pass.

    Args:
        model (T5Model): The T5 model (MCore or legacy)
        inference_wrapper_config (InferenceWrapperConfig): The command line arguments that were passed
        use_local (bool): Whether the T5 model's transformer impl
            is local (vs transformer_engine)
    """

    def __init__(
        self,
        model: T5Model,
        inference_wrapper_config: InferenceWrapperConfig,
        use_local: bool = False,
    ):
        super().__init__(model, inference_wrapper_config)
        self.use_local = use_local

    def prep_model_for_inference(
        self,
        prompts_tokens: torch.Tensor,
        encoder_prompts: List[str] = None,
        tokenizer: Any = None,
    ):
        """A utility function for preparing the model for inference

        This function is called before the forward pass. It puts the model in eval mode, builds
        position ids, and creates attention masks so that required slices can be extracted during
        the forward pass.

        Args:
            prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
            encoder_prompts (List[str]): List of encoder input prompt strings
            tokenizer (_type_): Tokenizer used for tokenizing and detokenizing text
        """
        super().prep_model_for_inference(prompts_tokens=prompts_tokens)

        # get max_sequence_length
        if hasattr(self.model, "module"):  # if self.model is Float16Module
            max_sequence_length = self.model.module.max_sequence_length
        else:
            max_sequence_length = self.model.max_sequence_length

        encoder_prompts_tokens_list = [
            self.tokenize_encoder_prompt(encoder_prompt, tokenizer)
            for encoder_prompt in encoder_prompts
        ]
        self.batch_encoder_prompts_tokens = self.pad_encoder_prompts_tokens(
            encoder_prompts_tokens_list, max_sequence_length, tokenizer
        )

        # create batch mask for encoder_prompt (self.batch_input_tokens) and
        # decoder_input (self.prompts_tokens), similar to megatron/core/datasets/t5_dataset.py
        decoder_prompts_tokens = self.prompts_tokens.cpu().numpy()
        encoder_prompts_tokens = self.batch_encoder_prompts_tokens.cpu().numpy()
        self.batch_mask_encoder = []
        self.batch_mask_decoder = []
        for i in range(len(self.prompts_tokens)):
            mask_encoder = encoder_prompts_tokens[i] == tokenizer.pad
            mask_decoder = decoder_prompts_tokens[i] == tokenizer.pad
            self.batch_mask_encoder.append(mask_encoder)
            self.batch_mask_decoder.append(mask_decoder)
        self.batch_mask_encoder = torch.tensor(numpy.array(self.batch_mask_encoder)).cuda()
        self.batch_mask_decoder = torch.tensor(numpy.array(self.batch_mask_decoder)).cuda()

    def tokenize_encoder_prompt(
        self, encoder_prompt: str, tokenizer
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Utility to tokenize the encoder_prompt

        Args:
            encoder_prompt (str): The encoder_prompt
            tokenizer (_type_): Tokenizer used for tokenizing and detokenizing strings

        Returns:
            torch.Tensor: Returns the tokenized prompt
        """

        # if the word "<mask>" appears in the prompt, replace it with a special additional token,
        # similar to the processing step in megatron/core/datasets/t5_dataset.py
        divided_encoder_prompt_list = encoder_prompt.split("<mask>")
        masks_count = len(divided_encoder_prompt_list) - 1
        sentinels = deque(tokenizer.additional_special_tokens_ids)

        encoder_prompt_tokens = []
        for divided_encoder_prompt in divided_encoder_prompt_list:
            divided_encoder_prompt_tokens = tokenizer.tokenize(divided_encoder_prompt)
            encoder_prompt_tokens.extend(divided_encoder_prompt_tokens)
            if masks_count > 0:
                sentinel = sentinels.popleft()
                encoder_prompt_tokens.extend([sentinel])
                masks_count -= 1

        return encoder_prompt_tokens

    def pad_encoder_prompts_tokens(
        self, encoder_prompts_tokens_list: List[List[int]], max_sequence_length: int, tokenizer
    ) -> torch.Tensor:
        """Method to pad input prompts

        Given a list of prompts, pad them all to uniform length

        Args:
            encoder_prompts_tokens_list (List[List[int]]): A list containing the
                encoder_input_tokens
            max_sequence_length (int): Maximum length of the encoder input tokens
            tokenizer (_type_): Tokenizer used for tokenizing and detokenizing text

        Returns:
            torch.Tensor: A torch tensor of shape [bs, max_sequence_length]
        """
        for encoder_prompt_tokens in encoder_prompts_tokens_list:
            padding_size = max_sequence_length - len(encoder_prompt_tokens)
            encoder_prompt_tokens.extend([tokenizer.pad] * padding_size)

        return torch.tensor(encoder_prompts_tokens_list).cuda()

    def get_batch_for_context_window(
        self, context_start_position: int, context_end_position: int
    ) -> List:
        """Returns the inference data given the context window

        This function gets called iteratively in a loop. Given the start and end context
        positions, it extracts the appropriate data.

        Args:
            context_start_position (int): Start of the context window. During
                the first inference step it is mostly 0.
            context_end_position (int): End of the context window. During the
                last inference step it will mostly be the max generated sequence length.

        Returns:
            List: A list of inputs that will be used by your model in the forward step
        """

        # T5 inference does not support a KV cache yet
        encoder_tokens2use = self.batch_encoder_prompts_tokens
        decoder_tokens2use = self.prompts_tokens[:, :context_end_position]
        encoder_mask2use = self.batch_mask_encoder
        decoder_mask2use = self.batch_mask_decoder[:, :context_end_position]

        # Configure attention mask based on different conditions
        # (e.g., transformer-impl, TE versions, TE backends)
        [encoder_mask2use, decoder_mask2use, encoder_decoder_mask2use] = (
            T5MaskedWordPieceDataset.config_attention_mask(
                encoder_tokens2use,
                decoder_tokens2use,
                encoder_mask2use,
                decoder_mask2use,
                self.use_local,
            )
        )

        data_at_step_idx = [
            encoder_tokens2use,
            decoder_tokens2use,
            encoder_mask2use,
            decoder_mask2use,
            encoder_decoder_mask2use,
        ]

        return data_at_step_idx

    def forward_pass_without_pipeline_parallel(self, inference_input: List) -> torch.Tensor:
        """Utility to carry out a simple forward pass for TP or no-model-parallel models

        Runs a very simple forward pass for the model. Used in the case of models without
        any parallelism or with only tensor parallelism.

        Args:
            inference_input (List): A list containing the inputs for the T5
                model [tokens, position ids, attention mask]

        Returns:
            torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]
        """
        [encoder_tokens, decoder_tokens, encoder_mask, decoder_mask, encoder_decoder_mask] = (
            inference_input
        )
        tokens = decoder_tokens

        # T5 inference does not support a KV cache yet
        logits = self.model(
            encoder_tokens,
            decoder_tokens,
            encoder_mask,
            decoder_mask,
            encoder_decoder_mask,
            inference_params=None,
        )
        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)

        return logits
After:

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from collections import deque
from typing import Any, Dict, List, Optional

import numpy
import torch

from megatron.core import tensor_parallel
from megatron.core.datasets.t5_dataset import T5MaskedWordPieceDataset
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
    AbstractModelInferenceWrapper,
)
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
    InferenceWrapperConfig,
)
from megatron.core.models.T5 import T5Model
from megatron.core.utils import get_attr_wrapped_model

# pylint: disable=line-too-long


class T5InferenceWrapper(AbstractModelInferenceWrapper):
    """Constructor for the model inference wrapper

    The wrapper prepares the model for inference, provides the required input
    data, and runs the forward pass.

    Args:
        model (T5Model): The T5 model (MCore or legacy)
        inference_wrapper_config (InferenceWrapperConfig): The command line arguments that were passed
        use_local (bool): Whether the T5 model's transformer impl
            is local (vs transformer_engine)
    """

    def __init__(
        self,
        model: T5Model,
        inference_wrapper_config: InferenceWrapperConfig,
        use_local: bool = False,
    ):
        super().__init__(model, inference_wrapper_config)
        self.use_local = use_local

    def prep_inference_input(
        self,
        prompts_tokens: torch.Tensor,
        encoder_prompts: Optional[List[str]] = None,
        tokenizer: Any = None,
    ) -> Dict[str, Any]:
        """Prepares the inference input data.

        Args:
            prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len]
            encoder_prompts (List[str]): List of encoder input prompt strings
            tokenizer (_type_): Tokenizer used for tokenizing and detokenizing text

        Returns:
            A dict with all the inference input needed for the batch.
        """
        # get max_sequence_length
        max_sequence_length = get_attr_wrapped_model(self.model, "max_sequence_length")

        encoder_prompts_tokens_list = [
            self.tokenize_encoder_prompt(encoder_prompt, tokenizer)
            for encoder_prompt in encoder_prompts
        ]
        batch_encoder_prompts_tokens = self.pad_encoder_prompts_tokens(
            encoder_prompts_tokens_list, max_sequence_length, tokenizer
        )

        # create batch mask for encoder_prompt (self.batch_input_tokens) and
        # decoder_input (prompts_tokens), similar to megatron/core/datasets/t5_dataset.py
        decoder_prompts_tokens = prompts_tokens
        encoder_prompts_tokens = batch_encoder_prompts_tokens
        decoder_prompts_tokens_numpy = decoder_prompts_tokens.cpu().numpy()
        encoder_prompts_tokens_numpy = encoder_prompts_tokens.cpu().numpy()
        batch_mask_encoder = []
        batch_mask_decoder = []
        for i in range(len(prompts_tokens)):
            mask_encoder = encoder_prompts_tokens_numpy[i] == tokenizer.pad
            mask_decoder = decoder_prompts_tokens_numpy[i] == tokenizer.pad
            batch_mask_encoder.append(mask_encoder)
            batch_mask_decoder.append(mask_decoder)
        batch_mask_encoder = torch.tensor(numpy.array(batch_mask_encoder)).cuda()
        batch_mask_decoder = torch.tensor(numpy.array(batch_mask_decoder)).cuda()

        return {
            "encoder_tokens": encoder_prompts_tokens,
            "decoder_tokens": decoder_prompts_tokens,
            "encoder_mask": batch_mask_encoder,
            "decoder_mask": batch_mask_decoder,
        }

    def tokenize_encoder_prompt(self, encoder_prompt: str, tokenizer) -> torch.Tensor:
        """Utility to tokenize the encoder_prompt

        Args:
            encoder_prompt (str): The encoder_prompt
            tokenizer (_type_): Tokenizer used for tokenizing and detokenizing strings

        Returns:
            torch.Tensor: Returns the tokenized prompt
        """

        # if the word "<mask>" appears in the prompt, replace it with a special additional token,
        # similar to the processing step in megatron/core/datasets/t5_dataset.py
        divided_encoder_prompt_list = encoder_prompt.split("<mask>")
        masks_count = len(divided_encoder_prompt_list) - 1
        sentinels = deque(tokenizer.additional_special_tokens_ids)

        encoder_prompt_tokens = []
        for divided_encoder_prompt in divided_encoder_prompt_list:
            divided_encoder_prompt_tokens = tokenizer.tokenize(divided_encoder_prompt)
            encoder_prompt_tokens.extend(divided_encoder_prompt_tokens)
            if masks_count > 0:
                sentinel = sentinels.popleft()
                encoder_prompt_tokens.extend([sentinel])
                masks_count -= 1

        return encoder_prompt_tokens

    def pad_encoder_prompts_tokens(
        self, encoder_prompts_tokens_list: List[List[int]], max_sequence_length: int, tokenizer
    ) -> torch.Tensor:
        """Method to pad input prompts

        Given a list of prompts, pad them all to uniform length

        Args:
            encoder_prompts_tokens_list (List[List[int]]): A list containing the
                encoder_input_tokens
            max_sequence_length (int): Maximum length of the encoder input tokens
            tokenizer (_type_): Tokenizer used for tokenizing and detokenizing text

        Returns:
            torch.Tensor: A torch tensor of shape [bs, max_sequence_length]
        """
        for encoder_prompt_tokens in encoder_prompts_tokens_list:
            padding_size = max_sequence_length - len(encoder_prompt_tokens)
            encoder_prompt_tokens.extend([tokenizer.pad] * padding_size)

        return torch.tensor(encoder_prompts_tokens_list).cuda()

    def get_batch_for_context_window(
        self,
        inference_input: Dict[str, Any],
        context_start_position: int,
        context_end_position: int,
    ) -> Dict[str, Any]:
        """Returns the inference data given the context window

        This function gets called iteratively in a loop. Given the start and end context
        positions, it extracts the appropriate data.

        Args:
            inference_input (Dict[str, Any]): The inference input for the batch.
            context_start_position (int): Start of the context window. During
                the first inference step it is mostly 0.
            context_end_position (int): End of the context window. During the
                last inference step it will mostly be the max generated sequence length.

        Returns:
            Dict: A dict of inputs that will be used by your model in the forward step
        """

        # T5 inference does not support a KV cache yet
        encoder_tokens2use = inference_input["encoder_tokens"]
        decoder_tokens2use = inference_input["decoder_tokens"][:, :context_end_position]
        encoder_mask2use = inference_input["encoder_mask"]
        decoder_mask2use = inference_input["decoder_mask"][:, :context_end_position]

        # Configure attention mask based on different conditions
        # (e.g., transformer-impl, TE versions, TE backends)
        [encoder_mask2use, decoder_mask2use, encoder_decoder_mask2use] = (
            T5MaskedWordPieceDataset.config_attention_mask(
                encoder_tokens2use,
                decoder_tokens2use,
                encoder_mask2use,
                decoder_mask2use,
                self.use_local,
            )
        )

        return {
            "encoder_tokens": encoder_tokens2use,
            "decoder_tokens": decoder_tokens2use,
            "encoder_mask": encoder_mask2use,
            "decoder_mask": decoder_mask2use,
            "encoder_decoder_mask": encoder_decoder_mask2use,
        }

    def forward_pass_without_pipeline_parallel(
        self, inference_input: Dict[str, Any]
    ) -> torch.Tensor:
        """Utility to carry out a simple forward pass for TP or no-model-parallel models

        Runs a very simple forward pass for the model. Used in the case of models without
        any parallelism or with only tensor parallelism.

        Args:
            inference_input (Dict[str, Any]): A dict containing the inputs for the T5
                model [tokens, position ids, attention mask]

        Returns:
            torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]
        """
        encoder_tokens = inference_input["encoder_tokens"]
        decoder_tokens = inference_input["decoder_tokens"]
        encoder_mask = inference_input["encoder_mask"]
        decoder_mask = inference_input["decoder_mask"]
        encoder_decoder_mask = inference_input["encoder_decoder_mask"]
        tokens = decoder_tokens

        # T5 inference does not support a KV cache yet
        logits = self.model(
            encoder_tokens,
            decoder_tokens,
            encoder_mask,
            decoder_mask,
            encoder_decoder_mask,
            inference_params=None,
        )
        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)

        return logits
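A minimal sketch (not from the diff) of how the new prep_inference_input flow might be driven directly. It assumes a T5 `model`, an `inference_wrapper_config`, pre-tokenized decoder `prompts_tokens`, and a `tokenizer` exposing `.pad`, `.tokenize()`, and `.additional_special_tokens_ids`; all of these are placeholders built elsewhere, normally by the engine/controller stack.

# Sketch only: placeholder objects, illustrative prompt.
wrapper = T5InferenceWrapper(model, inference_wrapper_config, use_local=False)
inference_input = wrapper.prep_inference_input(
    prompts_tokens=prompts_tokens,   # [batch_size, max_seq_len] decoder-side tokens
    encoder_prompts=["translate English to German: The house is <mask>."],
    tokenizer=tokenizer,
)
# Each step re-runs the full encoder/decoder (no KV cache yet) over a growing decoder window.
step_input = wrapper.get_batch_for_context_window(inference_input, 0, context_end_position=8)
logits = wrapper.forward_pass_without_pipeline_parallel(step_input)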
megatron/core/inference/modelopt_support/__init__.py
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Integrations with the NVIDIA TensorRT Model Optimizer (referred to as ModelOpt).

ModelOpt is a library comprising state-of-the-art model optimization techniques,
including quantization and sparsity, to compress models for efficient inference on
NVIDIA GPUs. ModelOpt is integrated with Megatron-core to provide a seamless
experience for users to optimize their Megatron-core models for inference.
More details on ModelOpt, including installation and usage, can be found at
https://github.com/NVIDIA/TensorRT-Model-Optimizer.
"""
megatron/core/inference/modelopt_support/gpt/model_specs.py
Before:

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules


# Use this spec for ModelOpt PTQ and TensorRT-LLM export
def get_gpt_layer_modelopt_spec(
    num_experts: int = None,
    moe_grouped_gemm: bool = False,
    remap_te_layernorm: bool = False,
    qk_layernorm: bool = False,
) -> ModuleSpec:
    """Mix the native spec with TENorm.

    This is essentially the native local spec, except that the layernorm implementation
    uses TENorm from Transformer-Engine. The issue is that FusedLayerNorm from apex
    has stopped supporting the RMSNorm needed by llama.
    """
    mlp = _get_mlp_module_spec(
        use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=False
    )
    sharded_state_dict_keys_map = {}
    if remap_te_layernorm:
        if num_experts:
            sharded_state_dict_keys_map = {
                'input_layernorm.': 'self_attention.linear_qkv.layer_norm_'
            }
        else:
            sharded_state_dict_keys_map = {
                'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
                'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
            }
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            input_layernorm=TENorm,
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=ColumnParallelLinear,
                    core_attention=TEDotProductAttention,
                    linear_proj=RowParallelLinear,
                    q_layernorm=TENorm if qk_layernorm else IdentityOp,
                    k_layernorm=TENorm if qk_layernorm else IdentityOp,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=TENorm,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
            # Map TE-layernorm-fusion keys back
            sharded_state_dict_keys_map=sharded_state_dict_keys_map,
        ),
    )
After:

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Optional

from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules


# Use this spec for ModelOpt PTQ and TensorRT-LLM export
def get_gpt_layer_modelopt_spec(
    num_experts: Optional[int] = None,
    local_core_attention: bool = False,
    moe_grouped_gemm: bool = False,
    remap_te_layernorm: bool = False,
    qk_layernorm: bool = False,
) -> ModuleSpec:
    """Mix the native spec with TENorm.

    This is essentially the native local spec, except that the layernorm implementation
    uses TENorm from Transformer-Engine. The issue is that FusedLayerNorm from apex
    has stopped supporting the RMSNorm needed by llama.
    """
    core_attention = DotProductAttention if local_core_attention else TEDotProductAttention
    mlp = get_mlp_module_spec(
        use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=False
    )
    sharded_state_dict_keys_map = {}
    if remap_te_layernorm:
        if num_experts:
            sharded_state_dict_keys_map = {
                'input_layernorm.': 'self_attention.linear_qkv.layer_norm_'
            }
        else:
            sharded_state_dict_keys_map = {
                'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
                'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
            }
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            input_layernorm=TENorm,
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=ColumnParallelLinear,
                    core_attention=core_attention,
                    linear_proj=RowParallelLinear,
                    q_layernorm=TENorm if qk_layernorm else IdentityOp,
                    k_layernorm=TENorm if qk_layernorm else IdentityOp,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=TENorm,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
            # Map TE-layernorm-fusion keys back
            sharded_state_dict_keys_map=sharded_state_dict_keys_map,
        ),
    )
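A hedged sketch of how the updated spec might be consumed when building a GPT model for ModelOpt PTQ; `transformer_config`, `vocab_size`, and `max_seq_len` are placeholders, and the flag values are illustrative only.

# Sketch only: placeholder config objects.
from megatron.core.models.gpt import GPTModel

layer_spec = get_gpt_layer_modelopt_spec(
    num_experts=None,
    local_core_attention=False,  # new flag: True swaps TEDotProductAttention for the local DotProductAttention
    remap_te_layernorm=True,
    qk_layernorm=False,
)
model = GPTModel(
    config=transformer_config,
    transformer_layer_spec=layer_spec,
    vocab_size=vocab_size,
    max_sequence_length=max_seq_len,
)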
megatron/core/inference/modelopt_support/mamba/__init__.py (new file, mode 100644)
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
megatron/core/inference/modelopt_support/mamba/model_specs.py (new file, mode 100644)
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules
from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules


# Use this spec for ModelOpt PTQ and TensorRT-LLM export
def get_mamba_stack_modelopt_spec(
    local_core_attention: bool = False, remap_te_layernorm: bool = False
) -> ModuleSpec:
    """Mix the native spec with TENorm.

    This is essentially the native local spec, except that the layernorm implementation
    uses TENorm from Transformer-Engine.
    """
    mamba_state_dict_keys_map = {}
    transformer_state_dict_keys_map = {}
    if remap_te_layernorm:
        mamba_state_dict_keys_map = {'norm.': 'mixer.in_proj.layer_norm_'}
        transformer_state_dict_keys_map = {
            'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
            'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
        }

    mamba_layer = ModuleSpec(
        module=MambaLayer,
        submodules=MambaLayerSubmodules(
            norm=TENorm,
            mixer=ModuleSpec(
                module=MambaMixer,
                submodules=MambaMixerSubmodules(
                    in_proj=ColumnParallelLinear, out_proj=RowParallelLinear
                ),
            ),
            mamba_bda=get_bias_dropout_add,
            sharded_state_dict_keys_map=mamba_state_dict_keys_map,
        ),
    )

    core_attention = DotProductAttention if local_core_attention else TEDotProductAttention
    attention_layer = ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            input_layernorm=TENorm,
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=ColumnParallelLinear,
                    core_attention=core_attention,
                    linear_proj=RowParallelLinear,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            sharded_state_dict_keys_map=transformer_state_dict_keys_map,
        ),
    )

    mlp_layer = ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            pre_mlp_layernorm=TENorm,
            mlp=ModuleSpec(
                module=MLP,
                submodules=MLPSubmodules(
                    linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear
                ),
            ),
            mlp_bda=get_bias_dropout_add,
            sharded_state_dict_keys_map=transformer_state_dict_keys_map,
        ),
    )

    return ModuleSpec(
        module=MambaStack,
        submodules=MambaStackSubmodules(
            mamba_layer=mamba_layer, attention_layer=attention_layer, mlp_layer=mlp_layer
        ),
    )
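A minimal sketch of how the new Mamba stack spec might be used, assuming a MambaModel-style constructor that accepts a stack spec; `transformer_config`, `vocab_size`, and `max_seq_len` are hypothetical placeholders and the exact constructor signature is an assumption, not part of this diff.

# Sketch only: placeholder config objects, assumed MambaModel constructor.
from megatron.core.models.mamba import MambaModel

stack_spec = get_mamba_stack_modelopt_spec(
    local_core_attention=False,  # True swaps TEDotProductAttention for the local DotProductAttention
    remap_te_layernorm=True,     # remap TE layernorm-fusion keys in the sharded state dict
)
model = MambaModel(
    config=transformer_config,
    mamba_stack_spec=stack_spec,
    vocab_size=vocab_size,
    max_sequence_length=max_seq_len,
)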
megatron/core/inference/sampling_params.py
Before:

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass


@dataclass
class SamplingParams:
    """Inference parameters sent along with the prompts.

    This class contains request-level attributes that control the sampling techniques used when
    generating text. This is distinct from megatron.core.InferenceParams, which sets model-level
    inference attributes such as the maximum sequence length and contains the KV cache.

    For an explanation of these parameters refer to this blog post:
    https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-temperature-parameters-ed6a31313910
    """

    temperature: float = 1.0
    top_k: int = 0
    top_p: float = 0.0
    return_log_probs: bool = False
    num_tokens_to_generate: int = 30

    def add_attributes(self, attribute_value_pair: dict):
        """Utility to add more attributes to sampling params

        Use this method to pass in a custom dictionary to add more sampling parameter attributes.

        c = SamplingParams()
        c.add_attributes({'min_length': 4, 'eod_id': 153})

        Args:
            attribute_value_pair (dict): A dictionary containing attributes as the key names and
                their values as the values.
        """
        for key, value in attribute_value_pair.items():
            setattr(self, key, value)
After:

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass


@dataclass
class SamplingParams:
    """Inference parameters sent along with the prompts.

    This class contains request-level attributes that control the sampling techniques used when
    generating text. This is distinct from megatron.core.InferenceParams, which sets model-level
    inference attributes such as the maximum sequence length and contains the KV cache.

    For an explanation of these parameters refer to this blog post:
    https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-temperature-parameters-ed6a31313910
    """

    temperature: float = 1.0
    top_k: int = 0
    top_p: float = 0.0
    return_log_probs: bool = False
    return_segments: bool = False  # Whether to return individually detokenized tokens
    num_tokens_to_generate: int = 30

    def add_attributes(self, attribute_value_pair: dict):
        """Utility to add more attributes to sampling params

        Use this method to pass in a custom dictionary to add more sampling parameter attributes.

        c = SamplingParams()
        c.add_attributes({'min_length': 4, 'eod_id': 153})

        Args:
            attribute_value_pair (dict): A dictionary containing attributes as the key names and
                their values as the values.
        """
        for key, value in attribute_value_pair.items():
            setattr(self, key, value)
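A short example of constructing SamplingParams and extending it via add_attributes; the extra keys below ('min_length', 'eod_id') are only illustrative, mirroring the docstring, and are not attributes the class defines itself.

# Example usage of the dataclass above.
params = SamplingParams(temperature=0.8, top_p=0.9, return_log_probs=True, num_tokens_to_generate=64)
params.add_attributes({'min_length': 4, 'eod_id': 153})   # illustrative custom keys
assert params.top_k == 0 and params.return_segments is False  # defaults from the dataclass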
megatron/core/inference/scheduler.py
Before:

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import time
import typing
from collections import OrderedDict
from typing import Dict

import torch

from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.utils import Counter


class Scheduler:
    """Scheduler for handling requests to the inference engine

    This class is responsible for handling all incoming requests.

    Args:
        max_batch_size (int): The max batch size that we can pass to the
            inference engine at a time.
    """

    def __init__(self, max_batch_size: int):
        self.max_batch_size = max_batch_size
        self.active_request_pool: Dict[int, InferenceRequest] = OrderedDict()
        self.waiting_request_pool: Dict[int, InferenceRequest] = OrderedDict()
        self.completed_request_pool: Dict[int, InferenceRequest] = OrderedDict()
        self.request_counter = Counter()

    def add_request(
        self,
        prompt: str,
        prompt_tokens: torch.Tensor,
        encoder_prompt: str = None,
        inference_parameters: SamplingParams = None,
        arrival_time: float = None,
    ):
        """Add an incoming request

        This method will add the request to either the active pool or the waiting pool
        depending on the batch size.

        Args:
            prompt (str): Input prompt string
            prompt_tokens (torch.Tensor): A torch tensor holding the tokenized input prompt
            encoder_prompt (str): Encoder input string
            inference_parameters (SamplingParams): The inference parameters
            arrival_time (float, optional): The incoming request time. Defaults to None.
        """
        request_id = str(next(self.request_counter))

        if arrival_time is None:
            arrival_time = time.time()

        status = (
            Status.ACTIVE_BUT_NOT_GENERATING_TOKENS
            if len(self.active_request_pool) < self.max_batch_size
            else Status.WAITING_IN_QUEUE
        )

        inference_request = InferenceRequest(
            request_id=request_id,
            prompt=prompt,
            inference_parameters=inference_parameters,
            arrival_time=arrival_time,
            prompt_tokens=prompt_tokens,
            status=status,
            encoder_prompt=encoder_prompt,
        )

        if status == status.ACTIVE_BUT_NOT_GENERATING_TOKENS:
            self.active_request_pool[request_id] = inference_request
        else:
            self.waiting_request_pool[request_id] = inference_request

    def have_requests_pending(self) -> bool:
        """Method to check if there are requests pending

        This method returns False only when there are no active or waiting requests.
        """
        num_requests_pending = len(self.active_request_pool) + len(self.waiting_request_pool)
        return num_requests_pending > 0

    def add_earliest_waiting_request_to_active_pool(self):
        """Utility to add the earliest waiting request to the active pool

        This method will add the earliest request (FIFO) that is in the waiting request
        pool to the active request pool.
        """
        assert (
            len(self.active_request_pool) < self.max_batch_size
        ), "Active request pool is already full. Cannot add any more requests"
        if len(self.waiting_request_pool) > 0:
            (earliest_waiting_request_request_id, earliest_waiting_request) = (
                self.waiting_request_pool.popitem(last=False)
            )
            earliest_waiting_request.status = Status.ACTIVE_BUT_NOT_GENERATING_TOKENS
            self.active_request_pool[earliest_waiting_request_request_id] = (
                earliest_waiting_request
            )

    def update_requests_pools(
        self, result_dict: typing.OrderedDict[int, InferenceRequest] = None
    ):
        """Update request pool status

        This method will fill up the active request pool from the waiting request pool
        if it has fewer than max batch size elements.

        If provided with a result dict, it will put the completed requests into the completed
        request pool and add waiting requests into the active pool.

        Args:
            result_dict (typing.OrderedDict[int, InferenceRequest], optional): The result returned
                by the engine. A dictionary with keys as the request ids, and values as the
                requests. Defaults to None
        """
        for result_request_id in list(result_dict.keys()):
            active_request = self.active_request_pool[result_request_id]

            # If a request has completed put it into the completed request pool.
            if active_request.status == Status.COMPLETED:
                completed_request = self.active_request_pool.pop(result_request_id)
                self.completed_request_pool[result_request_id] = completed_request

        # If the active request pool is not full, add waiting requests in FIFO order
        while (
            len(self.active_request_pool) < self.max_batch_size
            and len(self.waiting_request_pool) > 0
        ):
            self.add_earliest_waiting_request_to_active_pool()
After:

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import functools
import time
import typing
from collections import OrderedDict
from typing import Dict, Optional, Type, Union

import torch

from megatron.core.inference.async_stream import AsyncStream
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.utils import Counter


class Scheduler:
    """Scheduler for handling requests to the inference engine

    This class is responsible for handling all incoming requests.

    Args:
        max_batch_size (int): The max batch size that we can pass to the
            inference engine at a time.
        request_type (InferenceRequest): The class to use for instantiating new requests.
    """

    def __init__(self, max_batch_size):
        self.max_batch_size = max_batch_size
        self.requests: Dict[str, InferenceRequest] = OrderedDict()
        self.streams: Dict[str, AsyncStream] = OrderedDict()
        self.active_request_pool: Dict[str, InferenceRequest] = OrderedDict()
        self.waiting_request_pool: Dict[str, InferenceRequest] = OrderedDict()
        self.completed_request_pool: Dict[str, InferenceRequest] = OrderedDict()
        self.request_counter = Counter()

    def get_new_request_id(self) -> str:
        """Gets a new request id"""
        request_id = str(next(self.request_counter))
        return request_id

    def add_request(
        self,
        prompt: Optional[str] = None,
        prompt_tokens: Optional[torch.Tensor] = None,
        encoder_prompt: Optional[str] = None,
        inference_parameters: Optional[SamplingParams] = None,
        arrival_time: Optional[float] = None,
        streaming: bool = False,
        inference_request: Optional[InferenceRequest] = None,
    ) -> str:
        """Add an incoming request

        This method will add the request to either the active pool or the waiting pool
        depending on the batch size.

        Args:
            prompt (str): Input prompt string
            prompt_tokens (torch.Tensor): A torch tensor holding the tokenized input prompt
            encoder_prompt (str): Encoder input string
            inference_parameters (SamplingParams): The inference parameters
            arrival_time (float, optional): The incoming request time. Defaults to None.
            streaming (bool, optional): Whether to asynchronously stream tokens for this request.
            inference_request (InferenceRequest, optional): A fully constructed request.
                Defaults to None.

        Returns:
            The request_id for the new request.
        """
        status = (
            Status.ACTIVE_BUT_NOT_GENERATING_TOKENS
            if len(self.active_request_pool) < self.max_batch_size
            else Status.WAITING_IN_QUEUE
        )

        if inference_request is None:
            assert prompt is not None
            assert prompt_tokens is not None

            request_id = self.get_new_request_id()

            if arrival_time is None:
                arrival_time = time.time()

            inference_request = InferenceRequest(
                request_id=request_id,
                prompt=prompt,
                inference_parameters=inference_parameters,
                arrival_time=arrival_time,
                prompt_tokens=prompt_tokens,
                status=status,
                encoder_prompt=encoder_prompt,
            )
        else:
            request_id = inference_request.request_id
            inference_request.status = status
            if inference_request.arrival_time is None:
                inference_request.arrival_time = time.time()

        self.requests[request_id] = inference_request

        if streaming:
            abort_request = functools.partial(self.abort_request, request_id=request_id)
            self.streams[request_id] = AsyncStream(request_id, abort_request)

        if status == status.ACTIVE_BUT_NOT_GENERATING_TOKENS:
            self.active_request_pool[request_id] = inference_request
        else:
            self.waiting_request_pool[request_id] = inference_request

        return request_id

    def have_requests_pending(self) -> bool:
        """Method to check if there are requests pending

        This method returns False only when there are no active or waiting requests.
        """
        num_requests_pending = len(self.active_request_pool) + len(self.waiting_request_pool)
        return num_requests_pending > 0

    def add_earliest_waiting_request_to_active_pool(self):
        """Utility to add the earliest waiting request to the active pool

        This method will add the earliest request (FIFO) that is in the waiting request
        pool to the active request pool.
        """
        assert (
            len(self.active_request_pool) < self.max_batch_size
        ), "Active request pool is already full. Cannot add any more requests"
        if len(self.waiting_request_pool) > 0:
            (earliest_waiting_request_request_id, earliest_waiting_request) = (
                self.waiting_request_pool.popitem(last=False)
            )
            earliest_waiting_request.status = Status.ACTIVE_BUT_NOT_GENERATING_TOKENS
            self.active_request_pool[earliest_waiting_request_request_id] = (
                earliest_waiting_request
            )

    def update_requests_pools(
        self, result_dict: Optional[typing.OrderedDict[str, InferenceRequest]] = None
    ):
        """Update request pool status

        This method will fill up the active request pool from the waiting request pool
        if it has fewer than max batch size elements.

        If provided with a result dict, it will put the completed requests into the completed
        request pool and add waiting requests into the active pool.

        Args:
            result_dict (typing.OrderedDict[str, InferenceRequest], optional): The result returned
                by the engine. A dictionary with keys as the request ids, and values as the
                requests. Defaults to None
        """
        for result_request_id in list(result_dict.keys()):
            active_request = self.active_request_pool[result_request_id]

            # If a request has completed put it into the completed request pool.
            if active_request.status == Status.COMPLETED:
                completed_request = self.active_request_pool.pop(result_request_id)
                self.completed_request_pool[result_request_id] = completed_request

        # If the active request pool is not full, add waiting requests in FIFO order
        while (
            len(self.active_request_pool) < self.max_batch_size
            and len(self.waiting_request_pool) > 0
        ):
            self.add_earliest_waiting_request_to_active_pool()

    def abort_request(
        self,
        request_id: str,
        *,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ):
        """Cancels the given request"""
        stream = self.streams.get(request_id, None)
        if stream is not None:
            stream.finish(exception=exception)
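A hedged sketch of the updated Scheduler API: requests either land in the active pool or wait in FIFO order, and streaming requests additionally get an AsyncStream. The `tokens` tensor below is a placeholder for a pre-tokenized prompt; this is not code from the diff.

# Sketch only: `tokens` is a placeholder tensor of prompt token ids.
scheduler = Scheduler(max_batch_size=2)
for text in ["prompt one", "prompt two", "prompt three"]:
    request_id = scheduler.add_request(
        prompt=text,
        prompt_tokens=tokens,
        inference_parameters=SamplingParams(),
        streaming=False,
    )
# The third request overflows max_batch_size and waits in FIFO order.
assert len(scheduler.active_request_pool) == 2
assert len(scheduler.waiting_request_pool) == 1
assert scheduler.have_requests_pending()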
megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py
Before:

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import OrderedDict

import torch

from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
    TextGenerationController,
)


class EncoderDecoderTextGenerationController(TextGenerationController):
    """The text generation controller for encoder-decoder architectures

    This class inherits from TextGenerationController, adding features
    relating to the encoder input encoder_prompt.
    """

    def prep_model_for_inference(
        self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[int, InferenceRequest]
    ):
        """Prepares the batch for inference, using the respective wrapper's prep_model_for_inference method

        Args:
            prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
            active_requests (OrderedDict[int, InferenceRequest]): The input active requests
        """
        encoder_prompts = list(
            map(lambda request: request.encoder_prompt, active_requests.values())
        )

        self.inference_wrapped_model.prep_model_for_inference(
            prompts_tokens=prompts_tokens, encoder_prompts=encoder_prompts, tokenizer=self.tokenizer
        )
After:

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Any, Dict, OrderedDict

import torch

from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
    TextGenerationController,
)


class EncoderDecoderTextGenerationController(TextGenerationController):
    """The text generation controller for encoder-decoder architectures

    This class inherits from TextGenerationController, adding features
    relating to the encoder input encoder_prompt.
    """

    def prep_inference_input(
        self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest]
    ) -> Dict[str, Any]:
        """Prepares input data for inference, using the respective wrapper's prep_inference_input method  # pylint: disable=line-too-long

        Args:
            prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
            active_requests (OrderedDict[str, InferenceRequest]): The input active requests

        Returns:
            A dict of the inference input for the current batch.
        """
        encoder_prompts = list(
            map(lambda request: request.encoder_prompt, active_requests.values())
        )

        return self.inference_wrapped_model.prep_inference_input(
            prompts_tokens, encoder_prompts, tokenizer=self.tokenizer
        )
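A minimal sketch of how the controller's new prep_inference_input pairs with the T5 wrapper above; `t5_wrapper`, `tokenizer`, `prompts_tokens`, and `active_requests` (an OrderedDict of InferenceRequest objects carrying `encoder_prompt`) are placeholders built elsewhere, not objects introduced by this commit.

# Sketch only: placeholder wrapper, tokenizer, and request pool.
controller = EncoderDecoderTextGenerationController(
    inference_wrapped_model=t5_wrapper, tokenizer=tokenizer
)
inference_input = controller.prep_inference_input(prompts_tokens, active_requests)
# The returned dict holds encoder/decoder tokens and masks, ready for
# t5_wrapper.get_batch_for_context_window(...) inside the generation loop.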
megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.inference.text_generation_controllers.text_generation_controller import (  # noqa: F401 # pylint: disable=unused-import
    TextGenerationController as SimpleTextGenerationController,
)
megatron/core/inference/text_generation_controllers/text_generation_controller.py
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import List, OrderedDict, Tuple

import torch
import torch.nn.functional as F

from megatron.core import parallel_state
from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
    AbstractModelInferenceWrapper,
)
from megatron.core.inference.sampling_params import SamplingParams


class TextGenerationController:
    """The text generation controller (the main sampling loop)

    This class tokenizes the input, runs inference, samples from logits, and detokenizes the output.

    Args:
        inference_wrapped_model (AbstractModelInferenceWrapper): A model that
            is wrapped using the specs given in the abstract_model_inference_wrapper.py
        tokenizer (_type_): Tokenizer used for tokenizing and detokenizing the prompts
    """

    def __init__(self, inference_wrapped_model: AbstractModelInferenceWrapper, tokenizer):
        self.inference_wrapped_model = inference_wrapped_model
        self.tokenizer = tokenizer

        # For models without pipeline parallelism, is_first_stage and is_last_stage return True
        self.model_is_pipeline_parallel = not (
            parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
        )

    def tokenize_prompt(
        self, prompt: str, add_BOS: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Utility to tokenize the input prompts

        Args:
            prompt (str): The input prompt

        Returns:
            torch.Tensor: Returns the tokenized prompt
        """
        prompt_tokens = self.tokenizer.tokenize(prompt)

        if add_BOS:
            prompt_tokens = [self.tokenizer.bos] + prompt_tokens

        return prompt_tokens

    def detokenize_generations(self, prompt_tokens_with_generated_tokens: torch.Tensor) -> str:
        """Detokenize the output generations

        Args:
            prompt_tokens_with_generated_tokens (torch.Tensor): The input prompt
                tokens plus the generated tokens

        Returns:
            str: The detokenized output
        """
        tokens = prompt_tokens_with_generated_tokens.cpu().numpy().tolist()
        return self.tokenizer.detokenize(tokens)

    def sample_from_logits(
        self,
        last_token_logits: torch.Tensor,
        sampling_params: SamplingParams = None,
        vocab_size: int = None,
        **kwargs
    ) -> torch.Tensor:
        """Samples the logits to generate outputs

        Given the logits of the last token, this function samples them
        according to the parameters defined in sampling_params
        and returns the samples.

        Args:
            last_token_logits (torch.Tensor): The last token logits. A tensor of
                size [batch_size, vocab_size]
            sampling_params (SamplingParams): The parameters to use for inference.
            vocab_size (int): Obtained from the tokenizer. Defaults to None

        Returns:
            torch.Tensor: 1D tensor of the sampled tokens with [batch_size] elements
        """
        if kwargs.get('common_inference_params'):
            sampling_params = kwargs['common_inference_params']

        top_p = sampling_params.top_p
        top_k = sampling_params.top_k
        temperature = sampling_params.temperature

        assert not (top_k > 0 and top_p > 0), 'Cannot have top-p and top-k both greater than zero'
        assert top_p <= 1.0, 'top-p should be in (0, 1]'

        def modify_logits_for_top_k_filtering(logits, top_k):
            """Set the logits for non top-k values to -inf."""
            filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits.masked_fill_(filter_, float('-Inf'))

        def modify_logits_for_top_p_filtering(logits, top_p):
            """Set the logits for non top-p values to -inf."""
            # First sort and calculate cumulative sum of probabilities.
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

            # Filtering based on the cumulative sum.
            filter_ = cumulative_probs > top_p

            # This shift by 1 is weird and I cannot justify it. This existed
            # in the original implementation:
            # https://github.com/ari-holtzman/degen/blob/master/gen.py
            # and I guess it is needed so keeping it for now.
            filter_[:, 1:] = filter_[:, :-1].clone()

            # Make sure we at least have one token to select from.
            filter_[..., 0] = 0

            # Fill in the filtered part
            filter_ = filter_.scatter(1, sorted_indices, filter_)
            logits.masked_fill_(filter_, float('-Inf'))

        # Greedy sampling
        if top_k == 1:
            sampled_logits = torch.argmax(last_token_logits, dim=-1)
        else:
            last_token_logits = last_token_logits.clone()
            if temperature != 1.0:
                last_token_logits.div_(temperature)

            if top_k > 1:
                assert top_k <= last_token_logits.size(1), 'top-k is larger than logit size.'
                if vocab_size:
                    assert top_k < vocab_size, 'top-k is larger than vocab size.'
                modify_logits_for_top_k_filtering(last_token_logits, top_k)

            elif top_p > 0.0:
                modify_logits_for_top_p_filtering(last_token_logits, top_p)

            # After filtering, we need to recalculate the distribution.
            probabilities = last_token_logits.softmax(dim=-1)
            sampled_logits = torch.multinomial(probabilities, num_samples=1).view(-1)

            # If vocab size is provided, make sure the samples are in the range [0, vocab_size).
            if vocab_size:
                sampled_logits = torch.clamp(sampled_logits, min=0, max=(vocab_size - 1))
        return sampled_logits

    def update_generation_status(
        self,
        updated_prompts_tokens: torch.Tensor,
        generation_started: torch.Tensor,
        current_context_end_position: int,
        is_generation_done_tensor: torch.Tensor,
        generated_sequence_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Checks which prompts have reached an end condition

        We check which prompts have reached an end condition and set the corresponding
        flags of the is_generation_done_tensor to True. The generated sequence lengths
        increase as we keep generating, until that prompt hits an end condition. The
        generation_started tensor determines which prompts have started generating.

        Args:
            updated_prompts_tokens (torch.Tensor): The prompts tokens updated with the latest
                generated tokens. A tensor of shape [batch_size, max_seq_len]
                (i.e. max_seq_len = max_prompt_len + tokens_to_generate)
            generation_started (torch.Tensor): A boolean tensor of shape [batch_size]. True
                indicates the prompt at that index has started generating tokens.
            current_context_end_position (int): An integer indicating which position to
                extract from the prompts tokens to get the latest generated tokens.
            is_generation_done_tensor (torch.Tensor): A boolean tensor of shape [batch_size].
                True indicates the prompt at that index has reached its end condition.
            generated_sequence_lengths (torch.Tensor): An int tensor of shape [batch_size].
                Each value represents the generated sequence length for that prompt.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Returns the boolean
                is_generation_done_tensor and the generated_sequence_lengths after updating them
        """
        latest_samples = updated_prompts_tokens[:, current_context_end_position]
        # Make sure we are checking the eod criterion only for prompts that have started generating
        # (i.e. we only look at the generated tokens and not the input tokens).
        reached_eod = (latest_samples == self.tokenizer.eod) & generation_started
        is_generation_done_tensor = is_generation_done_tensor | reached_eod
        # We increment generated sequence lengths when that prompt has not hit the
        # EOD and generation has started
        generated_sequence_lengths += ~is_generation_done_tensor & generation_started

        return is_generation_done_tensor, generated_sequence_lengths

    def pad_input_prompt_tokens(
        self,
        batch_prompt_tokens_list: List[List[int]],
        max_prompt_length_in_batch: int,
        num_tokens_to_generate: int,
    ) -> torch.Tensor:
        """Method to pad input prompts

        Given a list of prompts, pad them all to uniform length

        Args:
            batch_prompt_tokens_list (List[List[int]]): A list containing the prompt tokens
            max_prompt_length_in_batch (int): Maximum of the length of the input prompt tokens
num_tokens_togenerate (int): The number of tokens to generate for each prompt
Returns:
torch.Tensor: A torch tensor of shape [bs, max_seq_len] (i.e)
max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate,
with extra indices for each tensor padded with mask id.
"""
max_seq_len
=
max_prompt_length_in_batch
+
num_tokens_to_generate
for
prompt_tokens
in
batch_prompt_tokens_list
:
padding_size
=
max_seq_len
-
len
(
prompt_tokens
)
prompt_tokens
.
extend
([
self
.
tokenizer
.
eod
]
*
padding_size
)
return
torch
.
tensor
(
batch_prompt_tokens_list
).
cuda
()
def
generate_output_tokens_dynamic_batch
(
self
,
active_requests
:
OrderedDict
[
int
,
InferenceRequest
]
)
->
OrderedDict
[
int
,
InferenceRequest
]:
"""Utility to generate the output tokens and probabilities for the prompts
This utility generates the output tokens for a dynamic batch. It will run one forward step
at a time, and pass control back to the engine, which will update the request pool and call
this method again.
Args:
active_requests (OrderedDict[int, InferenceRequest]): The input active requests.
Returns:
OrderedDict[int, InferenceRequest]: The result for each of the incoming requests
after running one forward step.
"""
raise
Exception
(
"Not implemented yet"
)
def
generate_all_output_tokens_static_batch
(
self
,
active_requests
:
OrderedDict
[
int
,
InferenceRequest
]
)
->
OrderedDict
[
int
,
InferenceRequest
]:
"""Utility to generate the all the output tokens and probabilities for the prompts .
This utility generates the output tokens for a static batch. It runs the forward steps till
all prompts complete generation, updates the status of these requests to completed, adds
the generated result and returns these requests
Args:
active_requests (OrderedDict[int, InferenceRequest]): The input active requests.
Returns:
OrderedDict[int, InferenceRequest]: The result for each of the incoming requests
"""
batch_prompt_tokens_list
=
list
(
map
(
lambda
request
:
request
.
prompt_tokens
,
active_requests
.
values
())
)
prompt_lengths_in_batch
=
torch
.
tensor
(
[
len
(
prompt_tokens
)
for
prompt_tokens
in
batch_prompt_tokens_list
]
).
cuda
()
max_prompt_length_in_batch
=
max
(
prompt_lengths_in_batch
)
min_prompt_length_in_batch
=
min
(
prompt_lengths_in_batch
)
# For batch inference the inference params are the same for all request
sampling_params
:
SamplingParams
=
list
(
active_requests
.
values
())[
0
].
inference_parameters
# max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
batch_prompt_tokens
=
self
.
pad_input_prompt_tokens
(
batch_prompt_tokens_list
,
max_prompt_length_in_batch
=
max_prompt_length_in_batch
,
num_tokens_to_generate
=
sampling_params
.
num_tokens_to_generate
,
)
batch_size
,
max_sequence_length
=
batch_prompt_tokens
.
shape
# Pre allocate log probs tensor
output_log_probs
=
None
if
sampling_params
.
return_log_probs
:
output_log_probs
=
torch
.
empty
(
(
batch_size
,
max_sequence_length
-
1
),
dtype
=
torch
.
float32
).
cuda
()
# An array to check which of the prompts have reached end of generation condition
is_generation_done_tensor
=
torch
.
zeros
(
batch_size
,
dtype
=
torch
.
bool
).
cuda
()
# An array to act as a counter to keep track of generated sequence lengths
generated_sequence_lengths
=
torch
.
zeros
(
batch_size
).
cuda
()
with
torch
.
no_grad
():
self
.
prep_model_for_inference
(
prompts_tokens
=
batch_prompt_tokens
,
active_requests
=
active_requests
)
context_start_position
=
0
# Pick the context window that we need to pass through the network.
for
context_end_position
in
range
(
min_prompt_length_in_batch
,
max_sequence_length
):
inference_input
=
self
.
inference_wrapped_model
.
get_batch_for_context_window
(
context_start_position
,
context_end_position
)
# Returns the final logits of shape [batch_size, context_length, vocab_size]
# Note: This is returned in all TP ranks or last PP stage in PP models
logits
=
self
.
inference_wrapped_model
.
run_one_forward_step
(
inference_input
)
if
self
.
model_is_pipeline_parallel
:
context_length
=
context_end_position
-
context_start_position
logits
=
broadcast_from_last_pipeline_stage
(
[
batch_size
,
context_length
,
self
.
tokenizer
.
vocab_size
],
dtype
=
self
.
inference_wrapped_model
.
inference_wrapper_config
.
params_dtype
,
tensor
=
logits
,
)
# Indicates which of the input prompts have started generating tokens.
# A 1D boolean tensor with [batch_size] elements (i.e) The shortest
# prompts will start generating first and so on
generation_started
=
prompt_lengths_in_batch
<=
context_end_position
last_token_logits
=
logits
[:,
-
1
,
:]
sampled_logits
=
self
.
sample_from_logits
(
last_token_logits
,
sampling_params
,
self
.
tokenizer
.
vocab_size
)
# Substitute the sampled logits only for only the prompts that
# have started generating tokens
batch_prompt_tokens
[
generation_started
,
context_end_position
]
=
sampled_logits
[
generation_started
]
if
sampling_params
.
return_log_probs
:
log_probs
=
F
.
log_softmax
(
logits
,
dim
=
2
)
indices
=
torch
.
unsqueeze
(
batch_prompt_tokens
[
:,
(
context_start_position
+
1
)
:
(
context_end_position
+
1
)
],
2
,
)
# Get the log probabilities for only the prompt tokens
output_log_probs
[:,
context_start_position
:
context_end_position
]
=
torch
.
gather
(
log_probs
,
2
,
indices
).
squeeze
(
2
)
context_start_position
=
context_end_position
# Check end of generation status for each tensor
# and update generated sequence lengths
(
is_generation_done_tensor
,
generated_sequence_lengths
)
=
(
self
.
update_generation_status
(
updated_prompts_tokens
=
batch_prompt_tokens
,
generation_started
=
generation_started
,
current_context_end_position
=
context_end_position
,
is_generation_done_tensor
=
is_generation_done_tensor
,
generated_sequence_lengths
=
generated_sequence_lengths
,
)
)
# Boolean flag indicating if all prompts are finished
all_prompts_done
=
torch
.
all
(
is_generation_done_tensor
)
if
all_prompts_done
:
break
# Include all the generated tokens
batch_prompt_tokens_with_generations
=
batch_prompt_tokens
[:,
:
(
context_end_position
+
1
)]
if
sampling_params
.
return_log_probs
:
output_log_probs
=
output_log_probs
[:,
:
context_end_position
]
generated_sequence_lengths
[
generated_sequence_lengths
>
sampling_params
.
num_tokens_to_generate
]
=
sampling_params
.
num_tokens_to_generate
for
idx
,
request
in
enumerate
(
active_requests
.
values
()):
input_prompt_length
=
int
(
prompt_lengths_in_batch
[
idx
])
# Shorter prompts might have generated more than required tokens. So we trim them down
required_sequence_length
=
int
(
min
(
generated_sequence_lengths
[
idx
],
sampling_params
.
num_tokens_to_generate
)
)
# Extract only the generated tokens
required_result_tokens
=
batch_prompt_tokens_with_generations
[
idx
,
input_prompt_length
:
(
input_prompt_length
+
required_sequence_length
)
]
request
.
generated_length
=
required_sequence_length
request
.
generated_tokens
=
required_result_tokens
request
.
generated_log_probs
=
(
None
if
output_log_probs
is
None
else
output_log_probs
[
idx
,
input_prompt_length
:
required_sequence_length
]
)
request
.
status
=
Status
.
COMPLETED
request
.
generated_text
=
self
.
detokenize_generations
(
required_result_tokens
)
return
active_requests
def
prep_model_for_inference
(
self
,
prompts_tokens
:
torch
.
Tensor
,
active_requests
:
OrderedDict
[
int
,
InferenceRequest
]
):
"""Preparing batch for inference, using respective wrapper's prep_model_for_inference method
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
active_requests (OrderedDict[int, InferenceRequest]): The input active requests
"""
self
.
inference_wrapped_model
.
prep_model_for_inference
(
prompts_tokens
=
prompts_tokens
)
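# --- Hypothetical usage sketch (not part of the commit): the top-k / top-p filtering
# --- used by sample_from_logits above, exercised on a toy CPU tensor. The values and
# --- the stand-alone style are illustrative only; the real method reads these knobs
# --- from a SamplingParams object.
import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])  # [batch_size=1, vocab_size=4]

# Top-k = 2: every logit outside the two largest is masked to -inf before sampling.
top_k = 2
filtered = logits.clone()
filtered[filtered < torch.topk(filtered, top_k)[0][..., -1, None]] = float('-Inf')
probs = filtered.softmax(dim=-1)                       # only two tokens keep probability mass
sample = torch.multinomial(probs, num_samples=1).view(-1)

# Greedy decoding corresponds to top_k == 1 (pure argmax), matching the branch above.
greedy = torch.argmax(logits, dim=-1)
print(sample, greedy)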
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import concurrent
import copy
import functools
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union

import torch
import torch.nn.functional as F

from megatron.core import parallel_state
from megatron.core.inference.async_stream import AsyncStream
from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
    AbstractModelInferenceWrapper,
)
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.transformer.cuda_graphs import create_cudagraphs
from megatron.core.utils import get_model_config


class TextGenerationController:
    """The text generation controller (the main sampling loop)

    This class tokenizes the input, runs inference, samples from logits, and detokenizes the output.

    Args:
        inference_wrapped_model (AbstractModelInferenceWrapper): A model that
            is wrapped using the specs given in the abstract_model_inference_wrapper.py
        tokenizer (_type_): Tokenizer used for tokenizing and detokenizing the prompts
    """

    def __init__(self, inference_wrapped_model: AbstractModelInferenceWrapper, tokenizer):
        self.inference_wrapped_model = inference_wrapped_model
        self.tokenizer = tokenizer

        # For models without pipeline parallelism, is_first_stage and is_last_stage return True
        self.model_is_pipeline_parallel = not (
            parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
        )

    def tokenize_prompt(
        self, prompt: str, add_BOS: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Utility to tokenize the input prompts

        Args:
            prompt (str): The input prompt

        Returns:
            torch.Tensor: Returns the tokenized prompt
        """
        prompt_tokens = self.tokenizer.tokenize(prompt)
        if add_BOS:
            prompt_tokens = [self.tokenizer.bos] + prompt_tokens
        return prompt_tokens

    def detokenize_generations(
        self,
        tokens_gpu_tensor: torch.Tensor,
        lengths_gpu_tensor: torch.Tensor,
        detokenize_segments: bool,
    ) -> tuple[str, Optional[List[List[str]]]]:
        """Detokenize the generated tokens.

        Args:
            tokens_gpu_tensor (torch.Tensor): Tensor containing the tokens
            lengths_gpu_tensor (torch.Tensor): Tensor containing the lengths of each sequence
            detokenize_segments (bool): If True, returns individually detokenized tokens. If False,
                returns None as the second element. Helpful for understanding per-token boundaries
                in generated text.

        Returns:
            tuple[str, List[str] | None]: A tuple containing:
                - str: The complete detokenized text
                - List[str] | None: List of segmented tokens if detokenize_segments is True, else None
        """
        # TODO(helenn): Unify with `detokenize_generations` from legacy textgen path

        if not detokenize_segments:
            tokens = tokens_gpu_tensor.cpu().numpy().tolist()
            return self.tokenizer.detokenize(tokens), None

        prompts_plus_generations: List[str] = []
        prompts_plus_generations_segments: List[List[str]] = []

        tokens_gpu_tensor = torch.unsqueeze(tokens_gpu_tensor, 0)
        tokens = tokens_gpu_tensor.cpu().numpy().tolist()
        lengths = lengths_gpu_tensor.cpu().numpy().tolist()

        for sequence_tokens, length in zip(tokens, lengths):
            sequence_tokens = sequence_tokens[:length]
            detok_str = self.tokenizer.detokenize(sequence_tokens)
            prompts_plus_generations.append(detok_str)
            offsets = self.tokenizer.offsets(sequence_tokens, detok_str)
            words = [
                detok_str[start:end] for start, end in zip(offsets, offsets[1:] + [len(detok_str)])
            ]
            prompts_plus_generations_segments.append(words)

        text = self.tokenizer.detokenize(tokens[0])

        return text, prompts_plus_generations_segments

    def sample_from_logits(
        self,
        last_token_logits: torch.Tensor,
        sampling_params: Optional[SamplingParams] = None,
        vocab_size: Optional[int] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Samples the logits to generate outputs

        Given the logits of the last token, this function samples them
        according to the parameters defined in sampling_params
        and returns the samples.

        Args:
            last_token_logits (torch.Tensor): The last token logits. A tensor of
                size [batch_size, vocab_size]
            sampling_params (SamplingParams): The parameters to use for inference.
            vocab_size (int): Obtained from the tokenizer. Defaults to None

        Returns:
            torch.Tensor: 1D tensor of the sampled logits with [batch_size] elements
        """
        if kwargs.get('common_inference_params'):
            sampling_params = kwargs['common_inference_params']

        top_p = sampling_params.top_p
        top_k = sampling_params.top_k
        temperature = sampling_params.temperature

        assert not (top_k > 0 and top_p > 0), 'Cannot have top-p and top-k both greater than zero'
        assert top_p <= 1.0, 'top-p should be in (0,1]'

        def modify_logits_for_top_k_filtering(logits, top_k):
            """Set the logits for non top-k values to -inf."""
            filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits.masked_fill_(filter_, float('-Inf'))

        def modify_logits_for_top_p_filtering(logits, top_p):
            """Set the logits for non top-p values to -inf."""
            # First sort and calculate cumulative sum of probabilities.
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

            # Filtration based on the cumulative sum.
            filter_ = cumulative_probs > top_p
            # This shift by 1 is weird and I cannot justify it. This existed
            # in the original implementation:
            # https://github.com/ari-holtzman/degen/blob/master/gen.py
            # and I guess it is needed so keeping it for now.
            filter_[:, 1:] = filter_[:, :-1].clone()
            # Make sure we at least have one token to select from.
            filter_[..., 0] = 0

            # Fill in the filtered part
            filter_ = filter_.scatter(1, sorted_indices, filter_)
            logits.masked_fill_(filter_, float('-Inf'))

        # Greedy sampling
        if top_k == 1:
            sampled_logits = torch.argmax(last_token_logits, dim=-1)
        else:
            last_token_logits = last_token_logits.clone()
            if temperature != 1.0:
                last_token_logits.div_(temperature)

            if top_k > 1:
                assert top_k <= last_token_logits.size(1), 'top-k is larger than logit size.'
                if vocab_size:
                    assert top_k < vocab_size, 'top-k is larger than vocab size.'
                modify_logits_for_top_k_filtering(last_token_logits, top_k)

            elif top_p > 0.0:
                modify_logits_for_top_p_filtering(last_token_logits, top_p)

            # After filtering, we need to recalculate the distribution.
            probabilities = last_token_logits.softmax(dim=-1)
            sampled_logits = torch.multinomial(probabilities, num_samples=1).view(-1)

            # If vocab size is provided, make sure the samples are in the range [0, vocab_size).
            if vocab_size:
                sampled_logits = torch.clamp(sampled_logits, min=0, max=(vocab_size - 1))
        return sampled_logits

    def update_generation_status(
        self,
        updated_prompts_tokens: torch.Tensor,
        generation_started: torch.Tensor,
        current_context_end_position: int,
        is_generation_done_tensor: torch.Tensor,
        generated_sequence_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Checks which prompts have reached an end condition

        We check which prompts have reached an end condition and set the corresponding
        flags of the is_generation_done_tensor to True. The generated sequence lengths
        increase as we keep generating, until that prompt hits an end condition. The
        generation_started tensor determines which prompts have started generating.

        Args:
            updated_prompts_tokens (torch.Tensor): The prompts tokens updated with the latest
                generated tokens. A tensor of shape [batch_size, max_seq_len]
                (i.e. max_seq_len = max_prompt_len + tokens_to_generate)
            generation_started (torch.Tensor): A boolean tensor of shape [batch_size]. True
                indicates the prompt at that index has started generating tokens.
            current_context_end_position (int): An integer indicating which position to
                extract from the prompts tokens to get the latest generated tokens.
            is_generation_done_tensor (torch.Tensor): A boolean tensor of shape [batch_size].
                True indicates the prompt at that index has reached the end condition.
            generated_sequence_lengths (torch.Tensor): An int tensor of shape [batch_size].
                Each value represents the generated sequence length for that prompt.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Returns the boolean
                is_generation_done_tensor and the generated_sequence_lengths after updating them
        """
        latest_samples = updated_prompts_tokens[:, current_context_end_position]
        # Make sure we are checking the eod criterion only for prompts that have started
        # generating (i.e. we only look at the generated tokens and not the input tokens).
        reached_eod = (latest_samples == self.tokenizer.eod) & generation_started
        is_generation_done_tensor = is_generation_done_tensor | reached_eod
        # We increment generated sequence lengths when that prompt has not hit the
        # EOD and generation has started
        generated_sequence_lengths += ~is_generation_done_tensor & generation_started

        return is_generation_done_tensor, generated_sequence_lengths.int()

    def pad_input_prompt_tokens(
        self,
        batch_prompt_tokens_list: List[List[int]],
        max_prompt_length_in_batch: int,
        num_tokens_to_generate: int,
    ) -> torch.Tensor:
        """Method to pad input prompts

        Given a list of prompts, pad them all to uniform length.

        Args:
            batch_prompt_tokens_list (List[List[int]]): A list containing the prompt tokens
            max_prompt_length_in_batch (int): Maximum of the lengths of the input prompt tokens
            num_tokens_to_generate (int): The number of tokens to generate for each prompt

        Returns:
            torch.Tensor: A torch tensor of shape [bs, max_seq_len] (i.e.
                max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate)
        """
        max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate

        for prompt_tokens in batch_prompt_tokens_list:
            padding_size = max_seq_len - len(prompt_tokens)
            prompt_tokens.extend([self.tokenizer.eod] * padding_size)

        return torch.tensor(batch_prompt_tokens_list, device=torch.cuda.current_device())

    def generate_output_tokens_dynamic_batch(
        self, active_requests: OrderedDict[str, InferenceRequest]
    ) -> OrderedDict[str, InferenceRequest]:
        """Utility to generate the output tokens and probabilities for the prompts

        This utility generates the output tokens for a dynamic batch. It will run one forward step
        at a time, and pass control back to the engine, which will update the request pool and call
        this method again.

        Args:
            active_requests (OrderedDict[str, InferenceRequest]): The input active requests.

        Returns:
            OrderedDict[str, InferenceRequest]: The result for each of the incoming requests
                after running one forward step.
        """
        raise Exception("Not implemented yet")

    def generate_all_output_tokens_static_batch(
        self,
        active_requests: OrderedDict[str, InferenceRequest],
        active_streams: Optional[OrderedDict[str, AsyncStream]] = None,
    ) -> OrderedDict[str, InferenceRequest]:
        """Utility to generate all the output tokens and probabilities for the prompts.

        This utility generates the output tokens for a static batch. It runs the forward steps till
        all prompts complete generation, updates the status of these requests to completed, adds
        the generated result and returns these requests.

        Args:
            active_requests (OrderedDict[str, InferenceRequest]): The input active requests.

        Returns:
            OrderedDict[str, InferenceRequest]: The result for each of the incoming requests
        """
        assert all(request.prompt_tokens is not None for request in active_requests.values())

        # Perform a deep copy so that the request prompt tokens do not get modified.
        batch_prompt_tokens_list: List[List[int]] = list(
            map(
                lambda request: copy.deepcopy(request.prompt_tokens),  # type: ignore[arg-type]
                active_requests.values(),
            )
        )
        prompt_lengths_in_batch = torch.tensor(
            [len(prompt_tokens) for prompt_tokens in batch_prompt_tokens_list],
            device=torch.cuda.current_device(),
        )
        max_prompt_length_in_batch = max(prompt_lengths_in_batch)
        min_prompt_length_in_batch = min(prompt_lengths_in_batch)

        # For batch inference the inference params are the same for all requests
        sampling_params: SamplingParams = list(active_requests.values())[0].inference_parameters

        # max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
        batch_prompt_tokens = self.pad_input_prompt_tokens(
            batch_prompt_tokens_list,
            max_prompt_length_in_batch=max_prompt_length_in_batch,
            num_tokens_to_generate=sampling_params.num_tokens_to_generate,
        )
        batch_size, max_sequence_length = batch_prompt_tokens.shape

        # Verify that the output sequence length is within the configured limit
        # TODO(ksanthanam): Raise TokenOverflowError once !2518 is merged
        inference_max_sequence_length = (
            self.inference_wrapped_model.inference_wrapper_config.inference_max_seq_length
        )
        assert max_sequence_length <= inference_max_sequence_length, (
            f"Maximum allowed sequence length was set to {inference_max_sequence_length} tokens "
            f"but requested generation of {max_sequence_length} tokens"
        )

        # Pre allocate log probs tensor
        output_log_probs = None
        if sampling_params.return_log_probs:
            output_log_probs = torch.empty(
                (batch_size, max_sequence_length - 1),
                dtype=torch.float32,
                device=torch.cuda.current_device(),
            )

        # An array to check which of the prompts have reached the end of generation condition
        is_generation_done_tensor = torch.zeros(
            batch_size, dtype=torch.bool, device=torch.cuda.current_device()
        )

        # An array to act as a counter to keep track of generated sequence lengths
        generated_sequence_lengths = torch.zeros(
            batch_size, device=torch.cuda.current_device()
        ).cuda()

        # Use padded vocab size because tokenizer vocab size might not include padding
        # to nearest power of 2
        vocab_size = self.inference_wrapped_model.inference_wrapper_config.padded_vocab_size

        # Check whether CUDA graphs are enabled
        enable_cuda_graph = get_model_config(self.inference_wrapped_model.model).enable_cuda_graph

        streaming_enabled = active_streams is not None and len(active_streams) > 0
        if streaming_enabled:
            # Start a separate thread for streaming tokens to avoid blocking the
            # main computation
            streaming_idx: List[int] = [
                i
                for (i, request_id) in enumerate(active_requests.keys())
                if request_id in active_streams
            ]
            streaming_request_ids: List[str] = list(active_streams.keys())
            streams: List[AsyncStream] = list(active_streams.values())
            streaming_requests: List[InferenceRequest] = [
                active_requests[request_id] for request_id in streaming_request_ids
            ]
            streaming_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
            stream_tokens = functools.partial(self.stream_tokens, sampling_params)

        with torch.no_grad():

            self.inference_wrapped_model.prep_model_for_inference(
                prompts_tokens=batch_prompt_tokens
            )

            inference_input: Dict[str, Any] = self.prep_inference_input(
                prompts_tokens=batch_prompt_tokens, active_requests=active_requests
            )

            assert (
                not self.inference_wrapped_model.inference_params.decode_mode
            ), f"Generation must start in prefill mode"

            context_start_position = 0
            # Pick the context window that we need to pass through the network.
            for context_end_position in range(min_prompt_length_in_batch, max_sequence_length):

                inference_input_for_context_window: Dict[str, Any] = (
                    self.inference_wrapped_model.get_batch_for_context_window(
                        inference_input, context_start_position, context_end_position
                    )
                )

                # Disable the attention mask when using CUDA graphs for decode
                if (
                    enable_cuda_graph
                    and self.inference_wrapped_model.inference_params.decode_mode
                    and "attention_mask" in inference_input_for_context_window
                ):
                    inference_input_for_context_window["attention_mask"] = None

                # Returns the final logits of shape [batch_size, context_length, vocab_size]
                # Note: This is returned in all TP ranks or last PP stage in PP models
                logits = self.inference_wrapped_model.run_one_forward_step(
                    inference_input_for_context_window
                )

                if enable_cuda_graph:
                    create_cudagraphs()

                if self.model_is_pipeline_parallel:
                    context_length = context_end_position - context_start_position
                    logits = broadcast_from_last_pipeline_stage(
                        [batch_size, context_length, vocab_size],
                        dtype=self.inference_wrapped_model.inference_wrapper_config.params_dtype,
                        tensor=logits,
                    )

                # Indicates which of the input prompts have started generating tokens.
                # A 1D boolean tensor with [batch_size] elements (i.e. the shortest
                # prompts will start generating first, and so on).
                generation_started = prompt_lengths_in_batch <= context_end_position
                last_token_logits = logits[:, -1, :]
                sampled_logits = self.sample_from_logits(
                    last_token_logits, sampling_params, vocab_size
                )

                # Substitute the sampled logits only for the prompts that
                # have started generating tokens
                batch_prompt_tokens[generation_started, context_end_position] = sampled_logits[
                    generation_started
                ]

                if sampling_params.return_log_probs:
                    log_probs = F.log_softmax(logits, dim=2)
                    indices = torch.unsqueeze(
                        batch_prompt_tokens[
                            :, (context_start_position + 1) : (context_end_position + 1)
                        ],
                        2,
                    )
                    # Get the log probabilities for only the prompt tokens
                    assert output_log_probs is not None
                    output_log_probs[:, context_start_position:context_end_position] = torch.gather(
                        log_probs, 2, indices
                    ).squeeze(2)

                context_start_position = context_end_position

                # Check end of generation status for each tensor
                # and update generated sequence lengths
                (is_generation_done_tensor, generated_sequence_lengths) = (
                    self.update_generation_status(
                        updated_prompts_tokens=batch_prompt_tokens,
                        generation_started=generation_started,
                        current_context_end_position=context_end_position,
                        is_generation_done_tensor=is_generation_done_tensor,
                        generated_sequence_lengths=generated_sequence_lengths,
                    )
                )

                # Stream intermediate outputs
                if streaming_enabled:
                    streaming_executor.submit(
                        stream_tokens,
                        streaming_request_ids,
                        streaming_requests,
                        streams,
                        generation_started[streaming_idx].cpu(),
                        is_generation_done_tensor[streaming_idx].cpu(),
                        batch_prompt_tokens[streaming_idx].cpu(),
                        prompt_lengths_in_batch[streaming_idx].cpu(),
                        generated_sequence_lengths[streaming_idx].cpu(),
                        (
                            output_log_probs[streaming_idx].cpu()
                            if output_log_probs is not None
                            else [None] * len(streaming_idx)
                        ),
                    )

                # Boolean flag indicating if all prompts are finished
                all_prompts_done = torch.all(is_generation_done_tensor)
                if all_prompts_done:
                    break

                # Change to decode mode if all prefill is complete
                if torch.all(generation_started):
                    self.inference_wrapped_model.inference_params.enable_decode_mode()

        # Close all streams
        if streaming_enabled:
            streaming_executor.shutdown()
            for stream in streams:
                stream.finish()

        # Include all the generated tokens
        batch_prompt_tokens_with_generations = batch_prompt_tokens[:, : (context_end_position + 1)]
        if sampling_params.return_log_probs:
            assert output_log_probs is not None
            output_log_probs = output_log_probs[:, :context_end_position]

        generated_sequence_lengths[
            generated_sequence_lengths > sampling_params.num_tokens_to_generate
        ] = sampling_params.num_tokens_to_generate

        for idx, request in enumerate(active_requests.values()):
            input_prompt_length = int(prompt_lengths_in_batch[idx])
            # Shorter prompts might have generated more than the required tokens, so we trim them.
            required_sequence_length = int(
                min(generated_sequence_lengths[idx], sampling_params.num_tokens_to_generate)
            )
            # Extract only the generated tokens
            required_result_tokens = batch_prompt_tokens_with_generations[
                idx, input_prompt_length : (input_prompt_length + required_sequence_length)
            ]
            generated_sequence_lengths = generated_sequence_lengths.to(dtype=torch.int32)

            request.generated_sequence_lengths = generated_sequence_lengths.to(dtype=torch.int32)
            request.generated_length = required_sequence_length
            request.generated_tokens = required_result_tokens
            request.prompt_log_probs = (
                None
                if output_log_probs is None
                else output_log_probs[idx, :input_prompt_length].cpu().numpy().tolist()
            )
            request.generated_log_probs = (
                None
                if output_log_probs is None
                else output_log_probs[
                    idx,
                    input_prompt_length - 1 : (input_prompt_length + required_sequence_length - 1),
                ]
                .cpu()
                .numpy()
                .tolist()
            )
            request.status = Status.COMPLETED

            text, segments = self.detokenize_generations(
                batch_prompt_tokens_with_generations[idx],
                input_prompt_length + generated_sequence_lengths,
                sampling_params.return_segments,
            )
            request.text = text  # Inference server returns prompts & generations together
            if sampling_params.return_segments:
                request.segments = segments[0]
            request.generated_text = text[len(request.prompt) :]
        return active_requests

    def prep_inference_input(
        self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest]
    ) -> Dict[str, Any]:
        """Prepare input data for inference, using the respective wrapper's prep_inference_input method  # pylint: disable=line-too-long

        Args:
            prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
            active_requests (OrderedDict[str, InferenceRequest]): The input active requests

        Returns:
            A dict of the inference input for the current batch.
        """
        return self.inference_wrapped_model.prep_inference_input(prompts_tokens)

    def stream_tokens(
        self,
        sampling_params: SamplingParams,
        request_ids: List[str],
        requests: List[InferenceRequest],
        streams: List[AsyncStream],
        generation_started: List[bool],
        is_generation_done: List[bool],
        tokens: torch.Tensor,
        prompt_lengths: List[int],
        generated_lengths: List[int],
        output_log_probs: Union[torch.Tensor, None],
    ):
        """Asynchronously streams tokens for the given requests.

        Args:
            sampling_params (SamplingParams): The sampling parameters.
            request_ids (List[str]): The request IDs.
            requests (List[InferenceRequest]): The requests.
            streams (List[AsyncStream]): The streams over which to send tokens.
            generation_started (List[bool]): Whether the decode step has started.
            is_generation_done (List[bool]): Whether generation has completed.
            tokens (torch.Tensor): The tokens for this request.
            prompt_lengths (List[int]): The number of prompt tokens for each request.
            generated_lengths (List[int]): The number of output tokens for each request.
            output_log_probs (torch.Tensor, optional): The log probs for each request.
        """

        def stream_token(
            request_id: str,
            request: InferenceRequest,
            stream: AsyncStream,
            generation_started: bool,
            is_generation_done: bool,
            tokens: torch.Tensor,
            prompt_length: int,
            generated_length: int,
            output_log_probs: Union[torch.Tensor, None],
        ):
            """Asynchronously streams a token for the given request."""

            if not generation_started or stream.finished:
                return

            num_tokens_to_generate = sampling_params.num_tokens_to_generate
            return_segments = sampling_params.return_segments
            detokenize_streaming_text = not getattr(
                sampling_params, "no_detokenize_streaming_text", False
            )

            generated_tokens = tokens[prompt_length : prompt_length + generated_length]

            if detokenize_streaming_text:
                generated_text, generated_segments = self.detokenize_generations(
                    generated_tokens, prompt_length + generated_length, return_segments
                )
            else:
                generated_text = ""
                generated_segments = []

            if output_log_probs is not None:
                generated_log_probs = (
                    output_log_probs[prompt_length - 1 : prompt_length + generated_length - 1]
                    .cpu()
                    .numpy()
                    .tolist()
                )
            else:
                generated_log_probs = None

            stream.put(
                InferenceRequest(
                    request_id=request_id,
                    prompt=request.prompt,
                    inference_parameters=request.inference_parameters,
                    prompt_tokens=request.prompt_tokens,
                    arrival_time=request.arrival_time,
                    status=request.status,
                    encoder_prompt=request.encoder_prompt,
                    generated_text=generated_text,
                    generated_segments=generated_segments,
                    generated_tokens=generated_tokens,
                    generated_log_probs=generated_log_probs,
                    generated_length=generated_length,
                )
            )

            if is_generation_done or generated_length == num_tokens_to_generate:
                stream.finish()

        ret = map(
            stream_token,
            request_ids,
            requests,
            streams,
            generation_started,
            is_generation_done,
            tokens,
            prompt_lengths,
            generated_lengths,
            output_log_probs,
        )
        list(ret)
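# --- Hypothetical sketch (not part of the commit): the boolean bookkeeping performed by
# --- update_generation_status above, shown on toy CPU tensors. The eod id of 0 and the
# --- tensor values are illustrative assumptions.
import torch

eod = 0
latest_samples = torch.tensor([0, 5, 0])                # last sampled token per prompt
generation_started = torch.tensor([True, True, False])  # third prompt is still in prefill
is_done = torch.tensor([False, False, False])
generated_lengths = torch.zeros(3, dtype=torch.int32)

reached_eod = (latest_samples == eod) & generation_started   # prefill prompts never count
is_done = is_done | reached_eod                              # -> [True, False, False]
generated_lengths += (~is_done & generation_started).int()   # only prompt 1 keeps counting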
megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py
0 → 100644
View file @ 688448db
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import OrderedDict

import torch

from megatron.core.inference.inference_request import InferenceRequest, VLMInferenceRequest
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
    TextGenerationController,
)


class VLMTextGenerationController(TextGenerationController):
    """The text generation controller for VLMs"""

    def prep_inference_input(
        self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest]
    ):
        """Prepare input data for inference, using the respective wrapper's prep_inference_input method  # pylint: disable=line-too-long

        Currently only supports batch size 1 inference.

        Args:
            prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
            active_requests (OrderedDict[str, InferenceRequest]): The input active requests
        """
        assert len(active_requests) == 1, f"VLM inference currently only supports batch size 1"

        request = list(active_requests.values())[0]
        assert isinstance(
            request, VLMInferenceRequest
        ), f"Found inference request of type {type(request)}, expected VLMInferenceRequest"

        return self.inference_wrapped_model.prep_inference_input(
            prompts_tokens,
            request.num_img_embeddings_per_tile,
            request.imgs,
            request.num_tiles,
            request.decoder_seq_length,
        )
megatron/core/inference_params.py
View file @ 688448db
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.


class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_idx):
        "swap between batches"
        if len(self.key_value_memory_dict) == 0:
            raise ValueError("should not swap when dict is empty")

        for layer_number in self.key_value_memory_dict.keys():
            inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
            assert (
                len(batch_idx) == inference_key_memory.shape[1]
            )  # make sure batch size is the same
            new_inference_key_memory = inference_key_memory[:, batch_idx]
            new_inference_value_memory = inference_value_memory[:, batch_idx]
            self.key_value_memory_dict[layer_number] = (
                new_inference_key_memory,
                new_inference_value_memory,
            )

    def __str__(self):
        return (
            f"InferenceParams(max_seq_len = {self.max_sequence_length}, "
            f"max_batch_size = {self.max_batch_size}, "
            f"sequence_len_offset = {self.sequence_len_offset}, "
            f"batch_size_offset = {self.batch_size_offset}, "
            f"key_value_memory_dict = {self.key_value_memory_dict.keys()})"
        )
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.


class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        # Required for bookkeeping variable-sized batches
        self.current_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.decode_mode = False
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_idx):
        "swap between batches"
        if len(self.key_value_memory_dict) == 0:
            raise ValueError("should not swap when dict is empty")

        for layer_number in self.key_value_memory_dict.keys():
            inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
            assert (
                len(batch_idx) == inference_key_memory.shape[1]
            )  # make sure batch size is the same
            new_inference_key_memory = inference_key_memory[:, batch_idx]
            new_inference_value_memory = inference_value_memory[:, batch_idx]
            self.key_value_memory_dict[layer_number] = (
                new_inference_key_memory,
                new_inference_value_memory,
            )

    def enable_prefill_mode(self):
        """
        Indicates the generation loop is in the prefill phase (still processing
        input prompt tokens). This should be enabled if the generation loop is
        encoding prompt tokens for *any* request in a batch.
        """
        self.decode_mode = False

    def enable_decode_mode(self):
        """
        Indicates the generation loop is in the decode phase (generating new output
        tokens). This should only be enabled if the generation loop has fully encoded
        the prompts for *all* requests in a batch.
        """
        self.decode_mode = True

    def reset(self):
        """Resets the inference state for a new batch."""
        self.current_batch_size = self.max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.enable_prefill_mode()

    def __str__(self):
        return (
            f"InferenceParams(max_seq_len = {self.max_sequence_length}, "
            f"max_batch_size = {self.max_batch_size}, "
            f"current_batch_size = {self.current_batch_size}, "
            f"sequence_len_offset = {self.sequence_len_offset}, "
            f"batch_size_offset = {self.batch_size_offset}, "
            f"key_value_memory_dict = {self.key_value_memory_dict.keys()}, "
            f"decode_mode = {self.decode_mode})"
        )

    def __eq__(self, other):
        if not isinstance(other, InferenceParams):
            return False

        # Check all attributes match
        basic_attrs = [
            'max_sequence_length',
            'max_batch_size',
            'current_batch_size',
            'sequence_len_offset',
            'batch_size_offset',
        ]
        if not all(hasattr(other, attr) for attr in basic_attrs):
            return False

        # Check dictionary keys match; i.e. the same number of layers are cached
        if self.key_value_memory_dict.keys() != other.key_value_memory_dict.keys():
            return False

        # Check each tensor tuple in the dictionary
        for key in self.key_value_memory_dict:
            self_tensors = self.key_value_memory_dict[key]
            other_tensors = other.key_value_memory_dict[key]

            # Compare each key, value tensor in the tuple
            for self_tensor, other_tensor in zip(self_tensors, other_tensors):
                if (
                    self_tensor.data_ptr() != other_tensor.data_ptr()
                    or self_tensor.shape != other_tensor.shape
                ):
                    return False

        return True
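# --- Hypothetical usage sketch (not part of the commit): driving the prefill/decode
# --- bookkeeping of the InferenceParams class above with a single fake KV-cache entry.
# --- The cache shapes and batch permutation are illustrative assumptions.
import torch
from megatron.core.inference_params import InferenceParams

params = InferenceParams(max_batch_size=4, max_sequence_length=128)
# One fake layer's KV cache with layout [seq_len, batch, num_heads, head_dim]
params.key_value_memory_dict[1] = (torch.zeros(128, 4, 8, 64), torch.zeros(128, 4, 8, 64))

params.enable_decode_mode()               # all prompts have finished prefill
params.swap_key_value_dict([2, 0, 1, 3])  # reorder the batch dimension of the cached KV
params.reset()                            # back to prefill for the next batch
assert params.decode_mode is False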
megatron/core/jit.py
View file @ 688448db
@@ -7,18 +7,4 @@ from megatron.core.utils import is_torch_min_version
jit_fuser = torch.jit.script
# nvFuser is deprecated in PyTorch JIT starting from 2.2
if is_torch_min_version("2.2.0a0"):
    jit_fuser = torch.compile(mode='max-autotune-no-cudagraphs')

# Decorator to disable Torch Dynamo
# See: https://github.com/NVIDIA/TransformerEngine/issues/308
no_torch_dynamo = lambda recursive=True: lambda func: func
if torch.__version__ >= "2":
    import torch._dynamo

    if torch.__version__ >= "2.1":
        no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(
            f, recursive=recursive
        )
    else:
        # no "recursive" option in PyTorch 2.0 - it acts as if recursive was True
        no_torch_dynamo = lambda recursive=True: torch._dynamo.disable

jit_fuser = torch.compile
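# --- Hypothetical usage sketch (not part of the commit): how the two helpers above are
# --- typically applied. jit_fuser compiles a small elementwise function, while
# --- no_torch_dynamo keeps a Python-heavy helper out of Dynamo tracing. The function
# --- bodies are illustrative only.
import torch
from megatron.core.jit import jit_fuser, no_torch_dynamo


@jit_fuser
def bias_gelu(bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Elementwise math is a good candidate for fusion by torch.compile.
    x = bias + y
    return x * 0.5 * (1.0 + torch.erf(x / 1.41421))


@no_torch_dynamo()
def debug_helper(t: torch.Tensor) -> None:
    # Data-dependent Python (printing) is deliberately excluded from tracing.
    print("mean:", t.float().mean().item())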
megatron/core/model_parallel_config.py
View file @ 688448db
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass
from typing import Callable, ContextManager, Optional

import torch


@dataclass
class ModelParallelConfig:
    """Base configuration for Megatron Core

    The initialization function has an argument for each parameter.
    """

    ###################
    # Model parallelism
    ###################
    tensor_model_parallel_size: int = 1
    """Intra-layer model parallelism. Splits tensors across GPU ranks."""

    pipeline_model_parallel_size: int = 1
    """Inter-layer model parallelism. Splits transformer layers across GPU ranks."""

    virtual_pipeline_model_parallel_size: Optional[int] = None
    """Interleaved pipeline parallelism is used to improve performance by reducing the pipeline
    bubble. Considers a transformer block as a list of smaller transformer (virtual) blocks.
    The number of virtual blocks per pipeline model parallel rank is the virtual model parallel
    size. See Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM:
    arxiv.org/pdf/2104.04473.pdf for more details.
    """

    sequence_parallel: bool = False
    """Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms
    and dropout sequentially. See Reducing Activation Recomputation in Large Transformer Models
    (https://arxiv.org/abs/2205.05198) for more details.
    """

    context_parallel_size: int = 1
    """Splits network input along sequence dimension across GPU ranks."""

    hierarchical_context_parallel_sizes: Optional[list[int]] = None
    """Degrees of the hierarchical context parallelism. Users should provide a list to specify
    the sizes for different levels. Taking the a2a+p2p cp comm type as an example, it contains
    groups of two levels, so the first value of the list indicates the group size of the a2a
    communication type, and the second value indicates the group size of the p2p communication
    type.
    """

    expert_model_parallel_size: int = 1
    """Distributes MoE Experts across sub data parallel dimension."""

    expert_tensor_parallel_size: Optional[int] = None
    """Intra-layer tensor model parallelism for expert layer. Splits tensors across GPU ranks."""

    moe_extended_tp: bool = False
    """NOTE: Deprecated from MCore v0.10. This flag is ignored.
    Its functionality is replaced by expert_tensor_parallel_size.
    """

    ###################
    # Initialization
    ###################
    perform_initialization: bool = True
    """If true, weights are initialized. This option can be useful when you know you are going to
    load values from a checkpoint.
    """

    use_cpu_initialization: bool = False
    """When set to False, we initialize the weights directly on the GPU. CPU initialization is the
    same regardless of tensor model parallelism, but GPU initialization is not. Transferring
    weights from CPU to GPU can take a significant amount of time for large models.
    """

    ###################
    # Training
    ###################
    fp16: bool = False
    """If true, train with fp16 mixed precision training."""

    bf16: bool = False
    """If true, train with bf16 mixed precision training."""

    params_dtype: torch.dtype = torch.float32
    """dtype used when initializing the weights."""

    timers: Optional[Callable] = None
    """Timers object to call for various timing functions. See megatron.core.timers.Timers"""

    finalize_model_grads_func: Optional[Callable] = None
    """Function that finalizes gradients on all workers. Could include ensuring that grads are
    all-reduced across data parallelism, pipeline parallelism, and sequence parallelism
    dimensions.
    """

    grad_scale_func: Optional[Callable] = None
    """If using loss scaling, this function should take the loss and return the scaled loss. If
    None, no function is called on the loss.
    """

    no_sync_func: Optional[Callable] = None
    """Function that creates a context that suppresses asynchronous data-parallel communication. If
    the model is an instance of core.distributed.DistributedDataParallel, the default is to use
    core.distributed.DistributedDataParallel.no_sync.
    """

    grad_sync_func: Optional[Callable] = None
    """Function that launches asynchronous gradient reductions (e.g. distributed optimizer gradient
    reduce-scatters). The function should take one argument: an iterable of parameters whose
    gradients are to be synchronized.
    """

    param_sync_func: Optional[Callable] = None
    """Function that launches asynchronous parameter synchronizations (e.g. distributed optimizer
    parameter all-gathers). The function should take one argument: an iterable of parameters to
    be synchronized.
    """

    deterministic_mode: bool = False
    """If true, code that has deterministic execution will be chosen. This usually
    means slower execution, but is good for debugging and testing. Defaults to False."""

    enable_autocast: bool = False
    """If true, runs the forward step function inside a torch.autocast context."""

    autocast_dtype: Optional[torch.dtype] = None
    """dtype to pass to torch.amp.autocast when enabled. If None, is set to pipeline_dtype."""

    num_microbatches_with_partial_activation_checkpoints: Optional[int] = None
    """If int, set the number of microbatches where not all of the layers will be checkpointed and
    recomputed. The rest of the microbatches within the window of maximum outstanding
    microbatches will recompute all layers (either full recompute or selective recompute). If
    None, the checkpoint and recompute will be left up to the forward_step function.
    """

    ###################
    # Optimizations
    ###################
    gradient_accumulation_fusion: bool = False
    """If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA extension
    fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install
    APEX with --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\"
    --global-option=\"--cuda_ext\"". Note that the extension requires CUDA>=11. Otherwise, you
    must turn off gradient accumulation fusion.
    """

    async_tensor_model_parallel_allreduce: bool = False
    """NOTE: Deprecated. This flag is ignored."""

    use_te_rng_tracker: bool = False
    """If true, uses the RNG state tracker in TransformerEngine if it exists."""

    tp_comm_overlap: bool = False
    """If true, allows overlapping of Linear layer execution with tensor parallel communication
    collectives like AllGather/ReduceScatter. Overlapping is done for the linear layers wherever
    possible during the forward and the backward pass.
    """

    tp_comm_bulk_wgrad: bool = True
    """If true, allows All-Gather overlap with Bprop activation gradient GEMM. Don't care if
    tp_comm_overlap is False.
    """

    tp_comm_bulk_dgrad: bool = True
    """If true, allows Reduce-Scatter overlap with Bprop weight gradient GEMM. Don't care if
    tp_comm_overlap is False.
    """

    tp_comm_overlap_ag: bool = True
    """If true, allows All-Gather overlap with GEMM by pipelining the GEMM and All-Gather.
    Don't care if tp_comm_overlap is False.
    """

    tp_comm_overlap_rs: bool = True
    """If true, allows Reduce-Scatter overlap with GEMM by pipelining the GEMM and Reduce-Scatter.
    Don't care if tp_comm_overlap is False.
    """

    tp_comm_overlap_rs_dgrad: bool = False
    """If true, allows Reduce-Scatter overlap with DGRAD GEMM by pipelining the
    GEMM and Reduce-Scatter splits. Don't care if tp_comm_overlap is False.
    """

    tp_comm_split_ag: bool = True
    """Deprecated from TransformerEngine v1.6.0.
    If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather
    splits. Don't care if tp_comm_overlap is False.
    """

    tp_comm_atomic_ag: bool = False
    """Deprecated from TransformerEngine v1.6.0.
    If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather
    both done atomically. Don't care if tp_comm_overlap is False.
    """

    tp_comm_split_rs: bool = True
    """Deprecated from TransformerEngine v1.6.0.
    If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and
    Reduce-Scatter splits. Don't care if tp_comm_overlap is False.
    """

    tp_comm_atomic_rs: bool = False
    """Deprecated from TransformerEngine v1.6.0.
    If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and
    Reduce-Scatter both done atomically. Don't care if tp_comm_overlap is False.
    """

    cross_entropy_loss_fusion: bool = False
    """If this is enabled, the fused cross entropy implementation would be used.
    Defaults to False.
    """

    tp_comm_overlap_disable_qkv: bool = False
    """If true, the AllGather -> Gemm overlap for QKV gets disabled."""

    tp_comm_overlap_disable_fc1: bool = False
    """If true, the AllGather -> Gemm overlap for the FC1 layer of the MLP gets disabled."""

    tp_comm_bootstrap_backend: str = 'nccl'
    """Set the bootstrapping backend out of 'nccl', 'mpi', and 'gloo'."""

    ###################
    # Pipeline Parallel
    ###################
    pipeline_dtype: torch.dtype = None
    """dtype used in p2p communication, usually params_dtype"""

    variable_seq_lengths: bool = False
    """Support for variable sequence lengths across microbatches. Setting this communicates the
    size of tensors during pipeline parallelism communication; because of this extra overhead it
    should only be set if the sequence length varies by microbatch within a global batch.
    """

    overlap_p2p_comm: bool = False
    """When True, some of the peer to peer communication for pipeline parallelism will overlap
    with computation. Must be False if batch_p2p_comm is true.
    """

    batch_p2p_comm: bool = True
    """Use batch_isend_irecv instead of individual isend/irecv calls. Must be False if
    overlap_p2p_comm is True.
    """

    batch_p2p_sync: bool = True
    """When using batch_isend_irecv, do a cuda.device.synchronize afterward to work around a bug
    in older versions of PyTorch.
    """

    use_ring_exchange_p2p: bool = False
    """Use custom ring_exchange kernel instead of torch.distributed.batch_isend_irecv(). Requires
    custom built torch with torch.distributed.ring_exchange.
    """

    deallocate_pipeline_outputs: bool = False
    """If True, output data is deallocated after the tensor is sent to the next pipeline stage.
    Helps with saving memory, does nothing when pipeline parallel is not used.
    """

    defer_embedding_wgrad_compute: bool = False
    """If true, defers the embedding WGRAD GEMMs while pipeline flush is
    taking place, enabling us to hide pipeline flush latency. Defaults to False.
    """

    wgrad_deferral_limit: int = 0
    """This value tunes the number of micro-batches for which the embedding weight gradient compute
    needs to be deferred to pipeline flush; this argument is invalid if
    `defer_embedding_wgrad_compute` is False.
    Defaults to 0, which means all micro-batches are deferred.
    """

    pipeline_model_parallel_split_rank: Optional[int] = None
    """If int, rank where encoder and decoder should be split in cases where the model has both an
    encoder and decoder (e.g., T5). Ignored if None.
    """

    overlap_p2p_comm_warmup_flush: bool = False
    """If true, overlap communication and computation in warm up and flush phase.
    Only valid when overlap_p2p_comm is True and batch_p2p_comm is False.
    Defaults to False.
    """

    microbatch_group_size_per_vp_stage: Optional[int] = None
    """This value specifies the number of micro-batches that are executed
    at a time for a given virtual stage (both forward and backward).
    Default (in __post_init__() method below) to pipeline_parallel_size
    which specifies a depth-first schedule.
    Example: for PP=2 VP=2, when microbatch_group_size_per_vp_stage=2,
    num_microbatches = 4, we have
    rank 0 | 0 1 0 1 2 3 2 3
    rank 1 | 0 1 0 1 2 3 2 3
    When microbatch_group_size_per_vp_stage=3, num_microbatches = 5,
    we have
    rank 0 | 0 1 2 0 1 2 3 4 3 4
    rank 1 | 0 1 2 0 1 2 3 4 3 4
    """

    ###################
    # CPU Offloading
    ###################
    cpu_offloading: bool = False
    """When set to True, all the activations are offloaded to the CPU asynchronously."""

    cpu_offloading_num_layers: int = 0
    """Tells the number of transformer layers for which activations have to be offloaded."""

    _cpu_offloading_context: Optional[ContextManager] = (
        None
        # Used for internal use only, not to be set by a user.
        # TODO: Need to move to the 'right' place when possible.
    )
    """For internal use only, do not set."""

    cpu_offloading_activations: bool = True
    """If True, offloads the activations to CPU."""

    cpu_offloading_weights: bool = True
    """If True, offloads the weights to CPU."""

    ###################
    # Timing
    ###################
    barrier_with_L1_time: bool = True
    """If true, use barrier with level 1 time measurements. It is up to the user to make sure
    calling barrier with their timers will not result in hangs. This can happen if, for example,
    the user adds a level 1 timer that is not called by all ranks.
    """

    def __post_init__(self):
        """Python dataclass method that is used to modify attributes after initialization.
        See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more
        details.
        """
        if self.sequence_parallel:
            if self.tensor_model_parallel_size <= 1:
                raise ValueError("Can not use sequence parallelism without tensor parallelism")

        if self.expert_tensor_parallel_size is None:
            self.expert_tensor_parallel_size = self.tensor_model_parallel_size

        if self.pipeline_model_parallel_size > 1:
            if self.pipeline_dtype is None:
                raise ValueError(
                    "When using pipeline parallelism, pipeline_dtype must be specified"
                )

        if self.autocast_dtype is None:
            self.autocast_dtype = self.params_dtype

        if self.defer_embedding_wgrad_compute and self.pipeline_model_parallel_size == 1:
            raise ValueError(
                "Cannot defer embedding wgrad compute when pipeline model parallel is not used"
            )

        if self.defer_embedding_wgrad_compute and not self.gradient_accumulation_fusion:
            raise ValueError(
                "Cannot defer embedding wgrad compute when gradient accumulation fusion is not used"
            )

        if self.defer_embedding_wgrad_compute and self.wgrad_deferral_limit < 0:
            raise ValueError(
                "Wgrad deferral limit should be greater than or equal to 0 when it is enabled!"
            )

        if self.expert_model_parallel_size > 1 and self.tensor_model_parallel_size > 1:
            if self.sequence_parallel is False:
                raise ValueError(
                    "When using expert parallelism and tensor parallelism, "
                    "sequence parallelism must be used"
                )

        if self.microbatch_group_size_per_vp_stage is None:
            self.microbatch_group_size_per_vp_stage = self.pipeline_model_parallel_size

        if self.overlap_p2p_comm_warmup_flush:
            if not self.overlap_p2p_comm or self.batch_p2p_comm:
                raise ValueError(
                    "Pipeline parallel communication overlapping in warmup and flush is only "
                    "compatible with overlap_p2p_comm but not batch_p2p_comm."
                )
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass
from typing import Callable, ContextManager, Optional

import torch


@dataclass
class ModelParallelConfig:
    """Base configuration for Megatron Core

    The initialization function has an argument for each parameter.
    """

    ###################
    # Model parallelism
    ###################
    tensor_model_parallel_size: int = 1
    """Intra-layer model parallelism. Splits tensors across GPU ranks."""

    pipeline_model_parallel_comm_backend: Optional[str] = None
    """Configures the backend option of pipeline parallel communication (e.g., nccl, ucc).
    If None, the default backend will be used.
    """

    pipeline_model_parallel_size: int = 1
    """Inter-layer model parallelism. Splits transformer layers across GPU ranks."""

    virtual_pipeline_model_parallel_size: Optional[int] = None
    """Interleaved pipeline parallelism is used to improve performance by reducing the pipeline
    bubble. Considers a transformer block as a list of smaller transformer (virtual) blocks.
    The number of virtual blocks per pipeline model parallel rank is the virtual model parallel
    size. See Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM:
    arxiv.org/pdf/2104.04473.pdf for more details.
    """

    sequence_parallel: bool = False
    """Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms
    and dropout sequentially. See Reducing Activation Recomputation in Large Transformer Models
    (https://arxiv.org/abs/2205.05198) for more details.
    """

    context_parallel_size: int = 1
    """Splits network input along sequence dimension across GPU ranks."""

    hierarchical_context_parallel_sizes: Optional[list[int]] = None
    """Degrees of the hierarchical context parallelism. Users should provide a list to specify
    the sizes for different levels. Taking the a2a+p2p cp comm type as example, it contains
    groups of two levels, so the first value of the list indicates the group size of the a2a
    communication type, and the second value indicates the group size of the p2p communication
    type.
    """

    expert_model_parallel_size: int = 1
    """Distributes MoE experts across the sub data parallel dimension."""

    expert_tensor_parallel_size: Optional[int] = None
    """Intra-layer tensor model parallelism for expert layers. Splits tensors across GPU ranks."""

    moe_extended_tp: bool = False
    """NOTE: Deprecated from MCore v0.10. This flag is ignored.
    Its functionality is replaced by expert_tensor_parallel_size.
    """

    ###################
    # Initialization
    ###################
    perform_initialization: bool = True
    """If true, weights are initialized. This option can be useful when you know you are going to
    load values from a checkpoint.
    """

    use_cpu_initialization: bool = False
    """When set to False, we initialize the weights directly on the GPU. CPU initialization is the
    same regardless of tensor model parallelism, but GPU initialization is not. Transferring
    weights from CPU to GPU can take a significant amount of time for large models.
    """

    ###################
    # Training
    ###################
    fp16: bool = False
    """If true, train with fp16 mixed precision training."""

    bf16: bool = False
    """If true, train with bf16 mixed precision training."""

    params_dtype: torch.dtype = torch.float32
    """dtype used when initializing the weights."""

    timers: Optional[Callable] = None
    """Timers object to call for various timing functions. See megatron.core.timers.Timers"""

    finalize_model_grads_func: Optional[Callable] = None
    """Function that finalizes gradients on all workers. Could include ensuring that grads are
    all-reduced across data parallelism, pipeline parallelism, and sequence parallelism
    dimensions.
    """

    grad_scale_func: Optional[Callable] = None
    """If using loss scaling, this function should take the loss and return the scaled loss. If
    None, no function is called on the loss.
    """

    no_sync_func: Optional[Callable] = None
    """Function that creates a context that suppresses asynchronous data-parallel communication. If
    the model is an instance of core.distributed.DistributedDataParallel, the default is to use
    core.distributed.DistributedDataParallel.no_sync.
    """

    grad_sync_func: Optional[Callable] = None
    """Function that launches asynchronous gradient reductions (e.g. distributed optimizer gradient
    reduce-scatters). The function should take one argument: an iterable of parameters whose
    gradients are to be synchronized.
    """

    param_sync_func: Optional[Callable] = None
    """Function that launches asynchronous parameter synchronizations (e.g. distributed optimizer
    parameter all-gathers). The function should take one argument: an iterable of parameters to
    be synchronized.
    """

    deterministic_mode: bool = False
    """If true, code that has deterministic execution will be chosen. This usually
    means slower execution, but is good for debugging and testing. Defaults to False."""

    enable_autocast: bool = False
    """If true, runs the forward step function inside torch.autocast context."""

    autocast_dtype: Optional[torch.dtype] = None
    """dtype to pass to torch.amp.autocast when enabled. If None, is set to pipeline_dtype."""

    num_microbatches_with_partial_activation_checkpoints: Optional[int] = None
    """If int, set the number of microbatches where not all of the layers will be checkpointed and
    recomputed. The rest of the microbatches within the window of maximum outstanding
    microbatches will recompute all layers (either full recompute or selective recompute). If
    None, the checkpoint and recompute will be left up to the forward_step function.
    """

    ###################
    # Optimizations
    ###################
    gradient_accumulation_fusion: bool = False
    """If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA extension
    fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install
    APEX with --cpp_ext and --cuda_ext. For example:
    "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\"".
    Note that the extension requires CUDA>=11. Otherwise, you must turn off gradient
    accumulation fusion.
    """

    async_tensor_model_parallel_allreduce: bool = False
    """NOTE: Deprecated. This flag is ignored."""

    use_te_rng_tracker: bool = False
    """If true, uses the RNG state tracker in TransformerEngine if it exists."""

    tp_comm_overlap: bool = False
    """If true, allows overlapping of Linear layer execution with tensor parallel communication
    collectives like AllGather/ReduceScatter. Overlapping is done for the linear layers wherever
    possible during the forward and the backward pass.
    """

    tp_comm_bulk_wgrad: bool = True
    """If true, allows All-Gather overlap with Bprop activation gradient GEMM. Don't care if
    tp_comm_overlap is False.
    """

    tp_comm_bulk_dgrad: bool = True
    """If true, allows Reduce-Scatter overlap with Bprop weight gradient GEMM. Don't care if
    tp_comm_overlap is False.
    """

    tp_comm_overlap_ag: bool = True
    """If true, allows All-Gather overlap with GEMM by pipelining the GEMM and All-Gather.
    Don't care if tp_comm_overlap is False.
    """

    tp_comm_overlap_rs: bool = True
    """If true, allows Reduce-Scatter overlap with GEMM by pipelining the GEMM and Reduce-Scatter.
    Don't care if tp_comm_overlap is False.
    """

    tp_comm_overlap_rs_dgrad: bool = False
    """If true, allows Reduce-Scatter overlap with DGRAD GEMM by pipelining the
    GEMM and Reduce-Scatter splits. Don't care if tp_comm_overlap is False.
    """

    tp_comm_split_ag: bool = True
    """Deprecated from TransformerEngine v1.6.0.
    If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather
    splits. Don't care if tp_comm_overlap is False.
    """

    tp_comm_atomic_ag: bool = False
    """Deprecated from TransformerEngine v1.6.0.
    If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather
    both done atomically. Don't care if tp_comm_overlap is False.
    """

    tp_comm_split_rs: bool = True
    """Deprecated from TransformerEngine v1.6.0.
    If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and
    Reduce-Scatter splits. Don't care if tp_comm_overlap is False.
    """

    tp_comm_atomic_rs: bool = False
    """Deprecated from TransformerEngine v1.6.0.
    If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and
    Reduce-Scatter both done atomically. Don't care if tp_comm_overlap is False.
    """

    cross_entropy_loss_fusion: bool = False
    """If this is enabled, the fused cross entropy implementation is used. Defaults to False."""

    tp_comm_overlap_disable_qkv: bool = False
    """If true, the AllGather -> Gemm overlap for QKV gets disabled."""

    tp_comm_overlap_disable_fc1: bool = False
    """If true, the AllGather -> Gemm overlap for the FC1 layer of the MLP gets disabled."""

    tp_comm_bootstrap_backend: str = 'nccl'
    """Set the bootstrapping backend out of 'nccl', 'mpi', and 'gloo'."""

    ###################
    # Pipeline Parallel
    ###################
    pipeline_dtype: torch.dtype = None
    """dtype used in p2p communication, usually params_dtype"""

    variable_seq_lengths: bool = False
    """Support for variable sequence lengths across microbatches. Setting this communicates the
    size of tensors during pipeline parallelism communication; because of this extra overhead it
    should only be set if the sequence length varies by microbatch within a global batch.
    """

    overlap_p2p_comm: bool = False
    """When True, some of the peer-to-peer communication for pipeline parallelism will overlap with
    computation. Must be False if batch_p2p_comm is true.
    """

    batch_p2p_comm: bool = True
    """Use batch_isend_irecv instead of individual isend/irecv calls. Must be False if
    overlap_p2p_comm is True.
    """

    batch_p2p_sync: bool = True
    """When using batch_isend_irecv, do a cuda.device.synchronize afterward to work around a bug in
    older versions of PyTorch.
    """

    use_ring_exchange_p2p: bool = False
    """Use custom ring_exchange kernel instead of torch.distributed.batch_isend_irecv(). Requires
    custom built torch with torch.distributed.ring_exchange.
    """

    deallocate_pipeline_outputs: bool = False
    """If True, output data is deallocated after the tensor is sent to the next pipeline stage.
    Helps with saving memory, does nothing when pipeline parallel is not used.
    """

    defer_embedding_wgrad_compute: bool = False
    """If true, defers the embedding WGRAD GEMMs while pipeline flush is
    taking place, enabling us to hide pipeline flush latency. Defaults to False.
    """

    wgrad_deferral_limit: int = 0
    """This value tunes the number of micro-batches for which the embedding weight gradient compute
    needs to be deferred to pipeline flush; this argument is invalid if
    `defer_embedding_wgrad_compute` is False.
    Defaults to 0, which means all micro-batches are deferred.
    """

    pipeline_model_parallel_split_rank: Optional[int] = None
    """If int, rank where encoder and decoder should be split in cases where the model has both an
    encoder and decoder (e.g., T5). Ignored if None.
    """

    overlap_p2p_comm_warmup_flush: bool = False
    """If true, overlap communication and computation in the warm up and flush phases.
    Only valid when overlap_p2p_comm is True and batch_p2p_comm is False.
    Defaults to False.
    """

    microbatch_group_size_per_vp_stage: Optional[int] = None
    """This value specifies the number of micro-batches that are executed
    at a time for a given virtual stage (both forward and backward).
    Defaults (in the __post_init__() method below) to pipeline_parallel_size,
    which specifies a depth-first schedule.
    Example: for PP=2 VP=2, when microbatch_group_size_per_vp_stage=2 and
    num_microbatches=4, we have
        rank 0 | 0 1 0 1 2 3 2 3
        rank 1 | 0 1 0 1 2 3 2 3
    When microbatch_group_size_per_vp_stage=3 and num_microbatches=5,
    we have
        rank 0 | 0 1 2 0 1 2 3 4 3 4
        rank 1 | 0 1 2 0 1 2 3 4 3 4
    """

    ###################
    # CPU Offloading
    ###################
    cpu_offloading: bool = False
    """When set to True, all the activations are offloaded to the CPU asynchronously."""

    cpu_offloading_num_layers: int = 0
    """Tells the number of transformer layers for which activations have to be offloaded."""

    _cpu_offloading_context: Optional[ContextManager] = (
        None
        # Used for internal use only, not to be set by a user.
        # TODO: Need to move to the 'right' place when possible.
    )
    """For internal use only, do not set."""

    cpu_offloading_activations: bool = True
    """If True, offloads the activations to CPU."""

    cpu_offloading_weights: bool = True
    """If True, offloads the weights to CPU."""

    ###################
    # Timing
    ###################
    barrier_with_L1_time: bool = True
    """If true, use barrier with level 1 time measurements. It is up to the user to make sure
    calling barrier with their timers will not result in hangs. This can happen if, for example,
    the user adds a level 1 timer that is not called by all ranks.
    """

    def __post_init__(self):
        """Python dataclass method that is used to modify attributes after initialization.
        See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more
        details.
        """
        if self.sequence_parallel:
            if self.tensor_model_parallel_size <= 1:
                raise ValueError("Can not use sequence parallelism without tensor parallelism")

        if self.expert_tensor_parallel_size is None:
            self.expert_tensor_parallel_size = self.tensor_model_parallel_size

        if self.pipeline_model_parallel_size > 1:
            if self.pipeline_dtype is None:
                raise ValueError(
                    "When using pipeline parallelism, pipeline_dtype must be specified"
                )

        if self.autocast_dtype is None:
            self.autocast_dtype = self.params_dtype

        if self.defer_embedding_wgrad_compute and self.pipeline_model_parallel_size == 1:
            raise ValueError(
                "Cannot defer embedding wgrad compute when pipeline model parallel is not used"
            )

        if self.defer_embedding_wgrad_compute and not self.gradient_accumulation_fusion:
            raise ValueError(
                "Cannot defer embedding wgrad compute when gradient accumulation fusion is not used"
            )

        if self.defer_embedding_wgrad_compute and self.wgrad_deferral_limit < 0:
            raise ValueError(
                "Wgrad deferral limit should be greater than or equal to 0 when it is enabled!"
            )

        if self.expert_model_parallel_size > 1 and self.tensor_model_parallel_size > 1:
            if self.sequence_parallel is False:
                raise ValueError(
                    "When using expert parallelism and tensor parallelism, "
                    "sequence parallelism must be used"
                )

        if self.microbatch_group_size_per_vp_stage is None:
            self.microbatch_group_size_per_vp_stage = self.pipeline_model_parallel_size

        if self.overlap_p2p_comm_warmup_flush:
            if not self.overlap_p2p_comm or self.batch_p2p_comm:
                raise ValueError(
                    "Pipeline parallel communication overlapping in warmup and flush is only "
                    "compatible with overlap_p2p_comm but not batch_p2p_comm."
                )
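As a minimal sketch of how this config behaves at construction time (the parallel sizes and dtypes below are illustrative values only, not recommendations):

import torch
from megatron.core.model_parallel_config import ModelParallelConfig

# Pipeline parallelism requires an explicit pipeline_dtype; __post_init__ raises otherwise.
config = ModelParallelConfig(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=2,
    sequence_parallel=True,
    pipeline_dtype=torch.bfloat16,
    params_dtype=torch.bfloat16,
)

# Defaults filled in by __post_init__:
assert config.autocast_dtype == torch.bfloat16         # falls back to params_dtype
assert config.expert_tensor_parallel_size == 2         # falls back to tensor_model_parallel_size
assert config.microbatch_group_size_per_vp_stage == 2  # falls back to pipeline_model_parallel_size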
megatron/core/models/T5/t5_model.py
...
@@ -10,9 +10,11 @@ from megatron.core.config_logger import has_config_logger_enabled, log_config_to
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.enums import ModelType
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.relative_pos_embedding import RelativePositionEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel.mappings import scatter_to_tensor_model_parallel_region
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
...
@@ -135,9 +137,13 @@ class T5Model(LanguageModule):
        fp16_lm_cross_entropy: bool = False,
        parallel_output: bool = True,
        share_embeddings_and_output_weights: bool = False,
        position_embedding_type: Literal['learned_absolute', 'rope', 'relative'] = 'learned_absolute',
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[float] = None,
        relative_attention_num_buckets: int = 32,
        relative_attention_max_distance: int = 128,
        add_encoder: bool = True,
        add_decoder: bool = True,
    ):
...
@@ -193,6 +199,23 @@ class T5Model(LanguageModule):
            use_cpu_initialization=self.config.use_cpu_initialization,
        )

        # Relative Position Embeddings
        if self.position_embedding_type == 'relative':
            self.encoder_relative_pos_emb = RelativePositionEmbedding(
                bidirectional=True,
                init_method=self.config.init_method,
                num_attention_heads=self.config.num_attention_heads,
                relative_attention_num_buckets=relative_attention_num_buckets,
                relative_attention_max_distance=relative_attention_max_distance,
            )
            self.decoder_relative_pos_emb = RelativePositionEmbedding(
                bidirectional=False,
                init_method=self.config.init_method,
                num_attention_heads=self.config.num_attention_heads,
                relative_attention_num_buckets=relative_attention_num_buckets,
                relative_attention_max_distance=relative_attention_max_distance,
            )

        # Transformer encoder
        encoder_spec, decoder_spec = (
            self.transformer_encoder_layer_spec,
...
@@ -284,6 +307,27 @@ class T5Model(LanguageModule):
            )
            rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)

        # Relative positional embeddings
        encoder_attention_bias_parallel = None
        if self.position_embedding_type == 'relative':
            query_seq_length = RelativePositionEmbedding.get_relative_seq_len(
                inference_params, self.encoder, encoder_input, self.config
            )
            key_seq_length = query_seq_length
            attention_bias = self.encoder_relative_pos_emb(query_seq_length, key_seq_length)
            # Scatter attention_bias to TP ranks
            # First, reshape [1, num_head, seqlen_q, seqlen_kv] to
            # [1, seqlen_q, seqlen_kv, num_head] so it is scattered along
            # the last (num_heads) dimension
            attention_bias = torch.permute(attention_bias, (0, 2, 3, 1))
            # Then, scatter to TP region
            attention_bias_parallel = scatter_to_tensor_model_parallel_region(attention_bias)
            # Lastly, revert the dimensions back to [1, num_head, seqlen_q, seqlen_kv]
            encoder_attention_bias_parallel = torch.permute(attention_bias_parallel, (0, 3, 1, 2))

        # Run encoder.
        if self.add_encoder:
            encoder_hidden_states = self.encoder(
...
@@ -291,6 +335,7 @@ class T5Model(LanguageModule):
                attention_mask=encoder_attn_mask,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb,
                attention_bias=encoder_attention_bias_parallel,
            )
        else:
            encoder_hidden_states = self.encoder_hidden_state
...
@@ -315,10 +360,29 @@ class T5Model(LanguageModule):
        rotary_pos_emb = None
        if self.position_embedding_type == 'rope':
            rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
                inference_params, self.decoder, decoder_input, self.config, packed_seq_params
            )
            rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)

        # Relative positional embeddings
        decoder_attention_bias_parallel = None
        if self.position_embedding_type == 'relative':
            query_seq_length = RelativePositionEmbedding.get_relative_seq_len(
                inference_params, self.decoder, decoder_input, self.config
            )
            key_seq_length = query_seq_length
            attention_bias = self.decoder_relative_pos_emb(query_seq_length, key_seq_length)
            # Scatter attention_bias to TP ranks
            # First, reshape [1, num_head, seqlen_q, seqlen_kv] to
            # [1, seqlen_q, seqlen_kv, num_head] so it is scattered along
            # the last (num_heads) dimension
            attention_bias = torch.permute(attention_bias, (0, 2, 3, 1))
            # Then, scatter to TP region
            attention_bias_parallel = scatter_to_tensor_model_parallel_region(attention_bias)
            # Lastly, revert the dimensions back to [1, num_head, seqlen_q, seqlen_kv]
            decoder_attention_bias_parallel = torch.permute(attention_bias_parallel, (0, 3, 1, 2))

        # Run decoder.
        decoder_hidden_states = self.decoder(
            hidden_states=decoder_input,
...
@@ -327,12 +391,15 @@ class T5Model(LanguageModule):
            context_mask=encoder_decoder_attn_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            attention_bias=decoder_attention_bias_parallel,
        )

        if self.post_process:
            output_weight = None
            if self.share_embeddings_and_output_weights:
                output_weight = self.shared_embedding_or_output_weight()
            lm_logits = self.lm_head(decoder_hidden_states, word_embeddings_weight=output_weight)

            if lm_labels is None:
                # [s b h] => [b s h]
                return lm_logits.transpose(0, 1).contiguous()
...
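The permute-scatter-permute bookkeeping above exists because scatter_to_tensor_model_parallel_region splits along the last dimension, while attention expects the bias as [1, num_heads, seqlen_q, seqlen_kv]. A shape-only sketch of the same idea in plain torch (torch.chunk stands in for the real TP scatter, which needs an initialized process group):

import torch

num_heads, seq_q, seq_kv, tp_size = 8, 16, 16, 2

bias = torch.randn(1, num_heads, seq_q, seq_kv)          # [1, h, q, kv] from RelativePositionEmbedding
bias_heads_last = torch.permute(bias, (0, 2, 3, 1))       # [1, q, kv, h]: heads now last
# scatter_to_tensor_model_parallel_region splits the last dim across TP ranks;
# torch.chunk emulates what a single rank would receive.
local_shard = torch.chunk(bias_heads_last, tp_size, dim=-1)[0]
local_bias = torch.permute(local_shard, (0, 3, 1, 2))     # back to [1, h/tp, q, kv]
assert local_bias.shape == (1, num_heads // tp_size, seq_q, seq_kv)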
megatron/core/models/common/embeddings/relative_pos_embedding.py
0 → 100644
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

import logging
import math
from typing import Callable

import torch
from torch import Tensor, nn

from megatron.core.inference_params import InferenceParams
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig

logger = logging.getLogger(__name__)

__all__ = ['RelativePositionEmbedding']


class RelativePositionEmbedding(nn.Module):
    """Relative Position Embedding for language model.

    Args:
        bidirectional (bool): Whether the attention is bidirectional (encoder) or causal (decoder).
        init_method (Callable): Initializer applied to the bias embedding weight.
        num_attention_heads (int): Number of attention heads; one bias per head.
        relative_attention_num_buckets (int): Number of buckets for relative positions.
        relative_attention_max_distance (int): Relative distance beyond which positions share a bucket.
    """

    def __init__(
        self,
        bidirectional: bool,
        init_method: Callable,
        num_attention_heads: int,
        relative_attention_num_buckets: int = 32,
        relative_attention_max_distance: int = 128,
    ) -> None:
        super().__init__()

        self.bidirectional = bidirectional
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance

        self.relative_attention_bias = torch.nn.Embedding(
            self.relative_attention_num_buckets, num_attention_heads
        )
        init_method(self.relative_attention_bias.weight)

    def _relative_position_bucket(
        self, relative_position, bidirectional=True, num_buckets=32, max_distance=128
    ):
        """
        Adapted from HuggingFace T5 Model:
        https://github.com/huggingface/transformers/blob/329f5dbf97a5cb2473914c88c05aa3dcb242e19a/
        src/transformers/models/t5/modeling_t5.py#L397

        Translate relative position to a bucket number for relative attention.
        The relative position is defined as memory_position - query_position, i.e. the
        distance in tokens from the attending position to the attended-to position.
        If bidirectional=False, then positive relative positions are invalid. We use
        smaller buckets for small absolute relative_position and larger buckets for
        larger absolute relative_positions. All relative positions >=max_distance map
        to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the
        model has been trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position,
            containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger
        # bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def _compute_bias(self, query_length, key_length):
        """
        Adapted from HuggingFace T5 Model:
        https://github.com/huggingface/transformers/blob/329f5dbf97a5cb2473914c88c05aa3dcb242e19a/
        src/transformers/models/t5/modeling_t5.py#L444C9-L444C21

        Compute binned relative position bias.

        Args:
            query_length (int): The length of the query sequence
                (e.g., the input sequence in attention).
            key_length (int): The length of the key sequence
                (e.g., the sequence to compare against in attention).

        Returns:
            torch.Tensor: A tensor representing the relative position bias, with shape
                (1, num_heads, query_length, key_length).
        """
        device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=self.bidirectional,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(
            relative_position_bucket
        )  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    @staticmethod
    def get_relative_seq_len(
        inference_params: InferenceParams,
        transformer: TransformerBlock,
        transformer_input: Tensor,
        transformer_config: TransformerConfig,
    ) -> float:
        """Function to get the relative position embedding sequence length.

        Args:
            inference_params: Used during inference time
            transformer (TransformerBlock): The transformer block (decoder/encoder) used
                by the model
            transformer_input (Tensor): Input tensor to the transformer
            transformer_config (TransformerConfig): Transformer config used by the model

        Returns:
            float: The relative position embedding sequence length
        """
        if inference_params is not None:
            relative_seq_len = inference_params.max_sequence_length
        else:
            if transformer.input_tensor is not None:
                relative_seq_len = transformer.input_tensor.size(0)
            else:
                relative_seq_len = transformer_input.size(0)

            if transformer_config.sequence_parallel:
                relative_seq_len *= transformer_config.tensor_model_parallel_size

        return relative_seq_len

    def forward(self, query_seq_length, key_seq_length):
        """Compute the relative position bias for the given query/key lengths.

        Args:
            query_seq_length (int): Length of the query sequence.
            key_seq_length (int): Length of the key sequence.

        Returns:
            torch.Tensor: Bias of shape (1, num_heads, query_seq_length, key_seq_length).
        """
        return self._compute_bias(query_seq_length, key_seq_length)
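A rough usage sketch of this new module; the initializer below is a stand-in (in the T5 model it comes from config.init_method):

import torch
from megatron.core.models.common.embeddings.relative_pos_embedding import (
    RelativePositionEmbedding,
)

rel_pos_emb = RelativePositionEmbedding(
    bidirectional=True,                          # encoder-style attention
    init_method=torch.nn.init.xavier_normal_,    # stand-in initializer
    num_attention_heads=12,
    relative_attention_num_buckets=32,
    relative_attention_max_distance=128,
)

bias = rel_pos_emb(query_seq_length=128, key_seq_length=128)
assert bias.shape == (1, 12, 128, 128)  # [1, num_heads, seqlen_q, seqlen_kv]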
megatron/core/models/common/embeddings/rotary_pos_embedding.py
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from megatron.core.transformer.transformer_config import TransformerConfig
    from megatron.core.transformer.transformer_block import TransformerBlock
    from megatron.core.inference_params import InferenceParams
    from megatron.core.packed_seq_params import PackedSeqParams

import logging
import math
from functools import lru_cache

import torch
from torch import Tensor, nn

from megatron.core import parallel_state
from megatron.core.models.common.embeddings.rope_utils import (
    # for backward compatibility; pylint: disable=unused-import
    _apply_rotary_pos_emb_bshd,
    _apply_rotary_pos_emb_thd,
    _rotate_half,
    apply_rotary_pos_emb,
    get_pos_emb_on_this_cp_rank,
)

logger = logging.getLogger(__name__)

__all__ = ['RotaryEmbedding']


class RotaryEmbedding(nn.Module):
    """Rotary Embedding for language model.

    Args:
        kv_channels (int): Projection weights dimension in multi-head attention. Obtained
            from transformer config
        rotary_percent (float): Percent of rotary dimension to use for rotary position
            embeddings.
        rotary_interleaved (bool, optional): If True, interleaved rotary position embeddings.
            Defaults to False.
        seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE
            for longer sequences. The value must be a float larger than 1.0. Defaults to None
        rotary_base (int, optional): Base period for rotary position embeddings. Defaults to
            10000.
        rope_scaling (bool, optional): Apply rope scaling as used in llama 3.x.
        rope_scaling_factor (float, optional): rope scaling factor in llama 3.x. Defaults to 8.
        use_cpu_initialization (bool, optional): If False, initialize the inv_freq directly
            on the GPU. Defaults to False
    """

    def __init__(
        self,
        kv_channels: int,
        rotary_percent: float,
        rotary_interleaved: bool = False,
        seq_len_interpolation_factor: float = None,
        rotary_base: int = 10000,
        rope_scaling: bool = False,
        rope_scaling_factor: float = 8.0,
        use_cpu_initialization: bool = False,
    ) -> None:
        super().__init__()

        dim = kv_channels
        if rotary_percent < 1.0:
            dim = int(dim * rotary_percent)
        self.rotary_interleaved = rotary_interleaved

        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        device = 'cpu' if use_cpu_initialization else torch.cuda.current_device()
        self.inv_freq = 1.0 / (
            rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
        )

        if rope_scaling:
            self.inv_freq = self._apply_scaling(self.inv_freq, factor=rope_scaling_factor)

    def _apply_scaling(
        self,
        freqs,
        factor=8,
        low_freq_factor=1,
        high_freq_factor=4,
        original_max_position_embeddings=8192,
    ):
        # This implementation is adapted from:
        # https://github.com/huggingface/transformers/blob/2a5a6ad18aa22e98429bb5ecb880660328030ea0/src/transformers/modeling_rope_utils.py#L303-L343
        factor = factor  # `8` in the original implementation
        low_freq_factor = low_freq_factor  # `1` in the original implementation
        high_freq_factor = high_freq_factor  # `4` in the original implementation
        old_context_len = original_max_position_embeddings  # `8192` in the original implementation

        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor

        wavelen = 2 * math.pi / freqs
        # wavelen < high_freq_wavelen: do nothing
        # wavelen > low_freq_wavelen: divide by factor
        inv_freq_llama = torch.where(wavelen > low_freq_wavelen, freqs / factor, freqs)
        # otherwise: interpolate between the two, using a smooth factor
        smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
            high_freq_factor - low_freq_factor
        )
        smoothed_inv_freq = (
            1 - smooth_factor
        ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
        is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)

        return inv_freq_llama

    def get_freqs_non_repeated(self, max_seq_len: int, offset: int = 0) -> Tensor:
        """Generates matrix of frequencies based on positions in the sequence,
        used to create positional encodings"""
        seq = (
            torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
            + offset
        )

        if self.seq_len_interpolation_factor is not None:
            seq *= 1 / self.seq_len_interpolation_factor

        freqs = torch.outer(seq, self.inv_freq)  # [seq len, dim]

        return freqs

    def get_cos_sin(self, max_seq_len: int, offset: int = 0) -> (Tensor, Tensor):
        """Cosine and sine values for RoPE are precomputed for all positions up to the maximum
        sequence length"""
        freqs = self.get_freqs_non_repeated(max_seq_len, offset)
        cos = torch.cos(freqs)
        sin = torch.sin(freqs)
        return cos, sin

    @lru_cache(maxsize=32)
    def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor:
        """Forward pass of RoPE embedding.

        Args:
            max_seq_len (int): Maximum size of sequence
            offset (int, optional): RoPE offset. Defaults to 0.
            packed_seq (bool, optional): Whether to use packed sequence. Defaults to False.

        Returns:
            Tensor: Embeddings after applying RoPE.
        """
        if self.inv_freq.device.type == 'cpu':
            # move `inv_freq` to GPU once at the first micro-batch forward pass
            self.inv_freq = self.inv_freq.to(device=torch.cuda.current_device())

        freqs = self.get_freqs_non_repeated(max_seq_len, offset)
        # first part even vector components, second part odd vector components,
        # 2 * dim in dimension size
        if not self.rotary_interleaved:
            emb = torch.cat((freqs, freqs), dim=-1)
        else:
            emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view(
                freqs.shape[0], -1
            )
        # emb [seq_length, .., dim]
        emb = emb[:, None, None, :]
        if parallel_state.get_context_parallel_world_size() > 1 and not packed_seq:
            # slice rotary_pos_emb along sequence dimension and select the partition of the
            # current CP rank
            emb = get_pos_emb_on_this_cp_rank(emb, 0)
        return emb

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        state_dict.pop(f'{prefix}inv_freq', None)
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def get_rotary_seq_len(
        self,
        inference_params: InferenceParams,
        transformer: TransformerBlock,
        transformer_input: Tensor,
        transformer_config: TransformerConfig,
        packed_seq_params: PackedSeqParams,
    ) -> float:
        """Function to get the rotary sequence length.

        Args:
            inference_params: Used during inference time
            transformer (TransformerBlock): The transformer block (decoder/encoder) used
                by the model
            transformer_input (Tensor): Input tensor to the transformer
            transformer_config (TransformerConfig): Transformer config used by the model
            packed_seq_params (PackedSeqParams): Packed sequence params

        Returns:
            float: The rotary sequence length
        """
        if packed_seq_params is not None:
            # max_seqlen are the max sequence length in the packed sequence before being divided
            # by the tp and cp size.
            return max(packed_seq_params.max_seqlen_q, packed_seq_params.max_seqlen_kv)
        elif inference_params is not None:
            rotary_seq_len = inference_params.max_sequence_length
        else:
            if transformer is not None and transformer.input_tensor is not None:
                rotary_seq_len = transformer.input_tensor.size(0)
            else:
                rotary_seq_len = transformer_input.size(0)

            if transformer_config.sequence_parallel:
                rotary_seq_len *= transformer_config.tensor_model_parallel_size

        rotary_seq_len *= transformer_config.context_parallel_size

        return rotary_seq_len
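A small usage sketch of the precomputed-table path, which runs on CPU when use_cpu_initialization=True (the full forward() additionally expects a CUDA device and Megatron's parallel state for context-parallel slicing); all values below are illustrative:

from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding

rope = RotaryEmbedding(
    kv_channels=64,            # per-head projection dim from the transformer config
    rotary_percent=1.0,
    rotary_base=10000,
    rope_scaling=True,         # apply the llama-3.x style inverse-frequency rescaling
    rope_scaling_factor=8.0,
    use_cpu_initialization=True,
)

cos, sin = rope.get_cos_sin(max_seq_len=2048)
assert cos.shape == (2048, 32)  # [seq_len, dim // 2] frequency table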
megatron/core/models/gpt/gpt_layer_specs.py
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import warnings
from typing import Optional

from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.multi_latent_attention import (
    MLASelfAttention,
    MLASelfAttentionSubmodules,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import (
    TransformerBlockSubmodules,
    get_num_layers_to_build,
)
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
from megatron.core.utils import is_te_min_version

try:
    from megatron.core.extensions.transformer_engine import (
        TEColumnParallelLinear,
        TEDotProductAttention,
        TELayerNormColumnParallelLinear,
        TENorm,
        TERowParallelLinear,
    )

    HAVE_TE = True
except ImportError:
    HAVE_TE = False

try:
    import apex  # pylint: disable=unused-import

    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm

    HAVE_APEX = True
    LNImpl = FusedLayerNorm
except ImportError:
    from megatron.core.transformer.torch_norm import WrappedTorchNorm

    warnings.warn('Apex is not installed. Falling back to Torch Norm')
    LNImpl = WrappedTorchNorm


def get_gpt_layer_with_transformer_engine_spec(
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    qk_layernorm: Optional[bool] = False,
    multi_latent_attention: Optional[bool] = False,
    fp8: Optional[str] = None,  # pylint: disable=unused-arguments
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
    """Use this spec to use lower-level Transformer Engine modules (required for fp8 training).

    Args:
        num_experts (int, optional): Number of experts. Defaults to None.
        moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
        qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
        fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
        moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
            Defaults to False.

    Returns:
        ModuleSpec: Module specification with TE modules
    """
    if fp8 is not None:
        warnings.warn(
            'The fp8 argument in "get_gpt_layer_with_transformer_engine_spec" has been deprecated'
            ' and will be removed soon. Please update your code accordingly.'
        )

    mlp = _get_mlp_module_spec(
        use_te=True,
        num_experts=num_experts,
        moe_grouped_gemm=moe_grouped_gemm,
        moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
    )

    if multi_latent_attention:
        return ModuleSpec(
            module=TransformerLayer,
            submodules=TransformerLayerSubmodules(
                input_layernorm=TENorm,
                self_attention=ModuleSpec(
                    module=MLASelfAttention,
                    params={"attn_mask_type": AttnMaskType.causal},
                    submodules=MLASelfAttentionSubmodules(
                        linear_q_proj=TEColumnParallelLinear,
                        linear_q_down_proj=TEColumnParallelLinear,
                        linear_q_up_proj=TEColumnParallelLinear,
                        linear_kv_down_proj=TEColumnParallelLinear,
                        linear_kv_up_proj=TEColumnParallelLinear,
                        core_attention=TEDotProductAttention,
                        linear_proj=TERowParallelLinear,
                        q_layernorm=TENorm if qk_layernorm else IdentityOp,
                        kv_layernorm=TENorm if qk_layernorm else IdentityOp,
                    ),
                ),
                self_attn_bda=get_bias_dropout_add,
                pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
                mlp=mlp,
                mlp_bda=get_bias_dropout_add,
            ),
        )
    else:
        # TENorm significantly harms convergence when used
        # for QKLayerNorm if TE Version < 1.9;
        # we instead use the Apex implementation.
        qk_norm = TENorm if is_te_min_version("1.9.0") else FusedLayerNorm

        return ModuleSpec(
            module=TransformerLayer,
            submodules=TransformerLayerSubmodules(
                self_attention=ModuleSpec(
                    module=SelfAttention,
                    params={"attn_mask_type": AttnMaskType.causal},
                    submodules=SelfAttentionSubmodules(
                        linear_qkv=TELayerNormColumnParallelLinear,
                        core_attention=TEDotProductAttention,
                        linear_proj=TERowParallelLinear,
                        q_layernorm=qk_norm if qk_layernorm else IdentityOp,
                        k_layernorm=qk_norm if qk_layernorm else IdentityOp,
                    ),
                ),
                self_attn_bda=get_bias_dropout_add,
                pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
                mlp=mlp,
                mlp_bda=get_bias_dropout_add,
            ),
        )


def get_gpt_layer_local_spec(
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    qk_layernorm: Optional[bool] = False,
    multi_latent_attention: Optional[bool] = False,
    fp8: Optional[str] = None,  # pylint: disable=unused-arguments
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
    """Use this spec for an implementation using only modules in Megatron-Core.

    Args:
        num_experts (int, optional): Number of experts. Defaults to None.
        moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
        qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
        fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
        moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
            Defaults to False.

    Returns:
        ModuleSpec: Module specification with Megatron-Core modules
    """
    if fp8 is not None:
        warnings.warn(
            'The fp8 argument in "get_gpt_layer_local_spec" has been deprecated'
            ' and will be removed soon. Please update your code accordingly.'
        )

    mlp = _get_mlp_module_spec(
        use_te=False,
        num_experts=num_experts,
        moe_grouped_gemm=moe_grouped_gemm,
        moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
    )

    if multi_latent_attention:
        return ModuleSpec(
            module=TransformerLayer,
            submodules=TransformerLayerSubmodules(
                input_layernorm=LNImpl,
                self_attention=ModuleSpec(
                    module=MLASelfAttention,
                    params={"attn_mask_type": AttnMaskType.causal},
                    submodules=MLASelfAttentionSubmodules(
                        linear_q_proj=ColumnParallelLinear,
                        linear_q_down_proj=ColumnParallelLinear,
                        linear_q_up_proj=ColumnParallelLinear,
                        linear_kv_down_proj=ColumnParallelLinear,
                        linear_kv_up_proj=ColumnParallelLinear,
                        core_attention=DotProductAttention,
                        linear_proj=RowParallelLinear,
                        q_layernorm=LNImpl if qk_layernorm else IdentityOp,
                        kv_layernorm=LNImpl if qk_layernorm else IdentityOp,
                    ),
                ),
                self_attn_bda=get_bias_dropout_add,
                pre_mlp_layernorm=LNImpl,
                mlp=mlp,
                mlp_bda=get_bias_dropout_add,
            ),
        )
    else:
        return ModuleSpec(
            module=TransformerLayer,
            submodules=TransformerLayerSubmodules(
                input_layernorm=LNImpl,
                self_attention=ModuleSpec(
                    module=SelfAttention,
                    params={"attn_mask_type": AttnMaskType.causal},
                    submodules=SelfAttentionSubmodules(
                        linear_qkv=ColumnParallelLinear,
                        core_attention=DotProductAttention,
                        linear_proj=RowParallelLinear,
                        q_layernorm=LNImpl if qk_layernorm else IdentityOp,
                        k_layernorm=LNImpl if qk_layernorm else IdentityOp,
                    ),
                ),
                self_attn_bda=get_bias_dropout_add,
                pre_mlp_layernorm=LNImpl,
                mlp=mlp,
                mlp_bda=get_bias_dropout_add,
                sharded_state_dict_keys_map={
                    'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
                    'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
                },
            ),
        )


def _get_mlp_module_spec(
    use_te: Optional[bool] = True,
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    fp8: Optional[str] = None,  # pylint: disable=unused-arguments
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
    """Helper function to get module spec for MLP/MoE"""
    if fp8 is not None:
        warnings.warn(
            'The fp8 argument in "_get_mlp_module_spec" has been deprecated'
            ' and will be removed soon. Please update your code accordingly.'
        )

    if num_experts is None:
        # Dense MLP w/ or w/o TE modules.
        return ModuleSpec(
            module=MLP,
            submodules=MLPSubmodules(
                linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear,
                linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
            ),
        )
    else:
        # Mixture of experts with modules in megatron core.
        return get_moe_module_spec(
            use_te=use_te,
            num_experts=num_experts,
            moe_grouped_gemm=moe_grouped_gemm,
            moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
        )


def get_gpt_decoder_block_spec(
    config: TransformerConfig, use_transformer_engine: bool
) -> TransformerBlockSubmodules:
    """GPT block spec."""
    if use_transformer_engine:
        layer_norm_impl = TENorm
    else:
        layer_norm_impl = LNImpl

    # Layer specs.
    dense_layer_spec = (
        get_gpt_layer_with_transformer_engine_spec(
            num_experts=None,
            moe_grouped_gemm=False,
            qk_layernorm=config.qk_layernorm,
            multi_latent_attention=config.multi_latent_attention,
            moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
        )
        if use_transformer_engine
        else get_gpt_layer_local_spec(
            num_experts=None,
            moe_grouped_gemm=False,
            qk_layernorm=config.qk_layernorm,
            multi_latent_attention=config.multi_latent_attention,
            moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
        )
    )
    moe_layer_spec = (
        get_gpt_layer_with_transformer_engine_spec(
            num_experts=config.num_moe_experts,
            moe_grouped_gemm=config.moe_grouped_gemm,
            qk_layernorm=config.qk_layernorm,
            multi_latent_attention=config.multi_latent_attention,
            moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
        )
        if use_transformer_engine
        else get_gpt_layer_local_spec(
            num_experts=config.num_moe_experts,
            moe_grouped_gemm=config.moe_grouped_gemm,
            qk_layernorm=config.qk_layernorm,
            multi_latent_attention=config.multi_latent_attention,
            moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
        )
    )

    # Parse config.moe_layer_freq to determine the pattern of expert/dense layers.
    # 0 stands for dense layers, 1 stands for expert layers.
    # For integer N: Creates a pattern with one expert layer every N layers.
    # For string pattern: Evaluates the str directly (e.g. "[1,0,1]" for alternating expert/dense).
    if isinstance(config.moe_layer_freq, int):
        moe_layer_pattern = [
            1 if (i % config.moe_layer_freq == 0) else 0 for i in range(config.num_layers)
        ]
    elif isinstance(config.moe_layer_freq, list):
        moe_layer_pattern = config.moe_layer_freq
        assert len(moe_layer_pattern) == config.num_layers, (
            f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, "
            f"expected {config.num_layers}, "
            f"current moe layer pattern: {config.moe_layer_freq}"
        )
    else:
        raise ValueError(
            f"Invalid moe_layer_freq: {type(config.moe_layer_freq)}, {config.moe_layer_freq}"
        )

    # Create the layer specs for the model.
    layer_specs = []
    for layer_number in range(config.num_layers):
        if moe_layer_pattern[layer_number] == 1:
            layer_specs.append(moe_layer_spec)
        elif moe_layer_pattern[layer_number] == 0:
            layer_specs.append(dense_layer_spec)
        else:
            raise ValueError(f"Invalid layer pattern: {moe_layer_pattern}")

    # Slice the layer specs to only include the layers that are built in this pipeline stage.
    # Note: MCore layer_number starts at 1
    offset = TransformerLayer._get_layer_offset(config)
    num_layers_to_build = get_num_layers_to_build(config)
    layer_specs = layer_specs[offset : offset + num_layers_to_build]

    # Block spec.
    block_spec = TransformerBlockSubmodules(layer_specs=layer_specs, layer_norm=layer_norm_impl)

    return block_spec
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import
warnings
from
typing
import
Optional
from
megatron.core.fusions.fused_bias_dropout
import
get_bias_dropout_add
from
megatron.core.models.gpt.moe_module_specs
import
get_moe_module_spec
from
megatron.core.tensor_parallel.layers
import
ColumnParallelLinear
,
RowParallelLinear
from
megatron.core.transformer.attention
import
SelfAttention
,
SelfAttentionSubmodules
from
megatron.core.transformer.dot_product_attention
import
DotProductAttention
from
megatron.core.transformer.enums
import
AttnMaskType
from
megatron.core.transformer.identity_op
import
IdentityOp
from
megatron.core.transformer.mlp
import
MLP
,
MLPSubmodules
from
megatron.core.transformer.multi_latent_attention
import
(
MLASelfAttention
,
MLASelfAttentionSubmodules
,
)
from
megatron.core.transformer.spec_utils
import
ModuleSpec
from
megatron.core.transformer.transformer_block
import
(
TransformerBlockSubmodules
,
get_num_layers_to_build
,
)
from
megatron.core.transformer.transformer_config
import
TransformerConfig
from
megatron.core.transformer.transformer_layer
import
(
TransformerLayer
,
TransformerLayerSubmodules
,
get_transformer_layer_offset
,
)
from
megatron.core.utils
import
is_te_min_version
try
:
from
megatron.core.extensions.transformer_engine
import
(
TEColumnParallelLinear
,
TEDotProductAttention
,
TELayerNormColumnParallelLinear
,
TENorm
,
TERowParallelLinear
,
)
HAVE_TE
=
True
except
ImportError
:
HAVE_TE
=
False
try
:
import
apex
# pylint: disable=unused-import
from
megatron.core.fusions.fused_layer_norm
import
FusedLayerNorm
HAVE_APEX
=
True
LNImpl
=
FusedLayerNorm
except
ImportError
:
from
megatron.core.transformer.torch_norm
import
WrappedTorchNorm
warnings
.
warn
(
'Apex is not installed. Falling back to Torch Norm'
)
LNImpl
=
WrappedTorchNorm
def get_gpt_layer_with_transformer_engine_spec(
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    qk_layernorm: Optional[bool] = False,
    multi_latent_attention: Optional[bool] = False,
    fp8: Optional[str] = None,  # pylint: disable=unused-arguments
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
    """Use this spec to use lower-level Transformer Engine modules (required for fp8 training).

    Args:
        num_experts (int, optional): Number of experts. Defaults to None.
        moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
        qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
        fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
        moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
            Defaults to False.

    Returns:
        ModuleSpec: Module specification with TE modules
    """
    if fp8 is not None:
        warnings.warn(
            'The fp8 argument in "get_gpt_layer_with_transformer_engine_spec" has been deprecated'
            ' and will be removed soon. Please update your code accordingly.'
        )

    mlp = get_mlp_module_spec(
        use_te=True,
        num_experts=num_experts,
        moe_grouped_gemm=moe_grouped_gemm,
        moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
    )

    if multi_latent_attention:
        return ModuleSpec(
            module=TransformerLayer,
            submodules=TransformerLayerSubmodules(
                input_layernorm=TENorm,
                self_attention=ModuleSpec(
                    module=MLASelfAttention,
                    params={"attn_mask_type": AttnMaskType.causal},
                    submodules=MLASelfAttentionSubmodules(
                        linear_q_proj=TEColumnParallelLinear,
                        linear_q_down_proj=TEColumnParallelLinear,
                        linear_q_up_proj=(
                            TELayerNormColumnParallelLinear
                            if qk_layernorm
                            else TEColumnParallelLinear
                        ),
                        linear_kv_down_proj=TEColumnParallelLinear,
                        linear_kv_up_proj=(
                            TELayerNormColumnParallelLinear
                            if qk_layernorm
                            else TEColumnParallelLinear
                        ),
                        core_attention=TEDotProductAttention,
                        linear_proj=TERowParallelLinear,
                        q_layernorm=IdentityOp,
                        kv_layernorm=IdentityOp,
                    ),
                ),
                self_attn_bda=get_bias_dropout_add,
                pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
                mlp=mlp,
                mlp_bda=get_bias_dropout_add,
            ),
        )
    else:
        # TENorm significantly harms convergence when used
        # for QKLayerNorm if TE Version < 1.9;
        # we instead use the Apex implementation.
        qk_norm = TENorm if is_te_min_version("1.9.0") else FusedLayerNorm

        return ModuleSpec(
            module=TransformerLayer,
            submodules=TransformerLayerSubmodules(
                self_attention=ModuleSpec(
                    module=SelfAttention,
                    params={"attn_mask_type": AttnMaskType.causal},
                    submodules=SelfAttentionSubmodules(
                        linear_qkv=TELayerNormColumnParallelLinear,
                        core_attention=TEDotProductAttention,
                        linear_proj=TERowParallelLinear,
                        q_layernorm=qk_norm if qk_layernorm else IdentityOp,
                        k_layernorm=qk_norm if qk_layernorm else IdentityOp,
                    ),
                ),
                self_attn_bda=get_bias_dropout_add,
                pre_mlp_layernorm=TENorm if num_experts else IdentityOp,
                mlp=mlp,
                mlp_bda=get_bias_dropout_add,
            ),
        )
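As a usage sketch (not part of the file; argument values are illustrative), the returned ModuleSpec is what GPTModel takes as its transformer_layer_spec. A dense layer spec needs no expert arguments, while an MoE variant differs only in the expert-related flags:

# Illustrative only: a dense TE-based layer spec and an MoE variant.
dense_spec = get_gpt_layer_with_transformer_engine_spec(qk_layernorm=True)
moe_spec = get_gpt_layer_with_transformer_engine_spec(num_experts=8, moe_grouped_gemm=True)
# Both wrap TransformerLayer, e.g. dense_spec.module is TransformerLayer.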
def get_gpt_layer_local_spec(
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    qk_layernorm: Optional[bool] = False,
    multi_latent_attention: Optional[bool] = False,
    fp8: Optional[str] = None,  # pylint: disable=unused-arguments
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
    """Use this spec for an implementation using only modules in Megatron-Core.

    Args:
        num_experts (int, optional): Number of experts. Defaults to None.
        moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
        qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
        fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
        moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
            Defaults to False.

    Returns:
        ModuleSpec: Module specification with Megatron-Core modules
    """
    if fp8 is not None:
        warnings.warn(
            'The fp8 argument in "get_gpt_layer_local_spec" has been deprecated'
            ' and will be removed soon. Please update your code accordingly.'
        )

    mlp = get_mlp_module_spec(
        use_te=False,
        num_experts=num_experts,
        moe_grouped_gemm=moe_grouped_gemm,
        moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
    )

    if multi_latent_attention:
        return ModuleSpec(
            module=TransformerLayer,
            submodules=TransformerLayerSubmodules(
                input_layernorm=LNImpl,
                self_attention=ModuleSpec(
                    module=MLASelfAttention,
                    params={"attn_mask_type": AttnMaskType.causal},
                    submodules=MLASelfAttentionSubmodules(
                        linear_q_proj=ColumnParallelLinear,
                        linear_q_down_proj=ColumnParallelLinear,
                        linear_q_up_proj=ColumnParallelLinear,
                        linear_kv_down_proj=ColumnParallelLinear,
                        linear_kv_up_proj=ColumnParallelLinear,
                        core_attention=DotProductAttention,
                        linear_proj=RowParallelLinear,
                        q_layernorm=LNImpl if qk_layernorm else IdentityOp,
                        kv_layernorm=LNImpl if qk_layernorm else IdentityOp,
                    ),
                ),
                self_attn_bda=get_bias_dropout_add,
                pre_mlp_layernorm=LNImpl,
                mlp=mlp,
                mlp_bda=get_bias_dropout_add,
            ),
        )
    else:
        return ModuleSpec(
            module=TransformerLayer,
            submodules=TransformerLayerSubmodules(
                input_layernorm=LNImpl,
                self_attention=ModuleSpec(
                    module=SelfAttention,
                    params={"attn_mask_type": AttnMaskType.causal},
                    submodules=SelfAttentionSubmodules(
                        linear_qkv=ColumnParallelLinear,
                        core_attention=DotProductAttention,
                        linear_proj=RowParallelLinear,
                        q_layernorm=LNImpl if qk_layernorm else IdentityOp,
                        k_layernorm=LNImpl if qk_layernorm else IdentityOp,
                    ),
                ),
                self_attn_bda=get_bias_dropout_add,
                pre_mlp_layernorm=LNImpl,
                mlp=mlp,
                mlp_bda=get_bias_dropout_add,
                sharded_state_dict_keys_map={
                    'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
                    'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
                },
            ),
        )
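A minimal usage sketch (illustrative, not part of the file) for environments without Transformer Engine:

# Illustrative only: a TE-free layer spec built from Megatron-Core modules.
local_spec = get_gpt_layer_local_spec(num_experts=None, qk_layernorm=True)
# q/k layernorms resolve to LNImpl (FusedLayerNorm if Apex is available, WrappedTorchNorm otherwise).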
def _get_mlp_module_spec(
    use_te: Optional[bool] = True,
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    fp8: Optional[str] = None,  # pylint: disable=unused-arguments
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
):
    warnings.warn(
        """This private function is on a deprecation track. Please switch to `get_mlp_module_spec`
        since it will be removed in a future release."""
    )

    return get_mlp_module_spec(
        use_te=use_te,
        num_experts=num_experts,
        moe_grouped_gemm=moe_grouped_gemm,
        fp8=fp8,
        moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
    )
def get_mlp_module_spec(
    use_te: Optional[bool] = True,
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    fp8: Optional[str] = None,  # pylint: disable=unused-arguments
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
    """Helper function to get module spec for MLP/MoE"""
    if fp8 is not None:
        warnings.warn(
            'The fp8 argument in "_get_mlp_module_spec" has been deprecated'
            ' and will be removed soon. Please update your code accordingly.'
        )

    if num_experts is None:
        # Dense MLP w/ or w/o TE modules.
        return ModuleSpec(
            module=MLP,
            submodules=MLPSubmodules(
                linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear,
                linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
            ),
        )
    else:
        # Mixture of experts with modules in megatron core.
        return get_moe_module_spec(
            use_te=use_te,
            num_experts=num_experts,
            moe_grouped_gemm=moe_grouped_gemm,
            moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
        )
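A short sketch of the two branches (illustrative, not part of the file): the helper returns a plain MLP spec when num_experts is None and delegates to get_moe_module_spec otherwise.

# Illustrative only.
dense_mlp = get_mlp_module_spec(use_te=False)  # ModuleSpec(module=MLP, ...)
moe_mlp = get_mlp_module_spec(use_te=False, num_experts=4, moe_grouped_gemm=False)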
def get_gpt_decoder_block_spec(
    config: TransformerConfig, use_transformer_engine: bool
) -> TransformerBlockSubmodules:
    """GPT block spec."""
    if use_transformer_engine:
        layer_norm_impl = TENorm
    else:
        layer_norm_impl = LNImpl

    # Layer specs.
    dense_layer_spec = (
        get_gpt_layer_with_transformer_engine_spec(
            num_experts=None,
            moe_grouped_gemm=False,
            qk_layernorm=config.qk_layernorm,
            multi_latent_attention=config.multi_latent_attention,
            moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
        )
        if use_transformer_engine
        else get_gpt_layer_local_spec(
            num_experts=None,
            moe_grouped_gemm=False,
            qk_layernorm=config.qk_layernorm,
            multi_latent_attention=config.multi_latent_attention,
            moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
        )
    )
    moe_layer_spec = (
        get_gpt_layer_with_transformer_engine_spec(
            num_experts=config.num_moe_experts,
            moe_grouped_gemm=config.moe_grouped_gemm,
            qk_layernorm=config.qk_layernorm,
            multi_latent_attention=config.multi_latent_attention,
            moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
        )
        if use_transformer_engine
        else get_gpt_layer_local_spec(
            num_experts=config.num_moe_experts,
            moe_grouped_gemm=config.moe_grouped_gemm,
            qk_layernorm=config.qk_layernorm,
            multi_latent_attention=config.multi_latent_attention,
            moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
        )
    )

    # Parse config.moe_layer_freq to determine the pattern of expert/dense layers.
    # 0 stands for dense layers, 1 stands for expert layers.
    # For integer N: Creates a pattern with one expert layer every N layers.
    # For string pattern: Evaluates the str directly (e.g. "[1,0,1]" for alternating expert/dense).
    if isinstance(config.moe_layer_freq, int):
        moe_layer_pattern = [
            1 if (i % config.moe_layer_freq == 0) else 0 for i in range(config.num_layers)
        ]
    elif isinstance(config.moe_layer_freq, list):
        moe_layer_pattern = config.moe_layer_freq
        assert len(moe_layer_pattern) == config.num_layers, (
            f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, "
            f"expected {config.num_layers}, "
            f"current moe layer pattern: {config.moe_layer_freq}"
        )
    else:
        raise ValueError(
            f"Invalid moe_layer_freq: {type(config.moe_layer_freq)}, {config.moe_layer_freq}"
        )

    # Create the layer specs for the model.
    layer_specs = []
    for layer_number in range(config.num_layers):
        if moe_layer_pattern[layer_number] == 1:
            layer_specs.append(moe_layer_spec)
        elif moe_layer_pattern[layer_number] == 0:
            layer_specs.append(dense_layer_spec)
        else:
            raise ValueError(f"Invalid layer pattern: {moe_layer_pattern}")

    # Slice the layer specs to only include the layers that are built in this pipeline stage.
    # Note: MCore layer_number starts at 1
    offset = get_transformer_layer_offset(config)
    num_layers_to_build = get_num_layers_to_build(config)
    layer_specs = layer_specs[offset : offset + num_layers_to_build]

    # Block spec.
    block_spec = TransformerBlockSubmodules(layer_specs=layer_specs, layer_norm=layer_norm_impl)

    return block_spec
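A minimal sketch of how the block spec is typically built from a TransformerConfig. The config fields shown are only the ones this function reads; real configs usually need more arguments, and the TE path assumes Transformer Engine is installed. Values are illustrative, not from this commit:

from megatron.core.transformer.transformer_config import TransformerConfig

# Illustrative values only.
config = TransformerConfig(
    num_layers=4,
    hidden_size=1024,
    num_attention_heads=16,
    num_moe_experts=8,
    moe_layer_freq=2,  # every other layer is an MoE layer
)
block_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=True)
# block_spec.layer_specs holds one ModuleSpec per layer built on this pipeline stage.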
megatron/core/models/gpt/gpt_model.py  View file @ 688448db
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from collections import OrderedDict
from typing import Dict, Literal, Optional

from torch import Tensor

from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig


class GPTModel(LanguageModule):
    """GPT Transformer language model.

    Args:
        config (TransformerConfig):
            Transformer config
        transformer_layer_spec (ModuleSpec):
            Specifies module to use for transformer layers
        vocab_size (int):
            Vocabulary size
        max_sequence_length (int):
            maximum size of sequence. This is used for positional embedding
        pre_process (bool, optional):
            Include embedding layer (used with pipeline parallelism). Defaults to True.
        post_process (bool, optional):
            Include an output layer (used with pipeline parallelism). Defaults to True.
        fp16_lm_cross_entropy (bool, optional):
            Defaults to False.
        parallel_output (bool, optional):
            Do not gather the outputs, keep them split across tensor
            parallel ranks. Defaults to True.
        share_embeddings_and_output_weights (bool, optional):
            When True, input embeddings and output logit weights are shared. Defaults to False.
        position_embedding_type (Literal[learned_absolute,rope], optional):
            Position embedding type. Defaults to 'learned_absolute'.
        rotary_percent (float, optional):
            Percent of rotary dimension to use for rotary position embeddings.
            Ignored unless position_embedding_type is 'rope'. Defaults to 1.0.
        rotary_base (int, optional):
            Base period for rotary position embeddings. Ignored unless
            position_embedding_type is 'rope'. Defaults to 10000.
        scatter_embedding_sequence_parallel (bool, optional):
            Whether embeddings should be scattered across the sequence parallel
            region or not. Defaults to True.
        seq_len_interpolation_factor (Optional[float], optional):
            scale of linearly interpolating RoPE for longer sequences.
            The value must be a float larger than 1.0. Defaults to None.
    """
    def __init__(
        self,
        config: TransformerConfig,
        transformer_layer_spec: ModuleSpec,
        vocab_size: int,
        max_sequence_length: int,
        pre_process: bool = True,
        post_process: bool = True,
        fp16_lm_cross_entropy: bool = False,
        parallel_output: bool = True,
        share_embeddings_and_output_weights: bool = False,
        position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
        rotary_percent: float = 1.0,
        rotary_base: int = 10000,
        rope_scaling: bool = False,
        scatter_embedding_sequence_parallel: bool = True,
        seq_len_interpolation_factor: Optional[float] = None,
    ) -> None:
        super().__init__(config=config)

        if has_config_logger_enabled(config):
            log_config_to_disk(config, locals(), prefix=type(self).__name__)

        self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
        self.position_embedding_type = position_embedding_type

        # megatron core pipelining currently depends on model type
        # TODO: remove this dependency ?
        self.model_type = ModelType.encoder_or_decoder

        # These 4 attributes are needed for TensorRT-LLM export.
        self.max_position_embeddings = max_sequence_length
        self.rotary_percent = rotary_percent
        self.rotary_base = rotary_base
        self.rotary_scaling = rope_scaling

        if self.pre_process:
            self.embedding = LanguageModelEmbedding(
                config=self.config,
                vocab_size=self.vocab_size,
                max_sequence_length=self.max_sequence_length,
                position_embedding_type=position_embedding_type,
                scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
            )

        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
            self.rotary_pos_emb = RotaryEmbedding(
                kv_channels=self.config.kv_channels,
                rotary_percent=rotary_percent,
                rotary_interleaved=self.config.rotary_interleaved,
                seq_len_interpolation_factor=seq_len_interpolation_factor,
                rotary_base=rotary_base,
                rope_scaling=rope_scaling,
                use_cpu_initialization=self.config.use_cpu_initialization,
            )

        # Transformer.
        self.decoder = TransformerBlock(
            config=self.config,
            spec=transformer_layer_spec,
            pre_process=self.pre_process,
            post_process=self.post_process,
        )

        # Output
        if post_process:
            if self.config.defer_embedding_wgrad_compute:
                # The embedding activation buffer preserves a reference to the input activations
                # of the final embedding projection layer GEMM. It will hold the activations for
                # all the micro-batches of a global batch for the last pipeline stage. Once we are
                # done with all the back props for all the microbatches for the last pipeline stage,
                # it will be in the pipeline flush stage. During this pipeline flush we use the
                # input activations stored in embedding activation buffer and gradient outputs
                # stored in gradient buffer to calculate the weight gradients for the embedding
                # final linear layer.
                self.embedding_activation_buffer = []
                self.grad_output_buffer = []
            else:
                self.embedding_activation_buffer = None
                self.grad_output_buffer = None

            self.output_layer = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                self.vocab_size,
                config=config,
                init_method=config.init_method,
                bias=False,
                skip_bias_add=False,
                gather_output=not self.parallel_output,
                skip_weight_param_allocation=self.pre_process
                and self.share_embeddings_and_output_weights,
                embedding_activation_buffer=self.embedding_activation_buffer,
                grad_output_buffer=self.grad_output_buffer,
            )

        if self.pre_process or self.post_process:
            self.setup_embeddings_and_output_layer()

        if has_config_logger_enabled(self.config):
            log_config_to_disk(
                self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt'
            )

    def set_input_tensor(self, input_tensor: Tensor) -> None:
        """Sets input tensor to the model.

        See megatron.model.transformer.set_input_tensor()

        Args:
            input_tensor (Tensor): Sets the input tensor for the model.
        """
        # This is usually handled in schedules.py but some inference code still
        # gives us non-lists or None
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]

        assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
        self.decoder.set_input_tensor(input_tensor[0])
    def forward(
        self,
        input_ids: Tensor,
        position_ids: Tensor,
        attention_mask: Tensor,
        decoder_input: Tensor = None,
        labels: Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        extra_block_kwargs: dict = None,
        runtime_gather_output: Optional[bool] = None,
    ) -> Tensor:
        """Forward function of the GPT Model. This function passes the input tensors
        through the embedding layer, then the decoder, and finally into the post
        processing layer (optional).

        It either returns the loss values if labels are given or the final hidden units.

        Args:
            runtime_gather_output (bool): Gather output at runtime. Default None means
                `parallel_output` arg in the constructor will be used.
        """
        # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
        # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.

        # Decoder embedding.
        if decoder_input is not None:
            pass
        elif self.pre_process:
            decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
        else:
            # intermediate stage of pipeline
            # decoder will get hidden_states from encoder.input_tensor
            decoder_input = None

        # Rotary positional embeddings (embedding is None for PP intermediate devices)
        rotary_pos_emb = None
        rotary_pos_cos = None
        rotary_pos_sin = None
        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
            if not self.training and self.config.flash_decode:
                # Flash decoding uses precomputed cos and sin for RoPE
                rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb.get_cos_sin(
                    inference_params.max_sequence_length
                )
            else:
                rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
                    inference_params, self.decoder, decoder_input, self.config, packed_seq_params
                )
                rotary_pos_emb = self.rotary_pos_emb(
                    rotary_seq_len,
                    packed_seq=packed_seq_params is not None
                    and packed_seq_params.qkv_format == 'thd',
                )

        # Run decoder.
        hidden_states = self.decoder(
            hidden_states=decoder_input,
            attention_mask=attention_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            packed_seq_params=packed_seq_params,
            **(extra_block_kwargs or {}),
        )

        if not self.post_process:
            return hidden_states

        # logits and loss
        output_weight = None
        if self.share_embeddings_and_output_weights:
            output_weight = self.shared_embedding_or_output_weight()
        logits, _ = self.output_layer(
            hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
        )

        if has_config_logger_enabled(self.config):
            payload = OrderedDict(
                {
                    'input_ids': input_ids,
                    'position_ids': position_ids,
                    'attention_mask': attention_mask,
                    'decoder_input': decoder_input,
                    'logits': logits,
                }
            )
            log_config_to_disk(self.config, payload, prefix='input_and_logits')

        if labels is None:
            # [s b h] => [b s h]
            return logits.transpose(0, 1).contiguous()

        loss = self.compute_language_model_loss(labels, logits)

        return loss

    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None
    ) -> ShardedStateDict:
        """Sharded state dict implementation for GPTModel backward-compatibility
        (removing extra state).

        Args:
            prefix (str): Module name prefix.
            sharded_offsets (tuple): PP related offsets, expected to be empty at this module level.
            metadata (Optional[Dict]): metadata controlling sharded state dict creation.

        Returns:
            ShardedStateDict: sharded state dict for the GPTModel
        """
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)

        output_layer_extra_state_key = f'{prefix}output_layer._extra_state'

        # Old GPT checkpoints only stored the output layer weight key. So we remove the
        # _extra_state key but check that it doesn't contain any data anyway
        output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None)
        assert not (
            output_extra_state and output_extra_state.data
        ), f'Expected output layer extra state to be empty, got: {output_extra_state}'

        return sharded_state_dict
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from collections import OrderedDict
from typing import Dict, Literal, Optional

import torch
from torch import Tensor

from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig


class GPTModel(LanguageModule):
    """GPT Transformer language model.

    Args:
        config (TransformerConfig):
            Transformer config
        transformer_layer_spec (ModuleSpec):
            Specifies module to use for transformer layers
        vocab_size (int):
            Vocabulary size
        max_sequence_length (int):
            maximum size of sequence. This is used for positional embedding
        pre_process (bool, optional):
            Include embedding layer (used with pipeline parallelism). Defaults to True.
        post_process (bool, optional):
            Include an output layer (used with pipeline parallelism). Defaults to True.
        fp16_lm_cross_entropy (bool, optional):
            Defaults to False.
        parallel_output (bool, optional):
            Do not gather the outputs, keep them split across tensor
            parallel ranks. Defaults to True.
        share_embeddings_and_output_weights (bool, optional):
            When True, input embeddings and output logit weights are shared. Defaults to False.
        position_embedding_type (Literal[learned_absolute,rope], optional):
            Position embedding type. Defaults to 'learned_absolute'.
        rotary_percent (float, optional):
            Percent of rotary dimension to use for rotary position embeddings.
            Ignored unless position_embedding_type is 'rope'. Defaults to 1.0.
        rotary_base (int, optional):
            Base period for rotary position embeddings. Ignored unless
            position_embedding_type is 'rope'. Defaults to 10000.
        rope_scaling (bool, optional): Toggle RoPE scaling.
        rope_scaling_factor (float): RoPE scaling factor. Default 8.
        scatter_embedding_sequence_parallel (bool, optional):
            Whether embeddings should be scattered across the sequence parallel
            region or not. Defaults to True.
        seq_len_interpolation_factor (Optional[float], optional):
            scale of linearly interpolating RoPE for longer sequences.
            The value must be a float larger than 1.0. Defaults to None.
    """
    def __init__(
        self,
        config: TransformerConfig,
        transformer_layer_spec: ModuleSpec,
        vocab_size: int,
        max_sequence_length: int,
        pre_process: bool = True,
        post_process: bool = True,
        fp16_lm_cross_entropy: bool = False,
        parallel_output: bool = True,
        share_embeddings_and_output_weights: bool = False,
        position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
        rotary_percent: float = 1.0,
        rotary_base: int = 10000,
        rope_scaling: bool = False,
        rope_scaling_factor: float = 8.0,
        scatter_embedding_sequence_parallel: bool = True,
        seq_len_interpolation_factor: Optional[float] = None,
    ) -> None:
        super().__init__(config=config)

        if has_config_logger_enabled(config):
            log_config_to_disk(config, locals(), prefix=type(self).__name__)

        self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
        self.position_embedding_type = position_embedding_type

        # megatron core pipelining currently depends on model type
        # TODO: remove this dependency ?
        self.model_type = ModelType.encoder_or_decoder

        # These 4 attributes are needed for TensorRT-LLM export.
        self.max_position_embeddings = max_sequence_length
        self.rotary_percent = rotary_percent
        self.rotary_base = rotary_base
        self.rotary_scaling = rope_scaling

        if self.pre_process:
            self.embedding = LanguageModelEmbedding(
                config=self.config,
                vocab_size=self.vocab_size,
                max_sequence_length=self.max_sequence_length,
                position_embedding_type=position_embedding_type,
                scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
            )

        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
            self.rotary_pos_emb = RotaryEmbedding(
                kv_channels=self.config.kv_channels,
                rotary_percent=rotary_percent,
                rotary_interleaved=self.config.rotary_interleaved,
                seq_len_interpolation_factor=seq_len_interpolation_factor,
                rotary_base=rotary_base,
                rope_scaling=rope_scaling,
                rope_scaling_factor=rope_scaling_factor,
                use_cpu_initialization=self.config.use_cpu_initialization,
            )

        # Cache for RoPE tensors which do not change between iterations.
        self.rotary_pos_emb_cache = {}

        # Transformer.
        self.decoder = TransformerBlock(
            config=self.config,
            spec=transformer_layer_spec,
            pre_process=self.pre_process,
            post_process=self.post_process,
        )

        # Output
        if post_process:
            if self.config.defer_embedding_wgrad_compute:
                # The embedding activation buffer preserves a reference to the input activations
                # of the final embedding projection layer GEMM. It will hold the activations for
                # all the micro-batches of a global batch for the last pipeline stage. Once we are
                # done with all the back props for all the microbatches for the last pipeline stage,
                # it will be in the pipeline flush stage. During this pipeline flush we use the
                # input activations stored in embedding activation buffer and gradient outputs
                # stored in gradient buffer to calculate the weight gradients for the embedding
                # final linear layer.
                self.embedding_activation_buffer = []
                self.grad_output_buffer = []
            else:
                self.embedding_activation_buffer = None
                self.grad_output_buffer = None

            self.output_layer = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                self.vocab_size,
                config=config,
                init_method=config.init_method,
                bias=False,
                skip_bias_add=False,
                gather_output=not self.parallel_output,
                skip_weight_param_allocation=self.pre_process
                and self.share_embeddings_and_output_weights,
                embedding_activation_buffer=self.embedding_activation_buffer,
                grad_output_buffer=self.grad_output_buffer,
            )

        if self.pre_process or self.post_process:
            self.setup_embeddings_and_output_layer()

        if has_config_logger_enabled(self.config):
            log_config_to_disk(
                self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt'
            )

    def set_input_tensor(self, input_tensor: Tensor) -> None:
        """Sets input tensor to the model.

        See megatron.model.transformer.set_input_tensor()

        Args:
            input_tensor (Tensor): Sets the input tensor for the model.
        """
        # This is usually handled in schedules.py but some inference code still
        # gives us non-lists or None
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]

        assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
        self.decoder.set_input_tensor(input_tensor[0])
    def forward(
        self,
        input_ids: Tensor,
        position_ids: Tensor,
        attention_mask: Tensor,
        decoder_input: Tensor = None,
        labels: Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        extra_block_kwargs: dict = None,
        runtime_gather_output: Optional[bool] = None,
    ) -> Tensor:
        """Forward function of the GPT Model. This function passes the input tensors
        through the embedding layer, then the decoder, and finally into the post
        processing layer (optional).

        It either returns the loss values if labels are given or the final hidden units.

        Args:
            runtime_gather_output (bool): Gather output at runtime. Default None means
                `parallel_output` arg in the constructor will be used.
        """
        # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
        # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.

        # Decoder embedding.
        if decoder_input is not None:
            pass
        elif self.pre_process:
            decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
        else:
            # intermediate stage of pipeline
            # decoder will get hidden_states from encoder.input_tensor
            decoder_input = None

        # Rotary positional embeddings (embedding is None for PP intermediate devices)
        rotary_pos_emb = None
        rotary_pos_cos = None
        rotary_pos_sin = None
        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
            if not self.training and self.config.flash_decode and inference_params:
                # Flash decoding uses precomputed cos and sin for RoPE
                rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
                    inference_params.max_sequence_length,
                    self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
                )
            else:
                rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
                    inference_params, self.decoder, decoder_input, self.config, packed_seq_params
                )
                rotary_pos_emb = self.rotary_pos_emb(
                    rotary_seq_len,
                    packed_seq=packed_seq_params is not None
                    and packed_seq_params.qkv_format == 'thd',
                )
        if (
            (self.config.enable_cuda_graph or self.config.flash_decode)
            and rotary_pos_cos is not None
            and inference_params
        ):
            sequence_len_offset = torch.tensor(
                [inference_params.sequence_len_offset] * inference_params.current_batch_size,
                dtype=torch.int32,
                device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
            )
        else:
            sequence_len_offset = None

        # Run decoder.
        hidden_states = self.decoder(
            hidden_states=decoder_input,
            attention_mask=attention_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            **(extra_block_kwargs or {}),
        )

        if not self.post_process:
            return hidden_states

        # logits and loss
        output_weight = None
        if self.share_embeddings_and_output_weights:
            output_weight = self.shared_embedding_or_output_weight()
        logits, _ = self.output_layer(
            hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
        )

        if has_config_logger_enabled(self.config):
            payload = OrderedDict(
                {
                    'input_ids': input_ids,
                    'position_ids': position_ids,
                    'attention_mask': attention_mask,
                    'decoder_input': decoder_input,
                    'logits': logits,
                }
            )
            log_config_to_disk(self.config, payload, prefix='input_and_logits')

        if labels is None:
            # [s b h] => [b s h]
            return logits.transpose(0, 1).contiguous()

        loss = self.compute_language_model_loss(labels, logits)

        return loss

    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None
    ) -> ShardedStateDict:
        """Sharded state dict implementation for GPTModel backward-compatibility
        (removing extra state).

        Args:
            prefix (str): Module name prefix.
            sharded_offsets (tuple): PP related offsets, expected to be empty at this module level.
            metadata (Optional[Dict]): metadata controlling sharded state dict creation.

        Returns:
            ShardedStateDict: sharded state dict for the GPTModel
        """
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)

        output_layer_extra_state_key = f'{prefix}output_layer._extra_state'

        # Old GPT checkpoints only stored the output layer weight key. So we remove the
        # _extra_state key but check that it doesn't contain any data anyway
        output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None)
        assert not (
            output_extra_state and output_extra_state.data
        ), f'Expected output layer extra state to be empty, got: {output_extra_state}'

        return sharded_state_dict
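A minimal end-to-end sketch of how this class is typically wired together with the layer specs above. Shapes and config values are illustrative assumptions, and a real run requires torch.distributed plus megatron.core.parallel_state.initialize_model_parallel() to have been set up first:

import torch

from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(num_layers=2, hidden_size=256, num_attention_heads=8)  # illustrative sizes
model = GPTModel(
    config=config,
    transformer_layer_spec=get_gpt_layer_local_spec(),
    vocab_size=32000,
    max_sequence_length=1024,
    position_embedding_type='rope',
)

batch, seq = 2, 16
input_ids = torch.randint(0, 32000, (batch, seq))
position_ids = torch.arange(seq).unsqueeze(0).expand(batch, -1)
# Assumed Megatron mask convention: boolean mask where True marks positions to mask out.
attention_mask = torch.triu(torch.ones(batch, 1, seq, seq, dtype=torch.bool), diagonal=1)
logits = model(input_ids, position_ids, attention_mask)  # [batch, seq, vocab] when labels is None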