jerrrrry / dcu_megatron

Commit 4e867b3c, authored Aug 06, 2025 by jerrrrry

Initial commit
Showing 7 changed files with 698 additions and 0 deletions (+698, -0), out of 327 changed files in the commit.
Megatron-LM/megatron/core/fusions/fused_bias_swiglu.py (+89, -0)
Megatron-LM/megatron/core/fusions/fused_cross_entropy.py (+148, -0)
Megatron-LM/megatron/core/fusions/fused_layer_norm.py (+169, -0)
Megatron-LM/megatron/core/fusions/fused_softmax.py (+220, -0)
Megatron-LM/megatron/core/inference/__init__.py (+1, -0)
Megatron-LM/megatron/core/inference/async_stream.py (+67, -0)
Megatron-LM/megatron/core/inference/common_inference_params.py (+4, -0)
Megatron-LM/megatron/core/fusions/fused_bias_swiglu.py (new file, mode 100644)

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import torch
import torch.nn.functional as F

from megatron.core.jit import jit_fuser

###### BIAS SWIGLU FUSION/ NO AUTOGRAD ################


@jit_fuser
def swiglu(y):
    y_1, y_2 = torch.chunk(y, 2, -1)
    return F.silu(y_1) * y_2


@jit_fuser
def bias_swiglu(y, bias):
    y = y + bias
    return swiglu(y)


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@jit_fuser
def swiglu_back(g, y):
    y_1, y_2 = torch.chunk(y, 2, -1)
    return torch.cat(
        (g * torch.sigmoid(y_1) * (1 + y_1 * (1 - torch.sigmoid(y_1))) * y_2, g * F.silu(y_1)), -1
    )


@jit_fuser
def bias_swiglu_back(g, y, bias):
    y = y + bias
    return swiglu_back(g, y)


class BiasSwiGLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias, fp8_input_store):
        input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input
        ctx.save_for_backward(input_for_backward, bias)
        ctx.ori_input_dtype = input.dtype
        ctx.fp8_input_store = fp8_input_store
        return bias_swiglu(input, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input
        tmp = bias_swiglu_back(grad_output, input, bias)
        return tmp, tmp, None


class SwiGLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, fp8_input_store):
        input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input
        ctx.save_for_backward(input_for_backward)
        ctx.ori_input_dtype = input.dtype
        ctx.fp8_input_store = fp8_input_store
        return swiglu(input)

    @staticmethod
    def backward(ctx, grad_output):
        input = ctx.saved_tensors[0]
        input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input
        tmp = swiglu_back(grad_output, input)
        return tmp, None


def bias_swiglu_impl(input, bias, fp8_input_store=False):
    ori_shape = input.shape
    assert len(ori_shape) in [2, 3]
    input = input.view(-1, ori_shape[-1])
    if bias is not None:
        output = BiasSwiGLUFunction.apply(input, bias, fp8_input_store)
    else:
        output = SwiGLUFunction.apply(input, fp8_input_store)

    return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1)


# bias_swiglu_impl = BiasSwiGLUFunction.apply
# swiglu_impl = SwiGLUFunction.apply
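
The hand-written swiglu_back above returns the gradient of silu(y_1) * y_2 with respect to the concatenated input y. A minimal, standalone sketch (plain PyTorch only, no jit_fuser; the ref_* names are chosen here for illustration) that checks the same manual backward against autograd:

import torch
import torch.nn.functional as F

def ref_swiglu(y):
    # same math as swiglu() above, without the fusion decorator
    y_1, y_2 = torch.chunk(y, 2, -1)
    return F.silu(y_1) * y_2

def ref_swiglu_back(g, y):
    # d/dy_1 [silu(y_1) * y_2] = sigmoid(y_1) * (1 + y_1 * (1 - sigmoid(y_1))) * y_2
    # d/dy_2 [silu(y_1) * y_2] = silu(y_1)
    y_1, y_2 = torch.chunk(y, 2, -1)
    s = torch.sigmoid(y_1)
    return torch.cat((g * s * (1 + y_1 * (1 - s)) * y_2, g * F.silu(y_1)), -1)

y = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
g = torch.randn(4, 4, dtype=torch.double)

out = ref_swiglu(y)
(autograd_grad,) = torch.autograd.grad(out, y, grad_outputs=g)
manual_grad = ref_swiglu_back(g, y.detach())
print(torch.allclose(autograd_grad, manual_grad))  # expected: True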
Megatron-LM/megatron/core/fusions/fused_cross_entropy.py (new file, mode 100644)

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from typing import Tuple

import torch

from megatron.core.jit import jit_fuser
from megatron.core.tensor_parallel.cross_entropy import VocabParallelCrossEntropy
from megatron.core.tensor_parallel.utils import VocabUtility


@jit_fuser
def calculate_logits_max(vocab_parallel_logits: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Calculates the maximum logits of the predicted tokens.
    """
    vocab_parallel_logits, logits_max = VocabParallelCrossEntropy.calculate_logits_max(
        vocab_parallel_logits
    )
    return vocab_parallel_logits, logits_max


@jit_fuser
def calculate_predicted_logits(
    vocab_parallel_logits: torch.Tensor,
    target: torch.Tensor,
    logits_max: torch.Tensor,
    vocab_start_index: int,
    vocab_end_index: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Calculates the predicted logits for the tokens.
    """
    (target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits) = (
        VocabParallelCrossEntropy.calculate_predicted_logits(
            vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index
        )
    )

    predicted_logits_sum_exp_logits = torch.cat((predicted_logits, sum_exp_logits))

    return target_mask, masked_target_1d, predicted_logits_sum_exp_logits, exp_logits


@jit_fuser
def calculate_cross_entropy_loss(
    exp_logits: torch.Tensor, predicted_logits_sum_exp_logits: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Calculates the final cross entropy loss for the tokens.
    """
    split_val = predicted_logits_sum_exp_logits.size()[0] // 2
    predicted_logits, sum_exp_logits = torch.split(predicted_logits_sum_exp_logits, split_val)

    exp_logits, loss = VocabParallelCrossEntropy.calculate_cross_entropy_loss(
        exp_logits, predicted_logits, sum_exp_logits
    )

    return exp_logits, loss


@jit_fuser
def calculate_gradients(
    softmax: torch.Tensor,
    grad_output: torch.Tensor,
    target_mask: torch.Tensor,
    masked_target_1d: torch.Tensor,
) -> torch.Tensor:
    """
    Calculate the logits gradients scaled based on the CE loss
    """
    (grad_2d, arange_1d, softmax_update, grad_input) = (
        VocabParallelCrossEntropy.prepare_gradient_calculation_operands(softmax, target_mask)
    )

    grad_input = VocabParallelCrossEntropy.calculate_gradients(
        grad_2d, arange_1d, masked_target_1d, softmax_update, grad_input, grad_output
    )

    grad_input = grad_input.to(torch.bfloat16)

    return grad_input


class _VocabParallelCrossEntropy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, vocab_parallel_logits, target, tp_group):
        """
        Forward implementation for the cross entropy loss.
        """
        vocab_parallel_logits, logits_max = calculate_logits_max(vocab_parallel_logits)
        torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=tp_group)

        # Get the partition's vocab indices
        get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
        partition_vocab_size = vocab_parallel_logits.size()[-1]
        vocab_start_index, vocab_end_index = get_vocab_range(
            partition_vocab_size, tp_group.rank(), tp_group.size()
        )

        (target_mask, masked_target_1d, predicted_logits_sum_exp_logits, exp_logits) = (
            calculate_predicted_logits(
                vocab_parallel_logits, target, logits_max, vocab_start_index, vocab_end_index
            )
        )

        # All reduce is needed to get the chunks from other GPUs.
        # In the fused case, tensors are batched to invoke a single
        # AllReduce call
        torch.distributed.all_reduce(
            predicted_logits_sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=tp_group
        )

        exp_logits, loss = calculate_cross_entropy_loss(exp_logits, predicted_logits_sum_exp_logits)

        # Store softmax, target-mask and masked-target for backward pass.
        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

        return loss

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward implementation for the cross entropy loss.
        """
        # Retrieve tensors from the forward path.
        softmax, target_mask, masked_target_1d = ctx.saved_tensors

        grad_input = calculate_gradients(softmax, grad_output, target_mask, masked_target_1d)

        return grad_input, None, None


def fused_vocab_parallel_cross_entropy(vocab_parallel_logits, target, tp_group):
    """
    Performs cross entropy loss when logits are split across tensor parallel ranks

    Args:
        vocab_parallel_logits: logits split across tensor parallel ranks
            dimension is [sequence_length, batch_size, hidden_size]
        target: correct vocab ids of dimension [sequence_length, micro_batch_size]
        tp_group: the tensor parallel group over which to all reduce
    """
    return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target, tp_group)
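
For intuition: the forward pass above is the standard numerically stable cross entropy (subtract the per-row max, exponentiate, sum), split across tensor-parallel ranks, with the predicted logits and the sum of exponentials packed into one tensor so that a single all-reduce suffices. A standalone single-rank sketch of the same arithmetic (illustrative names only; the real distributed path goes through VocabParallelCrossEntropy):

import torch
import torch.nn.functional as F

def single_rank_cross_entropy(logits, target):
    # logits: [seq, batch, vocab]; target: [seq, batch]
    logits_max = logits.max(dim=-1)[0]              # would be all-reduced with MAX
    shifted = logits - logits_max.unsqueeze(-1)
    exp_logits = shifted.exp()
    sum_exp_logits = exp_logits.sum(dim=-1)         # would be all-reduced with SUM
    predicted_logits = shifted.gather(-1, target.unsqueeze(-1)).squeeze(-1)
    # per-token loss = log(sum of exponentials) - predicted logit
    return torch.log(sum_exp_logits) - predicted_logits

logits = torch.randn(3, 2, 11)
target = torch.randint(0, 11, (3, 2))
loss = single_rank_cross_entropy(logits, target)
ref = F.cross_entropy(logits.view(-1, 11), target.view(-1), reduction='none').view(3, 2)
print(torch.allclose(loss, ref, atol=1e-6))  # expected: True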
Megatron-LM/megatron/core/fusions/fused_layer_norm.py (new file, mode 100644)

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import importlib
import inspect
import numbers

import torch
from torch import Tensor
from torch.nn import init
from torch.nn.parameter import Parameter

from megatron.core.transformer import TransformerConfig
from megatron.core.utils import make_viewless_tensor

try:
    from apex.contrib.layer_norm.layer_norm import FastLayerNormFN

    HAVE_PERSIST_LAYER_NORM = True
except ImportError:
    HAVE_PERSIST_LAYER_NORM = False

try:
    from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction

    HAVE_FUSED_LAYER_NORM = True
except ImportError:
    HAVE_FUSED_LAYER_NORM = False


class FusedLayerNorm(torch.nn.Module):
    """Layer Norm, fused into a single CUDA kernel.

    Args:
        hidden_size (int): Transformer hidden dimension.
        eps (float): Epsilon added to denominator, for numerical stability.
        persist_layer_norm (bool): Use persistent fused layer norm kernel.
            This kernel supports only a set of hidden sizes. Please
            check persist_ln_hidden_sizes if your hidden size is supported.
        zero_centered_gamma (bool): Adjust LayerNorm weights such that they are
            centered around zero. This improves numerical stability.
        config (TransformerConfig): Transformer config. Include to match custom
            layer norm interfaces.
        normalization (str): Normalization type, used for Transformer Engine.
            Must equal 'LayerNorm' here.
    """

    def __init__(
        self,
        config: TransformerConfig,
        hidden_size: int,
        eps: float = 1e-5,
        persist_layer_norm: bool = True,
        zero_centered_gamma: bool = False,
        normalization: str = "LayerNorm",  # included to match TE interface
    ):
        super().__init__()

        self.config = config

        self.zero_centered_gamma = self.config.layernorm_zero_centered_gamma
        assert (
            self.config.normalization == "LayerNorm"
        ), f'({self.config.normalization}) is not supported in FusedLayerNorm'

        # List of hidden sizes supported in the persistent layer norm kernel
        # If the hidden size is not supported, fall back to the non-persistent
        # kernel.
        persist_ln_hidden_sizes = [
            1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120, 6144, 8192,
            10240, 12288, 12800, 15360, 16384, 18432, 20480, 24576, 25600,
            30720, 32768, 40960, 49152, 65536,
        ]
        persist_layer_norm = self.config.persist_layer_norm
        if hidden_size not in persist_ln_hidden_sizes or not HAVE_PERSIST_LAYER_NORM:
            persist_layer_norm = False

        if not persist_layer_norm and not HAVE_FUSED_LAYER_NORM:
            # TODO: Add pytorch only layer norm
            raise ValueError('Apex must be installed to use FusedLayerNorm.')

        if isinstance(hidden_size, numbers.Integral):
            hidden_size = (hidden_size,)
        self.hidden_size = torch.Size(hidden_size)
        self.eps = eps
        # Parameters need to be initialized with torch.empty rather than torch.Tensor
        # for correct device placement with nemo2.
        self.weight = Parameter(torch.empty(*hidden_size))
        self.bias = Parameter(torch.empty(*hidden_size))
        self.reset_parameters()
        self.persist_layer_norm = persist_layer_norm
        self.sequence_parallel = self.config.sequence_parallel

        # set sequence parallelism flag on weight and bias parameters
        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
        setattr(self.bias, 'sequence_parallel', self.sequence_parallel)

    def reset_parameters(self):
        if self.zero_centered_gamma:
            init.zeros_(self.weight)
            init.zeros_(self.bias)
        else:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input: Tensor) -> Tensor:

        weight = self.weight + 1 if self.zero_centered_gamma else self.weight

        if self.persist_layer_norm:
            if 'memory_efficient' in inspect.getfullargspec(FastLayerNormFN.forward).args:
                output = FastLayerNormFN.apply(
                    input, weight, self.bias, self.eps, self.config.memory_efficient_layer_norm
                )
            else:
                output = FastLayerNormFN.apply(input, weight, self.bias, self.eps)

            # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
            # a populated '_base' field). This will result in schedule.py's
            # deallocate_output_tensor() throwing an error, so a viewless tensor is
            # created to prevent this.
            output = make_viewless_tensor(
                inp=output, requires_grad=input.requires_grad, keep_graph=True
            )

        else:
            if 'memory_efficient' in inspect.getfullargspec(FusedLayerNormAffineFunction.forward).args:
                return FusedLayerNormAffineFunction.apply(
                    input,
                    weight,
                    self.bias,
                    self.hidden_size,
                    self.eps,
                    self.config.memory_efficient_layer_norm,
                )
            else:
                return FusedLayerNormAffineFunction.apply(
                    input, weight, self.bias, self.hidden_size, self.eps
                )

        return output
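
The zero_centered_gamma path stores the LayerNorm weight shifted by -1 (reset_parameters initializes it to zeros) and adds the 1 back in forward, so the effective affine transform is unchanged while the stored parameter stays centered around zero. A small, framework-free sketch of that equivalence, assuming plain torch.nn.functional.layer_norm rather than the Apex kernels used above:

import torch
import torch.nn.functional as F

hidden = 8
x = torch.randn(2, 4, hidden)

# conventional parameterization: gamma initialized to ones
gamma = torch.ones(hidden)
beta = torch.zeros(hidden)
y_ref = F.layer_norm(x, (hidden,), gamma, beta, eps=1e-5)

# zero-centered parameterization: store (gamma - 1), add 1 at call time
gamma_centered = torch.zeros(hidden)  # what reset_parameters() writes
y_zero_centered = F.layer_norm(x, (hidden,), gamma_centered + 1, beta, eps=1e-5)

print(torch.allclose(y_ref, y_zero_centered))  # expected: True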
Megatron-LM/megatron/core/fusions/fused_softmax.py (new file, mode 100644)

# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

from typing import Optional

import torch
import torch.nn as nn

from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.utils import get_default_causal_mask


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following three operations in sequence
    1. Scale the tensor.
    2. Apply upper triangular mask (typically used in gpt models).
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        import scaled_upper_triang_masked_softmax_cuda

        scale_t = torch.tensor([scale])
        softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])

        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_upper_triang_masked_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors
        input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
            output_grads, softmax_results, scale_t[0]
        )

        return input_grads, None


class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following three operations in sequence
    1. Scale the tensor.
    2. Apply the mask.
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        import scaled_masked_softmax_cuda

        scale_t = torch.tensor([scale])
        softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_masked_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
        return input_grads, None, None


class ScaledSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following two operations in sequence
    1. Scale the tensor.
    2. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        import scaled_softmax_cuda

        scale_t = torch.tensor([scale])

        softmax_results = scaled_softmax_cuda.forward(inputs, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        import scaled_softmax_cuda

        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_softmax_cuda.backward(output_grads, softmax_results, scale_t[0])
        return input_grads, None, None


class FusedScaleMaskSoftmax(nn.Module):
    """
    fused operation: scaling + mask + softmax

    Args:
        input_in_fp16: flag to indicate if input is in fp16 data format.
        input_in_bf16: flag to indicate if input is in bf16 data format.
        attn_mask_type: attention mask type (pad or causal)
        scaled_masked_softmax_fusion: flag to indicate the user wants to use softmax fusion
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax is performed at fp32 precision.
        scale: scaling factor used in input tensor scaling.
    """

    def __init__(
        self,
        input_in_fp16,
        input_in_bf16,
        attn_mask_type,
        scaled_masked_softmax_fusion,
        mask_func,
        softmax_in_fp32,
        scale,
    ):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        assert not (
            self.input_in_fp16 and self.input_in_bf16
        ), "both fp16 and bf16 flags cannot be active at the same time."
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"

    def forward(self, input: torch.Tensor, mask: Optional[torch.Tensor]):
        """Forward pass of softmax with masked input.

        In case attn_mask_type is causal the mask is generated and None can be passed.
        A user-defined mask is only needed when attn_mask_type is not causal.
        """
        # [b, np, sq, sk]
        assert input.dim() == 4

        if self.is_kernel_available(mask, *input.size()):
            return self.forward_fused_softmax(input, mask)
        else:
            return self.forward_torch_softmax(input, mask)

    def is_kernel_available(self, mask, b, np, sq, sk):
        attn_batches = b * np

        if (
            self.scaled_masked_softmax_fusion  # user wants to fuse
            and self.input_in_float16  # input must be fp16
            and 16 < sk <= 4096  # sk must be 16 ~ 4096
            and sq % 4 == 0  # sq must be a multiple of 4
            and sk % 4 == 0  # sk must be a multiple of 4
            and attn_batches % 4 == 0  # np * b must be a multiple of 4
        ):
            if 0 <= sk <= 4096:
                batch_per_block = self.get_batch_per_block(sq, sk, b, np)

                if self.attn_mask_type == AttnMaskType.causal:
                    if attn_batches % batch_per_block == 0:
                        return True
                else:
                    if sq % batch_per_block == 0:
                        return True
        return False

    def forward_fused_softmax(self, input, mask):
        b, np, sq, sk = input.size()
        scale = self.scale if self.scale is not None else 1.0

        if self.attn_mask_type == AttnMaskType.causal:
            assert sq == sk, "causal mask is only for self attention"

            # input is 3D tensor (attn_batches, sq, sk)
            input = input.view(-1, sq, sk)
            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
            return probs.view(b, np, sq, sk)
        else:
            # input is 4D tensor (b, np, sq, sk)
            if mask is not None:
                return ScaledMaskedSoftmax.apply(input, mask, scale)
            else:
                return ScaledSoftmax.apply(input, scale)

    def forward_torch_softmax(self, input, mask):
        if self.input_in_float16 and self.softmax_in_fp32:
            input = input.float()

        if self.scale is not None:
            input = input * self.scale

        # Generate causal mask if not given
        sq, sk = input.size(2), input.size(3)
        if self.attn_mask_type == AttnMaskType.causal and mask is None and sq > 1:
            # If sq == 1 then either KV cache is used or one-element context is passed
            # so keeping mask=None in this case; subsequent code should handle it
            assert sq == sk, "causal mask is only for self attention"
            mask = get_default_causal_mask(sq)

        mask_output = self.mask_func(input, mask) if mask is not None else input
        probs = torch.nn.Softmax(dim=-1)(mask_output)

        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = probs.half()
            else:
                probs = probs.bfloat16()

        return probs

    @staticmethod
    def get_batch_per_block(sq, sk, b, np):
        import scaled_masked_softmax_cuda

        return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
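
The unfused fallback (forward_torch_softmax) is just scale, additive mask, softmax over the last dimension. A standalone sketch with a causal mask, using a hypothetical mask_func that mirrors the usual fill-with-large-negative masking (the fused paths above require the compiled scaled_*_softmax_cuda extensions and are not exercised here):

import torch

def causal_additive_mask_func(scores, mask):
    # hypothetical mask_func: hide masked positions before the softmax
    return scores.masked_fill(mask, torch.finfo(scores.dtype).min)

b, np, sq, sk = 2, 4, 16, 16            # [b, np, sq, sk], as in the docstring above
scores = torch.randn(b, np, sq, sk)
scale = 0.125

# boolean causal mask: True above the diagonal (positions to hide)
causal_mask = torch.triu(torch.ones(sq, sk, dtype=torch.bool), diagonal=1)

masked = causal_additive_mask_func(scores * scale, causal_mask)
probs = torch.softmax(masked, dim=-1)
print(probs.shape, torch.allclose(probs.sum(-1), torch.ones(b, np, sq)))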
Megatron-LM/megatron/core/inference/__init__.py (new file, mode 100644)

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
Megatron-LM/megatron/core/inference/async_stream.py (new file, mode 100644)

# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Copyright 2025 The vLLM authors.
#
# This code was adopted from https://github.com/vllm-project/vllm/
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.

import asyncio
from typing import Any, AsyncGenerator, Callable, Optional, Type, Union

from megatron.core.inference.inference_request import InferenceRequest

STOP_ITERATION = Exception()


class AsyncStream:
    """
    Class for encapsulating an asynchronous stream of InferenceRequest outputs.

    Adopted from https://github.com/vllm-project/vllm/blob/eb881ed006ca458b052905e33f0d16dbb428063a/vllm/v1/engine/async_stream.py  # pylint: disable=line-too-long
    """

    def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None:
        self._request_id = request_id
        self._cancel = cancel
        self._queue: asyncio.Queue = asyncio.Queue()
        self._finished = False
        self._loop = asyncio.get_running_loop()

    def put(self, item: Union[InferenceRequest, Exception]) -> None:
        """Adds a new value to the stream"""
        if not self._finished:
            self._loop.call_soon_threadsafe(self._queue.put_nowait, item)

    def finish(self, exception: Optional[Union[BaseException, Type[BaseException]]] = None) -> None:
        """Completes the stream by adding a sentinel value"""
        if not self._finished:
            self._finished = True
            self._loop.call_soon_threadsafe(
                self._queue.put_nowait,
                exception if self._is_raisable(exception) else STOP_ITERATION,
            )

    @property
    def finished(self) -> bool:
        """Whether the stream has finished"""
        return self._finished

    async def generator(self) -> AsyncGenerator[InferenceRequest, None]:
        """Creates an AsyncGenerator over the stream queue"""
        try:
            while True:
                result = await self._queue.get()
                if self._is_raisable(result):
                    if result == STOP_ITERATION:
                        return
                    raise result
                yield result
        except GeneratorExit:
            self._cancel()
            raise asyncio.CancelledError from None

    @staticmethod
    def _is_raisable(value: Any):
        return isinstance(value, BaseException) or (
            isinstance(value, type) and issubclass(value, BaseException)
        )
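
A minimal usage sketch of the producer/consumer pattern, assuming it runs inside an event loop. Plain strings stand in for InferenceRequest objects purely for illustration, and the cancel callback is a no-op stub accepting any arguments:

import asyncio

from megatron.core.inference.async_stream import AsyncStream

async def demo():
    # AsyncStream captures the running loop, so it must be built inside a coroutine.
    stream = AsyncStream(request_id="req-0", cancel=lambda *args: None)

    async def producer():
        for chunk in ("hello", "world"):
            stream.put(chunk)       # real callers put InferenceRequest objects
            await asyncio.sleep(0)  # yield control to the consumer
        stream.finish()             # enqueues the STOP_ITERATION sentinel

    asyncio.create_task(producer())
    async for item in stream.generator():
        print(item)

# asyncio.run(demo())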
Megatron-LM/megatron/core/inference/common_inference_params.py (new file, mode 100644)

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from megatron.core.inference.sampling_params import (  # noqa: F401 # pylint: disable=unused-import
    SamplingParams as CommonInferenceParams,
)
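
This last file only preserves the old import path: CommonInferenceParams is re-exported as an alias for SamplingParams so existing call sites keep working. A quick check of the aliasing, assuming the package (including the sampling_params module referenced above) is importable:

from megatron.core.inference.common_inference_params import CommonInferenceParams
from megatron.core.inference.sampling_params import SamplingParams

# Both names refer to the same class object, so old call sites need no changes.
assert CommonInferenceParams is SamplingParams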