Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
44740c6c
Commit
44740c6c
authored
Jul 22, 2025
by
yuguo
Browse files
Merge commit '
7a9a0825
' of...
Merge commit '
7a9a0825
' of
https://github.com/NVIDIA/TransformerEngine
parents
8113d9e0
7a9a0825
Changes
162
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
620 additions
and
242 deletions
+620
-242
transformer_engine/pytorch/triton/permutation.py
transformer_engine/pytorch/triton/permutation.py
+512
-241
transformer_engine/pytorch/utils.py
transformer_engine/pytorch/utils.py
+108
-1
No files found.
transformer_engine/pytorch/triton/permutation.py
View file @
44740c6c
...
...
@@ -10,6 +10,72 @@ import torch
import
triton
import
triton.language
as
tl
from
triton.language
import
core
from
triton.language.standard
import
_log2
# The following three argsort related kernels are adapted from
# the issue https://github.com/triton-lang/triton/issues/3698
@
triton
.
jit
def
_compare_and_swap
(
x
,
indices
,
flip
,
i
:
tl
.
constexpr
,
n_dims
:
tl
.
constexpr
):
n_outer
:
tl
.
constexpr
=
x
.
numel
>>
n_dims
shape
:
tl
.
constexpr
=
[
n_outer
*
(
2
**
i
),
2
,
2
**
(
n_dims
-
i
-
1
)]
y
=
tl
.
reshape
(
x
,
shape
)
z
=
tl
.
reshape
(
indices
,
shape
)
mask
=
tl
.
arange
(
0
,
2
)[
None
,
:,
None
]
l_value
=
tl
.
reshape
(
tl
.
broadcast_to
(
tl
.
sum
(
y
*
(
1
-
mask
),
1
)[:,
None
,
:],
shape
),
x
.
shape
).
to
(
x
.
dtype
)
r_value
=
tl
.
reshape
(
tl
.
broadcast_to
(
tl
.
sum
(
y
*
mask
,
1
)[:,
None
,
:],
shape
),
x
.
shape
).
to
(
x
.
dtype
)
l_indice
=
tl
.
reshape
(
tl
.
broadcast_to
(
tl
.
sum
(
z
*
(
1
-
mask
),
1
)[:,
None
,
:],
shape
),
x
.
shape
)
r_indice
=
tl
.
reshape
(
tl
.
broadcast_to
(
tl
.
sum
(
z
*
mask
,
1
)[:,
None
,
:],
shape
),
x
.
shape
)
idtype
=
core
.
get_int_dtype
(
bitwidth
=
x
.
dtype
.
primitive_bitwidth
,
signed
=
True
)
il_value
=
l_value
.
to
(
idtype
,
bitcast
=
True
)
ir_value
=
r_value
.
to
(
idtype
,
bitcast
=
True
)
ix
=
x
.
to
(
idtype
,
bitcast
=
True
)
flag1
=
tl
.
where
(((
l_value
>
r_value
)
^
flip
)
!=
0
,
il_value
^
ir_value
,
tl
.
zeros_like
(
ix
))
ret
=
ix
^
flag1
flag2
=
tl
.
where
(((
l_value
>
r_value
)
^
flip
)
!=
0
,
l_indice
^
r_indice
,
tl
.
zeros_like
(
ix
))
ind
=
indices
^
flag2
return
ret
.
to
(
x
.
dtype
,
bitcast
=
True
),
ind
@
triton
.
jit
def
_bitonic_merge
(
x
,
indices
,
stage
:
tl
.
constexpr
,
order
:
tl
.
constexpr
,
n_dims
:
tl
.
constexpr
):
n_outer
:
tl
.
constexpr
=
x
.
numel
>>
n_dims
tl
.
static_assert
(
stage
<=
n_dims
)
"""
order_type 0 == ascending
order_type 1 == descending
order_type 2 == alternating
"""
if
order
==
2
:
shape
:
tl
.
constexpr
=
[
n_outer
*
(
2
**
(
n_dims
-
1
-
stage
)),
2
,
2
**
stage
]
flip
=
tl
.
reshape
(
tl
.
broadcast_to
(
tl
.
arange
(
0
,
2
)[
None
,
:,
None
],
shape
),
x
.
shape
)
else
:
flip
=
tl
.
full
(
x
.
shape
,
value
=
order
,
dtype
=
tl
.
int32
)
for
i
in
tl
.
static_range
(
stage
):
x
,
indices
=
_compare_and_swap
(
x
,
indices
,
flip
,
i
+
(
n_dims
-
stage
),
n_dims
)
return
x
,
indices
@
triton
.
jit
def
_argsort
(
x
,
indices
,
n_dims
:
tl
.
constexpr
):
for
i
in
tl
.
static_range
(
1
,
n_dims
+
1
):
x
,
indices
=
_bitonic_merge
(
x
,
indices
,
i
,
2
if
i
<
n_dims
else
1
,
n_dims
)
return
x
,
indices
@
triton
.
jit
def
_row_id_map_pass_1_kernel
(
...
...
@@ -22,6 +88,8 @@ def _row_id_map_pass_1_kernel(
# strides
stride_routing_map_token
,
stride_routing_map_expert
,
stride_row_id_map_token
,
stride_row_id_map_expert
,
# metas
BLOCK_SIZE
:
tl
.
constexpr
,
):
...
...
@@ -32,10 +100,10 @@ def _row_id_map_pass_1_kernel(
routing_map_ptr
+
pid_m
*
stride_routing_map_expert
+
offset
*
stride_routing_map_token
,
mask
=
(
offset
<
num_tokens
),
other
=
0
,
).
to
(
tl
.
int
64
)
).
to
(
tl
.
int
32
)
row_id_within_token_block
=
tl
.
cumsum
(
expert_token_mask
)
*
expert_token_mask
tl
.
store
(
row_id_map_ptr
+
pid_m
*
num_tokens
+
offset
,
row_id_map_ptr
+
pid_m
*
stride_row_id_map_expert
+
offset
*
stride_row_id_map_token
,
row_id_within_token_block
,
mask
=
offset
<
num_tokens
,
)
...
...
@@ -50,6 +118,9 @@ def _row_id_map_pass_2_kernel(
workspace_ptr
,
# sizes
num_tokens
,
# strides
stride_row_id_map_token
,
stride_row_id_map_expert
,
# metas
WORKSPACE_LOAD_WIDTH
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
...
...
@@ -59,7 +130,9 @@ def _row_id_map_pass_2_kernel(
chunk_idx
=
pid_m
*
tl
.
cdiv
(
num_tokens
,
BLOCK_SIZE
)
+
pid_n
offset
=
pid_n
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
row_id_within_token_block
=
tl
.
load
(
row_id_map_ptr
+
pid_m
*
num_tokens
+
offset
,
mask
=
(
offset
<
num_tokens
),
other
=
0
row_id_map_ptr
+
pid_m
*
stride_row_id_map_expert
+
offset
*
stride_row_id_map_token
,
mask
=
(
offset
<
num_tokens
),
other
=
0
,
)
workspace_off
=
tl
.
arange
(
0
,
WORKSPACE_LOAD_WIDTH
)
...
...
@@ -70,23 +143,102 @@ def _row_id_map_pass_2_kernel(
row_id_within_token_block
+
tl
.
sum
(
n_tokens_per_chunk
)
-
1
,
)
tl
.
store
(
row_id_map_ptr
+
pid_m
*
num_tokens
+
offset
,
row_id_map_ptr
+
pid_m
*
stride_row_id_map_expert
+
offset
*
stride_row_id_map_token
,
row_id
,
mask
=
(
offset
<
num_tokens
),
)
@
triton
.
jit
def
_row_id_map_pass_3_kernel
(
# pointers
row_id_map_ptr
,
# sizes
num_experts
:
tl
.
constexpr
,
# strides
stride_row_id_map_token
,
stride_row_id_map_expert
,
# metas
LOAD_SIZE
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
n_dims
:
tl
.
constexpr
=
_log2
(
LOAD_SIZE
)
off
=
tl
.
arange
(
0
,
LOAD_SIZE
)
row_id_map
=
tl
.
load
(
row_id_map_ptr
+
pid
*
stride_row_id_map_token
+
stride_row_id_map_expert
*
off
,
mask
=
off
<
num_experts
,
other
=-
1
,
)
n_routed
=
tl
.
sum
(
tl
.
where
(
row_id_map
!=
-
1
,
1
,
0
))
indices
=
off
sorted_map
,
indices
=
_argsort
(
row_id_map
,
indices
,
n_dims
=
n_dims
)
tl
.
store
(
row_id_map_ptr
+
pid
*
stride_row_id_map_token
+
off
*
stride_row_id_map_expert
,
sorted_map
,
mask
=
off
<
n_routed
,
)
tl
.
store
(
row_id_map_ptr
+
pid
*
stride_row_id_map_token
+
(
num_experts
+
off
)
*
stride_row_id_map_expert
,
indices
,
mask
=
off
<
n_routed
,
)
tl
.
store
(
row_id_map_ptr
+
pid
*
stride_row_id_map_token
+
num_experts
*
2
*
stride_row_id_map_expert
,
n_routed
,
)
def
make_row_id_map
(
routing_map
:
torch
.
Tensor
,
num_tokens
:
int
,
num_experts
:
int
,
):
# pylint: disable=missing-function-docstring
row_id_map
=
torch
.
empty
((
num_experts
,
num_tokens
),
dtype
=
torch
.
int64
,
device
=
"cuda"
)
block_size
=
256
"""
Prepare the row_id_map for the permutation.
Parameters
----------
routing_map: torch.Tensor
Input tensor of shape `[num_tokens, num_experts]`. It is a mask tensor that indicates
which experts are routed to which tokens. The values in it: 1 means the token is routed to
this expert and 0 means not.
num_tokens: int
Number of tokens in the input tensor.
num_experts: int
Number of experts in the input tensor.
Returns
-------
row_id_map: torch.Tensor
The row_id_map for the permutation of shape `[num_tokens, num_experts * 2 + 1]`.
For each token, the last item is the number of experts that are routed (n_routed).
The first n_routed items are the destination row indices in the permuted tokens.
The [num_experts, num_experts + n_routed) items are the indices of the experts corresponding
to the first n_routed row indices above.
"""
row_id_map
=
torch
.
empty
((
num_tokens
,
num_experts
*
2
+
1
),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
block_size
=
1024
grid
=
(
num_experts
,
triton
.
cdiv
(
num_tokens
,
block_size
))
workspace_tensor
=
torch
.
empty
(
grid
,
dtype
=
torch
.
int64
,
device
=
"cuda"
)
# block cumsum
workspace_tensor
=
torch
.
empty
(
grid
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)
# supposing num_tokens == 5, num_experts == 3, block_size == 3
# and we have a routing_map like this:
# [[1, 1, 0],
# [1, 0, 1],
# [0, 0, 1],
# [1, 1, 0],
# [0, 0, 0]]
# pass 1: block cumsum
# for each expert, compute the cumsum of every block_size tokens
# the row_id_map will be like this after pass 1 (r means useless values):
# [[1, 1, 0, r, r, r, r],
# [2, 0, 1, r, r, r, r],
# [0, 0, 2, r, r, r, r],
# [1, 1, 0, r, r, r, r],
# [0, 0, 0, r, r, r, r]]
_row_id_map_pass_1_kernel
[
grid
](
routing_map
,
row_id_map
,
...
...
@@ -94,16 +246,44 @@ def make_row_id_map(
num_tokens
,
routing_map
.
stride
(
0
),
routing_map
.
stride
(
1
),
row_id_map
.
stride
(
0
),
row_id_map
.
stride
(
1
),
block_size
,
)
# cumsum all and process the mask
# pass 2: cumsum all and process the mask
# process the block cumsum into the global cumsum and then into the dst row indices
# the row_id_map will be like this after pass 2 (r means useless value):
# [[ 0, 3, -1, r, r, r, r],
# [ 1, -1, 5, r, r, r, r],
# [-1, -1, 6, r, r, r, r],
# [ 2, 4, -1, r, r, r, r],
# [-1, -1, -1, r, r, r, r]]
_row_id_map_pass_2_kernel
[
grid
](
row_id_map
,
workspace_tensor
,
num_tokens
,
row_id_map
.
stride
(
0
),
row_id_map
.
stride
(
1
),
triton
.
next_power_of_2
(
num_experts
*
triton
.
cdiv
(
num_tokens
,
block_size
)),
block_size
,
)
# pass 3: make the row_id_map from the sparse structure to the dense structure
# the row_id_map will be like this after pass 3 (r means useless value):
# [[3, 0, r, 1, 0, r, 2],
# [5, 1, r, 2, 0, r, 2],
# [6, r, r, 2, r, r, 1],
# [4, 2, r, 1, 0, r, 2],
# [r, r, r, r, r, r, 0]]
grid
=
(
num_tokens
,)
_row_id_map_pass_3_kernel
[
grid
](
row_id_map
,
num_experts
,
row_id_map
.
stride
(
0
),
row_id_map
.
stride
(
1
),
triton
.
next_power_of_2
(
num_experts
),
)
return
row_id_map
...
...
@@ -118,11 +298,12 @@ def _permute_kernel(
permuted_probs_ptr
,
permuted_scale_ptr
,
# sizes
num_tokens
,
num_experts
,
hidden_size
,
num_experts
:
tl
.
constexpr
,
hidden_size
:
tl
.
constexpr
,
scale_hidden_dim
,
# strides
stride_row_id_map_token
,
stride_row_id_map_expert
,
stride_input_token
,
stride_input_hidden
,
stride_output_token
,
...
...
@@ -139,35 +320,50 @@ def _permute_kernel(
PERMUTE_SCALE
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
cur_pos
=
0
while
cur_pos
<
hidden_size
:
cur_off
=
cur_pos
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
cur_off
<
hidden_size
input_off
=
pid
*
stride_input_token
+
cur_off
*
stride_input_hidden
inp
=
tl
.
load
(
input_ptr
+
input_off
,
mask
=
mask
)
pid_t
=
tl
.
program_id
(
0
)
pid_h
=
tl
.
program_id
(
1
)
cur_off
=
pid_h
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
cur_off
<
hidden_size
input_off
=
pid_t
*
stride_input_token
+
cur_off
*
stride_input_hidden
inp
=
tl
.
load
(
input_ptr
+
input_off
,
mask
=
mask
)
if
PERMUTE_SCALE
:
mask_scale
=
cur_off
<
scale_hidden_dim
scale_off
=
pid_t
*
stride_scale_token
+
cur_off
*
stride_scale_hidden
scale
=
tl
.
load
(
scale_ptr
+
scale_off
,
mask
=
mask_scale
)
n_routed
=
tl
.
load
(
row_id_map_ptr
+
pid_t
*
stride_row_id_map_token
+
num_experts
*
2
*
stride_row_id_map_expert
)
for
idx
in
tl
.
range
(
n_routed
):
dst_row
=
tl
.
load
(
row_id_map_ptr
+
pid_t
*
stride_row_id_map_token
+
idx
*
stride_row_id_map_expert
)
output_off
=
dst_row
*
stride_output_token
+
cur_off
*
stride_output_hidden
if
PERMUTE_SCALE
:
mask_scale
=
cur_off
<
scale_hidden_dim
scale_off
=
pid
*
stride_scale_token
+
cur_off
*
stride_scale_hidden
scale
=
tl
.
load
(
scale_ptr
+
scale_off
,
mask
=
mask_scale
)
for
expert_idx
in
range
(
num_experts
):
dst_row
=
tl
.
load
(
row_id_map_ptr
+
expert_idx
*
num_tokens
+
pid
)
if
dst_row
!=
-
1
:
output_off
=
dst_row
*
stride_output_token
+
cur_off
*
stride_output_hidden
permuted_scale_off
=
(
dst_row
*
stride_permuted_scale_token
+
cur_off
*
stride_permuted_scale_hidden
)
tl
.
store
(
permuted_scale_ptr
+
permuted_scale_off
,
scale
,
mask
=
mask_scale
)
if
PERMUTE_PROBS
:
expert_idx
=
tl
.
load
(
row_id_map_ptr
+
pid_t
*
stride_row_id_map_token
+
(
num_experts
+
idx
)
*
stride_row_id_map_expert
)
prob_off
=
pid_t
*
stride_probs_token
+
expert_idx
*
stride_probs_expert
prob
=
tl
.
load
(
probs_ptr
+
prob_off
)
if
pid_h
==
0
:
permuted_prob_off
=
dst_row
*
stride_permuted_probs_token
tl
.
store
(
permuted_probs_ptr
+
permuted_prob_off
,
prob
)
if
prob
==
0.0
:
# for routing_map padding
# dst_row != -1 and prob == 0.0 means that this slot is padded
tl
.
store
(
output_ptr
+
output_off
,
0
,
mask
=
mask
)
else
:
tl
.
store
(
output_ptr
+
output_off
,
inp
,
mask
=
mask
)
if
PERMUTE_SCALE
:
permuted_scale_off
=
(
dst_row
*
stride_permuted_scale_token
+
cur_off
*
stride_permuted_scale_hidden
)
tl
.
store
(
permuted_scale_ptr
+
permuted_scale_off
,
scale
,
mask
=
mask_scale
)
if
PERMUTE_PROBS
:
if
cur_pos
==
0
:
prob_off
=
pid
*
stride_probs_token
+
expert_idx
*
stride_probs_expert
prob
=
tl
.
load
(
probs_ptr
+
prob_off
)
permuted_prob_off
=
dst_row
*
stride_permuted_probs_token
tl
.
store
(
permuted_probs_ptr
+
permuted_prob_off
,
prob
)
cur_pos
+=
BLOCK_SIZE
else
:
tl
.
store
(
output_ptr
+
output_off
,
inp
,
mask
=
mask
)
try
:
...
...
@@ -178,6 +374,8 @@ try:
triton
.
Config
({
"BLOCK_SIZE"
:
256
}),
triton
.
Config
({
"BLOCK_SIZE"
:
512
}),
triton
.
Config
({
"BLOCK_SIZE"
:
1024
}),
triton
.
Config
({
"BLOCK_SIZE"
:
2048
}),
triton
.
Config
({
"BLOCK_SIZE"
:
4096
}),
],
key
=
[
"hidden_size"
],
)(
_permute_kernel
)
...
...
@@ -196,7 +394,30 @@ def permute_with_mask_map(
hidden_size
:
int
,
scale_hidden_dim
:
int
,
):
# pylint: disable=missing-function-docstring
"""
Permute the input tensor based on the row_id_map.
Parameters
----------
inp: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
row_id_map: torch.Tensor
The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
probs: torch.Tensor
The probabilities of the input tensor. If it is not None, it will be permuted.
scale: torch.Tensor
The scale of the input tensor. If it is not None, it will be permuted.
num_tokens: int
Number of tokens in the input tensor.
num_experts: int
Number of experts in the input tensor.
num_out_tokens: int
Number of tokens in the permuted tensor.
hidden_size: int
Hidden size of the input tensor.
scale_hidden_dim: int
Hidden size of the scale tensor.
"""
output
=
torch
.
empty
((
num_out_tokens
,
hidden_size
),
dtype
=
inp
.
dtype
,
device
=
"cuda"
)
if
probs
is
not
None
:
permuted_probs
=
torch
.
empty
((
num_out_tokens
,),
dtype
=
probs
.
dtype
,
device
=
"cuda"
)
...
...
@@ -209,8 +430,8 @@ def permute_with_mask_map(
)
else
:
permuted_scale
=
None
grid
=
(
num_tokens
,
)
# pylint: disable=unnecessary-lambda-assignment
grid
=
lambda
META
:
(
num_tokens
,
triton
.
cdiv
(
hidden_size
,
META
[
"BLOCK_SIZE"
])
)
_permute_kernel
[
grid
](
inp
,
output
,
...
...
@@ -219,10 +440,11 @@ def permute_with_mask_map(
scale
,
permuted_probs
,
permuted_scale
,
num_tokens
,
num_experts
,
hidden_size
,
scale_hidden_dim
,
row_id_map
.
stride
(
0
),
row_id_map
.
stride
(
1
),
inp
.
stride
(
0
),
inp
.
stride
(
1
),
output
.
stride
(
0
),
...
...
@@ -250,10 +472,11 @@ def _unpermute_kernel(
permuted_probs_ptr
,
unpermuted_probs_ptr
,
# sizes
num_tokens
,
num_experts
,
hidden_size
,
num_experts
:
tl
.
constexpr
,
hidden_size
:
tl
.
constexpr
,
# strides
stride_row_id_map_token
,
stride_row_id_map_expert
,
stride_input_token
,
stride_input_hidden
,
stride_output_token
,
...
...
@@ -264,6 +487,7 @@ def _unpermute_kernel(
stride_unpermuted_probs_token
,
stride_unpermuted_probs_expert
,
# metas
PROBS_LOAD_WIDTH
:
tl
.
constexpr
,
WITH_MERGING_PROBS
:
tl
.
constexpr
,
PERMUTE_PROBS
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
...
...
@@ -271,41 +495,63 @@ def _unpermute_kernel(
data_type
=
input_ptr
.
dtype
.
element_ty
compute_type
=
tl
.
float32
pid
=
tl
.
program_id
(
0
)
current_start
=
0
while
current_start
<
hidden_size
:
current_offset
=
current_start
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
current_offset
<
hidden_size
accumulator
=
tl
.
zeros
((
BLOCK_SIZE
,),
dtype
=
compute_type
)
for
expert_idx
in
range
(
num_experts
):
src_row
=
tl
.
load
(
row_id_map_ptr
+
expert_idx
*
num_tokens
+
pid
)
if
src_row
!=
-
1
:
input_off
=
src_row
*
stride_input_token
+
current_offset
*
stride_input_hidden
inp
=
tl
.
load
(
input_ptr
+
input_off
,
mask
=
mask
)
inp
=
inp
.
to
(
compute_type
)
if
WITH_MERGING_PROBS
:
merging_prob_off
=
(
pid
*
stride_merging_probs_token
+
expert_idx
*
stride_merging_probs_expert
)
merging_prob
=
tl
.
load
(
merging_probs_ptr
+
merging_prob_off
).
to
(
compute_type
)
inp
*=
merging_prob
accumulator
+=
inp
if
PERMUTE_PROBS
:
if
current_start
==
0
:
unpermuted_prob_off
=
(
pid
*
stride_unpermuted_probs_token
+
expert_idx
*
stride_unpermuted_probs_expert
)
if
src_row
!=
-
1
:
permuted_prob_off
=
src_row
*
stride_permuted_probs_token
prob
=
tl
.
load
(
permuted_probs_ptr
+
permuted_prob_off
)
tl
.
store
(
unpermuted_probs_ptr
+
unpermuted_prob_off
,
prob
)
else
:
tl
.
store
(
unpermuted_probs_ptr
+
unpermuted_prob_off
,
0.0
)
accumulator
=
accumulator
.
to
(
data_type
)
output_off
=
pid
*
stride_output_token
+
current_offset
*
stride_output_hidden
tl
.
store
(
output_ptr
+
output_off
,
accumulator
,
mask
=
mask
)
current_start
+=
BLOCK_SIZE
pid_t
=
tl
.
program_id
(
0
)
pid_h
=
tl
.
program_id
(
1
)
current_offset
=
pid_h
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
current_offset
<
hidden_size
if
PERMUTE_PROBS
:
# write 0.0 to probs_grad that are not routed
if
pid_h
==
0
:
map_load_off
=
tl
.
arange
(
0
,
PROBS_LOAD_WIDTH
)
unpermuted_prob_off
=
(
pid_t
*
stride_unpermuted_probs_token
+
stride_unpermuted_probs_expert
*
map_load_off
)
tl
.
store
(
unpermuted_probs_ptr
+
unpermuted_prob_off
,
0.0
,
mask
=
map_load_off
<
num_experts
)
accumulator
=
tl
.
zeros
((
BLOCK_SIZE
,),
dtype
=
compute_type
)
n_routed
=
tl
.
load
(
row_id_map_ptr
+
pid_t
*
stride_row_id_map_token
+
num_experts
*
2
*
stride_row_id_map_expert
)
for
idx
in
tl
.
range
(
n_routed
):
src_row
=
tl
.
load
(
row_id_map_ptr
+
pid_t
*
stride_row_id_map_token
+
idx
*
stride_row_id_map_expert
)
input_off
=
src_row
*
stride_input_token
+
current_offset
*
stride_input_hidden
inp
=
tl
.
load
(
input_ptr
+
input_off
,
mask
=
mask
)
inp
=
inp
.
to
(
compute_type
)
if
WITH_MERGING_PROBS
:
expert_idx
=
tl
.
load
(
row_id_map_ptr
+
pid_t
*
stride_row_id_map_token
+
(
num_experts
+
idx
)
*
stride_row_id_map_expert
)
merging_prob_off
=
(
pid_t
*
stride_merging_probs_token
+
expert_idx
*
stride_merging_probs_expert
)
merging_prob
=
tl
.
load
(
merging_probs_ptr
+
merging_prob_off
).
to
(
compute_type
)
inp
*=
merging_prob
accumulator
+=
inp
if
PERMUTE_PROBS
:
if
pid_h
==
0
:
expert_idx
=
tl
.
load
(
row_id_map_ptr
+
pid_t
*
stride_row_id_map_token
+
(
num_experts
+
idx
)
*
stride_row_id_map_expert
)
unpermuted_prob_off
=
(
pid_t
*
stride_unpermuted_probs_token
+
expert_idx
*
stride_unpermuted_probs_expert
)
permuted_prob_off
=
src_row
*
stride_permuted_probs_token
prob
=
tl
.
load
(
permuted_probs_ptr
+
permuted_prob_off
)
tl
.
store
(
unpermuted_probs_ptr
+
unpermuted_prob_off
,
prob
)
accumulator
=
accumulator
.
to
(
data_type
)
output_off
=
pid_t
*
stride_output_token
+
current_offset
*
stride_output_hidden
tl
.
store
(
output_ptr
+
output_off
,
accumulator
,
mask
=
mask
)
try
:
...
...
@@ -316,6 +562,8 @@ try:
triton
.
Config
({
"BLOCK_SIZE"
:
256
}),
triton
.
Config
({
"BLOCK_SIZE"
:
512
}),
triton
.
Config
({
"BLOCK_SIZE"
:
1024
}),
triton
.
Config
({
"BLOCK_SIZE"
:
2048
}),
triton
.
Config
({
"BLOCK_SIZE"
:
4096
}),
],
key
=
[
"hidden_size"
],
)(
_unpermute_kernel
)
...
...
@@ -332,7 +580,27 @@ def unpermute_with_mask_map(
num_experts
:
int
,
hidden_size
:
int
,
):
# pylint: disable=missing-function-docstring
"""
Unpermute the input tensor based on the row_id_map.
Parameters
----------
inp: torch.Tensor
Input tensor of shape `[num_out_tokens, hidden_size]`.
row_id_map: torch.Tensor
The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
merging_probs: torch.Tensor
The merging probabilities of the input tensor. If it is not None, it will be used as weights
to reduce the unpermuted tokens.
permuted_probs: torch.Tensor
The permuted probabilities of the input tensor. If it is not None, it will be unpermuted.
num_tokens: int
Number of tokens in the permuted tensor.
num_experts: int
Number of experts in the permuted tensor.
hidden_size: int
Hidden size of the permuted tensor.
"""
output
=
torch
.
empty
((
num_tokens
,
hidden_size
),
dtype
=
inp
.
dtype
,
device
=
"cuda"
)
if
permuted_probs
is
not
None
:
unpermuted_probs
=
torch
.
empty
(
...
...
@@ -340,7 +608,8 @@ def unpermute_with_mask_map(
)
else
:
unpermuted_probs
=
None
grid
=
(
num_tokens
,)
# pylint: disable=unnecessary-lambda-assignment
grid
=
lambda
META
:
(
num_tokens
,
triton
.
cdiv
(
hidden_size
,
META
[
"BLOCK_SIZE"
]))
_unpermute_kernel
[
grid
](
inp
,
output
,
...
...
@@ -348,9 +617,10 @@ def unpermute_with_mask_map(
merging_probs
,
permuted_probs
,
unpermuted_probs
,
num_tokens
,
num_experts
,
hidden_size
,
row_id_map
.
stride
(
0
),
row_id_map
.
stride
(
1
),
inp
.
stride
(
0
),
inp
.
stride
(
1
),
output
.
stride
(
0
),
...
...
@@ -360,6 +630,7 @@ def unpermute_with_mask_map(
permuted_probs
.
stride
(
0
)
if
permuted_probs
is
not
None
else
None
,
unpermuted_probs
.
stride
(
0
)
if
unpermuted_probs
is
not
None
else
None
,
unpermuted_probs
.
stride
(
1
)
if
unpermuted_probs
is
not
None
else
None
,
PROBS_LOAD_WIDTH
=
triton
.
next_power_of_2
(
num_experts
),
WITH_MERGING_PROBS
=
merging_probs
is
not
None
,
PERMUTE_PROBS
=
permuted_probs
is
not
None
,
)
...
...
@@ -376,10 +647,11 @@ def _unpermute_bwd_with_merging_probs_kernel(
merging_probs_grad_ptr
,
row_id_map_ptr
,
# sizes
num_tokens
,
num_experts
,
hidden_size
,
num_experts
:
tl
.
constexpr
,
hidden_size
:
tl
.
constexpr
,
# strides
stride_row_id_map_token
,
stride_row_id_map_expert
,
stride_fwd_output_grad_token
,
stride_fwd_output_grad_hidden
,
stride_fwd_input_grad_token
,
...
...
@@ -391,56 +663,63 @@ def _unpermute_bwd_with_merging_probs_kernel(
stride_merging_probs_grad_token
,
stride_merging_probs_grad_expert
,
# metas
PROBS_LOAD_WIDTH
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
data_type
=
fwd_output_grad_ptr
.
dtype
.
element_ty
compute_type
=
tl
.
float32
pid
=
tl
.
program_id
(
0
)
for
expert_idx
in
range
(
num_experts
):
dst_row
=
tl
.
load
(
row_id_map_ptr
+
expert_idx
*
num_tokens
+
pid
)
if
dst_row
!=
-
1
:
prob_grad_accum
=
tl
.
zeros
((
BLOCK_SIZE
,),
dtype
=
compute_type
)
current_start
=
0
while
current_start
<
hidden_size
:
current_offset
=
current_start
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
current_offset
<
hidden_size
input_off
=
(
pid
*
stride_fwd_output_grad_token
+
current_offset
*
stride_fwd_output_grad_hidden
)
inp
=
tl
.
load
(
fwd_output_grad_ptr
+
input_off
,
mask
=
mask
)
inp
=
inp
.
to
(
compute_type
)
merging_prob_off
=
(
pid
*
stride_merging_probs_token
+
expert_idx
*
stride_merging_probs_expert
)
merging_prob
=
tl
.
load
(
merging_probs_ptr
+
merging_prob_off
).
to
(
compute_type
)
output
=
inp
*
merging_prob
output
=
output
.
to
(
data_type
)
output_off
=
(
dst_row
*
stride_fwd_input_grad_token
+
current_offset
*
stride_fwd_input_grad_hidden
)
tl
.
store
(
fwd_input_grad_ptr
+
output_off
,
output
,
mask
=
mask
)
fwd_input_off
=
(
dst_row
*
stride_fwd_input_token
+
current_offset
*
stride_fwd_input_hidden
)
fwd_input
=
tl
.
load
(
fwd_input_ptr
+
fwd_input_off
,
mask
=
mask
)
prob_grad_accum
+=
fwd_input
.
to
(
compute_type
)
*
inp
current_start
+=
BLOCK_SIZE
probs_grad
=
tl
.
sum
(
prob_grad_accum
).
to
(
merging_probs_grad_ptr
.
dtype
.
element_ty
)
probs_grad_off
=
(
pid
*
stride_merging_probs_grad_token
+
expert_idx
*
stride_merging_probs_grad_expert
map_load_off
=
tl
.
arange
(
0
,
PROBS_LOAD_WIDTH
)
token_probs_grad_off
=
(
pid
*
stride_merging_probs_grad_token
+
stride_merging_probs_grad_expert
*
map_load_off
)
tl
.
store
(
merging_probs_grad_ptr
+
token_probs_grad_off
,
0.0
,
mask
=
map_load_off
<
num_experts
)
n_routed
=
tl
.
load
(
row_id_map_ptr
+
pid
*
stride_row_id_map_token
+
num_experts
*
2
*
stride_row_id_map_expert
)
for
idx
in
tl
.
range
(
n_routed
):
dst_row
=
tl
.
load
(
row_id_map_ptr
+
pid
*
stride_row_id_map_token
+
idx
*
stride_row_id_map_expert
)
expert_idx
=
tl
.
load
(
row_id_map_ptr
+
pid
*
stride_row_id_map_token
+
(
num_experts
+
idx
)
*
stride_row_id_map_expert
)
prob_grad_accum
=
tl
.
zeros
((
BLOCK_SIZE
,),
dtype
=
compute_type
)
current_start
=
0
while
current_start
<
hidden_size
:
current_offset
=
current_start
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
current_offset
<
hidden_size
input_off
=
(
pid
*
stride_fwd_output_grad_token
+
current_offset
*
stride_fwd_output_grad_hidden
)
tl
.
store
(
merging_probs_grad_ptr
+
probs_grad_off
,
probs_grad
)
else
:
probs_grad_off
=
(
pid
*
stride_merging_probs_grad_token
+
expert_idx
*
stride_merging_probs_grad_expert
inp
=
tl
.
load
(
fwd_output_grad_ptr
+
input_off
,
mask
=
mask
)
inp
=
inp
.
to
(
compute_type
)
merging_prob_off
=
(
pid
*
stride_merging_probs_token
+
expert_idx
*
stride_merging_probs_expert
)
merging_prob
=
tl
.
load
(
merging_probs_ptr
+
merging_prob_off
).
to
(
compute_type
)
output
=
inp
*
merging_prob
output
=
output
.
to
(
data_type
)
output_off
=
(
dst_row
*
stride_fwd_input_grad_token
+
current_offset
*
stride_fwd_input_grad_hidden
)
tl
.
store
(
merging_probs_grad_ptr
+
probs_grad_off
,
0.0
)
tl
.
store
(
fwd_input_grad_ptr
+
output_off
,
output
,
mask
=
mask
)
fwd_input_off
=
(
dst_row
*
stride_fwd_input_token
+
current_offset
*
stride_fwd_input_hidden
)
fwd_input
=
tl
.
load
(
fwd_input_ptr
+
fwd_input_off
,
mask
=
mask
)
prob_grad_accum
+=
fwd_input
.
to
(
compute_type
)
*
inp
current_start
+=
BLOCK_SIZE
probs_grad
=
tl
.
sum
(
prob_grad_accum
).
to
(
merging_probs_grad_ptr
.
dtype
.
element_ty
)
probs_grad_off
=
(
pid
*
stride_merging_probs_grad_token
+
expert_idx
*
stride_merging_probs_grad_expert
)
tl
.
store
(
merging_probs_grad_ptr
+
probs_grad_off
,
probs_grad
)
try
:
...
...
@@ -451,6 +730,8 @@ try:
triton
.
Config
({
"BLOCK_SIZE"
:
256
}),
triton
.
Config
({
"BLOCK_SIZE"
:
512
}),
triton
.
Config
({
"BLOCK_SIZE"
:
1024
}),
triton
.
Config
({
"BLOCK_SIZE"
:
2048
}),
triton
.
Config
({
"BLOCK_SIZE"
:
4096
}),
],
key
=
[
"hidden_size"
],
)(
_unpermute_bwd_with_merging_probs_kernel
)
...
...
@@ -468,7 +749,28 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
num_out_tokens
:
int
,
hidden_size
:
int
,
):
# pylint: disable=missing-function-docstring
"""
Unpermute backward pass kernel with merging probs.
Parameters
----------
fwd_output_grad: torch.Tensor
The gradient of the output tensor of shape `[num_tokens, hidden_size]`.
row_id_map: torch.Tensor
The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
fwd_input: torch.Tensor
The input tensor of the forward pass of shape `[num_out_tokens, hidden_size]`.
merging_probs: torch.Tensor
The merging probabilities of the input tensor of shape `[num_tokens, num_experts]`.
num_tokens: int
Number of tokens in the permuted tensor.
num_experts: int
Number of experts in the permuted tensor.
num_out_tokens: int
Number of tokens in the output tensor.
hidden_size: int
Hidden size of the output tensor.
"""
act_grad
=
torch
.
empty
(
(
num_out_tokens
,
hidden_size
),
dtype
=
fwd_output_grad
.
dtype
,
device
=
"cuda"
)
...
...
@@ -483,9 +785,10 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
merging_probs
,
merging_probs_grad
,
row_id_map
,
num_tokens
,
num_experts
,
hidden_size
,
row_id_map
.
stride
(
0
),
row_id_map
.
stride
(
1
),
fwd_output_grad
.
stride
(
0
),
fwd_output_grad
.
stride
(
1
),
act_grad
.
stride
(
0
),
...
...
@@ -496,34 +799,21 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
merging_probs
.
stride
(
1
),
merging_probs_grad
.
stride
(
0
),
merging_probs_grad
.
stride
(
1
),
PROBS_LOAD_WIDTH
=
triton
.
next_power_of_2
(
num_experts
),
)
return
act_grad
,
merging_probs_grad
@
triton
.
jit
def
_
sort
_chunk
s_by_idxs
_kernel
(
def
_
make
_chunk
_sort_map
_kernel
(
# pointers
input_ptr
,
split_sizes_ptr
,
sorted_indices_ptr
,
output_ptr
,
dst_rows_ptr
,
probs_ptr
,
permuted_probs_ptr
,
# sizes
num_splits
,
hidden_size
,
# strides
stride_input_token
,
stride_input_hidden
,
stride_output_token
,
stride_output_hidden
,
stride_probs_token
,
stride_permuted_probs_token
,
num_splits
:
tl
.
constexpr
,
# metas
PERMUTE_PROBS
:
tl
.
constexpr
,
IDX_LOAD_WIDTH
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
...
...
@@ -533,104 +823,58 @@ def _sort_chunks_by_idxs_kernel(
)
# get chunk idx of the current token in the input tensor
input_chunk_idx
=
-
1
in_chunk_offset
=
tl
.
zeros
([],
dtype
=
tl
.
int64
)
acc_chunk_sizes
=
tl
.
zeros
([],
dtype
=
tl
.
int64
)
cursor
=
0
while
cursor
<
num_splits
:
cur_chunk_size
=
tl
.
load
(
split_sizes_ptr
+
cursor
).
to
(
tl
.
int64
)
acc_chunk_sizes
+=
cur_chunk_size
if
input_chunk_idx
==
-
1
and
acc_chunk_sizes
>
pid
:
input_chunk_idx
=
cursor
in_chunk_offset
=
pid
-
(
acc_chunk_sizes
-
cur_chunk_size
)
cursor
+=
1
input_split_sizes
=
tl
.
load
(
split_sizes_ptr
+
load_split_offset
,
mask
=
load_split_offset
<
num_splits
,
other
=
0
).
to
(
tl
.
int32
)
input_split_sizes_cumsum
=
tl
.
cumsum
(
input_split_sizes
)
input_split_sizes_mask
=
tl
.
where
(
input_split_sizes_cumsum
<=
pid
,
1
,
0
)
input_chunk_idx
=
tl
.
sum
(
input_split_sizes_mask
)
input_split_sizes_presum
=
tl
.
sum
(
input_split_sizes
*
input_split_sizes_mask
)
in_chunk_offset
=
pid
-
input_split_sizes_presum
# get chunk idx of the current token in the output tensor
output_chunk_idx
=
0
cursor
=
0
while
cursor
<
num_splits
:
cur_input_idx
=
tl
.
load
(
sorted_indices_ptr
+
cursor
)
if
cur_input_idx
==
input_chunk_idx
:
output_chunk_idx
=
cursor
cursor
+=
1
output_chunk_mask
=
tl
.
where
(
sorted_indices
==
input_chunk_idx
,
1
,
0
)
output_chunk_idx
=
tl
.
argmax
(
output_chunk_mask
,
axis
=-
1
)
# make row_id_map
output_split_sizes
=
tl
.
load
(
split_sizes_ptr
+
sorted_indices
,
mask
=
load_split_offset
<
num_splits
).
to
(
tl
.
int
64
)
).
to
(
tl
.
int
32
)
output_pre_split_sizes
=
tl
.
where
(
load_split_offset
<
output_chunk_idx
,
output_split_sizes
,
0
)
dst_row
=
tl
.
sum
(
output_pre_split_sizes
)
+
in_chunk_offset
tl
.
store
(
dst_rows_ptr
+
pid
,
dst_row
)
current_start
=
0
while
current_start
<
hidden_size
:
current_offset
=
current_start
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
current_offset
<
hidden_size
input_offsets
=
pid
*
stride_input_token
+
current_offset
*
stride_input_hidden
output_offsets
=
dst_row
*
stride_output_token
+
current_offset
*
stride_output_hidden
inp
=
tl
.
load
(
input_ptr
+
input_offsets
,
mask
=
mask
)
tl
.
store
(
output_ptr
+
output_offsets
,
inp
,
mask
=
mask
)
current_start
+=
BLOCK_SIZE
if
PERMUTE_PROBS
:
prob_off
=
pid
*
stride_probs_token
prob
=
tl
.
load
(
probs_ptr
+
prob_off
)
permuted_prob_off
=
dst_row
*
stride_permuted_probs_token
tl
.
store
(
permuted_probs_ptr
+
permuted_prob_off
,
prob
)
try
:
_sort_chunks_by_idxs_kernel
=
triton
.
autotune
(
configs
=
[
triton
.
Config
({
"BLOCK_SIZE"
:
64
}),
triton
.
Config
({
"BLOCK_SIZE"
:
128
}),
triton
.
Config
({
"BLOCK_SIZE"
:
256
}),
triton
.
Config
({
"BLOCK_SIZE"
:
512
}),
triton
.
Config
({
"BLOCK_SIZE"
:
1024
}),
],
key
=
[
"hidden_size"
],
)(
_sort_chunks_by_idxs_kernel
)
except
RuntimeError
:
pass
def
sort_chunks_by_idx
(
inp
:
torch
.
Tensor
,
def
make_chunk_sort_map
(
split_sizes
:
torch
.
Tensor
,
sorted_indices
:
torch
.
Tensor
,
probs
:
torch
.
Tensor
,
num_tokens
:
int
,
hidden_size
:
int
,
num_splits
:
int
,
):
# pylint: disable=missing-function-docstring
row_id_map
=
torch
.
empty
((
num_tokens
,),
dtype
=
torch
.
int64
,
device
=
"cuda"
)
output
=
torch
.
empty
((
num_tokens
,
hidden_size
),
dtype
=
inp
.
dtype
,
device
=
"cuda"
)
if
probs
is
not
None
:
permuted_probs
=
torch
.
empty
((
num_tokens
,),
dtype
=
probs
.
dtype
,
device
=
"cuda"
)
else
:
permuted_probs
=
None
"""
Make a row_id_map for chunk sort.
Parameters
----------
split_sizes: torch.Tensor
The sizes of the chunks of shape `[num_splits,]`.
sorted_indices: torch.Tensor
The indices of the sorted chunks of shape `[num_splits,]`.
num_tokens: int
Number of tokens in the input tensor.
num_splits: int
Number of splits of split_sizes and sorted_indices.
"""
row_id_map
=
torch
.
empty
((
num_tokens
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
grid
=
(
num_tokens
,)
_sort_chunks_by_idxs_kernel
[
grid
](
inp
,
_make_chunk_sort_map_kernel
[
grid
](
split_sizes
,
sorted_indices
,
output
,
row_id_map
,
probs
,
permuted_probs
,
num_splits
,
hidden_size
,
inp
.
stride
(
0
),
inp
.
stride
(
1
),
output
.
stride
(
0
),
output
.
stride
(
1
),
probs
.
stride
(
0
)
if
probs
is
not
None
else
None
,
permuted_probs
.
stride
(
0
)
if
permuted_probs
is
not
None
else
None
,
PERMUTE_PROBS
=
probs
is
not
None
,
IDX_LOAD_WIDTH
=
triton
.
next_power_of_2
(
num_splits
),
)
return
output
,
row_id_map
,
permuted_probs
return
row_id_map
@
triton
.
jit
...
...
@@ -642,7 +886,7 @@ def _sort_chunks_by_map_kernel(
probs_ptr
,
permuted_probs_ptr
,
# sizes
hidden_size
,
hidden_size
:
tl
.
constexpr
,
# strides
stride_input_token
,
stride_input_hidden
,
...
...
@@ -653,23 +897,28 @@ def _sort_chunks_by_map_kernel(
# metas
PERMUTE_PROBS
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
FORWARD
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
dst_row
=
tl
.
load
(
row_id_map_ptr
+
pid
)
current_start
=
0
while
current_start
<
hidden_size
:
current_offset
=
current_start
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
current_offset
<
hidden_size
input_offsets
=
dst_row
*
stride_input_token
+
current_offset
*
stride_input_hidden
output_offsets
=
pid
*
stride_output_token
+
current_offset
*
stride_output_hidden
inp
=
tl
.
load
(
input_ptr
+
input_offsets
,
mask
=
mask
)
tl
.
store
(
output_ptr
+
output_offsets
,
inp
,
mask
=
mask
)
current_start
+=
BLOCK_SIZE
pid_t
=
tl
.
program_id
(
0
)
pid_h
=
tl
.
program_id
(
1
)
if
FORWARD
:
src_row
=
pid_t
dst_row
=
tl
.
load
(
row_id_map_ptr
+
pid_t
)
else
:
src_row
=
tl
.
load
(
row_id_map_ptr
+
pid_t
)
dst_row
=
pid_t
current_offset
=
pid_h
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
current_offset
<
hidden_size
input_offsets
=
src_row
*
stride_input_token
+
current_offset
*
stride_input_hidden
output_offsets
=
dst_row
*
stride_output_token
+
current_offset
*
stride_output_hidden
inp
=
tl
.
load
(
input_ptr
+
input_offsets
,
mask
=
mask
)
tl
.
store
(
output_ptr
+
output_offsets
,
inp
,
mask
=
mask
)
if
PERMUTE_PROBS
:
prob_off
=
dst_row
*
stride_probs_token
prob
=
tl
.
load
(
probs_ptr
+
prob_off
)
permuted_prob_off
=
pid
*
stride_permuted_probs_token
tl
.
store
(
permuted_probs_ptr
+
permuted_prob_off
,
prob
)
if
pid_h
==
0
:
prob_off
=
src_row
*
stride_probs_token
prob
=
tl
.
load
(
probs_ptr
+
prob_off
)
permuted_prob_off
=
dst_row
*
stride_permuted_probs_token
tl
.
store
(
permuted_probs_ptr
+
permuted_prob_off
,
prob
)
try
:
...
...
@@ -680,6 +929,8 @@ try:
triton
.
Config
({
"BLOCK_SIZE"
:
256
}),
triton
.
Config
({
"BLOCK_SIZE"
:
512
}),
triton
.
Config
({
"BLOCK_SIZE"
:
1024
}),
triton
.
Config
({
"BLOCK_SIZE"
:
2048
}),
triton
.
Config
({
"BLOCK_SIZE"
:
4096
}),
],
key
=
[
"hidden_size"
],
)(
_sort_chunks_by_map_kernel
)
...
...
@@ -693,14 +944,33 @@ def sort_chunks_by_map(
probs
:
torch
.
Tensor
,
num_tokens
:
int
,
hidden_size
:
int
,
is_forward
:
bool
,
):
# pylint: disable=missing-function-docstring
"""
Sort chunks with row_id_map.
Parameters
----------
inp: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`.
row_id_map: torch.Tensor
The token to expert mapping tensor of shape `[num_tokens,]`.
probs: torch.Tensor
The probabilities of the input tensor. If it is not None, it will be permuted.
num_tokens: int
Number of tokens in the input tensor.
hidden_size: int
Hidden size of the input tensor.
is_forward: bool
Whether the sort is for forward or backward.
"""
output
=
torch
.
empty
((
num_tokens
,
hidden_size
),
dtype
=
inp
.
dtype
,
device
=
"cuda"
)
if
probs
is
not
None
:
permuted_probs
=
torch
.
empty
((
num_tokens
,),
dtype
=
probs
.
dtype
,
device
=
"cuda"
)
else
:
permuted_probs
=
None
grid
=
(
num_tokens
,)
# pylint: disable=unnecessary-lambda-assignment
grid
=
lambda
META
:
(
num_tokens
,
triton
.
cdiv
(
hidden_size
,
META
[
"BLOCK_SIZE"
]))
_sort_chunks_by_map_kernel
[
grid
](
inp
,
output
,
...
...
@@ -715,5 +985,6 @@ def sort_chunks_by_map(
probs
.
stride
(
0
)
if
probs
is
not
None
else
None
,
permuted_probs
.
stride
(
0
)
if
permuted_probs
is
not
None
else
None
,
PERMUTE_PROBS
=
probs
is
not
None
,
FORWARD
=
is_forward
,
)
return
output
,
permuted_probs
transformer_engine/pytorch/utils.py
View file @
44740c6c
...
...
@@ -7,7 +7,7 @@ from __future__ import annotations
import
functools
import
math
import
os
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
import
numpy
as
np
import
torch
...
...
@@ -654,3 +654,110 @@ else:
gpu_autocast_ctx
=
torch
.
cuda
.
amp
.
autocast
from
..debug.pytorch.debug_quantization
import
DebugQuantizedTensor
_torch_dtype_to_np_typestr_dict
=
{
torch
.
float16
:
"<f2"
,
torch
.
float32
:
"<f4"
,
torch
.
int64
:
"<i8"
,
torch
.
int32
:
"<i4"
,
torch
.
int8
:
"|i1"
,
torch
.
float8_e4m3fn
:
"|i1"
,
torch
.
qint8
:
"|u1"
,
torch
.
bool
:
"|b1"
,
torch
.
bfloat16
:
"<f2"
,
}
class
_WeakRefTensor
:
"""
A wrapper wraps raw data pointer to a tensor-like object. Could be compatibale with openai triton kernel and be converted to `torch.Tensor` with zero-copy overhead.
"""
def
__init__
(
self
,
data_ptr
:
int
,
dtype
:
torch
.
dtype
,
shape
:
Sequence
[
int
],
):
self
.
_data_ptr
=
data_ptr
self
.
dtype
=
dtype
self
.
shape
=
shape
def
data_ptr
(
self
):
"""Data pointer of the tensor."""
return
self
.
_data_ptr
@
property
def
dtype
(
self
):
"""Dtype of the tensor."""
return
self
.
_dtype
@
property
def
shape
(
self
):
"""Shape of the tensor."""
return
getattr
(
self
,
"_shape"
,
None
)
@
dtype
.
setter
def
dtype
(
self
,
dtype
:
torch
.
dtype
):
self
.
_dtype
=
dtype
@
shape
.
setter
def
shape
(
self
,
shape
:
Sequence
[
int
]):
self
.
_shape
=
tuple
(
int
(
i
)
for
i
in
shape
)
def
numel
(
self
):
"""Number of elements in the tensor."""
return
np
.
prod
(
self
.
shape
)
@
property
def
__cuda_array_interface__
(
self
):
return
{
"shape"
:
self
.
shape
,
"typestr"
:
self
.
torch_dtype_to_np_typestr
(),
"data"
:
(
self
.
data_ptr
()
if
self
.
numel
()
>
0
else
0
,
False
),
"version"
:
3
,
}
def
torch_dtype_to_np_typestr
(
self
):
"""Convert PyTorch dtype to numpy typestr."""
ret
=
_torch_dtype_to_np_typestr_dict
.
get
(
self
.
dtype
)
assert
ret
is
not
None
,
f
"Unsupported dtype:
{
self
.
dtype
}
"
return
ret
def
make_weak_ref
(
x
):
"""
This function is to make a weak reference to the input so that the memory can be released.
"""
def
convert_to_torch_tensor
(
tensor
:
Union
[
_WeakRefTensor
,
torch
.
Tensor
])
->
torch
.
Tensor
:
"""
This function is to convert the `_WeakRefTensor` to torch.Tensor.
"""
if
isinstance
(
tensor
,
torch
.
Tensor
):
return
tensor
old_ptr
=
tensor
.
data_ptr
()
new_tensor
=
torch
.
as_tensor
(
tensor
).
view
(
tensor
.
dtype
)
new_ptr
=
new_tensor
.
data_ptr
()
if
old_ptr
!=
new_ptr
:
raise
RuntimeError
(
"Data pointer mismatch after converting to torch.Tensor"
)
return
new_tensor
if
isinstance
(
x
,
torch
.
Tensor
):
return
(
convert_to_torch_tensor
(
_WeakRefTensor
(
x
.
data_ptr
(),
x
.
dtype
,
x
.
shape
))
if
x
.
is_cuda
else
x
)
if
isinstance
(
x
,
tuple
):
return
tuple
(
make_weak_ref
(
i
)
for
i
in
x
)
if
isinstance
(
x
,
list
):
return
[
make_weak_ref
(
i
)
for
i
in
x
]
if
isinstance
(
x
,
dict
):
return
{
k
:
make_weak_ref
(
v
)
for
k
,
v
in
x
.
items
()}
if
isinstance
(
x
,
(
int
,
float
,
bool
)):
return
x
if
x
is
None
:
return
None
raise
TypeError
(
f
"Invalid type
{
type
(
x
)
}
to make weak ref"
)
Prev
1
…
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment