OpenDAS / Megatron-LM · Commits · 688448db

Commit 688448db, authored Mar 14, 2025 by silencealiang
Commit message: Update code ("更新代码")
Parent: a02a5490
Pipeline #2503 passed
Changes: 172 · Pipelines: 1
Showing 20 changed files with 3236 additions and 2198 deletions (+3236 −2198):

megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py (+208 −0)
megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py (+225 −215)
megatron/core/inference/modelopt_support/__init__.py (+10 −8)
megatron/core/inference/modelopt_support/gpt/model_specs.py (+68 −63)
megatron/core/inference/modelopt_support/mamba/__init__.py (+1 −0)
megatron/core/inference/modelopt_support/mamba/model_specs.py (+89 −0)
megatron/core/inference/sampling_params.py (+36 −35)
megatron/core/inference/scheduler.py (+175 −127)
megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py (+38 −35)
megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py (+5 −5)
megatron/core/inference/text_generation_controllers/text_generation_controller.py (+674 −400)
megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py (+40 −0)
megatron/core/inference_params.py (+100 −31)
megatron/core/jit.py (+1 −15)
megatron/core/model_parallel_config.py (+392 −387)
megatron/core/models/T5/t5_model.py (+72 −5)
megatron/core/models/common/embeddings/relative_pos_embedding.py (+173 −0)
megatron/core/models/common/embeddings/rotary_pos_embedding.py (+215 −213)
megatron/core/models/gpt/gpt_layer_specs.py (+383 −350)
megatron/core/models/gpt/gpt_model.py (+331 −309)
Note: too many changes to show. To preserve performance, only 172 of 172+ files are displayed.
megatron/core/inference/model_inference_wrappers/multimodal/vlm_inference_wrapper.py — new file (0 → 100644) at commit 688448db

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Any, Dict

import torch

from megatron.core import parallel_state
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
    GPTInferenceWrapper,
)
from megatron.core.inference_params import InferenceParams

# pylint: disable=line-too-long


class VLMInferenceWrapper(GPTInferenceWrapper):
    """Inference wrapper for VLMs"""

    def prep_model_for_inference(self, prompts_tokens: torch.Tensor):
        """A utility function for preparing model for inference

        The function gets called once before the auto regressive inference loop.
        It puts the model in eval mode.

        Args:
            prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len]
        """
        super().prep_model_for_inference(prompts_tokens)

        # For TP only model both is_pp_first_stage and _is_pp_last_stage returns True
        self.model_is_pipeline_parallel = not (
            parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
        )

        self._recv_only_vision_embeds = False
        pp_rank = parallel_state.get_pipeline_model_parallel_rank()
        # Checks if the previous stage only has a vision encoder, and that the current stage
        # has part of the LM decoder. In this case, the current stage should only receive
        # vision embeddings.
        if pp_rank > 0:
            self._recv_only_vision_embeds = (
                parallel_state.is_inside_encoder(pp_rank - 1)
                and (not parallel_state.is_inside_decoder(pp_rank - 1))
                and parallel_state.is_inside_decoder()
            )

        # Checks if the current stage only has a vision encoder
        self._encoder_only = (
            parallel_state.is_inside_encoder() and not parallel_state.is_inside_decoder()
        )

        # For TP only model both is_pp_first_stage and _is_pp_last_stage returns True
        self.model_is_pipeline_parallel = not (
            parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
        )

    def prep_inference_input(
        self,
        prompts_tokens: torch.Tensor,
        num_img_embeddings_per_tile: int,
        images: torch.Tensor,
        num_tiles: torch.Tensor,
        decoder_seq_length: int,
    ):
        """Prepares the inference input data.

        Args:
            prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len]
            num_img_embeddings_per_tile (int): The number of image embeddings per tile
            images (torch.Tensor): The image embeddings
            num_tiles (torch.Tensor): The number of tiles for each input image
            decoder_seq_length (int): The decoder sequence length
        """
        inference_input = super().prep_inference_input(prompts_tokens)

        total_num_tiles = torch.sum(num_tiles).item()
        num_img_embeddings = num_img_embeddings_per_tile * total_num_tiles

        batch_size, max_sequence_length = prompts_tokens.shape
        self.inference_params = InferenceParams(
            batch_size, max_sequence_length + num_img_embeddings
        )

        inference_input["images"] = images
        inference_input["num_tiles"] = num_tiles
        inference_input["num_img_embeddings"] = num_img_embeddings
        inference_input["decoder_seq_length"] = decoder_seq_length

        return inference_input

    def get_batch_for_context_window(
        self,
        inference_input: Dict[str, Any],
        context_start_position: int,
        context_end_position: int,
    ) -> Dict[str, Any]:
        """Returns the inference data given context window

        This function gets called iteratively in a loop. Given the start and end context
        positions, it extracts the appropriate data.

        Args:
            inference_input (Dict[str, Any]): The inference input for the batch.
            context_start_position (int): Start of the context window. During the first
                inference step it is mostly 0
            context_end_position (int): End of the context window. During the last inference
                step it will mostly be the max generated sequence length.

        Returns:
            Dict[str, Any]: A dict of inputs that will be used by your model in the forward step
        """
        tokens = inference_input["tokens"]
        position_ids = inference_input["position_ids"]
        images = inference_input["images"]
        num_tiles = inference_input["num_tiles"]
        num_img_embeddings = inference_input["num_img_embeddings"]
        decoder_seq_length = inference_input["decoder_seq_length"]
        tokens2use = tokens[:, context_start_position:context_end_position]
        positions2use = position_ids[:, context_start_position:context_end_position]

        return {
            "tokens": tokens2use,
            "position_ids": positions2use,
            "images": images,
            "num_tiles": num_tiles,
            "num_img_embeddings": num_img_embeddings,
            "decoder_seq_length": decoder_seq_length,
        }

    def _forward(self, inference_input: Dict[str, Any]):
        """Runs a forward pass of the model.

        Args:
            inference_input (Dict[str, Any]): The input data.

        Returns:
            The model output logits.
        """
        images = inference_input["images"]
        tokens = inference_input["tokens"]
        position_ids = inference_input["position_ids"]
        num_image_tiles = inference_input["num_tiles"]
        output = self.model(
            images,
            tokens,
            position_ids=position_ids,
            attention_mask=None,
            inference_params=self.inference_params,
            num_image_tiles=num_image_tiles,
            runtime_gather_output=True,
        )
        if isinstance(output, tuple):
            logits, _ = output
        else:
            logits = output
        return logits

    def run_one_forward_step(self, inference_input: Dict[str, Any]) -> torch.Tensor:
        tokens = inference_input["tokens"]
        num_image_tokens = (tokens == self.model.module.image_token_index).sum().item()
        num_img_embeddings = inference_input["num_img_embeddings"]
        decoder_seq_length = inference_input["decoder_seq_length"]
        num_tokens = tokens.size(1)
        recv_buffer_seq_len = None
        if num_image_tokens > 0:
            # When there are image tokens and this stage only receives vision embeddings,
            # adjust the recv buffer seq length to match the image embeddings sequence length.
            # If there are image tokens and this stage receives full embeddings, make sure we
            # compensate for expansion of image tokens.
            # Note that this will set a recv_buffer_seq_len for the encoder stage,
            # this length is irrelevant since that recv buffer is never allocated.
            if self._recv_only_vision_embeds:
                recv_buffer_seq_len = num_img_embeddings
            else:
                recv_buffer_seq_len = min(
                    num_img_embeddings + num_tokens - num_image_tokens, decoder_seq_length
                )
        elif self._recv_only_vision_embeds:
            # If this stage only receives vision embeddings and there are no image tokens
            # we won't run the encoder and therefore shouldn't try to recv.
            recv_buffer_seq_len = 0

        # If the pipeline stage only has a vision encoder, then it only needs to
        # run when there are image tokens
        if not (self._encoder_only and num_image_tokens == 0):
            output = super().run_one_forward_step(
                inference_input, recv_buffer_seq_len=recv_buffer_seq_len
            )
        else:
            output = None
        logits = output

        # On the first inference iteration, we compute image tokens.
        # On every PP stage (although inference params should only matter for decoder),
        # update the sequence length offset by the number of image tokens.
        if num_tokens > 1 and num_image_tokens > 0:
            if "image_tokens_count" not in self.inference_params.key_value_memory_dict:
                self.inference_params.key_value_memory_dict["image_tokens_count"] = (
                    num_img_embeddings
                )

            if num_img_embeddings + num_tokens - num_image_tokens > decoder_seq_length:
                self.inference_params.sequence_len_offset += decoder_seq_length - num_tokens
            else:
                self.inference_params.sequence_len_offset += (
                    self.inference_params.key_value_memory_dict["image_tokens_count"]
                    - num_image_tokens
                )

        return logits
```
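The branching in `run_one_forward_step` that picks `recv_buffer_seq_len` is the core of how this wrapper handles pipeline stages that only see vision embeddings. Below is a minimal standalone sketch of that decision using plain integers instead of the wrapper's state, so the three cases are easy to trace; the function name and the example numbers are hypothetical and not part of the file above.

```python
from typing import Optional


def pick_recv_buffer_seq_len(
    num_image_tokens: int,
    num_img_embeddings: int,
    num_tokens: int,
    decoder_seq_length: int,
    recv_only_vision_embeds: bool,
) -> Optional[int]:
    """Mirror of the recv-buffer sizing logic in VLMInferenceWrapper.run_one_forward_step.

    Returns None when the default buffer size should be used.
    """
    if num_image_tokens > 0:
        if recv_only_vision_embeds:
            # This stage receives only vision embeddings: buffer matches the image embeddings.
            return num_img_embeddings
        # Full embeddings: compensate for image-token expansion, capped at the decoder length.
        return min(num_img_embeddings + num_tokens - num_image_tokens, decoder_seq_length)
    if recv_only_vision_embeds:
        # No image tokens and this stage only receives vision embeddings: nothing to receive.
        return 0
    return None


if __name__ == "__main__":
    # One image token expanded into 576 image embeddings (hypothetical numbers).
    print(pick_recv_buffer_seq_len(1, 576, 32, 4096, recv_only_vision_embeds=False))  # 607
    print(pick_recv_buffer_seq_len(1, 576, 32, 4096, recv_only_vision_embeds=True))   # 576
    print(pick_recv_buffer_seq_len(0, 0, 32, 4096, recv_only_vision_embeds=True))     # 0
```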
megatron/core/inference/model_inference_wrappers/t5/t5_inference_wrapper.py — file at commit 688448db

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from collections import deque
from typing import Any, Dict, List, Optional

import numpy
import torch

from megatron.core import tensor_parallel
from megatron.core.datasets.t5_dataset import T5MaskedWordPieceDataset
from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
    AbstractModelInferenceWrapper,
)
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
    InferenceWrapperConfig,
)
from megatron.core.models.T5 import T5Model
from megatron.core.utils import get_attr_wrapped_model


# pylint: disable=line-too-long
class T5InferenceWrapper(AbstractModelInferenceWrapper):
    """Constructor for the model inference wrapper

    The wrapper prepares the model for inference, provides the required input
    data, and runs the forward pass

    Args:
        model (T5Model): The T5 model (MCore or legacy)
        inference_wrapper_config (InferenceWrapperConfig): The command line arguments that were passed
        use_local (bool): Whether the T5 model's transformer impl
            is local (vs transformer_engine)
    """

    def __init__(
        self,
        model: T5Model,
        inference_wrapper_config: InferenceWrapperConfig,
        use_local: bool = False,
    ):
        super().__init__(model, inference_wrapper_config)
        self.use_local = use_local

    def prep_inference_input(
        self,
        prompts_tokens: torch.Tensor,
        encoder_prompts: Optional[List[str]] = None,
        tokenizer: Any = None,
    ) -> Dict[str, Any]:
        """Prepares the inference input data.

        Args:
            prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len]
            encoder_prompts (dict): List of string of encoder input prompts
            tokenizer (_type_): Tokenizer used for tokenizing and detokenizing text

        Returns:
            A dict with all the inference input needed for the batch.
        """

        # get max_sequence_length
        max_sequence_length = get_attr_wrapped_model(self.model, "max_sequence_length")

        encoder_prompts_tokens_list = [
            self.tokenize_encoder_prompt(encoder_prompt, tokenizer)
            for encoder_prompt in encoder_prompts
        ]
        batch_encoder_prompts_tokens = self.pad_encoder_prompts_tokens(
            encoder_prompts_tokens_list, max_sequence_length, tokenizer
        )

        # create batch mask for encoder_prompt (self.batch_input_tokens) and
        # decoder_input (prompts_tokens), similar to megatron/core/datasets/t5_dataset.py
        decoder_prompts_tokens = prompts_tokens
        encoder_prompts_tokens = batch_encoder_prompts_tokens

        decoder_prompts_tokens_numpy = decoder_prompts_tokens.cpu().numpy()
        encoder_prompts_tokens_numpy = encoder_prompts_tokens.cpu().numpy()
        batch_mask_encoder = []
        batch_mask_decoder = []
        for i in range(len(prompts_tokens)):
            mask_encoder = encoder_prompts_tokens_numpy[i] == tokenizer.pad
            mask_decoder = decoder_prompts_tokens_numpy[i] == tokenizer.pad
            batch_mask_encoder.append(mask_encoder)
            batch_mask_decoder.append(mask_decoder)
        batch_mask_encoder = torch.tensor(numpy.array(batch_mask_encoder)).cuda()
        batch_mask_decoder = torch.tensor(numpy.array(batch_mask_decoder)).cuda()

        return {
            "encoder_tokens": encoder_prompts_tokens,
            "decoder_tokens": decoder_prompts_tokens,
            "encoder_mask": batch_mask_encoder,
            "decoder_mask": batch_mask_decoder,
        }

    def tokenize_encoder_prompt(self, encoder_prompt: str, tokenizer) -> torch.Tensor:
        """Utility to tokenize the encoder_prompt

        Args:
            encoder_prompt (str): The encoder_prompt
            tokenizer (_type_): Tokenizer used for tokenizing and detokenizing string

        Returns:
            torch.Tensor: Returns the tokenized prompt
        """

        # if there is the word "<mask>" in prompt, replacing it with special_additional_token,
        # similar to processing step in megatron/core/datasets/t5_dataset.py
        divided_encoder_prompt_list = encoder_prompt.split("<mask>")
        masks_count = len(divided_encoder_prompt_list) - 1
        sentinels = deque(tokenizer.additional_special_tokens_ids)

        encoder_prompt_tokens = []
        for divided_encoder_prompt in divided_encoder_prompt_list:
            divided_encoder_prompt_tokens = tokenizer.tokenize(divided_encoder_prompt)
            encoder_prompt_tokens.extend(divided_encoder_prompt_tokens)
            if masks_count > 0:
                sentinel = sentinels.popleft()
                encoder_prompt_tokens.extend([sentinel])
                masks_count -= 1

        return encoder_prompt_tokens

    def pad_encoder_prompts_tokens(
        self, encoder_prompts_tokens_list: List[List[int]], max_sequence_length: int, tokenizer
    ) -> torch.Tensor:
        """Method to pad input prompts

        Given a list of prompts, pad them all to uniform length

        Args:
            encoder_prompts_tokens_list (List[List[int]]): A list containing the
                encoder_input_tokens
            max_sequence_length (int): Maximum of the length of the encoder inputs tokens
            tokenizer (_type_): Tokenizer used for tokenizing and detokenizing text

        Returns:
            torch.Tensor: A torch tensor of shape [bs, max_sequence_length]
        """

        for encoder_prompt_tokens in encoder_prompts_tokens_list:
            padding_size = max_sequence_length - len(encoder_prompt_tokens)
            encoder_prompt_tokens.extend([tokenizer.pad] * padding_size)

        return torch.tensor(encoder_prompts_tokens_list).cuda()

    def get_batch_for_context_window(
        self,
        inference_input: Dict[str, Any],
        context_start_position: int,
        context_end_position: int,
    ) -> Dict[str, Any]:
        """Returns the inference data given context window

        This function gets called iteratively in a loop. Given the start and end context
        positions, it extracts the appropriate data.

        Args:
            inference_input (Dict[str, Any]): The inference input for the batch.
            context_start_position (int): Start of the context window. During
                the first inference step it is mostly 0
            context_end_position (int): End of the context window. During the
                last inference step it will mostly be the max generated sequence length.

        Returns:
            Dict: A dict of inputs that will be used by your model in the forward step
        """

        # T5 inference not yet support kv_cache
        encoder_tokens2use = inference_input["encoder_tokens"]
        decoder_tokens2use = inference_input["decoder_tokens"][:, :context_end_position]
        encoder_mask2use = inference_input["encoder_mask"]
        decoder_mask2use = inference_input["decoder_mask"][:, :context_end_position]

        # Configure attention mask based on different conditions
        # (e.g., transformer-impl, TE versions, TE backends)
        [encoder_mask2use, decoder_mask2use, encoder_decoder_mask2use] = (
            T5MaskedWordPieceDataset.config_attention_mask(
                encoder_tokens2use,
                decoder_tokens2use,
                encoder_mask2use,
                decoder_mask2use,
                self.use_local,
            )
        )

        return {
            "encoder_tokens": encoder_tokens2use,
            "decoder_tokens": decoder_tokens2use,
            "encoder_mask": encoder_mask2use,
            "decoder_mask": decoder_mask2use,
            "encoder_decoder_mask": encoder_decoder_mask2use,
        }

    def forward_pass_without_pipeline_parallel(
        self, inference_input: Dict[str, Any]
    ) -> torch.Tensor:
        """Utility to carry out simple forward pass for TP or no model parallel models

        Runs a very simple forward pass for model. Used in the case of models without
        any parallelism or only tensor parallelism.

        Args:
            inference_input (Dict[str, Any]): A dict containing the inputs for the
                model [tokens, position ids, attention mask]

        Returns:
            torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]
        """
        encoder_tokens = inference_input["encoder_tokens"]
        decoder_tokens = inference_input["decoder_tokens"]
        encoder_mask = inference_input["encoder_mask"]
        decoder_mask = inference_input["decoder_mask"]
        encoder_decoder_mask = inference_input["encoder_decoder_mask"]
        tokens = decoder_tokens

        # T5 inference not yet support kv_cache
        logits = self.model(
            encoder_tokens,
            decoder_tokens,
            encoder_mask,
            decoder_mask,
            encoder_decoder_mask,
            inference_params=None,
        )
        logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)

        return logits
```
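`tokenize_encoder_prompt` replaces each literal `<mask>` in the encoder prompt with the tokenizer's next additional special ("sentinel") token, mirroring the preprocessing in megatron/core/datasets/t5_dataset.py. Here is a minimal sketch of that replacement with a toy whitespace tokenizer; the `ToyTokenizer` class, its vocabulary, and the `_sketch` function name are invented for illustration and are not part of the file above.

```python
from collections import deque


class ToyTokenizer:
    """Hypothetical stand-in for the real tokenizer: whitespace tokens, two sentinel ids."""

    additional_special_tokens_ids = [32000, 32001]

    def tokenize(self, text: str):
        # Map each whitespace-separated word to a fake integer id.
        return [hash(word) % 1000 for word in text.split()]


def tokenize_encoder_prompt_sketch(encoder_prompt: str, tokenizer) -> list:
    """Same structure as T5InferenceWrapper.tokenize_encoder_prompt."""
    divided = encoder_prompt.split("<mask>")
    masks_count = len(divided) - 1
    sentinels = deque(tokenizer.additional_special_tokens_ids)

    tokens = []
    for piece in divided:
        tokens.extend(tokenizer.tokenize(piece))
        if masks_count > 0:
            # Each <mask> occurrence consumes the next sentinel id, left to right.
            tokens.append(sentinels.popleft())
            masks_count -= 1
    return tokens


print(tokenize_encoder_prompt_sketch("The capital of France is <mask> .", ToyTokenizer()))
```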
megatron/core/inference/modelopt_support/__init__.py — file at commit 688448db

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Integrations with NVIDIA TensorRT Model Optimizer (referred as ModelOpt).

ModelOpt is a library comprising state-of-the-art model optimization techniques
including quantization and sparsity to compress model for efficient inference on
NVIDIA GPUs. ModelOpt is integrated with Megatron-core to provide a seamless
experience for users to optimize their Megatron-core models for inference.
More details on ModelOpt including installation and usage can be found at
https://github.com/NVIDIA/TensorRT-Model-Optimizer.
"""
```
megatron/core/inference/modelopt_support/gpt/model_specs.py — file at commit 688448db

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Optional

from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules


# Use this spec for ModelOpt PTQ and TensorRT-LLM export
def get_gpt_layer_modelopt_spec(
    num_experts: Optional[int] = None,
    local_core_attention: bool = False,
    moe_grouped_gemm: bool = False,
    remap_te_layernorm: bool = False,
    qk_layernorm: bool = False,
) -> ModuleSpec:
    """Mix the native spec with TENorm.

    This is essentially the native local spec except for the layernorm implementation
    is using TENorm from Transformer-Engine. The issue is that FusedLayerNorm from apex
    has stopped supporting RMSNorm needed by llama.
    """
    core_attention = DotProductAttention if local_core_attention else TEDotProductAttention
    mlp = get_mlp_module_spec(
        use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=False
    )
    sharded_state_dict_keys_map = {}
    if remap_te_layernorm:
        if num_experts:
            sharded_state_dict_keys_map = {
                'input_layernorm.': 'self_attention.linear_qkv.layer_norm_'
            }
        else:
            sharded_state_dict_keys_map = {
                'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
                'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
            }
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            input_layernorm=TENorm,
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=ColumnParallelLinear,
                    core_attention=core_attention,
                    linear_proj=RowParallelLinear,
                    q_layernorm=TENorm if qk_layernorm else IdentityOp,
                    k_layernorm=TENorm if qk_layernorm else IdentityOp,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=TENorm,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
            # Map TE-layernorm-fusion keys back
            sharded_state_dict_keys_map=sharded_state_dict_keys_map,
        ),
    )
```
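As the comment above notes, this spec targets ModelOpt PTQ and TensorRT-LLM export. A hedged usage sketch follows, assuming Megatron-Core with Transformer Engine is installed and that a `TransformerConfig` is built elsewhere; the surrounding model construction is illustrative and not part of this commit.

```python
# Assumes a working Megatron-Core + Transformer Engine environment.
from megatron.core.inference.modelopt_support.gpt.model_specs import get_gpt_layer_modelopt_spec

# qk_layernorm adds TENorm on the Q/K projections; remap_te_layernorm rewrites the
# sharded-state-dict keys so checkpoints trained with TE layernorm fusion still load.
layer_spec = get_gpt_layer_modelopt_spec(
    num_experts=None, local_core_attention=False, remap_te_layernorm=True, qk_layernorm=False
)

# The spec is then passed to GPTModel(config=..., transformer_layer_spec=layer_spec, ...)
# in place of the default Transformer Engine layer spec.
```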
megatron/core/inference/modelopt_support/mamba/__init__.py — new file (0 → 100644) at commit 688448db

```python
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
```
megatron/core/inference/modelopt_support/mamba/model_specs.py — new file (0 → 100644) at commit 688448db

```python
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
from megatron.core.extensions.transformer_engine import TEDotProductAttention, TENorm
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules
from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules


# Use this spec for ModelOpt PTQ and TensorRT-LLM export
def get_mamba_stack_modelopt_spec(
    local_core_attention: bool = False, remap_te_layernorm: bool = False
) -> ModuleSpec:
    """Mix the native spec with TENorm.

    This is essentially the native local spec except for the layernorm implementation
    is using TENorm from Transformer-Engine.
    """
    mamba_state_dict_keys_map = {}
    transformer_state_dict_keys_map = {}
    if remap_te_layernorm:
        mamba_state_dict_keys_map = {'norm.': 'mixer.in_proj.layer_norm_'}
        transformer_state_dict_keys_map = {
            'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
            'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
        }

    mamba_layer = ModuleSpec(
        module=MambaLayer,
        submodules=MambaLayerSubmodules(
            norm=TENorm,
            mixer=ModuleSpec(
                module=MambaMixer,
                submodules=MambaMixerSubmodules(
                    in_proj=ColumnParallelLinear, out_proj=RowParallelLinear
                ),
            ),
            mamba_bda=get_bias_dropout_add,
            sharded_state_dict_keys_map=mamba_state_dict_keys_map,
        ),
    )

    core_attention = DotProductAttention if local_core_attention else TEDotProductAttention
    attention_layer = ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            input_layernorm=TENorm,
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": AttnMaskType.causal},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=ColumnParallelLinear,
                    core_attention=core_attention,
                    linear_proj=RowParallelLinear,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            sharded_state_dict_keys_map=transformer_state_dict_keys_map,
        ),
    )

    mlp_layer = ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            pre_mlp_layernorm=TENorm,
            mlp=ModuleSpec(
                module=MLP,
                submodules=MLPSubmodules(
                    linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear
                ),
            ),
            mlp_bda=get_bias_dropout_add,
            sharded_state_dict_keys_map=transformer_state_dict_keys_map,
        ),
    )

    return ModuleSpec(
        module=MambaStack,
        submodules=MambaStackSubmodules(
            mamba_layer=mamba_layer, attention_layer=attention_layer, mlp_layer=mlp_layer
        ),
    )
```
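The Mamba spec is consumed the same way as the GPT ModelOpt spec above: call the function and hand the resulting `ModuleSpec` to the model constructor. A hedged sketch, assuming Megatron-Core with Transformer Engine is installed and a Mamba model config is set up elsewhere; the surrounding construction is illustrative and not part of this commit.

```python
# Assumes a working Megatron-Core + Transformer Engine environment.
from megatron.core.inference.modelopt_support.mamba.model_specs import (
    get_mamba_stack_modelopt_spec,
)

# remap_te_layernorm=True rewrites sharded-state-dict keys so checkpoints saved with
# TE layernorm fusion load into this spec; local_core_attention=False keeps TEDotProductAttention.
mamba_stack_spec = get_mamba_stack_modelopt_spec(
    local_core_attention=False, remap_te_layernorm=True
)

# The returned ModuleSpec is then passed to the Mamba model constructor, in the same way
# the layer spec from get_gpt_layer_modelopt_spec() is passed to GPTModel.
```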
megatron/core/inference/sampling_params.py — file at commit 688448db

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass


@dataclass
class SamplingParams:
    """Inference parameters sent along with the prompts.

    This class contains request-level attributes that control the sampling techniques used when
    generating text. This is distinct from megatron.core.InferenceParams, which sets model-level
    inference attributes such as the maximum sequence length, and contains the KV cache.

    For an explanation of these parameters refer to this blog
    https://ivibudh.medium.com/a-guide-to-controlling-llm-model-output-exploring-top-k-top-p-and-
    temperature-parameters-ed6a31313910
    """

    temperature: float = 1.0
    top_k: int = 0
    top_p: float = 0.0
    return_log_probs: bool = False
    return_segments: bool = False  # Whether to return individually detokenized tokens
    num_tokens_to_generate: int = 30

    def add_attributes(self, attribute_value_pair: dict):
        """Utility to add more attributes to sampling params

        Use this method to pass in a custom dictionary to add more sampling parameter attributes.
        c = SamplingParams
        c.add_attributes({'min_length':4, 'eod_id':153})

        Args:
            attribute_value_pair (dict): A dictionary containing attributes as the key names and
                their values as the values.
        """
        for key, value in attribute_value_pair.items():
            setattr(self, key, value)
```
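Since `SamplingParams` is a plain dataclass and `add_attributes` is just a `setattr` loop, its behavior can be shown with a self-contained mimic. The `MiniSamplingParams` class below is a reduced stand-in written for illustration, not the class from the file above.

```python
from dataclasses import dataclass


@dataclass
class MiniSamplingParams:
    """Reduced stand-in for megatron.core.inference.sampling_params.SamplingParams."""

    temperature: float = 1.0
    top_k: int = 0
    top_p: float = 0.0
    num_tokens_to_generate: int = 30

    def add_attributes(self, attribute_value_pair: dict):
        # Attach arbitrary extra sampling attributes at runtime, as the real class does.
        for key, value in attribute_value_pair.items():
            setattr(self, key, value)


params = MiniSamplingParams(temperature=0.7, top_k=40)
params.add_attributes({'min_length': 4, 'eod_id': 153})
print(params.temperature, params.top_k, params.min_length, params.eod_id)  # 0.7 40 4 153
```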
megatron/core/inference/scheduler.py — file at commit 688448db

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import functools
import time
import typing
from collections import OrderedDict
from typing import Dict, Optional, Type, Union

import torch

from megatron.core.inference.async_stream import AsyncStream
from megatron.core.inference.inference_request import InferenceRequest, Status
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.utils import Counter


class Scheduler:
    """Scheduler for handling requests to inference engine

    This class is responsible for handling all of the incoming requests

    Args:
        max_batch_size (int): The max batch size that we can pass to the
            inference engine at a time.
        request_type (InferenceRequest): The class to use for instantiating new requests.
    """

    def __init__(self, max_batch_size):
        self.max_batch_size = max_batch_size
        self.requests: Dict[str, InferenceRequest] = OrderedDict()
        self.streams: Dict[str, AsyncStream] = OrderedDict()
        self.active_request_pool: Dict[str, InferenceRequest] = OrderedDict()
        self.waiting_request_pool: Dict[str, InferenceRequest] = OrderedDict()
        self.completed_request_pool: Dict[str, InferenceRequest] = OrderedDict()
        self.request_counter = Counter()

    def get_new_request_id(self) -> str:
        """Gets a new request id"""
        request_id = str(next(self.request_counter))
        return request_id

    def add_request(
        self,
        prompt: Optional[str] = None,
        prompt_tokens: Optional[torch.Tensor] = None,
        encoder_prompt: Optional[str] = None,
        inference_parameters: Optional[SamplingParams] = None,
        arrival_time: Optional[float] = None,
        streaming: bool = False,
        inference_request: Optional[InferenceRequest] = None,
    ) -> str:
        """Add an incoming request

        This method will add the request to either the active pool or the waiting pool
        depending on the batch size.

        Args:
            prompt (str): Input prompt string
            prompt_tokens (torch.Tensor): A torch tensor having the input prompts tokenized
            encoder_prompt (str): Encoder input string
            inference_parameters (SamplingParams): The inference parameters
            arrival_time (float, optional): The incoming request time. Defaults to None.
            streaming (bool, optional): Whether to asynchronously stream tokens for this request.
            inference_request (InferenceRequest, optional): A fully constructed request.
                Defaults to None.

        Returns:
            The request_id for the new request.
        """
        status = (
            Status.ACTIVE_BUT_NOT_GENERATING_TOKENS
            if len(self.active_request_pool) < self.max_batch_size
            else Status.WAITING_IN_QUEUE
        )

        if inference_request is None:
            assert prompt is not None
            assert prompt_tokens is not None
            request_id = self.get_new_request_id()
            if arrival_time is None:
                arrival_time = time.time()
            inference_request = InferenceRequest(
                request_id=request_id,
                prompt=prompt,
                inference_parameters=inference_parameters,
                arrival_time=arrival_time,
                prompt_tokens=prompt_tokens,
                status=status,
                encoder_prompt=encoder_prompt,
            )
        else:
            request_id = inference_request.request_id
            inference_request.status = status
            if inference_request.arrival_time is None:
                inference_request.arrival_time = time.time()

        self.requests[request_id] = inference_request

        if streaming:
            abort_request = functools.partial(self.abort_request, request_id=request_id)
            self.streams[request_id] = AsyncStream(request_id, abort_request)

        if status == status.ACTIVE_BUT_NOT_GENERATING_TOKENS:
            self.active_request_pool[request_id] = inference_request
        else:
            self.waiting_request_pool[request_id] = inference_request

        return request_id

    def have_requests_pending(self) -> bool:
        """Method to check if there are requests pending

        This method returns False only when there are no active requests or waiting requests.
        """
        num_requests_pending = len(self.active_request_pool) + len(self.waiting_request_pool)
        return num_requests_pending > 0

    def add_earliest_waiting_request_to_active_pool(self):
        """Utility to add the waiting request to active pool

        This method will add the earliest request (FIFO) that is in the waiting request
        pool to the active request pool.
        """
        assert (
            len(self.active_request_pool) < self.max_batch_size
        ), "Active request pool is already full. Cant add any more requests"
        if len(self.waiting_request_pool) > 0:
            (earliest_waiting_request_request_id, earliest_waiting_request) = (
                self.waiting_request_pool.popitem(last=False)
            )
            earliest_waiting_request.status = Status.ACTIVE_BUT_NOT_GENERATING_TOKENS
            self.active_request_pool[earliest_waiting_request_request_id] = (
                earliest_waiting_request
            )

    def update_requests_pools(
        self, result_dict: Optional[typing.OrderedDict[str, InferenceRequest]] = None
    ):
        """Update request pool status

        This method will fill up the active request pool, if it has less than max batch size
        elements, from the waiting request pool.
        If provided with a request dict, it will put the completed requests into the completed
        request pool and add waiting requests into the active pool.

        Args:
            result (typing.OrderedDict[str, InferenceRequest], optional): The result returned
                by the engine. A dictionary with keys as the request ids, and values as the
                requests. Defaults to None
        """
        for result_request_id in list(result_dict.keys()):
            active_request = self.active_request_pool[result_request_id]

            # If a request has completed put it into the completed request pool.
            if active_request.status == Status.COMPLETED:
                completed_request = self.active_request_pool.pop(result_request_id)
                self.completed_request_pool[result_request_id] = completed_request

        # If the active request pool is not full, add waiting requests in FIFO order
        while (
            len(self.active_request_pool) < self.max_batch_size
            and len(self.waiting_request_pool) > 0
        ):
            self.add_earliest_waiting_request_to_active_pool()

    def abort_request(
        self,
        request_id: str,
        *,
        exception: Optional[Union[BaseException, Type[BaseException]]] = None,
    ):
        """Cancels the given request"""
        stream = self.streams.get(request_id, None)
        if stream is not None:
            stream.finish(exception=exception)
```
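The pool bookkeeping in `update_requests_pools` and `add_earliest_waiting_request_to_active_pool` is plain `OrderedDict` FIFO handling, independent of the model. A minimal sketch of that flow with string statuses instead of the `Status` enum; the `toy_` names and the example ids are invented for illustration.

```python
from collections import OrderedDict

MAX_BATCH_SIZE = 2

active = OrderedDict()     # request_id -> status
waiting = OrderedDict()
completed = OrderedDict()


def toy_add_request(request_id: str):
    # New requests go active if there is room, otherwise they wait (FIFO).
    if len(active) < MAX_BATCH_SIZE:
        active[request_id] = "ACTIVE"
    else:
        waiting[request_id] = "WAITING"


def toy_update_pools(finished_ids):
    # Move finished requests out, then promote the earliest waiting requests.
    for request_id in finished_ids:
        if request_id in active:
            completed[request_id] = "COMPLETED"
            del active[request_id]
    while len(active) < MAX_BATCH_SIZE and waiting:
        request_id, _ = waiting.popitem(last=False)  # earliest waiting request
        active[request_id] = "ACTIVE"


for rid in ["0", "1", "2", "3"]:
    toy_add_request(rid)
toy_update_pools(["0"])
print(list(active), list(waiting), list(completed))  # ['1', '2'] ['3'] ['0']
```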
megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py — file at commit 688448db

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import Any, Dict, OrderedDict

import torch

from megatron.core.inference.inference_request import InferenceRequest
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
    TextGenerationController,
)


class EncoderDecoderTextGenerationController(TextGenerationController):
    """The text generation controller for encoder-decoder architecture

    This class inherits from TextGenerationController, adding features
    relating to encoder input encoder_prompt
    """

    def prep_inference_input(
        self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest]
    ) -> Dict[str, Any]:
        """Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long

        Args:
            prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
            active_requests (OrderedDict[str, InferenceRequest]): The input active requests

        Returns:
            A dict of the inference input for the current batch.
        """
        encoder_prompts = list(
            map(lambda request: request.encoder_prompt, active_requests.values())
        )

        return self.inference_wrapped_model.prep_inference_input(
            prompts_tokens, encoder_prompts, tokenizer=self.tokenizer
        )
```
megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py — file at commit 688448db

```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from megatron.core.inference.text_generation_controllers.text_generation_controller import (  # noqa: F401 # pylint: disable=unused-import
    TextGenerationController as SimpleTextGenerationController,
)
```
megatron/core/inference/text_generation_controllers/text_generation_controller.py
View file @
688448db
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from
typing
import
List
,
OrderedDict
,
Tuple
import
concurrent
import
copy
import
torch
import
functools
import
torch.nn.functional
as
F
from
typing
import
Any
,
Dict
,
List
,
Optional
,
OrderedDict
,
Tuple
,
Union
from
megatron.core
import
parallel_state
import
torch
from
megatron.core.inference.communication_utils
import
broadcast_from_last_pipeline_stage
import
torch.nn.functional
as
F
from
megatron.core.inference.inference_request
import
InferenceRequest
,
Status
from
megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper
import
(
from
megatron.core
import
parallel_state
AbstractModelInferenceWrapper
,
from
megatron.core.inference.async_stream
import
AsyncStream
)
from
megatron.core.inference.communication_utils
import
broadcast_from_last_pipeline_stage
from
megatron.core.inference.sampling_params
import
SamplingParams
from
megatron.core.inference.inference_request
import
InferenceRequest
,
Status
from
megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper
import
(
AbstractModelInferenceWrapper
,
class
TextGenerationController
:
)
"""The text generation controller (the main sampling loop)
from
megatron.core.inference.sampling_params
import
SamplingParams
from
megatron.core.transformer.cuda_graphs
import
create_cudagraphs
This class tokenizes the input, runs inference, samples from logits, and detokenizes the output.
from
megatron.core.utils
import
get_model_config
Args:
inference_wrapped_model (AbstractModelInferenceWrapper): A model that
class
TextGenerationController
:
is wrapped using the specs given in the abstract_model_inference_wrapper.py
"""The text generation controller (the main sampling loop)
tokenizer (_type_): Tokenizer used for tokenizing and detokenizing the prompts
"""
This class tokenizes the input, runs inference, samples from logits, and detokenizes the output.
def
__init__
(
self
,
inference_wrapped_model
:
AbstractModelInferenceWrapper
,
tokenizer
):
Args:
self
.
inference_wrapped_model
=
inference_wrapped_model
inference_wrapped_model (AbstractModelInferenceWrapper): A model that
self
.
tokenizer
=
tokenizer
is wrapped using the specs given in the abstract_model_inference_wrapper.py
tokenizer (_type_): Tokenizer used for tokenizing and detokenizing the prompts
# For models without pipeline parallelism, is_first_stage and is_last_stage returns True
"""
self
.
model_is_pipeline_parallel
=
not
(
parallel_state
.
is_pipeline_first_stage
()
and
parallel_state
.
is_pipeline_last_stage
()
def
__init__
(
self
,
inference_wrapped_model
:
AbstractModelInferenceWrapper
,
tokenizer
):
)
self
.
inference_wrapped_model
=
inference_wrapped_model
self
.
tokenizer
=
tokenizer
def
tokenize_prompt
(
self
,
prompt
:
str
,
add_BOS
:
bool
=
False
# For models without pipeline parallelism, is_first_stage and is_last_stage returns True
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
self
.
model_is_pipeline_parallel
=
not
(
"""Utility to tokenize the input prompts
parallel_state
.
is_pipeline_first_stage
()
and
parallel_state
.
is_pipeline_last_stage
()
)
Args:
prompt (str): The input prompt
def
tokenize_prompt
(
self
,
prompt
:
str
,
add_BOS
:
bool
=
False
Returns:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
torch.Tensor: Returns the tokenized prompt
"""Utility to tokenize the input prompts
"""
prompt_tokens
=
self
.
tokenizer
.
tokenize
(
prompt
)
Args:
prompt (str): The input prompt
if
add_BOS
:
prompt_tokens
=
[
self
.
tokenizer
.
bos
]
+
prompt_tokens
Returns:
torch.Tensor: Returns the tokenized prompt
return
prompt_tokens
"""
prompt_tokens
=
self
.
tokenizer
.
tokenize
(
prompt
)
def
detokenize_generations
(
self
,
prompt_tokens_with_generated_tokens
:
torch
.
Tensor
)
->
str
:
"""Detokenize the output generations
if
add_BOS
:
prompt_tokens
=
[
self
.
tokenizer
.
bos
]
+
prompt_tokens
Args:
prompt_tokens_with_generated_tokens (torch.Tensor): The input prompt
return
prompt_tokens
    def detokenize_generations(
        self,
        tokens_gpu_tensor: torch.Tensor,
        lengths_gpu_tensor: torch.Tensor,
        detokenize_segments: bool,
    ) -> tuple[str, Optional[List[List[str]]]]:
        """Detokenize the generated tokens.

        Args:
            tokens_gpu_tensor (torch.Tensor): Tensor containing the tokens
            lengths_gpu_tensor (torch.Tensor): Tensor containing the lengths of each sequence
            detokenize_segments (bool): If True, returns individually detokenized tokens. If False,
                returns None as the second element. Helpful for understanding per-token boundaries
                in generated text.

        Returns:
            tuple[str, List[str] | None]: A tuple containing:
                - str: The complete detokenized text
                - List[str] | None: List of segmented tokens if detokenize_segments is True,
                  else None
        """
        # TODO(helenn): Unify with `detokenize_generations` from legacy textgen path
        if not detokenize_segments:
            tokens = tokens_gpu_tensor.cpu().numpy().tolist()
            return self.tokenizer.detokenize(tokens), None

        prompts_plus_generations: List[str] = []
        prompts_plus_generations_segments: List[List[str]] = []

        tokens_gpu_tensor = torch.unsqueeze(tokens_gpu_tensor, 0)
        tokens = tokens_gpu_tensor.cpu().numpy().tolist()
        lengths = lengths_gpu_tensor.cpu().numpy().tolist()

        for sequence_tokens, length in zip(tokens, lengths):
            sequence_tokens = sequence_tokens[:length]
            detok_str = self.tokenizer.detokenize(sequence_tokens)
            prompts_plus_generations.append(detok_str)
            offsets = self.tokenizer.offsets(sequence_tokens, detok_str)
            words = [
                detok_str[start:end] for start, end in zip(offsets, offsets[1:] + [len(detok_str)])
            ]

            prompts_plus_generations_segments.append(words)

        text = self.tokenizer.detokenize(tokens[0])

        return text, prompts_plus_generations_segments
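
The segment slicing above depends only on the character offsets returned by the tokenizer. A minimal standalone sketch (separate from the file above, using hand-written stand-in values for the tokenizer output) of how the per-token segments are built:

tokens = [72, 105, 33]          # stand-in token ids
detok_str = "Hi!"               # what tokenizer.detokenize(tokens) would return
offsets = [0, 1, 2]             # start offset of each token inside detok_str

# Same slicing pattern as in detokenize_generations() above:
segments = [detok_str[start:end] for start, end in zip(offsets, offsets[1:] + [len(detok_str)])]
print(segments)                 # ['H', 'i', '!']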
    def sample_from_logits(
        self,
        last_token_logits: torch.Tensor,
        sampling_params: Optional[SamplingParams] = None,
        vocab_size: Optional[int] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Samples the logits to generate outputs

        Given the logits of the last token, this function samples them
        according to the parameters defined in sampling_params
        and returns the samples.

        Args:
            last_token_logits (torch.Tensor): The last token logits. A tensor of
                size [batch_size, vocab_size]
            sampling_params (SamplingParams): The parameters to use for inference.
            vocab_size (int): Obtained from the tokenizer. Defaults to None

        Returns:
            torch.Tensor: 1D tensor of the sampled logits with [batch_size] elements
        """

        if kwargs.get('common_inference_params'):
            sampling_params = kwargs['common_inference_params']

        top_p = sampling_params.top_p
        top_k = sampling_params.top_k
        temperature = sampling_params.temperature

        assert not (top_k > 0 and top_p > 0), 'Cannot have top-p and top-k both greater than zero'
        assert top_p <= 1.0, 'top-p should be in (0, 1]'

        def modify_logits_for_top_k_filtering(logits, top_k):
            """Set the logits for non top-k values to -inf."""
            filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits.masked_fill_(filter_, float('-Inf'))

        def modify_logits_for_top_p_filtering(logits, top_p):
            """Set the logits for non top-p values to -inf."""
            # First sort and calculate cumulative sum of probabilities.
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

            # Filtering based on the cumulative sum.
            filter_ = cumulative_probs > top_p
            # This shift by 1 is weird and I cannot justify it. This existed
            # in the original implementation:
            #   https://github.com/ari-holtzman/degen/blob/master/gen.py
            # and I guess it is needed so keeping it for now.
            filter_[:, 1:] = filter_[:, :-1].clone()
            # Make sure we at least have one token to select from.
            filter_[..., 0] = 0

            # Fill in the filtered part
            filter_ = filter_.scatter(1, sorted_indices, filter_)
            logits.masked_fill_(filter_, float('-Inf'))

        # Greedy sampling
        if top_k == 1:
            sampled_logits = torch.argmax(last_token_logits, dim=-1)
        else:
            last_token_logits = last_token_logits.clone()
            if temperature != 1.0:
                last_token_logits.div_(temperature)

            if top_k > 1:
                assert top_k <= last_token_logits.size(1), 'top-k is larger than logit size.'
                if vocab_size:
                    assert top_k < vocab_size, 'top-k is larger than vocab size.'
                modify_logits_for_top_k_filtering(last_token_logits, top_k)

            elif top_p > 0.0:
                modify_logits_for_top_p_filtering(last_token_logits, top_p)

            # After filtering, we need to recalculate the distribution.
            probabilities = last_token_logits.softmax(dim=-1)
            sampled_logits = torch.multinomial(probabilities, num_samples=1).view(-1)

            # If vocab size is provided, make sure the samples are in the range [0, vocab_size).
            if vocab_size:
                sampled_logits = torch.clamp(sampled_logits, min=0, max=(vocab_size - 1))
        return sampled_logits
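
As a quick illustration of the non-greedy branch above, the sketch below applies the same top-k mask and multinomial draw to toy logits; it assumes only torch and uses no Megatron classes:

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0],
                       [0.3, 0.2, 2.5, 0.0, -0.5]])   # [batch_size=2, vocab_size=5]

top_k = 2
# Mask everything below the k-th largest logit, exactly as in the helper above.
filter_ = logits < torch.topk(logits, top_k)[0][..., -1, None]
filtered = logits.masked_fill(filter_, float('-Inf'))

# Re-normalise and sample one token per row, as in the non-greedy branch.
probabilities = filtered.softmax(dim=-1)
samples = torch.multinomial(probabilities, num_samples=1).view(-1)
print(samples.shape)   # torch.Size([2]); each entry is one of that row's top-2 token ids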
    def update_generation_status(
        self,
        updated_prompts_tokens: torch.Tensor,
        generation_started: torch.Tensor,
        current_context_end_position: int,
        is_generation_done_tensor: torch.Tensor,
        generated_sequence_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Checks which prompts have reached an end condition

        We check which prompts have reached an end condition and set the corresponding
        flags of the is_generation_done_tensor to True. The generated sequence lengths
        increase as we keep generating, until that prompt hits an end condition. The
        generation_started tensor determines which prompts have started generating.

        Args:
            updated_prompts_tokens (torch.Tensor): The prompts tokens updated with the latest
                generated tokens. A tensor of shape [batch_size, max_seq_len]
                (i.e. max_seq_len = max_prompt_len + tokens_to_generate)
            generation_started (torch.Tensor): A boolean tensor of shape [batch_size]. True
                indicates the prompt at that index has started generating tokens.
            current_context_end_position (int): An integer indicating which position to
                extract from the prompts tokens to get the latest generated tokens.
            is_generation_done_tensor (torch.Tensor): A boolean tensor of shape [batch_size].
                True indicates the prompt at that index has reached an end condition.
            generated_sequence_lengths (torch.Tensor): An int tensor of shape [batch_size].
                Each value represents the generated sequence length for that prompt.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Returns the boolean
                is_generation_done_tensor and the generated_sequence_lengths after updating them
        """
        latest_samples = updated_prompts_tokens[:, current_context_end_position]
        # Make sure we are checking the eod criterion only for prompts that have started
        # generating (i.e. we only look at the generated tokens and not the input tokens).
        reached_eod = (latest_samples == self.tokenizer.eod) & generation_started
        is_generation_done_tensor = is_generation_done_tensor | reached_eod
        # We increment generated sequence lengths when that prompt has not hit the
        # EOD and generation has started
        generated_sequence_lengths += ~is_generation_done_tensor & generation_started

        return is_generation_done_tensor, generated_sequence_lengths.int()
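
The bookkeeping above can be checked on toy tensors; the sketch below assumes an eod id of 0 and mirrors the three tensor updates:

import torch

eod = 0
latest_samples = torch.tensor([0, 7, 0, 3])                  # token sampled for each prompt this step
generation_started = torch.tensor([True, True, False, True])
is_done = torch.tensor([False, False, False, False])
generated_lengths = torch.zeros(4)

reached_eod = (latest_samples == eod) & generation_started   # prompt 2 has not started, so it is ignored
is_done = is_done | reached_eod
generated_lengths += ~is_done & generation_started

print(is_done.tolist())            # [True, False, False, False]
print(generated_lengths.tolist())  # [0.0, 1.0, 0.0, 1.0]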
    def pad_input_prompt_tokens(
        self,
        batch_prompt_tokens_list: List[List[int]],
        max_prompt_length_in_batch: int,
        num_tokens_to_generate: int,
    ) -> torch.Tensor:
        """Method to pad input prompts

        Given a list of prompts, pad them all to uniform length

        Args:
            batch_prompt_tokens_list (List[List[int]]): A list containing the prompt tokens
            max_prompt_length_in_batch (int): Maximum of the length of the input prompt tokens
            num_tokens_to_generate (int): The number of tokens to generate for each prompt

        Returns:
            torch.Tensor: A torch tensor of shape [bs, max_seq_len] (i.e.
                max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate),
                with the extra indices for each tensor padded with the mask id.
        """
        max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate

        for prompt_tokens in batch_prompt_tokens_list:
            padding_size = max_seq_len - len(prompt_tokens)
            prompt_tokens.extend([self.tokenizer.eod] * padding_size)

        return torch.tensor(batch_prompt_tokens_list, device=torch.cuda.current_device())
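
The same right-padding scheme on CPU, with a stand-in eod id (the real method pads with self.tokenizer.eod and allocates the tensor on the current CUDA device):

import torch

eod = -1                                  # placeholder; the real pad value is the tokenizer's eod id
prompts = [[5, 6, 7], [8, 9]]
num_tokens_to_generate = 4
max_seq_len = max(len(p) for p in prompts) + num_tokens_to_generate   # 3 + 4 = 7

for p in prompts:
    p.extend([eod] * (max_seq_len - len(p)))

batch = torch.tensor(prompts)
print(batch.shape)                        # torch.Size([2, 7])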
    def generate_output_tokens_dynamic_batch(
        self, active_requests: OrderedDict[str, InferenceRequest]
    ) -> OrderedDict[str, InferenceRequest]:
        """Utility to generate the output tokens and probabilities for the prompts

        This utility generates the output tokens for a dynamic batch. It will run one forward step
        at a time, and pass control back to the engine, which will update the request pool and call
        this method again.

        Args:
            active_requests (OrderedDict[str, InferenceRequest]): The input active requests.

        Returns:
            OrderedDict[str, InferenceRequest]: The result for each of the incoming requests
                after running one forward step.
        """
        raise Exception("Not implemented yet")
    def generate_all_output_tokens_static_batch(
        self,
        active_requests: OrderedDict[str, InferenceRequest],
        active_streams: Optional[OrderedDict[str, AsyncStream]] = None,
    ) -> OrderedDict[str, InferenceRequest]:
        """Utility to generate all the output tokens and probabilities for the prompts.

        This utility generates the output tokens for a static batch. It runs the forward steps till
        all prompts complete generation, updates the status of these requests to completed, adds
        the generated result and returns these requests.

        Args:
            active_requests (OrderedDict[str, InferenceRequest]): The input active requests.

        Returns:
            OrderedDict[str, InferenceRequest]: The result for each of the incoming requests
        """
        assert all(request.prompt_tokens is not None for request in active_requests.values())

        # Perform a deep copy so that the request prompt tokens do not get modified.
        batch_prompt_tokens_list: List[List[int]] = list(
            map(
                lambda request: copy.deepcopy(request.prompt_tokens),  # type: ignore[arg-type]
                active_requests.values(),
            )
        )
        prompt_lengths_in_batch = torch.tensor(
            [len(prompt_tokens) for prompt_tokens in batch_prompt_tokens_list],
            device=torch.cuda.current_device(),
        )
        max_prompt_length_in_batch = max(prompt_lengths_in_batch)
        min_prompt_length_in_batch = min(prompt_lengths_in_batch)

        # For batch inference the inference params are the same for all requests
        sampling_params: SamplingParams = list(active_requests.values())[0].inference_parameters

        # max_seq_len = max_prompt_length_in_batch + num_tokens_to_generate
        batch_prompt_tokens = self.pad_input_prompt_tokens(
            batch_prompt_tokens_list,
            max_prompt_length_in_batch=max_prompt_length_in_batch,
            num_tokens_to_generate=sampling_params.num_tokens_to_generate,
        )
        batch_size, max_sequence_length = batch_prompt_tokens.shape

        # Verify that the output sequence length is within the configured limit
        # TODO(ksanthanam): Raise TokenOverflowError once !2518 is merged
        inference_max_sequence_length = (
            self.inference_wrapped_model.inference_wrapper_config.inference_max_seq_length
        )
        assert max_sequence_length <= inference_max_sequence_length, (
            f"Maximum allowed sequence length was set to {inference_max_sequence_length} tokens "
            f"but requested generation of {max_sequence_length} tokens"
        )

        # Pre-allocate log probs tensor
        output_log_probs = None
        if sampling_params.return_log_probs:
            output_log_probs = torch.empty(
                (batch_size, max_sequence_length - 1),
                dtype=torch.float32,
                device=torch.cuda.current_device(),
            )

        # An array to check which of the prompts have reached the end-of-generation condition
        is_generation_done_tensor = torch.zeros(
            batch_size, dtype=torch.bool, device=torch.cuda.current_device()
        )

        # An array to act as a counter to keep track of generated sequence lengths
        generated_sequence_lengths = torch.zeros(
            batch_size, device=torch.cuda.current_device()
        )

        # Use padded vocab size because tokenizer vocab size might not include padding
        # to nearest power of 2
        vocab_size = self.inference_wrapped_model.inference_wrapper_config.padded_vocab_size

        # Check whether CUDA graphs are enabled
        enable_cuda_graph = get_model_config(self.inference_wrapped_model.model).enable_cuda_graph

        streaming_enabled = active_streams is not None and len(active_streams) > 0
        if streaming_enabled:
            # Start a separate thread for streaming tokens to avoid blocking the
            # main computation
            streaming_idx: List[int] = [
                i
                for (i, request_id) in enumerate(active_requests.keys())
                if request_id in active_streams
            ]
            streaming_request_ids: List[str] = list(active_streams.keys())
            streams: List[AsyncStream] = list(active_streams.values())
            streaming_requests: List[InferenceRequest] = [
                active_requests[request_id] for request_id in streaming_request_ids
            ]
            streaming_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
            stream_tokens = functools.partial(self.stream_tokens, sampling_params)

        with torch.no_grad():
            self.inference_wrapped_model.prep_model_for_inference(
                prompts_tokens=batch_prompt_tokens
            )

            inference_input: Dict[str, Any] = self.prep_inference_input(
                prompts_tokens=batch_prompt_tokens, active_requests=active_requests
            )

            assert (
                not self.inference_wrapped_model.inference_params.decode_mode
            ), "Generation must start in prefill mode"
            context_start_position = 0

            # Pick the context window that we need to pass through the network.
            for context_end_position in range(min_prompt_length_in_batch, max_sequence_length):
                inference_input_for_context_window: Dict[str, Any] = (
                    self.inference_wrapped_model.get_batch_for_context_window(
                        inference_input, context_start_position, context_end_position
                    )
                )

                # Disable attention mask when using CUDA graphs for decode
                if (
                    enable_cuda_graph
                    and self.inference_wrapped_model.inference_params.decode_mode
                    and "attention_mask" in inference_input_for_context_window
                ):
                    inference_input_for_context_window["attention_mask"] = None

                # Returns the final logits of shape [batch_size, context_length, vocab_size]
                # Note: This is returned in all TP ranks or last PP stage in PP models
                logits = self.inference_wrapped_model.run_one_forward_step(
                    inference_input_for_context_window
                )

                if enable_cuda_graph:
                    create_cudagraphs()

                if self.model_is_pipeline_parallel:
                    context_length = context_end_position - context_start_position
                    logits = broadcast_from_last_pipeline_stage(
                        [batch_size, context_length, vocab_size],
                        dtype=self.inference_wrapped_model.inference_wrapper_config.params_dtype,
                        tensor=logits,
                    )

                # Indicates which of the input prompts have started generating tokens.
                # A 1D boolean tensor with [batch_size] elements (i.e) The shortest
                # prompts will start generating first and so on
                generation_started = prompt_lengths_in_batch <= context_end_position
                last_token_logits = logits[:, -1, :]
                sampled_logits = self.sample_from_logits(
                    last_token_logits, sampling_params, vocab_size
                )

                # Substitute the sampled logits only for the prompts that
                # have started generating tokens
                batch_prompt_tokens[generation_started, context_end_position] = sampled_logits[
                    generation_started
                ]

                if sampling_params.return_log_probs:
                    log_probs = F.log_softmax(logits, dim=2)
                    indices = torch.unsqueeze(
                        batch_prompt_tokens[
                            :, (context_start_position + 1) : (context_end_position + 1)
                        ],
                        2,
                    )
                    # Get the log probabilities for only the prompt tokens
                    assert output_log_probs is not None
                    output_log_probs[:, context_start_position:context_end_position] = torch.gather(
                        log_probs, 2, indices
                    ).squeeze(2)

                context_start_position = context_end_position

                # Check end of generation status for each tensor
                # and update generated sequence lengths
                (is_generation_done_tensor, generated_sequence_lengths) = (
                    self.update_generation_status(
                        updated_prompts_tokens=batch_prompt_tokens,
                        generation_started=generation_started,
                        current_context_end_position=context_end_position,
                        is_generation_done_tensor=is_generation_done_tensor,
                        generated_sequence_lengths=generated_sequence_lengths,
                    )
                )

                # Stream intermediate outputs
                if streaming_enabled:
                    streaming_executor.submit(
                        stream_tokens,
                        streaming_request_ids,
                        streaming_requests,
                        streams,
                        generation_started[streaming_idx].cpu(),
                        is_generation_done_tensor[streaming_idx].cpu(),
                        batch_prompt_tokens[streaming_idx].cpu(),
                        prompt_lengths_in_batch[streaming_idx].cpu(),
                        generated_sequence_lengths[streaming_idx].cpu(),
                        (
                            output_log_probs[streaming_idx].cpu()
                            if output_log_probs is not None
                            else [None] * len(streaming_idx)
                        ),
                    )

                # Boolean flag indicating if all prompts are finished
                all_prompts_done = torch.all(is_generation_done_tensor)
                if all_prompts_done:
                    break

                # Change to decode mode if all prefill is complete
                if torch.all(generation_started):
                    self.inference_wrapped_model.inference_params.enable_decode_mode()

        # Close all streams
        if streaming_enabled:
            streaming_executor.shutdown()
            for stream in streams:
                stream.finish()
        # Include all the generated tokens
        batch_prompt_tokens_with_generations = batch_prompt_tokens[:, : (context_end_position + 1)]
        if sampling_params.return_log_probs:
            assert output_log_probs is not None
            output_log_probs = output_log_probs[:, :context_end_position]

        generated_sequence_lengths[
            generated_sequence_lengths > sampling_params.num_tokens_to_generate
        ] = sampling_params.num_tokens_to_generate

        for idx, request in enumerate(active_requests.values()):
            input_prompt_length = int(prompt_lengths_in_batch[idx])
            # Shorter prompts might have generated more than required tokens. So we trim them down
            required_sequence_length = int(
                min(generated_sequence_lengths[idx], sampling_params.num_tokens_to_generate)
            )
            # Extract only the generated tokens
            required_result_tokens = batch_prompt_tokens_with_generations[
                idx, input_prompt_length : (input_prompt_length + required_sequence_length)
            ]
            generated_sequence_lengths = generated_sequence_lengths.to(dtype=torch.int32)
            request.generated_sequence_lengths = generated_sequence_lengths.to(dtype=torch.int32)
            request.generated_length = required_sequence_length
            request.generated_tokens = required_result_tokens
            request.prompt_log_probs = (
                None
                if output_log_probs is None
                else output_log_probs[idx, :input_prompt_length].cpu().numpy().tolist()
            )
            request.generated_log_probs = (
                None
                if output_log_probs is None
                else output_log_probs[
                    idx,
                    input_prompt_length - 1 : (input_prompt_length + required_sequence_length - 1),
                ]
                .cpu()
                .numpy()
                .tolist()
            )
            request.status = Status.COMPLETED

            text, segments = self.detokenize_generations(
                batch_prompt_tokens_with_generations[idx],
                input_prompt_length + generated_sequence_lengths,
                sampling_params.return_segments,
            )
            request.text = text  # Inference server returns prompts & generations together
            if sampling_params.return_segments:
                request.segments = segments[0]
            request.generated_text = text[len(request.prompt) :]

        return active_requests
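
At the end of the loop, each request's generated tokens are sliced out of its padded batch row. A toy version of that trimming step (plain torch, placeholder numbers, separate from the file above):

import torch

batch_row = torch.tensor([11, 12, 13, 101, 102, 0, 0])   # prompt + generations + eod padding
input_prompt_length = 3
generated_length = 2
num_tokens_to_generate = 4

required = int(min(generated_length, num_tokens_to_generate))
generated_tokens = batch_row[input_prompt_length : input_prompt_length + required]
print(generated_tokens.tolist())                          # [101, 102]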
    def prep_inference_input(
        self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest]
    ) -> Dict[str, Any]:
        """Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long

        Args:
            prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
            active_requests (OrderedDict[str, InferenceRequest]): The input active requests

        Returns:
            A dict of the inference input for the current batch.
        """
        return self.inference_wrapped_model.prep_inference_input(prompts_tokens)
    def stream_tokens(
        self,
        sampling_params: SamplingParams,
        request_ids: List[str],
        requests: List[InferenceRequest],
        streams: List[AsyncStream],
        generation_started: List[bool],
        is_generation_done: List[bool],
        tokens: torch.Tensor,
        prompt_lengths: List[int],
        generated_lengths: List[int],
        output_log_probs: Union[torch.Tensor, None],
    ):
        """Asynchronously streams tokens for the given requests.

        Args:
            sampling_params (SamplingParams): The sampling parameters.
            request_ids (List[str]): The request IDs.
            requests (List[InferenceRequest]): The requests.
            streams (List[AsyncStream]): The streams over which to send tokens.
            generation_started (List[bool]): Whether the decode step has started.
            is_generation_done (List[bool]): Whether generation has completed.
            tokens (torch.Tensor): The tokens for this request.
            prompt_lengths (List[int]): The number of prompt tokens for each request.
            generated_lengths (List[int]): The number of output tokens for each request.
            output_log_probs (torch.Tensor, optional): The log probs for each request.
        """

        def stream_token(
            request_id: str,
            request: InferenceRequest,
            stream: AsyncStream,
            generation_started: bool,
            is_generation_done: bool,
            tokens: torch.Tensor,
            prompt_length: int,
            generated_length: int,
            output_log_probs: Union[torch.Tensor, None],
        ):
            """Asynchronously streams a token for the given request."""

            if not generation_started or stream.finished:
                return

            num_tokens_to_generate = sampling_params.num_tokens_to_generate
            return_segments = sampling_params.return_segments
            detokenize_streaming_text = not getattr(
                sampling_params, "no_detokenize_streaming_text", False
            )

            generated_tokens = tokens[prompt_length : prompt_length + generated_length]

            if detokenize_streaming_text:
                generated_text, generated_segments = self.detokenize_generations(
                    generated_tokens, prompt_length + generated_length, return_segments
                )
            else:
                generated_text = ""
                generated_segments = []

            if output_log_probs is not None:
                generated_log_probs = (
                    output_log_probs[prompt_length - 1 : prompt_length + generated_length - 1]
                    .cpu()
                    .numpy()
                    .tolist()
                )
            else:
                generated_log_probs = None

            stream.put(
                InferenceRequest(
                    request_id=request_id,
                    prompt=request.prompt,
                    inference_parameters=request.inference_parameters,
                    prompt_tokens=request.prompt_tokens,
                    arrival_time=request.arrival_time,
                    status=request.status,
                    encoder_prompt=request.encoder_prompt,
                    generated_text=generated_text,
                    generated_segments=generated_segments,
                    generated_tokens=generated_tokens,
                    generated_log_probs=generated_log_probs,
                    generated_length=generated_length,
                )
            )

            if is_generation_done or generated_length == num_tokens_to_generate:
                stream.finish()

        ret = map(
            stream_token,
            request_ids,
            requests,
            streams,
            generation_started,
            is_generation_done,
            tokens,
            prompt_lengths,
            generated_lengths,
            output_log_probs,
        )
        list(ret)
megatron/core/inference/text_generation_controllers/vlm_text_generation_controller.py
0 → 100644
View file @
688448db
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from typing import OrderedDict

import torch

from megatron.core.inference.inference_request import InferenceRequest, VLMInferenceRequest
from megatron.core.inference.text_generation_controllers.text_generation_controller import (
    TextGenerationController,
)


class VLMTextGenerationController(TextGenerationController):
    """The text generation controller for VLMs"""

    def prep_inference_input(
        self, prompts_tokens: torch.Tensor, active_requests: OrderedDict[str, InferenceRequest]
    ):
        """Preparing input data for inference, using respective wrapper's prep_inference_input method # pylint: disable=line-too-long

        Currently only supports batch size 1 inference.

        Args:
            prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_sequence_length]
            active_requests (OrderedDict[str, InferenceRequest]): The input active requests
        """
        assert len(active_requests) == 1, "VLM inference currently only supports batch size 1"

        request = list(active_requests.values())[0]
        assert isinstance(
            request, VLMInferenceRequest
        ), f"Found inference request of type {type(request)}, expected VLMInferenceRequest"

        return self.inference_wrapped_model.prep_inference_input(
            prompts_tokens,
            request.num_img_embeddings_per_tile,
            request.imgs,
            request.num_tiles,
            request.decoder_seq_length,
        )
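
A minimal sketch of what this controller expects from the single active request; _FakeVLMRequest and its field values below are placeholders, not the real VLMInferenceRequest class:

from dataclasses import dataclass

@dataclass
class _FakeVLMRequest:
    num_img_embeddings_per_tile: int = 576   # placeholder value
    imgs: object = None                      # image tensor(s) in the real request
    num_tiles: object = None                 # tiles per image
    decoder_seq_length: int = 1024           # placeholder value

active_requests = {"req-0": _FakeVLMRequest()}
assert len(active_requests) == 1, "VLM inference currently only supports batch size 1"
request = list(active_requests.values())[0]
print(request.num_img_embeddings_per_tile, request.decoder_seq_length)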
megatron/core/inference_params.py
View file @
688448db
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.


class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        # Required for bookkeeping variable-sized batches
        self.current_batch_size = max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.key_value_memory_dict = {}
        self.decode_mode = False

    def swap_key_value_dict(self, batch_idx):
        "swap between batches"
        if len(self.key_value_memory_dict) == 0:
            raise ValueError("should not swap when dict is empty")

        for layer_number in self.key_value_memory_dict.keys():
            inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
            assert (
                len(batch_idx) == inference_key_memory.shape[1]
            )  # make sure batch size is the same
            new_inference_key_memory = inference_key_memory[:, batch_idx]
            new_inference_value_memory = inference_value_memory[:, batch_idx]
            self.key_value_memory_dict[layer_number] = (
                new_inference_key_memory,
                new_inference_value_memory,
            )

    def enable_prefill_mode(self):
        """
        Indicates the generation loop is in the prefill phase (still processing
        input prompt tokens). This should be enabled if the generation loop is
        encoding prompt tokens for *any* request in a batch.
        """
        self.decode_mode = False

    def enable_decode_mode(self):
        """
        Indicates the generation loop is in the decode phase (generating new output
        tokens). This should only be enabled if the generation loop has fully encoded
        the prompts for *all* requests in a batch.
        """
        self.decode_mode = True

    def reset(self):
        """Resets the inference state for a new batch."""
        self.current_batch_size = self.max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.enable_prefill_mode()

    def __str__(self):
        return (
            f"InferenceParams(max_seq_len = {self.max_sequence_length}, "
            f"max_batch_size = {self.max_batch_size}, "
            f"current_batch_size = {self.current_batch_size}, "
            f"sequence_len_offset = {self.sequence_len_offset}, "
            f"batch_size_offset = {self.batch_size_offset}, "
            f"key_value_memory_dict = {self.key_value_memory_dict.keys()}, "
            f"decode_mode = {self.decode_mode})"
        )

    def __eq__(self, other):
        if not isinstance(other, InferenceParams):
            return False

        # Check all attributes match
        basic_attrs = [
            'max_sequence_length',
            'max_batch_size',
            'current_batch_size',
            'sequence_len_offset',
            'batch_size_offset',
        ]
        if not all(hasattr(other, attr) for attr in basic_attrs):
            return False

        # Check dictionary keys match; i.e. the same number of layers are cached
        if self.key_value_memory_dict.keys() != other.key_value_memory_dict.keys():
            return False

        # Check each tensor tuple in the dictionary
        for key in self.key_value_memory_dict:
            self_tensors = self.key_value_memory_dict[key]
            other_tensors = other.key_value_memory_dict[key]

            # Compare each key, value tensor in the tuple
            for self_tensor, other_tensor in zip(self_tensors, other_tensors):
                if (
                    self_tensor.data_ptr() != other_tensor.data_ptr()
                    or self_tensor.shape != other_tensor.shape
                ):
                    return False

        return True
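
Assuming megatron.core is importable, the prefill/decode toggle and reset() behave as plain flag and counter updates; a short usage sketch:

from megatron.core.inference_params import InferenceParams

params = InferenceParams(max_batch_size=4, max_sequence_length=2048)
assert params.decode_mode is False          # starts in prefill
params.enable_decode_mode()                 # after all prompts in the batch are fully encoded
assert params.decode_mode is True
params.reset()                              # new batch: back to prefill, offsets cleared
assert params.sequence_len_offset == 0 and params.decode_mode is False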
megatron/core/jit.py
View file @
688448db
...
@@ -7,18 +7,4 @@ from megatron.core.utils import is_torch_min_version

 jit_fuser = torch.jit.script
 # nvFuser is deprecated in PyTorch JIT starting from 2.2
 if is_torch_min_version("2.2.0a0"):
-    jit_fuser = torch.compile(mode='max-autotune-no-cudagraphs')
-
-# Decorator to disable Torch Dynamo
-# See: https://github.com/NVIDIA/TransformerEngine/issues/308
-no_torch_dynamo = lambda recursive=True: lambda func: func
-if torch.__version__ >= "2":
-    import torch._dynamo
-
-    if torch.__version__ >= "2.1":
-        no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive)
-    else:
-        # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
-        no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
+    jit_fuser = torch.compile
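
A hedged usage sketch: jit_fuser is applied as a decorator to small element-wise functions, exactly as torch.compile or torch.jit.script would be (the bias-GELU body here is illustrative, not taken from this commit):

import torch
from megatron.core.jit import jit_fuser

@jit_fuser
def bias_gelu(bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Fusable element-wise math: add the bias, then apply an erf-based GELU.
    x = bias + y
    return x * 0.5 * (1.0 + torch.erf(x / 1.41421))

out = bias_gelu(torch.zeros(8), torch.randn(8))
print(out.shape)  # torch.Size([8])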
megatron/core/model_parallel_config.py
View file @
688448db
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass
from typing import Callable, ContextManager, Optional

import torch


@dataclass
class ModelParallelConfig:
    """Base configuration for Megatron Core

    The initialization function has an argument for each parameter.
    """

    ###################
    # Model parallelism
    ###################
    tensor_model_parallel_size: int = 1
    """Intra-layer model parallelism. Splits tensors across GPU ranks."""

    pipeline_model_parallel_comm_backend: Optional[str] = None
    """Configures the backend option of pipeline parallel communication (e.g., nccl, ucc).
    If None, the default backend will be used.
    """

    pipeline_model_parallel_size: int = 1
    """Inter-layer model parallelism. Splits transformer layers across GPU ranks."""

    virtual_pipeline_model_parallel_size: Optional[int] = None
    """Interleaved pipeline parallelism is used to improve performance by reducing the pipeline
    bubble. Considers a transformer block as a list of smaller transformer (virtual) blocks.
    The number of virtual blocks per pipeline model parallel rank is the virtual model parallel
    size. See Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM:
    arxiv.org/pdf/2104.04473.pdf for more details.
    """

    sequence_parallel: bool = False
    """Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms
    and dropout sequentially. See Reducing Activation Recomputation in Large Transformer Models
    (https://arxiv.org/abs/2205.05198) for more details.
    """

    context_parallel_size: int = 1
    """Splits network input along sequence dimension across GPU ranks."""

    hierarchical_context_parallel_sizes: Optional[list[int]] = None
    """Degrees of the hierarchical context parallelism. Users should provide a list to specify
    the sizes for different levels. Taking the a2a+p2p cp comm type as an example, it contains
    groups of two levels, so the first value of the list indicates the group size of the a2a
    communication type, and the second value indicates the group size of the p2p communication
    type.
    """

    expert_model_parallel_size: int = 1
    """Distributes MoE experts across the sub data parallel dimension."""

    expert_tensor_parallel_size: Optional[int] = None
    """Intra-layer tensor model parallelism for the expert layer. Splits tensors across GPU ranks."""

    moe_extended_tp: bool = False
    """NOTE: Deprecated from MCore v0.10. This flag is ignored.
    Its functionality is replaced by expert_tensor_parallel_size.
    """

    ###################
    # Initialization
    ###################
    perform_initialization: bool = True
    """If true, weights are initialized. This option can be useful when you know you are going to
    load values from a checkpoint.
    """

    use_cpu_initialization: bool = False
    """When set to False, we initialize the weights directly on the GPU. CPU initialization is the
    same regardless of tensor model parallelism, but GPU initialization is not. Transferring
    weights from CPU to GPU can take a significant amount of time for large models.
    """

    ###################
    # Training
    ###################
    fp16: bool = False
    """If true, train with fp16 mixed precision training."""

    bf16: bool = False
    """If true, train with bf16 mixed precision training."""

    params_dtype: torch.dtype = torch.float32
    """dtype used when initializing the weights."""

    timers: Optional[Callable] = None
    """Timers object to call for various timing functions. See megatron.core.timers.Timers"""

    finalize_model_grads_func: Optional[Callable] = None
    """Function that finalizes gradients on all workers. Could include ensuring that grads are
    all-reduced across data parallelism, pipeline parallelism, and sequence parallelism
    dimensions.
    """

    grad_scale_func: Optional[Callable] = None
    """If using loss scaling, this function should take the loss and return the scaled loss. If
    None, no function is called on the loss.
    """

    no_sync_func: Optional[Callable] = None
    """Function that creates a context that suppresses asynchronous data-parallel communication. If
    the model is an instance of core.distributed.DistributedDataParallel, the default is to use
    core.distributed.DistributedDataParallel.no_sync.
    """

    grad_sync_func: Optional[Callable] = None
    """Function that launches asynchronous gradient reductions (e.g. distributed optimizer gradient
    reduce-scatters). The function should take one argument: an iterable of parameters whose
    gradients are to be synchronized.
    """

    param_sync_func: Optional[Callable] = None
    """Function that launches asynchronous parameter synchronizations (e.g. distributed optimizer
    parameter all-gathers). The function should take one argument: an iterable of parameters to
    be synchronized.
    """

    deterministic_mode: bool = False
    """If true, code that has deterministic execution will be chosen. This usually
    means slower execution, but is good for debugging and testing. Defaults to False."""

    enable_autocast: bool = False
    """If true, runs the forward step function inside the torch.autocast context."""

    autocast_dtype: Optional[torch.dtype] = None
    """dtype to pass to torch.amp.autocast when enabled. If None, is set to pipeline_dtype."""

    num_microbatches_with_partial_activation_checkpoints: Optional[int] = None
    """If int, sets the number of microbatches where not all of the layers will be checkpointed and
    recomputed. The rest of the microbatches within the window of maximum outstanding
    microbatches will recompute all layers (either full recompute or selective recompute). If
    None, the checkpoint and recompute will be left up to the forward_step function.
    """

    ###################
    # Optimizations
    ###################
    gradient_accumulation_fusion: bool = False
    """If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA extension
    fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install
    APEX with --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\"
    --global-option=\"--cuda_ext\"". Note that the extension requires CUDA>=11. Otherwise, you
    must turn off gradient accumulation fusion.
    """

    async_tensor_model_parallel_allreduce: bool = False
    """NOTE: Deprecated. This flag is ignored."""

    use_te_rng_tracker: bool = False
    """If true, uses the RNG state tracker in TransformerEngine if it exists.
    """

    tp_comm_overlap: bool = False
    """If true, allows overlapping of Linear layer execution with tensor parallel communication
    collectives like AllGather/ReduceScatter. Overlapping is done for the linear layers wherever
    possible during the forward and the backward pass.
    """

    tp_comm_bulk_wgrad: bool = True
    """If true, allows All-Gather overlap with Bprop activation gradient GEMM. Don't care if
    tp_comm_overlap is False.
    """

    tp_comm_bulk_dgrad: bool = True
    """If true, allows Reduce-Scatter overlap with Bprop weight gradient GEMM. Don't care if
    tp_comm_overlap is False.
    """

    tp_comm_overlap_ag: bool = True
    """If true, allows All-Gather overlap with GEMM by pipelining the GEMM and All-Gather.
    Don't care if tp_comm_overlap is False.
    """

    tp_comm_overlap_rs: bool = True
    """If true, allows Reduce-Scatter overlap with GEMM by pipelining the GEMM and Reduce-Scatter.
    Don't care if tp_comm_overlap is False.
    """

    tp_comm_overlap_rs_dgrad: bool = False
    """If true, allows Reduce-Scatter overlap with DGRAD GEMM by pipelining the
    GEMM and Reduce-Scatter splits. Don't care if tp_comm_overlap is False.
    """

    tp_comm_split_ag: bool = True
    """Deprecated from TransformerEngine v1.6.0.
    If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather
    splits. Don't care if tp_comm_overlap is False.
    """

    tp_comm_atomic_ag: bool = False
    """Deprecated from TransformerEngine v1.6.0.
    If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather
    both done atomically. Don't care if tp_comm_overlap is False.
    """

    tp_comm_split_rs: bool = True
    """Deprecated from TransformerEngine v1.6.0.
    If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and
    Reduce-Scatter splits. Don't care if tp_comm_overlap is False.
    """

    tp_comm_atomic_rs: bool = False
    """Deprecated from TransformerEngine v1.6.0.
    If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and
    Reduce-Scatter both done atomically. Don't care if tp_comm_overlap is False.
    """

    cross_entropy_loss_fusion: bool = False
    """If this is enabled, the fused cross entropy implementation will be used.
    Defaults to False.
    """

    tp_comm_overlap_disable_qkv: bool = False
    """
    If true, the AllGather -> Gemm overlap for QKV gets disabled
    """

    tp_comm_overlap_disable_fc1: bool = False
    """
    If true, the AllGather -> Gemm overlap for FC1 layer of MLP gets disabled
If true, the AllGather -> Gemm overlap for QKV gets disabled
"""
"""
tp_comm_bootstrap_backend
:
str
=
'nccl'
tp_comm_overlap_disable_fc1
:
bool
=
False
"""
"""
Set the bootstrapping backend out of 'nccl', 'mpi', and 'gloo'
If true, the AllGather -> Gemm overlap for FC1 layer of MLP gets disabled
"""
"""
###################
tp_comm_bootstrap_backend
:
str
=
'nccl'
# Pipeline Parallel
"""
###################
Set the bootstrapping backend out of 'nccl', 'mpi', and 'gloo'
pipeline_dtype
:
torch
.
dtype
=
None
"""
"""dtype used in p2p communication, usually params_dtype"""
###################
variable_seq_lengths
:
bool
=
False
# Pipeline Parallel
"""Support for variable sequence lengths across microbatches. Setting this communicates the size
###################
of tensors during pipeline parallelism communication, because of this extra overhead it
pipeline_dtype
:
torch
.
dtype
=
None
should only be set if the sequence length varies by microbatch within a global batch.
"""dtype used in p2p communication, usually params_dtype"""
"""
variable_seq_lengths
:
bool
=
False
overlap_p2p_comm
:
bool
=
False
"""Support for variable sequence lengths across microbatches. Setting this communicates the size
"""When True some of the peer to peer communication for pipeline parallelism will overlap with
of tensors during pipeline parallelism communication, because of this extra overhead it
computation. Must be False if batch_p2p_comm is true.
should only be set if the sequence length varies by microbatch within a global batch.
"""
"""
batch_p2p_comm
:
bool
=
True
overlap_p2p_comm
:
bool
=
False
"""Use batch_isend_irecv instead of individual isend/irecv calls. Must be False if
"""When True some of the peer to peer communication for pipeline parallelism will overlap with
overlap_p2p_comm is True.
computation. Must be False if batch_p2p_comm is true.
"""
"""
batch_p2p_sync
:
bool
=
True
batch_p2p_comm
:
bool
=
True
"""When using batch_isend_irecv, do a cuda.device.synchronize afterward to work around a bug in
"""Use batch_isend_irecv instead of individual isend/irecv calls. Must be False if
older version of PyTorch.
overlap_p2p_comm is True.
"""
"""
use_ring_exchange_p2p
:
bool
=
False
batch_p2p_sync
:
bool
=
True
"""Use custom ring_exchange kernel instead of torch.distributed.batch_isend_irecv(). Requires
"""When using batch_isend_irecv, do a cuda.device.synchronize afterward to work around a bug in
custom built torch with torch.distributed.ring_exchange.
older version of PyTorch.
"""
"""
deallocate_pipeline_outputs
:
bool
=
False
use_ring_exchange_p2p
:
bool
=
False
"""If True, output data is deallocated after the tensor is sent to the next pipeline stage.
"""Use custom ring_exchange kernel instead of torch.distributed.batch_isend_irecv(). Requires
Helps with saving memory, does nothing when pipeline parallel is not used.
custom built torch with torch.distributed.ring_exchange.
"""
"""
defer_embedding_wgrad_compute
:
bool
=
False
deallocate_pipeline_outputs
:
bool
=
False
"""If true, defers the embedding WGRAD GEMMs while pipeline flush is
"""If True, output data is deallocated after the tensor is sent to the next pipeline stage.
taking place enabling us to hide pipeline flush latency. Defaults to False.
Helps with saving memory, does nothing when pipeline parallel is not used.
"""
"""
wgrad_deferral_limit
:
int
=
0
defer_embedding_wgrad_compute
:
bool
=
False
"""This value tunes the number of micro-batches for which the embedding weight gradient compute
"""If true, defers the embedding WGRAD GEMMs while pipeline flush is
needs to be deferred to pipeline flush, this argument is invalid if
taking place enabling us to hide pipeline flush latency. Defaults to False.
`defer_embedding_wgrad_compute` is False.
"""
Defaults to 0, which means all micro-batches are deferred.
"""
wgrad_deferral_limit
:
int
=
0
"""This value tunes the number of micro-batches for which the embedding weight gradient compute
pipeline_model_parallel_split_rank
:
Optional
[
int
]
=
None
needs to be deferred to pipeline flush, this argument is invalid if
"""If int, rank where encoder and decoder should be split in cases where the model has both an
`defer_embedding_wgrad_compute` is False.
encoder and decoder (e.g., T5). Ignored if None.
Defaults to 0, which means all micro-batches are deferred.
"""
"""
overlap_p2p_comm_warmup_flush
:
bool
=
False
pipeline_model_parallel_split_rank
:
Optional
[
int
]
=
None
"""If true, overlap communication and computation in warm up and flush phase.
"""If int, rank where encoder and decoder should be split in cases where the model has both an
Only valid when overlap_p2p_comm is True and batch_p2p_comm is False.
encoder and decoder (e.g., T5). Ignored if None.
Defaults to False.
"""
"""
overlap_p2p_comm_warmup_flush
:
bool
=
False
microbatch_group_size_per_vp_stage
:
Optional
[
int
]
=
None
"""If true, overlap communication and computation in warm up and flush phase.
"""This value specifies the number of micro-batches that are executed
Only valid when overlap_p2p_comm is True and batch_p2p_comm is False.
at a time for a given virtual stage (both forward and backward).
Defaults to False.
Default (in __post_init__() method below) to pipeline_parallel_size
"""
which specifies a depth-first schedule.
Example: for PP=2 VP=2, when microbatch_group_size_per_vp_stage=2,
microbatch_group_size_per_vp_stage
:
Optional
[
int
]
=
None
num_microbatches = 4, we have
"""This value specifies the number of micro-batches that are executed
rank 0 | 0 1 0 1 2 3 2 3
at a time for a given virtual stage (both forward and backward).
rank 1 | 0 1 0 1 2 3 2 3
Default (in __post_init__() method below) to pipeline_parallel_size
When microbatch_group_size_per_vp_stage=3, num_microbatches = 5,
which specifies a depth-first schedule.
we have
Example: for PP=2 VP=2, when microbatch_group_size_per_vp_stage=2,
rank 0 | 0 1 2 0 1 2 3 4 3 4
num_microbatches = 4, we have
rank 1 | 0 1 2 0 1 2 3 4 3 4
rank 0 | 0 1 0 1 2 3 2 3
"""
rank 1 | 0 1 0 1 2 3 2 3
When microbatch_group_size_per_vp_stage=3, num_microbatches = 5,
###################
we have
# CPU Offloading
rank 0 | 0 1 2 0 1 2 3 4 3 4
###################
rank 1 | 0 1 2 0 1 2 3 4 3 4
cpu_offloading
:
bool
=
False
"""
"""When set to True, all the activations are offloaded to the CPU asynchronously."""
###################
cpu_offloading_num_layers
:
int
=
0
# CPU Offloading
"""Tells the number of transformer layers for which activations has to be offloaded."""
###################
cpu_offloading
:
bool
=
False
_cpu_offloading_context
:
Optional
[
ContextManager
]
=
(
"""When set to True, all the activations are offloaded to the CPU asynchronously."""
None
# Used for internal use only, not to be set by a user.
cpu_offloading_num_layers
:
int
=
0
# TODO: Need to move to the 'right' place when possible.
"""Tells the number of transformer layers for which activations has to be offloaded."""
)
"""For internal use only, do not set."""
_cpu_offloading_context
:
Optional
[
ContextManager
]
=
(
None
cpu_offloading_activations
:
bool
=
True
# Used for internal use only, not to be set by a user.
"""If True, offloads the activations to CPU."""
# TODO: Need to move to the 'right' place when possible.
)
cpu_offloading_weights
:
bool
=
True
"""For internal use only, do not set."""
"""If True, offloads the weights to CPU."""
cpu_offloading_activations
:
bool
=
True
###################
"""If True, offloads the activations to CPU."""
# Timing
###################
cpu_offloading_weights
:
bool
=
True
barrier_with_L1_time
:
bool
=
True
"""If True, offloads the weights to CPU."""
"""If true, use barrier with level 1 time measurements. It is up to the user to make sure
calling barrier with their timers will not result in hangs. This can happen if for example
###################
the user adds a level 1 timer that is not called by all ranks.
# Timing
"""
###################
barrier_with_L1_time
:
bool
=
True
def
__post_init__
(
self
):
"""If true, use barrier with level 1 time measurements. It is up to the user to make sure
"""Python dataclass method that is used to modify attributes after initialization.
calling barrier with their timers will not result in hangs. This can happen if for example
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more
the user adds a level 1 timer that is not called by all ranks.
details.
"""
"""
if
self
.
sequence_parallel
:
def
__post_init__
(
self
):
if
self
.
tensor_model_parallel_size
<=
1
:
"""Python dataclass method that is used to modify attributes after initialization.
raise
ValueError
(
"Can not use sequence paralllelism without tensor parallelism"
)
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more
details.
if
self
.
expert_tensor_parallel_size
is
None
:
"""
self
.
expert_tensor_parallel_size
=
self
.
tensor_model_parallel_size
if
self
.
sequence_parallel
:
if
self
.
tensor_model_parallel_size
<=
1
:
if
self
.
pipeline_model_parallel_size
>
1
:
raise
ValueError
(
"Can not use sequence paralllelism without tensor parallelism"
)
if
self
.
pipeline_dtype
is
None
:
raise
ValueError
(
if
self
.
expert_tensor_parallel_size
is
None
:
"When using pipeline parallelism, pipeline_dtype must be specified"
self
.
expert_tensor_parallel_size
=
self
.
tensor_model_parallel_size
)
if
self
.
pipeline_model_parallel_size
>
1
:
if
self
.
autocast_dtype
is
None
:
if
self
.
pipeline_dtype
is
None
:
self
.
autocast_dtype
=
self
.
params_dtype
raise
ValueError
(
"When using pipeline parallelism, pipeline_dtype must be specified"
if
self
.
defer_embedding_wgrad_compute
and
self
.
pipeline_model_parallel_size
==
1
:
)
raise
ValueError
(
"Cannot defer embedding wgrad compute when pipeline model parallel is not used"
if
self
.
autocast_dtype
is
None
:
)
self
.
autocast_dtype
=
self
.
params_dtype
if
self
.
defer_embedding_wgrad_compute
and
not
self
.
gradient_accumulation_fusion
:
if
self
.
defer_embedding_wgrad_compute
and
self
.
pipeline_model_parallel_size
==
1
:
raise
ValueError
(
raise
ValueError
(
"Cannot defer embedding wgrad compute when gradient accumulation fusion is not used"
"Cannot defer embedding wgrad compute when pipeline model parallel is not used"
)
)
if
self
.
defer_embedding_wgrad_compute
and
self
.
wgrad_deferral_limit
<
0
:
if
self
.
defer_embedding_wgrad_compute
and
not
self
.
gradient_accumulation_fusion
:
raise
ValueError
(
raise
ValueError
(
"Wgrad deferral limit should be greater than or equal to 0 when it is enabled!"
"Cannot defer embedding wgrad compute when gradient accumulation fusion is not used"
)
)
if
self
.
expert_model_parallel_size
>
1
and
self
.
tensor_model_parallel_size
>
1
:
if
self
.
defer_embedding_wgrad_compute
and
self
.
wgrad_deferral_limit
<
0
:
if
self
.
sequence_parallel
is
False
:
raise
ValueError
(
raise
ValueError
(
"Wgrad deferral limit should be greater than or equal to 0 when it is enabled!"
"When using expert parallelism and tensor parallelism, "
)
"sequence parallelism must be used"
)
if
self
.
expert_model_parallel_size
>
1
and
self
.
tensor_model_parallel_size
>
1
:
if
self
.
sequence_parallel
is
False
:
if
self
.
microbatch_group_size_per_vp_stage
is
None
:
raise
ValueError
(
self
.
microbatch_group_size_per_vp_stage
=
self
.
pipeline_model_parallel_size
"When using expert parallelism and tensor parallelism, "
"sequence parallelism must be used"
if
self
.
overlap_p2p_comm_warmup_flush
:
)
if
not
self
.
overlap_p2p_comm
or
self
.
batch_p2p_comm
:
raise
ValueError
(
if
self
.
microbatch_group_size_per_vp_stage
is
None
:
"Pipeline parallel communication overlapping in warmup and flush is only "
self
.
microbatch_group_size_per_vp_stage
=
self
.
pipeline_model_parallel_size
"compatible with overlap_p2p_comm but not batch_p2p_comm."
)
if
self
.
overlap_p2p_comm_warmup_flush
:
if
not
self
.
overlap_p2p_comm
or
self
.
batch_p2p_comm
:
raise
ValueError
(
"Pipeline parallel communication overlapping in warmup and flush is only "
"compatible with overlap_p2p_comm but not batch_p2p_comm."
)
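Editor's note (not part of the commit): a minimal sketch of how the __post_init__ validation above behaves when the config is constructed. The field names come from the dataclass shown in this file; torch.bfloat16 is just an example dtype.

import torch

from megatron.core.model_parallel_config import ModelParallelConfig

# With pipeline parallelism enabled, pipeline_dtype is mandatory;
# omitting it raises ValueError inside __post_init__.
config = ModelParallelConfig(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=2,
    sequence_parallel=True,          # requires tensor_model_parallel_size > 1
    pipeline_dtype=torch.bfloat16,
)

# microbatch_group_size_per_vp_stage defaults to the pipeline parallel size,
# which corresponds to the depth-first schedule described in the docstring above.
assert config.microbatch_group_size_per_vp_stage == 2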
megatron/core/models/T5/t5_model.py
View file @ 688448db
...
@@ -10,9 +10,11 @@ from megatron.core.config_logger import has_config_logger_enabled, log_config_to
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.enums import ModelType
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.relative_pos_embedding import RelativePositionEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.tensor_parallel.mappings import scatter_to_tensor_model_parallel_region
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
...
@@ -135,9 +137,13 @@ class T5Model(LanguageModule):
        fp16_lm_cross_entropy: bool = False,
        parallel_output: bool = True,
        share_embeddings_and_output_weights: bool = False,
        position_embedding_type: Literal['learned_absolute', 'rope', 'relative'] = 'learned_absolute',
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[float] = None,
        relative_attention_num_buckets: int = 32,
        relative_attention_max_distance: int = 128,
        add_encoder: bool = True,
        add_decoder: bool = True,
    ):
...
@@ -193,6 +199,23 @@ class T5Model(LanguageModule):
                use_cpu_initialization=self.config.use_cpu_initialization,
            )

        # Relative Position Embeddings
        if self.position_embedding_type == 'relative':
            self.encoder_relative_pos_emb = RelativePositionEmbedding(
                bidirectional=True,
                init_method=self.config.init_method,
                num_attention_heads=self.config.num_attention_heads,
                relative_attention_num_buckets=relative_attention_num_buckets,
                relative_attention_max_distance=relative_attention_max_distance,
            )
            self.decoder_relative_pos_emb = RelativePositionEmbedding(
                bidirectional=False,
                init_method=self.config.init_method,
                num_attention_heads=self.config.num_attention_heads,
                relative_attention_num_buckets=relative_attention_num_buckets,
                relative_attention_max_distance=relative_attention_max_distance,
            )

        # Transformer encoder
        encoder_spec, decoder_spec = (
            self.transformer_encoder_layer_spec,
...
@@ -284,6 +307,27 @@ class T5Model(LanguageModule):
            )
            rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)

        # Relative positional embeddings
        encoder_attention_bias_parallel = None
        if self.position_embedding_type == 'relative':
            query_seq_length = RelativePositionEmbedding.get_relative_seq_len(
                inference_params, self.encoder, encoder_input, self.config
            )
            key_seq_length = query_seq_length
            attention_bias = self.encoder_relative_pos_emb(query_seq_length, key_seq_length)
            # Scatter attention_bias to TP ranks
            # First, reshape [1, num_head, seqlen_q, seqlen_kv] to
            # [1, seqlen_q, seqlen_kv, num_head] to be scatter along
            # the last (num_heads dimension)
            attention_bias = torch.permute(attention_bias, (0, 2, 3, 1))
            # Then, scatter to TP region
            attention_bias_parallel = scatter_to_tensor_model_parallel_region(attention_bias)
            # Lastly, revert the dimension back to [1, num_head, seqlen_q, seqlen_kv]
            encoder_attention_bias_parallel = torch.permute(attention_bias_parallel, (0, 3, 1, 2))

        # Run encoder.
        if self.add_encoder:
            encoder_hidden_states = self.encoder(
...
@@ -291,6 +335,7 @@ class T5Model(LanguageModule):
                attention_mask=encoder_attn_mask,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb,
                attention_bias=encoder_attention_bias_parallel,
            )
        else:
            encoder_hidden_states = self.encoder_hidden_state
...
@@ -315,10 +360,29 @@ class T5Model(LanguageModule):
        rotary_pos_emb = None
        if self.position_embedding_type == 'rope':
            rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
                inference_params, self.decoder, decoder_input, self.config, packed_seq_params
            )
            rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)

        # Relative positional embeddings
        decoder_attention_bias_parallel = None
        if self.position_embedding_type == 'relative':
            query_seq_length = RelativePositionEmbedding.get_relative_seq_len(
                inference_params, self.decoder, decoder_input, self.config
            )
            key_seq_length = query_seq_length
            attention_bias = self.decoder_relative_pos_emb(query_seq_length, key_seq_length)
            # Scatter attention_bias to TP ranks
            # First, reshape [1, num_head, seqlen_q, seqlen_kv] to
            # [1, seqlen_q, seqlen_kv, num_head] to be scatter along
            # the last (num_heads dimension)
            attention_bias = torch.permute(attention_bias, (0, 2, 3, 1))
            # Then, scatter to TP region
            attention_bias_parallel = scatter_to_tensor_model_parallel_region(attention_bias)
            # Lastly, revert the dimension back to [1, num_head, seqlen_q, seqlen_kv]
            decoder_attention_bias_parallel = torch.permute(attention_bias_parallel, (0, 3, 1, 2))

        # Run decoder.
        decoder_hidden_states = self.decoder(
            hidden_states=decoder_input,
...
@@ -327,12 +391,15 @@ class T5Model(LanguageModule):
            context_mask=encoder_decoder_attn_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            attention_bias=decoder_attention_bias_parallel,
        )

        if self.post_process:
            output_weight = None
            if self.share_embeddings_and_output_weights:
                output_weight = self.shared_embedding_or_output_weight()
            lm_logits = self.lm_head(decoder_hidden_states, word_embeddings_weight=output_weight)

            if lm_labels is None:
                # [s b h] => [b s h]
                return lm_logits.transpose(0, 1).contiguous()
...
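Editor's sketch (not part of the commit): the hunks above distribute the relative attention bias across tensor-parallel ranks by splitting it along the heads dimension. The round trip of permutes is shown below with plain tensor ops; torch.chunk and the tp_world_size variable stand in for scatter_to_tensor_model_parallel_region, which performs the actual split in Megatron.

import torch

num_heads, seqlen_q, seqlen_kv, tp_world_size = 8, 4, 4, 2
attention_bias = torch.randn(1, num_heads, seqlen_q, seqlen_kv)

# [1, num_head, q, kv] -> [1, q, kv, num_head] so the split happens over heads
bias_heads_last = torch.permute(attention_bias, (0, 2, 3, 1))
per_rank_chunks = torch.chunk(bias_heads_last, tp_world_size, dim=-1)

# Each TP rank keeps its own chunk and restores [1, num_head / tp, q, kv]
local_bias = torch.permute(per_rank_chunks[0], (0, 3, 1, 2))
assert local_bias.shape == (1, num_heads // tp_world_size, seqlen_q, seqlen_kv)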
megatron/core/models/common/embeddings/relative_pos_embedding.py 0 → 100644
View file @ 688448db
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

import logging
import math
from typing import Callable

import torch
from torch import Tensor, nn

from megatron.core.inference_params import InferenceParams
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig

logger = logging.getLogger(__name__)

__all__ = ['RelativePositionEmbedding']


class RelativePositionEmbedding(nn.Module):
    """Relative Position Embedding for language model.

    Args:
    """

    def __init__(
        self,
        bidirectional: bool,
        init_method: Callable,
        num_attention_heads: int,
        relative_attention_num_buckets: int = 32,
        relative_attention_max_distance: int = 128,
    ) -> None:
        super().__init__()

        self.bidirectional = bidirectional
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance

        self.relative_attention_bias = torch.nn.Embedding(
            self.relative_attention_num_buckets, num_attention_heads
        )
        init_method(self.relative_attention_bias.weight)

    def _relative_position_bucket(
        self, relative_position, bidirectional=True, num_buckets=32, max_distance=128
    ):
        """
        Adapted from HuggingFace T5 Model:
        https://github.com/huggingface/transformers/blob/329f5dbf97a5cb2473914c88c05aa3dcb242e19a/
        src/transformers/models/t5/modeling_t5.py#L397
        Translate relative position to a bucket number for relative attention.
        The relative position is defined as memory_position - query_position, i.e. the
        distance in tokens from the attending position to the attended-to position.
        If bidirectional=False, then positive relative positions are invalid. We use
        smaller buckets for small absolute relative_position and larger buckets for
        larger absolute relative_positions. All relative positions >=max_distance map
        to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the
        model has been trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position,
            containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger
        # bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1),
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def _compute_bias(self, query_length, key_length):
        """
        Adapted from HuggingFace T5 Model
        https://github.com/huggingface/transformers/blob/329f5dbf97a5cb2473914c88c05aa3dcb242e19a/
        src/transformers/models/t5/modeling_t5.py#L444C9-L444C21
        Compute binned relative position bias

        Args:
            query_length (int): The length of the query sequence
                (e.g., the input sequence in attention).
            key_length (int): The length of the key sequence
                (e.g., the sequence to compare against in attention).

        Returns:
            torch.Tensor: A tensor representing the relative position bias, with shape
                (1, num_heads, query_length, key_length).
        """
        device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape(query_length,key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=self.bidirectional,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(
            relative_position_bucket
        )  # shape(query_length,key_length,num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape(1, num_heads,query_length,key_length)
        return values

    @staticmethod
    def get_relative_seq_len(
        inference_params: InferenceParams,
        transformer: TransformerBlock,
        transformer_input: Tensor,
        transformer_config: TransformerConfig,
    ) -> float:
        """Function to get the rotary sequence length.

        Args:
            inference_params : Used during Inference time
            transformer (TransformerBlock): The transformer block (decoder/encoder) used
                by the model
            transformer_input (Tensor): Input tensor to the transformer
            transformer_config (TransformerConfig): Transformer config used by the model

        Returns:
            float: The rotary sequence length
        """
        if inference_params is not None:
            relative_seq_len = inference_params.max_sequence_length
        else:
            if transformer.input_tensor is not None:
                relative_seq_len = transformer.input_tensor.size(0)
            else:
                relative_seq_len = transformer_input.size(0)

            if transformer_config.sequence_parallel:
                relative_seq_len *= transformer_config.tensor_model_parallel_size

        return relative_seq_len

    def forward(self, query_seq_length, key_seq_length):
        """
        Args:
        Returns:
        """
        return self._compute_bias(query_seq_length, key_seq_length)
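Editor's usage sketch (not part of the commit): computing a T5-style bias with the module defined above. torch.nn.init.normal_ stands in for the init_method that the transformer config would normally supply.

import torch

from megatron.core.models.common.embeddings.relative_pos_embedding import (
    RelativePositionEmbedding,
)

rel_pos_emb = RelativePositionEmbedding(
    bidirectional=True,                 # encoder-style attention
    init_method=torch.nn.init.normal_,
    num_attention_heads=12,
    relative_attention_num_buckets=32,
    relative_attention_max_distance=128,
)

bias = rel_pos_emb(query_seq_length=16, key_seq_length=16)
# One bias value per (head, query position, key position), broadcast over the batch.
assert bias.shape == (1, 12, 16, 16)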
megatron/core/models/common/embeddings/rotary_pos_embedding.py
View file @ 688448db
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from megatron.core.transformer.transformer_config import TransformerConfig
    from megatron.core.transformer.transformer_block import TransformerBlock
    from megatron.core.inference_params import InferenceParams
    from megatron.core.packed_seq_params import PackedSeqParams

import logging
import math
from functools import lru_cache

import torch
from torch import Tensor, nn

from megatron.core import parallel_state
from megatron.core.models.common.embeddings.rope_utils import (  # for backward compatibility; pylint: disable=unused-import
    _apply_rotary_pos_emb_bshd,
    _apply_rotary_pos_emb_thd,
    _rotate_half,
    apply_rotary_pos_emb,
    get_pos_emb_on_this_cp_rank,
)

logger = logging.getLogger(__name__)

__all__ = ['RotaryEmbedding']


class RotaryEmbedding(nn.Module):
    """Rotary Embedding for language model.

    Args:
        kv_channels (int): Projection weights dimension in multi-head attention. Obtained
            from transformer config
        rotary_percent (float): Percent of rotary dimension to use for rotary position
            embeddings.
        rotary_interleaved (bool, optional): If True, interleaved rotary position embeddings.
            Defaults to False.
        seq_len_interpolation_factor (float, optional): scale of linearly interpolating RoPE
            for longer sequences. The value must be a float larger than 1.0. Defaults to None
        rotary_base (int, optional): Base period for rotary position embeddings. Defaults to
            10000.
        rope_scaling (bool, optional): Apply rope scaling as used in llama 3.x.
        rope_scaling_factor (float, optional): rope scaling factor in llama 3.x. Defaults to 8.
        use_cpu_initialization (bool, optional): If False, initialize the inv_freq directly
            on the GPU. Defaults to False
    """

    def __init__(
        self,
        kv_channels: int,
        rotary_percent: float,
        rotary_interleaved: bool = False,
        seq_len_interpolation_factor: float = None,
        rotary_base: int = 10000,
        rope_scaling: bool = False,
        rope_scaling_factor: float = 8.0,
        use_cpu_initialization: bool = False,
    ) -> None:
        super().__init__()

        dim = kv_channels
        if rotary_percent < 1.0:
            dim = int(dim * rotary_percent)
        self.rotary_interleaved = rotary_interleaved

        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        device = 'cpu' if use_cpu_initialization else torch.cuda.current_device()
        self.inv_freq = 1.0 / (
            rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
        )

        if rope_scaling:
            self.inv_freq = self._apply_scaling(self.inv_freq, factor=rope_scaling_factor)

    def _apply_scaling(
        self,
        freqs,
        factor=8,
        low_freq_factor=1,
        high_freq_factor=4,
        original_max_position_embeddings=8192,
    ):
        # This implementation is adapted from:
        # https://github.com/huggingface/transformers/blob/2a5a6ad18aa22e98429bb5ecb880660328030ea0/src/transformers/modeling_rope_utils.py#L303-L343

        factor = factor  # `8` in the original implementation
        low_freq_factor = low_freq_factor  # `1` in the original implementation
        high_freq_factor = high_freq_factor  # `4` in the original implementation
        old_context_len = original_max_position_embeddings  # `8192` in the original implementation

        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor

        wavelen = 2 * math.pi / freqs
        # wavelen < high_freq_wavelen: do nothing
        # wavelen > low_freq_wavelen: divide by factor
        inv_freq_llama = torch.where(wavelen > low_freq_wavelen, freqs / factor, freqs)
        # otherwise: interpolate between the two, using a smooth factor
        smooth_factor = (old_context_len / wavelen - low_freq_factor) / (
            high_freq_factor - low_freq_factor
        )
        smoothed_inv_freq = (
            1 - smooth_factor
        ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
        is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
        inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)

        return inv_freq_llama

    def get_freqs_non_repeated(self, max_seq_len: int, offset: int = 0) -> Tensor:
        """Generates matrix of frequencies based on positions in the sequence,
        used to create positional encodings"""
        seq = (
            torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
            + offset
        )

        if self.seq_len_interpolation_factor is not None:
            seq *= 1 / self.seq_len_interpolation_factor

        freqs = torch.outer(seq, self.inv_freq)  # [seq len, dim]

        return freqs

    def get_cos_sin(self, max_seq_len: int, offset: int = 0) -> (Tensor, Tensor):
        """Cosine and sine values for RoPE are precomputed for all positions up to the maximum
        sequence length"""
        freqs = self.get_freqs_non_repeated(max_seq_len, offset)
        cos = torch.cos(freqs)
        sin = torch.sin(freqs)
        return cos, sin

    @lru_cache(maxsize=32)
    def forward(self, max_seq_len: int, offset: int = 0, packed_seq: bool = False) -> Tensor:
        """Forward pass of RoPE embedding.

        Args:
            max_seq_len (int): Maximum size of sequence
            offset (int, optional): RoPE offset. Defaults to 0.
            packed_seq (bool, optional): Whether to use packed sequence. Defaults to False.

        Returns:
            Tensor: Embeddings after applying RoPE.
        """
        if self.inv_freq.device.type == 'cpu':
            # move `inv_freq` to GPU once at the first micro-batch forward pass
            self.inv_freq = self.inv_freq.to(device=torch.cuda.current_device())

        freqs = self.get_freqs_non_repeated(max_seq_len, offset)
        # first part even vector components, second part odd vector components,
        # 2 * dim in dimension size
        if not self.rotary_interleaved:
            emb = torch.cat((freqs, freqs), dim=-1)
        else:
            emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view(
                freqs.shape[0], -1
            )
        # emb [seq_length, .., dim]
        emb = emb[:, None, None, :]
        if parallel_state.get_context_parallel_world_size() > 1 and not packed_seq:
            # slice rotary_pos_emb along sequence dimension and select the parition of the current
            # CP rank
            emb = get_pos_emb_on_this_cp_rank(emb, 0)
        return emb

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        state_dict.pop(f'{prefix}inv_freq', None)
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def get_rotary_seq_len(
        self,
        inference_params: InferenceParams,
        transformer: TransformerBlock,
        transformer_input: Tensor,
        transformer_config: TransformerConfig,
        packed_seq_params: PackedSeqParams,
    ) -> float:
        """Function to get the rotary sequence length.

        Args:
            inference_params : Used during Inference time
            transformer (TransformerBlock): The transformer block (decoder/encoder) used
                by the model
            transformer_input (Tensor): Input tensor to the transformer
            transformer_config (TransformerConfig): Transformer config used by the model
            packed_seq_params (PackedSeqParams): Packed sequence params

        Returns:
            float: The rotary sequence length
        """
        if packed_seq_params is not None:
            # max_seqlen are the max sequence length in the packed sequence before being divived
            # by the tp and cp size.
            return max(packed_seq_params.max_seqlen_q, packed_seq_params.max_seqlen_kv)
        elif inference_params is not None:
            rotary_seq_len = inference_params.max_sequence_length
        else:
            if transformer is not None and transformer.input_tensor is not None:
                rotary_seq_len = transformer.input_tensor.size(0)
            else:
                rotary_seq_len = transformer_input.size(0)

            if transformer_config.sequence_parallel:
                rotary_seq_len *= transformer_config.tensor_model_parallel_size

        rotary_seq_len *= transformer_config.context_parallel_size

        return rotary_seq_len
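Editor's usage sketch (not part of the commit): instantiating the rotary embedding above with the new rope_scaling knobs. use_cpu_initialization=True keeps inv_freq on the CPU so the cos/sin precompute below runs without a GPU; the parameter values are illustrative only.

from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding

rope = RotaryEmbedding(
    kv_channels=128,
    rotary_percent=1.0,
    rotary_base=10000,
    rope_scaling=True,           # Llama-3.x style frequency scaling
    rope_scaling_factor=8.0,     # divides the low-frequency bands
    use_cpu_initialization=True,
)

cos, sin = rope.get_cos_sin(max_seq_len=2048)
# One (cos, sin) pair per position and per rotary frequency: [seq_len, dim // 2]
assert cos.shape == (2048, 64)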
megatron/core/models/gpt/gpt_layer_specs.py
View file @
688448db
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import
warnings
import
warnings
from
typing
import
Optional
from
typing
import
Optional
from
megatron.core.fusions.fused_bias_dropout
import
get_bias_dropout_add
from
megatron.core.fusions.fused_bias_dropout
import
get_bias_dropout_add
from
megatron.core.models.gpt.moe_module_specs
import
get_moe_module_spec
from
megatron.core.models.gpt.moe_module_specs
import
get_moe_module_spec
from
megatron.core.tensor_parallel.layers
import
ColumnParallelLinear
,
RowParallelLinear
from
megatron.core.tensor_parallel.layers
import
ColumnParallelLinear
,
RowParallelLinear
from
megatron.core.transformer.attention
import
SelfAttention
,
SelfAttentionSubmodules
from
megatron.core.transformer.attention
import
SelfAttention
,
SelfAttentionSubmodules
from
megatron.core.transformer.dot_product_attention
import
DotProductAttention
from
megatron.core.transformer.dot_product_attention
import
DotProductAttention
from
megatron.core.transformer.enums
import
AttnMaskType
from
megatron.core.transformer.enums
import
AttnMaskType
from
megatron.core.transformer.identity_op
import
IdentityOp
from
megatron.core.transformer.identity_op
import
IdentityOp
from
megatron.core.transformer.mlp
import
MLP
,
MLPSubmodules
from
megatron.core.transformer.mlp
import
MLP
,
MLPSubmodules
from
megatron.core.transformer.multi_latent_attention
import
(
from
megatron.core.transformer.multi_latent_attention
import
(
MLASelfAttention
,
MLASelfAttention
,
MLASelfAttentionSubmodules
,
MLASelfAttentionSubmodules
,
)
)
from
megatron.core.transformer.spec_utils
import
ModuleSpec
from
megatron.core.transformer.spec_utils
import
ModuleSpec
from
megatron.core.transformer.transformer_block
import
(
from
megatron.core.transformer.transformer_block
import
(
TransformerBlockSubmodules
,
TransformerBlockSubmodules
,
get_num_layers_to_build
,
get_num_layers_to_build
,
)
)
from
megatron.core.transformer.transformer_config
import
TransformerConfig
from
megatron.core.transformer.transformer_config
import
TransformerConfig
from
megatron.core.transformer.transformer_layer
import
TransformerLayer
,
TransformerLayerSubmodules
from
megatron.core.transformer.transformer_layer
import
(
from
megatron.core.utils
import
is_te_min_version
TransformerLayer
,
TransformerLayerSubmodules
,
try
:
get_transformer_layer_offset
,
from
megatron.core.extensions.transformer_engine
import
(
)
TEColumnParallelLinear
,
from
megatron.core.utils
import
is_te_min_version
TEDotProductAttention
,
TELayerNormColumnParallelLinear
,
try
:
TENorm
,
from
megatron.core.extensions.transformer_engine
import
(
TERowParallelLinear
,
TEColumnParallelLinear
,
)
TEDotProductAttention
,
TELayerNormColumnParallelLinear
,
HAVE_TE
=
True
TENorm
,
except
ImportError
:
TERowParallelLinear
,
HAVE_TE
=
False
)
try
:
HAVE_TE
=
True
import
apex
# pylint: disable=unused-import
except
ImportError
:
HAVE_TE
=
False
from
megatron.core.fusions.fused_layer_norm
import
FusedLayerNorm
try
:
HAVE_APEX
=
True
import
apex
# pylint: disable=unused-import
LNImpl
=
FusedLayerNorm
except
ImportError
:
from
megatron.core.fusions.fused_layer_norm
import
FusedLayerNorm
from
megatron.core.transformer.torch_norm
import
WrappedTorchNorm
HAVE_APEX
=
True
warnings
.
warn
(
'Apex is not installed. Falling back to Torch Norm'
)
LNImpl
=
FusedLayerNorm
LNImpl
=
WrappedTorchNorm
except
ImportError
:
from
megatron.core.transformer.torch_norm
import
WrappedTorchNorm
def
get_gpt_layer_with_transformer_engine_spec
(
warnings
.
warn
(
'Apex is not installed. Falling back to Torch Norm'
)
num_experts
:
Optional
[
int
]
=
None
,
LNImpl
=
WrappedTorchNorm
moe_grouped_gemm
:
Optional
[
bool
]
=
False
,
qk_layernorm
:
Optional
[
bool
]
=
False
,
multi_latent_attention
:
Optional
[
bool
]
=
False
,
def
get_gpt_layer_with_transformer_engine_spec
(
fp8
:
Optional
[
str
]
=
None
,
# pylint: disable=unused-arguments
num_experts
:
Optional
[
int
]
=
None
,
moe_use_legacy_grouped_gemm
:
Optional
[
bool
]
=
False
,
moe_grouped_gemm
:
Optional
[
bool
]
=
False
,
)
->
ModuleSpec
:
qk_layernorm
:
Optional
[
bool
]
=
False
,
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training).
multi_latent_attention
:
Optional
[
bool
]
=
False
,
fp8
:
Optional
[
str
]
=
None
,
# pylint: disable=unused-arguments
moe_use_legacy_grouped_gemm
:
Optional
[
bool
]
=
False
,
Args:
)
->
ModuleSpec
:
num_experts (int, optional): Number of experts. Defaults to None.
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training).
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
Args:
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
num_experts (int, optional): Number of experts. Defaults to None.
Defaults to False.
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
Returns:
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
ModuleSpec: Module specification with TE modules
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
"""
Defaults to False.
if
fp8
is
not
None
:
warnings
.
warn
(
Returns:
'The fp8 argument in "get_gpt_layer_with_transformer_engine_spec" has been deprecated'
ModuleSpec: Module specification with TE modules
' and will be removed soon. Please update your code accordingly.'
"""
)
if
fp8
is
not
None
:
warnings
.
warn
(
mlp
=
_get_mlp_module_spec
(
'The fp8 argument in "get_gpt_layer_with_transformer_engine_spec" has been deprecated'
use_te
=
True
,
' and will be removed soon. Please update your code accordingly.'
num_experts
=
num_experts
,
)
moe_grouped_gemm
=
moe_grouped_gemm
,
moe_use_legacy_grouped_gemm
=
moe_use_legacy_grouped_gemm
,
mlp
=
get_mlp_module_spec
(
)
use_te
=
True
,
num_experts
=
num_experts
,
if
multi_latent_attention
:
moe_grouped_gemm
=
moe_grouped_gemm
,
return
ModuleSpec
(
moe_use_legacy_grouped_gemm
=
moe_use_legacy_grouped_gemm
,
module
=
TransformerLayer
,
)
submodules
=
TransformerLayerSubmodules
(
input_layernorm
=
TENorm
,
if
multi_latent_attention
:
self_attention
=
ModuleSpec
(
return
ModuleSpec
(
module
=
MLASelfAttention
,
module
=
TransformerLayer
,
params
=
{
"attn_mask_type"
:
AttnMaskType
.
causal
},
submodules
=
TransformerLayerSubmodules
(
submodules
=
MLASelfAttentionSubmodules
(
input_layernorm
=
TENorm
,
linear_q_proj
=
TEColumnParallelLinear
,
self_attention
=
ModuleSpec
(
linear_q_down_proj
=
TEColumnParallelLinear
,
module
=
MLASelfAttention
,
linear_q_up_proj
=
TEColumnParallelLinear
,
params
=
{
"attn_mask_type"
:
AttnMaskType
.
causal
},
linear_kv_down_proj
=
TEColumnParallelLinear
,
submodules
=
MLASelfAttentionSubmodules
(
linear_kv_up_proj
=
TEColumnParallelLinear
,
linear_q_proj
=
TEColumnParallelLinear
,
core_attention
=
TEDotProductAttention
,
linear_q_down_proj
=
TEColumnParallelLinear
,
linear_proj
=
TERowParallelLinear
,
linear_q_up_proj
=
(
q_layernorm
=
TENorm
if
qk_layernorm
else
IdentityOp
,
TELayerNormColumnParallelLinear
kv_layernorm
=
TENorm
if
qk_layernorm
else
IdentityOp
,
if
qk_layernorm
),
else
TEColumnParallelLinear
),
),
self_attn_bda
=
get_bias_dropout_add
,
linear_kv_down_proj
=
TEColumnParallelLinear
,
pre_mlp_layernorm
=
TENorm
if
num_experts
else
IdentityOp
,
linear_kv_up_proj
=
(
mlp
=
mlp
,
TELayerNormColumnParallelLinear
mlp_bda
=
get_bias_dropout_add
,
if
qk_layernorm
),
else
TEColumnParallelLinear
)
),
else
:
core_attention
=
TEDotProductAttention
,
linear_proj
=
TERowParallelLinear
,
# TENorm significantly harms convergence when used
q_layernorm
=
IdentityOp
,
# for QKLayerNorm if TE Version < 1.9;
kv_layernorm
=
IdentityOp
,
# we instead use the Apex implementation.
),
qk_norm
=
TENorm
if
is_te_min_version
(
"1.9.0"
)
else
FusedLayerNorm
),
self_attn_bda
=
get_bias_dropout_add
,
return
ModuleSpec
(
pre_mlp_layernorm
=
TENorm
if
num_experts
else
IdentityOp
,
module
=
TransformerLayer
,
mlp
=
mlp
,
submodules
=
TransformerLayerSubmodules
(
mlp_bda
=
get_bias_dropout_add
,
self_attention
=
ModuleSpec
(
),
module
=
SelfAttention
,
)
params
=
{
"attn_mask_type"
:
AttnMaskType
.
causal
},
else
:
submodules
=
SelfAttentionSubmodules
(
linear_qkv
=
TELayerNormColumnParallelLinear
,
# TENorm significantly harms convergence when used
core_attention
=
TEDotProductAttention
,
# for QKLayerNorm if TE Version < 1.9;
linear_proj
=
TERowParallelLinear
,
# we instead use the Apex implementation.
q_layernorm
=
qk_norm
if
qk_layernorm
else
IdentityOp
,
qk_norm
=
TENorm
if
is_te_min_version
(
"1.9.0"
)
else
FusedLayerNorm
k_layernorm
=
qk_norm
if
qk_layernorm
else
IdentityOp
,
),
return
ModuleSpec
(
),
module
=
TransformerLayer
,
self_attn_bda
=
get_bias_dropout_add
,
submodules
=
TransformerLayerSubmodules
(
pre_mlp_layernorm
=
TENorm
if
num_experts
else
IdentityOp
,
self_attention
=
ModuleSpec
(
mlp
=
mlp
,
module
=
SelfAttention
,
mlp_bda
=
get_bias_dropout_add
,
params
=
{
"attn_mask_type"
:
AttnMaskType
.
causal
},
),
submodules
=
SelfAttentionSubmodules
(
)
linear_qkv
=
TELayerNormColumnParallelLinear
,
core_attention
=
TEDotProductAttention
,
linear_proj
=
TERowParallelLinear
,
def
get_gpt_layer_local_spec
(
q_layernorm
=
qk_norm
if
qk_layernorm
else
IdentityOp
,
num_experts
:
Optional
[
int
]
=
None
,
k_layernorm
=
qk_norm
if
qk_layernorm
else
IdentityOp
,
moe_grouped_gemm
:
Optional
[
bool
]
=
False
,
),
qk_layernorm
:
Optional
[
bool
]
=
False
,
),
multi_latent_attention
:
Optional
[
bool
]
=
False
,
self_attn_bda
=
get_bias_dropout_add
,
fp8
:
Optional
[
str
]
=
None
,
# pylint: disable=unused-arguments
pre_mlp_layernorm
=
TENorm
if
num_experts
else
IdentityOp
,
moe_use_legacy_grouped_gemm
:
Optional
[
bool
]
=
False
,
mlp
=
mlp
,
)
->
ModuleSpec
:
mlp_bda
=
get_bias_dropout_add
,
"""Use this spec for an implementation using only modules in Megatron-Core.
),
)

def get_gpt_layer_local_spec(
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    qk_layernorm: Optional[bool] = False,
    multi_latent_attention: Optional[bool] = False,
    fp8: Optional[str] = None,  # pylint: disable=unused-arguments
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
    """Use this spec for an implementation using only modules in Megatron-Core.

    Args:
        num_experts (int, optional): Number of experts. Defaults to None.
        moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
        qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
        fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
        moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
            Defaults to False.

    Returns:
        ModuleSpec: Module specification with Megatron-Core modules
    """
    if fp8 is not None:
        warnings.warn(
            'The fp8 argument in "get_gpt_layer_local_spec" has been deprecated'
            ' and will be removed soon. Please update your code accordingly.'
        )

    mlp = get_mlp_module_spec(
        use_te=False,
        num_experts=num_experts,
        moe_grouped_gemm=moe_grouped_gemm,
        moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
    )

    if multi_latent_attention:
        return ModuleSpec(
            module=TransformerLayer,
            submodules=TransformerLayerSubmodules(
                input_layernorm=LNImpl,
                self_attention=ModuleSpec(
                    module=MLASelfAttention,
                    params={"attn_mask_type": AttnMaskType.causal},
                    submodules=MLASelfAttentionSubmodules(
                        linear_q_proj=ColumnParallelLinear,
                        linear_q_down_proj=ColumnParallelLinear,
                        linear_q_up_proj=ColumnParallelLinear,
                        linear_kv_down_proj=ColumnParallelLinear,
                        linear_kv_up_proj=ColumnParallelLinear,
                        core_attention=DotProductAttention,
                        linear_proj=RowParallelLinear,
                        q_layernorm=LNImpl if qk_layernorm else IdentityOp,
                        kv_layernorm=LNImpl if qk_layernorm else IdentityOp,
                    ),
                ),
                self_attn_bda=get_bias_dropout_add,
                pre_mlp_layernorm=LNImpl,
                mlp=mlp,
                mlp_bda=get_bias_dropout_add,
            ),
        )
    else:
        return ModuleSpec(
            module=TransformerLayer,
            submodules=TransformerLayerSubmodules(
                input_layernorm=LNImpl,
                self_attention=ModuleSpec(
                    module=SelfAttention,
                    params={"attn_mask_type": AttnMaskType.causal},
                    submodules=SelfAttentionSubmodules(
                        linear_qkv=ColumnParallelLinear,
                        core_attention=DotProductAttention,
                        linear_proj=RowParallelLinear,
                        q_layernorm=LNImpl if qk_layernorm else IdentityOp,
                        k_layernorm=LNImpl if qk_layernorm else IdentityOp,
                    ),
                ),
                self_attn_bda=get_bias_dropout_add,
                pre_mlp_layernorm=LNImpl,
                mlp=mlp,
                mlp_bda=get_bias_dropout_add,
                sharded_state_dict_keys_map={
                    'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
                    'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
                },
            ),
        )
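For reference, a minimal usage sketch (not part of this file) of the spec helper above; the flag and expert count are illustrative values only:

from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec

# Dense GPT layer built only from Megatron-Core modules, with QK layernorm enabled.
dense_layer = get_gpt_layer_local_spec(qk_layernorm=True)

# MoE variant of the same layer; 8 experts is an arbitrary illustrative value.
moe_layer = get_gpt_layer_local_spec(num_experts=8)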

def _get_mlp_module_spec(
    use_te: Optional[bool] = True,
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    fp8: Optional[str] = None,  # pylint: disable=unused-arguments
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
):
    warnings.warn(
        """This private function is on a deprecation track. Please switch to `get_mlp_module_spec`
        since it will be removed in a future release."""
    )

    return get_mlp_module_spec(
        use_te=use_te,
        num_experts=num_experts,
        moe_grouped_gemm=moe_grouped_gemm,
        fp8=fp8,
        moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
    )

def get_mlp_module_spec(
    use_te: Optional[bool] = True,
    num_experts: Optional[int] = None,
    moe_grouped_gemm: Optional[bool] = False,
    fp8: Optional[str] = None,  # pylint: disable=unused-arguments
    moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
    """Helper function to get module spec for MLP/MoE"""
    if fp8 is not None:
        warnings.warn(
            'The fp8 argument in "_get_mlp_module_spec" has been deprecated'
            ' and will be removed soon. Please update your code accordingly.'
        )

    if num_experts is None:
        # Dense MLP w/ or w/o TE modules.
        return ModuleSpec(
            module=MLP,
            submodules=MLPSubmodules(
                linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear,
                linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
            ),
        )
    else:
        # Mixture of experts with modules in megatron core.
        return get_moe_module_spec(
            use_te=use_te,
            num_experts=num_experts,
            moe_grouped_gemm=moe_grouped_gemm,
            moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
        )
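A short sketch (again not part of the file) of how this MLP helper behaves; the values are illustrative:

from megatron.core.models.gpt.gpt_layer_specs import get_mlp_module_spec

dense_mlp = get_mlp_module_spec(use_te=False)                # ModuleSpec wrapping MLP
moe_mlp = get_mlp_module_spec(use_te=False, num_experts=16)  # routed through get_moe_module_spec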

def get_gpt_decoder_block_spec(
    config: TransformerConfig, use_transformer_engine: bool
) -> TransformerBlockSubmodules:
    """GPT block spec."""
    if use_transformer_engine:
        layer_norm_impl = TENorm
    else:
        layer_norm_impl = LNImpl

    # Layer specs.
    dense_layer_spec = (
        get_gpt_layer_with_transformer_engine_spec(
            num_experts=None,
            moe_grouped_gemm=False,
            qk_layernorm=config.qk_layernorm,
            multi_latent_attention=config.multi_latent_attention,
            moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
        )
        if use_transformer_engine
        else get_gpt_layer_local_spec(
            num_experts=None,
            moe_grouped_gemm=False,
            qk_layernorm=config.qk_layernorm,
            multi_latent_attention=config.multi_latent_attention,
            moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
        )
    )
    moe_layer_spec = (
        get_gpt_layer_with_transformer_engine_spec(
            num_experts=config.num_moe_experts,
            moe_grouped_gemm=config.moe_grouped_gemm,
            qk_layernorm=config.qk_layernorm,
            multi_latent_attention=config.multi_latent_attention,
            moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
        )
        if use_transformer_engine
        else get_gpt_layer_local_spec(
            num_experts=config.num_moe_experts,
            moe_grouped_gemm=config.moe_grouped_gemm,
            qk_layernorm=config.qk_layernorm,
            multi_latent_attention=config.multi_latent_attention,
            moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
        )
    )

    # Parse config.moe_layer_freq to determine the pattern of expert/dense layers.
    # 0 stands for dense layers, 1 stands for expert layers.
    # For integer N: Creates a pattern with one expert layer every N layers.
    # For string pattern: Evaluates the str directly (e.g. "[1,0,1]" for alternating expert/dense).
    if isinstance(config.moe_layer_freq, int):
        moe_layer_pattern = [
            1 if (i % config.moe_layer_freq == 0) else 0 for i in range(config.num_layers)
        ]
    elif isinstance(config.moe_layer_freq, list):
        moe_layer_pattern = config.moe_layer_freq
        assert len(moe_layer_pattern) == config.num_layers, (
            f"Invalid length of moe_layer_pattern: {len(moe_layer_pattern)}, "
            f"expected {config.num_layers}, "
            f"current moe layer pattern: {config.moe_layer_freq}"
        )
    else:
        raise ValueError(
            f"Invalid moe_layer_freq: {type(config.moe_layer_freq)}, {config.moe_layer_freq}"
        )

    # Create the layer specs for the model.
    layer_specs = []
    for layer_number in range(config.num_layers):
        if moe_layer_pattern[layer_number] == 1:
            layer_specs.append(moe_layer_spec)
        elif moe_layer_pattern[layer_number] == 0:
            layer_specs.append(dense_layer_spec)
        else:
            raise ValueError(f"Invalid layer pattern: {moe_layer_pattern}")

    # Slice the layer specs to only include the layers that are built in this pipeline stage.
    # Note: MCore layer_number starts at 1
    offset = get_transformer_layer_offset(config)
    num_layers_to_build = get_num_layers_to_build(config)
    layer_specs = layer_specs[offset : offset + num_layers_to_build]

    # Block spec.
    block_spec = TransformerBlockSubmodules(layer_specs=layer_specs, layer_norm=layer_norm_impl)

    return block_spec
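A hedged sketch of building a block spec from a config. It assumes megatron.core.parallel_state has already been initialized (the offset and layer-count helpers query pipeline-parallel state), and the config values are illustrative; a real MoE config may need additional fields:

from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec
from megatron.core.transformer.transformer_config import TransformerConfig

# With moe_layer_freq=2 and num_layers=4 the parsed pattern is [1, 0, 1, 0]:
# expert layers at indices 0 and 2, dense layers at 1 and 3.
config = TransformerConfig(
    num_layers=4,
    hidden_size=1024,
    num_attention_heads=16,
    num_moe_experts=8,
    moe_layer_freq=2,
)
block_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=False)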
megatron/core/models/gpt/gpt_model.py
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from collections import OrderedDict
from typing import Dict, Literal, Optional

import torch
from torch import Tensor

from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig

class GPTModel(LanguageModule):
    """GPT Transformer language model.

    Args:
        config (TransformerConfig):
            Transformer config
        transformer_layer_spec (ModuleSpec):
            Specifies module to use for transformer layers
        vocab_size (int):
            Vocabulary size
        max_sequence_length (int):
            maximum size of sequence. This is used for positional embedding
        pre_process (bool, optional):
            Include embedding layer (used with pipeline parallelism). Defaults to True.
        post_process (bool, optional):
            Include an output layer (used with pipeline parallelism). Defaults to True.
        fp16_lm_cross_entropy (bool, optional):
            Defaults to False.
        parallel_output (bool, optional):
            Do not gather the outputs, keep them split across tensor
            parallel ranks. Defaults to True.
        share_embeddings_and_output_weights (bool, optional):
            When True, input embeddings and output logit weights are shared. Defaults to False.
        position_embedding_type (Literal[learned_absolute,rope], optional):
            Position embedding type. Defaults to 'learned_absolute'.
        rotary_percent (float, optional):
            Percent of rotary dimension to use for rotary position embeddings.
            Ignored unless position_embedding_type is 'rope'. Defaults to 1.0.
        rotary_base (int, optional):
            Base period for rotary position embeddings. Ignored unless
            position_embedding_type is 'rope'.
            Defaults to 10000.
        rope_scaling (bool, optional): Toggle RoPE scaling.
        rope_scaling_factor (float): RoPE scaling factor. Default 8.
        scatter_embedding_sequence_parallel (bool, optional):
            Whether embeddings should be scattered across sequence parallel
            region or not. Defaults to True.
        seq_len_interpolation_factor (Optional[float], optional):
            scale of linearly interpolating RoPE for longer sequences.
            The value must be a float larger than 1.0. Defaults to None.
    """

    def __init__(
        self,
        config: TransformerConfig,
        transformer_layer_spec: ModuleSpec,
        vocab_size: int,
        max_sequence_length: int,
        pre_process: bool = True,
        post_process: bool = True,
        fp16_lm_cross_entropy: bool = False,
        parallel_output: bool = True,
        share_embeddings_and_output_weights: bool = False,
        position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute',
        rotary_percent: float = 1.0,
        rotary_base: int = 10000,
        rope_scaling: bool = False,
        rope_scaling_factor: float = 8.0,
        scatter_embedding_sequence_parallel: bool = True,
        seq_len_interpolation_factor: Optional[float] = None,
    ) -> None:
        super().__init__(config=config)

        if has_config_logger_enabled(config):
            log_config_to_disk(config, locals(), prefix=type(self).__name__)

        self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
        self.vocab_size = vocab_size
        self.max_sequence_length = max_sequence_length
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
        self.position_embedding_type = position_embedding_type

        # megatron core pipelining currently depends on model type
        # TODO: remove this dependency ?
        self.model_type = ModelType.encoder_or_decoder

        # These 4 attributes are needed for TensorRT-LLM export.
        self.max_position_embeddings = max_sequence_length
        self.rotary_percent = rotary_percent
        self.rotary_base = rotary_base
        self.rotary_scaling = rope_scaling

        if self.pre_process:
            self.embedding = LanguageModelEmbedding(
                config=self.config,
                vocab_size=self.vocab_size,
                max_sequence_length=self.max_sequence_length,
                position_embedding_type=position_embedding_type,
                scatter_to_sequence_parallel=scatter_embedding_sequence_parallel,
            )

        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
            self.rotary_pos_emb = RotaryEmbedding(
                kv_channels=self.config.kv_channels,
                rotary_percent=rotary_percent,
                rotary_interleaved=self.config.rotary_interleaved,
                seq_len_interpolation_factor=seq_len_interpolation_factor,
                rotary_base=rotary_base,
                rope_scaling=rope_scaling,
                rope_scaling_factor=rope_scaling_factor,
                use_cpu_initialization=self.config.use_cpu_initialization,
            )

        # Cache for RoPE tensors which do not change between iterations.
        self.rotary_pos_emb_cache = {}

        # Transformer.
        self.decoder = TransformerBlock(
            config=self.config,
            spec=transformer_layer_spec,
            pre_process=self.pre_process,
            post_process=self.post_process,
        )

        # Output
        if post_process:
            if self.config.defer_embedding_wgrad_compute:
                # The embedding activation buffer preserves a reference to the input activations
                # of the final embedding projection layer GEMM. It will hold the activations for
                # all the micro-batches of a global batch for the last pipeline stage. Once we are
                # done with all the back props for all the microbatches for the last pipeline stage,
                # it will be in the pipeline flush stage. During this pipeline flush we use the
                # input activations stored in embedding activation buffer and gradient outputs
                # stored in gradient buffer to calculate the weight gradients for the embedding
                # final linear layer.
                self.embedding_activation_buffer = []
                self.grad_output_buffer = []
            else:
                self.embedding_activation_buffer = None
                self.grad_output_buffer = None

            self.output_layer = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                self.vocab_size,
                config=config,
                init_method=config.init_method,
                bias=False,
                skip_bias_add=False,
                gather_output=not self.parallel_output,
                skip_weight_param_allocation=self.pre_process
                and self.share_embeddings_and_output_weights,
                embedding_activation_buffer=self.embedding_activation_buffer,
                grad_output_buffer=self.grad_output_buffer,
            )

        if self.pre_process or self.post_process:
            self.setup_embeddings_and_output_layer()

        if has_config_logger_enabled(self.config):
            log_config_to_disk(
                self.config, self.state_dict(), prefix=f'{type(self).__name__}_init_ckpt'
            )
    def set_input_tensor(self, input_tensor: Tensor) -> None:
        """Sets input tensor to the model.

        See megatron.model.transformer.set_input_tensor()

        Args:
            input_tensor (Tensor): Sets the input tensor for the model.
        """
        # This is usually handled in schedules.py but some inference code still
        # gives us non-lists or None
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]

        assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert'
        self.decoder.set_input_tensor(input_tensor[0])
    def forward(
        self,
        input_ids: Tensor,
        position_ids: Tensor,
        attention_mask: Tensor,
        decoder_input: Tensor = None,
        labels: Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        extra_block_kwargs: dict = None,
        runtime_gather_output: Optional[bool] = None,
    ) -> Tensor:
        """Forward function of the GPT Model. This function passes the input tensors
        through the embedding layer, and then the decoder and finally into the post
        processing layer (optional).

        It either returns the Loss values if labels are given or the final hidden units

        Args:
            runtime_gather_output (bool): Gather output at runtime. Default None means
                `parallel_output` arg in the constructor will be used.
        """
        # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
        # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.

        # Decoder embedding.
        if decoder_input is not None:
            pass
        elif self.pre_process:
            decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
        else:
            # intermediate stage of pipeline
            # decoder will get hidden_states from encoder.input_tensor
            decoder_input = None

        # Rotary positional embeddings (embedding is None for PP intermediate devices)
        rotary_pos_emb = None
        rotary_pos_cos = None
        rotary_pos_sin = None
        if self.position_embedding_type == 'rope' and not self.config.multi_latent_attention:
            if not self.training and self.config.flash_decode and inference_params:
                # Flash decoding uses precomputed cos and sin for RoPE
                rotary_pos_cos, rotary_pos_sin = self.rotary_pos_emb_cache.setdefault(
                    inference_params.max_sequence_length,
                    self.rotary_pos_emb.get_cos_sin(inference_params.max_sequence_length),
                )
            else:
                rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
                    inference_params, self.decoder, decoder_input, self.config, packed_seq_params
                )
                rotary_pos_emb = self.rotary_pos_emb(
                    rotary_seq_len,
                    packed_seq=packed_seq_params is not None
                    and packed_seq_params.qkv_format == 'thd',
                )
        if (
            (self.config.enable_cuda_graph or self.config.flash_decode)
            and rotary_pos_cos is not None
            and inference_params
        ):
            sequence_len_offset = torch.tensor(
                [inference_params.sequence_len_offset] * inference_params.current_batch_size,
                dtype=torch.int32,
                device=rotary_pos_cos.device,  # Co-locate this with the rotary tensors
            )
        else:
            sequence_len_offset = None

        # Run decoder.
        hidden_states = self.decoder(
            hidden_states=decoder_input,
            attention_mask=attention_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
            **(extra_block_kwargs or {}),
        )

        if not self.post_process:
            return hidden_states

        # logits and loss
        output_weight = None
        if self.share_embeddings_and_output_weights:
            output_weight = self.shared_embedding_or_output_weight()
        logits, _ = self.output_layer(
            hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
        )

        if has_config_logger_enabled(self.config):
            payload = OrderedDict(
                {
                    'input_ids': input_ids,
                    'position_ids': position_ids,
                    'attention_mask': attention_mask,
                    'decoder_input': decoder_input,
                    'logits': logits,
                }
            )
            log_config_to_disk(self.config, payload, prefix='input_and_logits')

        if labels is None:
            # [s b h] => [b s h]
            return logits.transpose(0, 1).contiguous()

        loss = self.compute_language_model_loss(labels, logits)

        return loss
    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None
    ) -> ShardedStateDict:
        """Sharded state dict implementation for GPTModel backward-compatibility
        (removing extra state).

        Args:
            prefix (str): Module name prefix.
            sharded_offsets (tuple): PP related offsets, expected to be empty at this module level.
            metadata (Optional[Dict]): metadata controlling sharded state dict creation.

        Returns:
            ShardedStateDict: sharded state dict for the GPTModel
        """
        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)

        output_layer_extra_state_key = f'{prefix}output_layer._extra_state'

        # Old GPT checkpoints only stored the output layer weight key. So we remove the
        # _extra_state key but check that it doesn't contain any data anyway
        output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None)
        assert not (
            output_extra_state and output_extra_state.data
        ), f'Expected output layer extra state to be empty, got: {output_extra_state}'

        return sharded_state_dict
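To close, an end-to-end sketch (not part of the diff) that wires the pieces above together. It is a minimal illustration only: it assumes model-parallel state has already been initialized, the config values and tensor shapes are hypothetical, and the boolean mask follows Megatron-Core's convention that True marks positions to exclude:

import torch

from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(num_layers=2, hidden_size=256, num_attention_heads=4)
model = GPTModel(
    config=config,
    transformer_layer_spec=get_gpt_layer_local_spec(),
    vocab_size=32000,
    max_sequence_length=512,
    position_embedding_type='rope',
)

batch, seq = 2, 128
input_ids = torch.randint(0, 32000, (batch, seq))
position_ids = torch.arange(seq).unsqueeze(0).expand(batch, -1)
# Causal mask: True above the diagonal, i.e. future positions are masked out.
attention_mask = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)[None, None]

# With labels=None, forward returns logits shaped [batch, seq, vocab_size].
logits = model(input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask)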