Commit 44740c6c authored by yuguo's avatar yuguo
Browse files

Merge commit '7a9a0825' of...

Merge commit '7a9a0825' of https://github.com/NVIDIA/TransformerEngine
parents 8113d9e0 7a9a0825
...@@ -10,6 +10,72 @@ import torch ...@@ -10,6 +10,72 @@ import torch
import triton import triton
import triton.language as tl import triton.language as tl
from triton.language import core
from triton.language.standard import _log2
# The following three argsort related kernels are adapted from
# the issue https://github.com/triton-lang/triton/issues/3698
@triton.jit
def _compare_and_swap(x, indices, flip, i: tl.constexpr, n_dims: tl.constexpr):
    """One compare-and-swap stage of a bitonic sort over the last 2**n_dims
    elements of ``x``, carrying ``indices`` along so that the same swaps are
    applied to both tensors.

    ``i`` selects the stage: elements are paired with a partner
    2**(n_dims - i - 1) positions away. ``flip`` gives the per-element sort
    direction (0: keep smaller value on the left, nonzero: keep larger).
    Returns the swapped ``(values, indices)`` pair.
    """
    n_outer: tl.constexpr = x.numel >> n_dims
    # View the flat data as [pairs-of-groups, 2, group] so the middle axis
    # separates the two partners of each compare-and-swap pair.
    shape: tl.constexpr = [n_outer * (2**i), 2, 2 ** (n_dims - i - 1)]
    y = tl.reshape(x, shape)
    z = tl.reshape(indices, shape)
    # mask is 0 for the left partner slot and 1 for the right partner slot.
    mask = tl.arange(0, 2)[None, :, None]
    # Broadcast each partner's value/index back to the full shape, so every
    # element can see both its own and its partner's data.
    l_value = tl.reshape(tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape), x.shape).to(
        x.dtype
    )
    r_value = tl.reshape(tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape), x.shape).to(
        x.dtype
    )
    l_indice = tl.reshape(tl.broadcast_to(tl.sum(z * (1 - mask), 1)[:, None, :], shape), x.shape)
    r_indice = tl.reshape(tl.broadcast_to(tl.sum(z * mask, 1)[:, None, :], shape), x.shape)
    # Branch-free conditional swap: bitcast to an integer type of the same
    # width and XOR with (left ^ right) exactly when a swap is needed —
    # x ^ (l ^ r) maps l to r and r to l, and XOR with 0 is the identity.
    idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
    il_value = l_value.to(idtype, bitcast=True)
    ir_value = r_value.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)
    flag1 = tl.where(((l_value > r_value) ^ flip) != 0, il_value ^ ir_value, tl.zeros_like(ix))
    ret = ix ^ flag1
    # Apply the identical swap decision to the carried indices.
    flag2 = tl.where(((l_value > r_value) ^ flip) != 0, l_indice ^ r_indice, tl.zeros_like(ix))
    ind = indices ^ flag2
    return ret.to(x.dtype, bitcast=True), ind
@triton.jit
def _bitonic_merge(x, indices, stage: tl.constexpr, order: tl.constexpr, n_dims: tl.constexpr):
    """Bitonic merge of ``2**stage``-sized runs within ``x`` (with ``indices``
    carried along), i.e. ``stage`` successive compare-and-swap passes.

    ``order`` selects the direction as documented below; the "alternating"
    mode builds the bitonic sequences needed by intermediate sort stages.
    """
    n_outer: tl.constexpr = x.numel >> n_dims
    tl.static_assert(stage <= n_dims)
    """
    order_type 0 == ascending
    order_type 1 == descending
    order_type 2 == alternating
    """
    if order == 2:
        # Alternating direction: neighbouring runs of length 2**stage get
        # flip = 0, 1, 0, 1, ... so adjacent runs are sorted in opposite
        # directions, producing bitonic sequences for the next merge.
        shape: tl.constexpr = [n_outer * (2 ** (n_dims - 1 - stage)), 2, 2**stage]
        flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape)
    else:
        # Uniform direction for every element (0 ascending, 1 descending).
        flip = tl.full(x.shape, value=order, dtype=tl.int32)
    # Compile-time-unrolled passes with shrinking partner distance.
    for i in tl.static_range(stage):
        x, indices = _compare_and_swap(x, indices, flip, i + (n_dims - stage), n_dims)
    return x, indices
@triton.jit
def _argsort(x, indices, n_dims: tl.constexpr):
    """Full bitonic sort of the ``2**n_dims`` elements of ``x``, returning the
    sorted values together with the correspondingly permuted ``indices``.

    Intermediate merges use order 2 (alternating) to build bitonic runs; the
    final merge uses order 1, so the result is in descending order (per the
    order_type table in ``_bitonic_merge``).
    """
    for i in tl.static_range(1, n_dims + 1):
        x, indices = _bitonic_merge(x, indices, i, 2 if i < n_dims else 1, n_dims)
    return x, indices
@triton.jit @triton.jit
def _row_id_map_pass_1_kernel( def _row_id_map_pass_1_kernel(
...@@ -22,6 +88,8 @@ def _row_id_map_pass_1_kernel( ...@@ -22,6 +88,8 @@ def _row_id_map_pass_1_kernel(
# strides # strides
stride_routing_map_token, stride_routing_map_token,
stride_routing_map_expert, stride_routing_map_expert,
stride_row_id_map_token,
stride_row_id_map_expert,
# metas # metas
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
...@@ -32,10 +100,10 @@ def _row_id_map_pass_1_kernel( ...@@ -32,10 +100,10 @@ def _row_id_map_pass_1_kernel(
routing_map_ptr + pid_m * stride_routing_map_expert + offset * stride_routing_map_token, routing_map_ptr + pid_m * stride_routing_map_expert + offset * stride_routing_map_token,
mask=(offset < num_tokens), mask=(offset < num_tokens),
other=0, other=0,
).to(tl.int64) ).to(tl.int32)
row_id_within_token_block = tl.cumsum(expert_token_mask) * expert_token_mask row_id_within_token_block = tl.cumsum(expert_token_mask) * expert_token_mask
tl.store( tl.store(
row_id_map_ptr + pid_m * num_tokens + offset, row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
row_id_within_token_block, row_id_within_token_block,
mask=offset < num_tokens, mask=offset < num_tokens,
) )
...@@ -50,6 +118,9 @@ def _row_id_map_pass_2_kernel( ...@@ -50,6 +118,9 @@ def _row_id_map_pass_2_kernel(
workspace_ptr, workspace_ptr,
# sizes # sizes
num_tokens, num_tokens,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
# metas # metas
WORKSPACE_LOAD_WIDTH: tl.constexpr, WORKSPACE_LOAD_WIDTH: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
...@@ -59,7 +130,9 @@ def _row_id_map_pass_2_kernel( ...@@ -59,7 +130,9 @@ def _row_id_map_pass_2_kernel(
chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n
offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
row_id_within_token_block = tl.load( row_id_within_token_block = tl.load(
row_id_map_ptr + pid_m * num_tokens + offset, mask=(offset < num_tokens), other=0 row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
mask=(offset < num_tokens),
other=0,
) )
workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH) workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH)
...@@ -70,23 +143,102 @@ def _row_id_map_pass_2_kernel( ...@@ -70,23 +143,102 @@ def _row_id_map_pass_2_kernel(
row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1, row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1,
) )
tl.store( tl.store(
row_id_map_ptr + pid_m * num_tokens + offset, row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
row_id, row_id,
mask=(offset < num_tokens), mask=(offset < num_tokens),
) )
@triton.jit
def _row_id_map_pass_3_kernel(
    # pointers
    row_id_map_ptr,  # in/out: [num_tokens, num_experts * 2 + 1] map, updated in place
    # sizes
    num_experts: tl.constexpr,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    # metas
    LOAD_SIZE: tl.constexpr,  # power of two >= num_experts (bitonic sort width)
):
    """Convert one token's row of the row_id_map from sparse to dense form.

    One program per token. On entry, entry ``e`` of the row holds the
    destination row index for expert ``e``, or -1 if the token is not routed
    to that expert. On exit the row holds:
      [0, n_routed)                        -> destination row indices
      [num_experts, num_experts + n_routed) -> the matching expert indices
      [num_experts * 2]                    -> n_routed
    """
    pid = tl.program_id(0)
    n_dims: tl.constexpr = _log2(LOAD_SIZE)
    off = tl.arange(0, LOAD_SIZE)
    # Out-of-range and padded slots load as -1 so they sort behind real rows.
    row_id_map = tl.load(
        row_id_map_ptr + pid * stride_row_id_map_token + stride_row_id_map_expert * off,
        mask=off < num_experts,
        other=-1,
    )
    # Number of experts this token is actually routed to.
    n_routed = tl.sum(tl.where(row_id_map != -1, 1, 0))
    indices = off
    # Descending argsort packs the valid (>= 0) row ids into the first
    # n_routed slots; `indices` then names the expert each slot came from.
    sorted_map, indices = _argsort(row_id_map, indices, n_dims=n_dims)
    tl.store(
        row_id_map_ptr + pid * stride_row_id_map_token + off * stride_row_id_map_expert,
        sorted_map,
        mask=off < n_routed,
    )
    tl.store(
        row_id_map_ptr
        + pid * stride_row_id_map_token
        + (num_experts + off) * stride_row_id_map_expert,
        indices,
        mask=off < n_routed,
    )
    # Last slot records how many entries of this row are valid.
    tl.store(
        row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert,
        n_routed,
    )
def make_row_id_map( def make_row_id_map(
routing_map: torch.Tensor, routing_map: torch.Tensor,
num_tokens: int, num_tokens: int,
num_experts: int, num_experts: int,
): ):
# pylint: disable=missing-function-docstring """
row_id_map = torch.empty((num_experts, num_tokens), dtype=torch.int64, device="cuda") Prepare the row_id_map for the permutation.
block_size = 256
Parameters
----------
routing_map: torch.Tensor
Input tensor of shape `[num_tokens, num_experts]`. It is a mask tensor that indicates
which experts are routed to which tokens. The values in it: 1 means the token is routed to
this expert and 0 means not.
num_tokens: int
Number of tokens in the input tensor.
num_experts: int
Number of experts in the input tensor.
Returns
-------
row_id_map: torch.Tensor
The row_id_map for the permutation of shape `[num_tokens, num_experts * 2 + 1]`.
For each token, the last item is the number of experts that are routed (n_routed).
The first n_routed items are the destination row indices in the permuted tokens.
The [num_experts, num_experts + n_routed) items are the indices of the experts corresponding
to the first n_routed row indices above.
"""
row_id_map = torch.empty((num_tokens, num_experts * 2 + 1), dtype=torch.int32, device="cuda")
block_size = 1024
grid = (num_experts, triton.cdiv(num_tokens, block_size)) grid = (num_experts, triton.cdiv(num_tokens, block_size))
workspace_tensor = torch.empty(grid, dtype=torch.int64, device="cuda") workspace_tensor = torch.empty(grid, dtype=torch.int32, device="cuda")
# block cumsum
# supposing num_tokens == 5, num_experts == 3, block_size == 3
# and we have a routing_map like this:
# [[1, 1, 0],
# [1, 0, 1],
# [0, 0, 1],
# [1, 1, 0],
# [0, 0, 0]]
# pass 1: block cumsum
# for each expert, compute the cumsum of every block_size tokens
# the row_id_map will be like this after pass 1 (r means useless values):
# [[1, 1, 0, r, r, r, r],
# [2, 0, 1, r, r, r, r],
# [0, 0, 2, r, r, r, r],
# [1, 1, 0, r, r, r, r],
# [0, 0, 0, r, r, r, r]]
_row_id_map_pass_1_kernel[grid]( _row_id_map_pass_1_kernel[grid](
routing_map, routing_map,
row_id_map, row_id_map,
...@@ -94,16 +246,44 @@ def make_row_id_map( ...@@ -94,16 +246,44 @@ def make_row_id_map(
num_tokens, num_tokens,
routing_map.stride(0), routing_map.stride(0),
routing_map.stride(1), routing_map.stride(1),
row_id_map.stride(0),
row_id_map.stride(1),
block_size, block_size,
) )
# cumsum all and process the mask
# pass 2: cumsum all and process the mask
# process the block cumsum into the global cumsum and then into the dst row indices
# the row_id_map will be like this after pass 2 (r means useless value):
# [[ 0, 3, -1, r, r, r, r],
# [ 1, -1, 5, r, r, r, r],
# [-1, -1, 6, r, r, r, r],
# [ 2, 4, -1, r, r, r, r],
# [-1, -1, -1, r, r, r, r]]
_row_id_map_pass_2_kernel[grid]( _row_id_map_pass_2_kernel[grid](
row_id_map, row_id_map,
workspace_tensor, workspace_tensor,
num_tokens, num_tokens,
row_id_map.stride(0),
row_id_map.stride(1),
triton.next_power_of_2(num_experts * triton.cdiv(num_tokens, block_size)), triton.next_power_of_2(num_experts * triton.cdiv(num_tokens, block_size)),
block_size, block_size,
) )
# pass 3: make the row_id_map from the sparse structure to the dense structure
# the row_id_map will be like this after pass 3 (r means useless value):
# [[3, 0, r, 1, 0, r, 2],
# [5, 1, r, 2, 0, r, 2],
# [6, r, r, 2, r, r, 1],
# [4, 2, r, 1, 0, r, 2],
# [r, r, r, r, r, r, 0]]
grid = (num_tokens,)
_row_id_map_pass_3_kernel[grid](
row_id_map,
num_experts,
row_id_map.stride(0),
row_id_map.stride(1),
triton.next_power_of_2(num_experts),
)
return row_id_map return row_id_map
...@@ -118,11 +298,12 @@ def _permute_kernel( ...@@ -118,11 +298,12 @@ def _permute_kernel(
permuted_probs_ptr, permuted_probs_ptr,
permuted_scale_ptr, permuted_scale_ptr,
# sizes # sizes
num_tokens, num_experts: tl.constexpr,
num_experts, hidden_size: tl.constexpr,
hidden_size,
scale_hidden_dim, scale_hidden_dim,
# strides # strides
stride_row_id_map_token,
stride_row_id_map_expert,
stride_input_token, stride_input_token,
stride_input_hidden, stride_input_hidden,
stride_output_token, stride_output_token,
...@@ -139,35 +320,50 @@ def _permute_kernel( ...@@ -139,35 +320,50 @@ def _permute_kernel(
PERMUTE_SCALE: tl.constexpr, PERMUTE_SCALE: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
pid = tl.program_id(0) pid_t = tl.program_id(0)
cur_pos = 0 pid_h = tl.program_id(1)
while cur_pos < hidden_size: cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
cur_off = cur_pos + tl.arange(0, BLOCK_SIZE) mask = cur_off < hidden_size
mask = cur_off < hidden_size input_off = pid_t * stride_input_token + cur_off * stride_input_hidden
input_off = pid * stride_input_token + cur_off * stride_input_hidden inp = tl.load(input_ptr + input_off, mask=mask)
inp = tl.load(input_ptr + input_off, mask=mask) if PERMUTE_SCALE:
mask_scale = cur_off < scale_hidden_dim
scale_off = pid_t * stride_scale_token + cur_off * stride_scale_hidden
scale = tl.load(scale_ptr + scale_off, mask=mask_scale)
n_routed = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ num_experts * 2 * stride_row_id_map_expert
)
for idx in tl.range(n_routed):
dst_row = tl.load(
row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
)
output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
if PERMUTE_SCALE: if PERMUTE_SCALE:
mask_scale = cur_off < scale_hidden_dim permuted_scale_off = (
scale_off = pid * stride_scale_token + cur_off * stride_scale_hidden dst_row * stride_permuted_scale_token + cur_off * stride_permuted_scale_hidden
scale = tl.load(scale_ptr + scale_off, mask=mask_scale) )
for expert_idx in range(num_experts): tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid) if PERMUTE_PROBS:
if dst_row != -1: expert_idx = tl.load(
output_off = dst_row * stride_output_token + cur_off * stride_output_hidden row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert
prob = tl.load(probs_ptr + prob_off)
if pid_h == 0:
permuted_prob_off = dst_row * stride_permuted_probs_token
tl.store(permuted_probs_ptr + permuted_prob_off, prob)
if prob == 0.0:
# for routing_map padding
# dst_row != -1 and prob == 0.0 means that this slot is padded
tl.store(output_ptr + output_off, 0, mask=mask)
else:
tl.store(output_ptr + output_off, inp, mask=mask) tl.store(output_ptr + output_off, inp, mask=mask)
if PERMUTE_SCALE: else:
permuted_scale_off = ( tl.store(output_ptr + output_off, inp, mask=mask)
dst_row * stride_permuted_scale_token
+ cur_off * stride_permuted_scale_hidden
)
tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
if PERMUTE_PROBS:
if cur_pos == 0:
prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert
prob = tl.load(probs_ptr + prob_off)
permuted_prob_off = dst_row * stride_permuted_probs_token
tl.store(permuted_probs_ptr + permuted_prob_off, prob)
cur_pos += BLOCK_SIZE
try: try:
...@@ -178,6 +374,8 @@ try: ...@@ -178,6 +374,8 @@ try:
triton.Config({"BLOCK_SIZE": 256}), triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}), triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}), triton.Config({"BLOCK_SIZE": 1024}),
triton.Config({"BLOCK_SIZE": 2048}),
triton.Config({"BLOCK_SIZE": 4096}),
], ],
key=["hidden_size"], key=["hidden_size"],
)(_permute_kernel) )(_permute_kernel)
...@@ -196,7 +394,30 @@ def permute_with_mask_map( ...@@ -196,7 +394,30 @@ def permute_with_mask_map(
hidden_size: int, hidden_size: int,
scale_hidden_dim: int, scale_hidden_dim: int,
): ):
# pylint: disable=missing-function-docstring """
Permute the input tensor based on the row_id_map.
Parameters
----------
inp: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied.
row_id_map: torch.Tensor
The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
probs: torch.Tensor
The probabilities of the input tensor. If it is not None, it will be permuted.
scale: torch.Tensor
The scale of the input tensor. If it is not None, it will be permuted.
num_tokens: int
Number of tokens in the input tensor.
num_experts: int
Number of experts in the input tensor.
num_out_tokens: int
Number of tokens in the permuted tensor.
hidden_size: int
Hidden size of the input tensor.
scale_hidden_dim: int
Hidden size of the scale tensor.
"""
output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda") output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
if probs is not None: if probs is not None:
permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda") permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda")
...@@ -209,8 +430,8 @@ def permute_with_mask_map( ...@@ -209,8 +430,8 @@ def permute_with_mask_map(
) )
else: else:
permuted_scale = None permuted_scale = None
# pylint: disable=unnecessary-lambda-assignment
grid = (num_tokens,) grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
_permute_kernel[grid]( _permute_kernel[grid](
inp, inp,
output, output,
...@@ -219,10 +440,11 @@ def permute_with_mask_map( ...@@ -219,10 +440,11 @@ def permute_with_mask_map(
scale, scale,
permuted_probs, permuted_probs,
permuted_scale, permuted_scale,
num_tokens,
num_experts, num_experts,
hidden_size, hidden_size,
scale_hidden_dim, scale_hidden_dim,
row_id_map.stride(0),
row_id_map.stride(1),
inp.stride(0), inp.stride(0),
inp.stride(1), inp.stride(1),
output.stride(0), output.stride(0),
...@@ -250,10 +472,11 @@ def _unpermute_kernel( ...@@ -250,10 +472,11 @@ def _unpermute_kernel(
permuted_probs_ptr, permuted_probs_ptr,
unpermuted_probs_ptr, unpermuted_probs_ptr,
# sizes # sizes
num_tokens, num_experts: tl.constexpr,
num_experts, hidden_size: tl.constexpr,
hidden_size,
# strides # strides
stride_row_id_map_token,
stride_row_id_map_expert,
stride_input_token, stride_input_token,
stride_input_hidden, stride_input_hidden,
stride_output_token, stride_output_token,
...@@ -264,6 +487,7 @@ def _unpermute_kernel( ...@@ -264,6 +487,7 @@ def _unpermute_kernel(
stride_unpermuted_probs_token, stride_unpermuted_probs_token,
stride_unpermuted_probs_expert, stride_unpermuted_probs_expert,
# metas # metas
PROBS_LOAD_WIDTH: tl.constexpr,
WITH_MERGING_PROBS: tl.constexpr, WITH_MERGING_PROBS: tl.constexpr,
PERMUTE_PROBS: tl.constexpr, PERMUTE_PROBS: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
...@@ -271,41 +495,63 @@ def _unpermute_kernel( ...@@ -271,41 +495,63 @@ def _unpermute_kernel(
data_type = input_ptr.dtype.element_ty data_type = input_ptr.dtype.element_ty
compute_type = tl.float32 compute_type = tl.float32
pid = tl.program_id(0) pid_t = tl.program_id(0)
current_start = 0 pid_h = tl.program_id(1)
while current_start < hidden_size: current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
current_offset = current_start + tl.arange(0, BLOCK_SIZE) mask = current_offset < hidden_size
mask = current_offset < hidden_size if PERMUTE_PROBS:
accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type) # write 0.0 to probs_grad that are not routed
for expert_idx in range(num_experts): if pid_h == 0:
src_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid) map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
if src_row != -1: unpermuted_prob_off = (
input_off = src_row * stride_input_token + current_offset * stride_input_hidden pid_t * stride_unpermuted_probs_token
inp = tl.load(input_ptr + input_off, mask=mask) + stride_unpermuted_probs_expert * map_load_off
inp = inp.to(compute_type) )
if WITH_MERGING_PROBS: tl.store(
merging_prob_off = ( unpermuted_probs_ptr + unpermuted_prob_off, 0.0, mask=map_load_off < num_experts
pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert )
) accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) n_routed = tl.load(
inp *= merging_prob row_id_map_ptr
accumulator += inp + pid_t * stride_row_id_map_token
if PERMUTE_PROBS: + num_experts * 2 * stride_row_id_map_expert
if current_start == 0: )
unpermuted_prob_off = ( for idx in tl.range(n_routed):
pid * stride_unpermuted_probs_token src_row = tl.load(
+ expert_idx * stride_unpermuted_probs_expert row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
) )
if src_row != -1: input_off = src_row * stride_input_token + current_offset * stride_input_hidden
permuted_prob_off = src_row * stride_permuted_probs_token inp = tl.load(input_ptr + input_off, mask=mask)
prob = tl.load(permuted_probs_ptr + permuted_prob_off) inp = inp.to(compute_type)
tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob) if WITH_MERGING_PROBS:
else: expert_idx = tl.load(
tl.store(unpermuted_probs_ptr + unpermuted_prob_off, 0.0) row_id_map_ptr
accumulator = accumulator.to(data_type) + pid_t * stride_row_id_map_token
output_off = pid * stride_output_token + current_offset * stride_output_hidden + (num_experts + idx) * stride_row_id_map_expert
tl.store(output_ptr + output_off, accumulator, mask=mask) )
current_start += BLOCK_SIZE merging_prob_off = (
pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
)
merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
inp *= merging_prob
accumulator += inp
if PERMUTE_PROBS:
if pid_h == 0:
expert_idx = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
unpermuted_prob_off = (
pid_t * stride_unpermuted_probs_token
+ expert_idx * stride_unpermuted_probs_expert
)
permuted_prob_off = src_row * stride_permuted_probs_token
prob = tl.load(permuted_probs_ptr + permuted_prob_off)
tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
accumulator = accumulator.to(data_type)
output_off = pid_t * stride_output_token + current_offset * stride_output_hidden
tl.store(output_ptr + output_off, accumulator, mask=mask)
try: try:
...@@ -316,6 +562,8 @@ try: ...@@ -316,6 +562,8 @@ try:
triton.Config({"BLOCK_SIZE": 256}), triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}), triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}), triton.Config({"BLOCK_SIZE": 1024}),
triton.Config({"BLOCK_SIZE": 2048}),
triton.Config({"BLOCK_SIZE": 4096}),
], ],
key=["hidden_size"], key=["hidden_size"],
)(_unpermute_kernel) )(_unpermute_kernel)
...@@ -332,7 +580,27 @@ def unpermute_with_mask_map( ...@@ -332,7 +580,27 @@ def unpermute_with_mask_map(
num_experts: int, num_experts: int,
hidden_size: int, hidden_size: int,
): ):
# pylint: disable=missing-function-docstring """
Unpermute the input tensor based on the row_id_map.
Parameters
----------
inp: torch.Tensor
Input tensor of shape `[num_out_tokens, hidden_size]`.
row_id_map: torch.Tensor
The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
merging_probs: torch.Tensor
The merging probabilities of the input tensor. If it is not None, it will be used as weights
to reduce the unpermuted tokens.
permuted_probs: torch.Tensor
The permuted probabilities of the input tensor. If it is not None, it will be unpermuted.
num_tokens: int
Number of tokens in the permuted tensor.
num_experts: int
Number of experts in the permuted tensor.
hidden_size: int
Hidden size of the permuted tensor.
"""
output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
if permuted_probs is not None: if permuted_probs is not None:
unpermuted_probs = torch.empty( unpermuted_probs = torch.empty(
...@@ -340,7 +608,8 @@ def unpermute_with_mask_map( ...@@ -340,7 +608,8 @@ def unpermute_with_mask_map(
) )
else: else:
unpermuted_probs = None unpermuted_probs = None
grid = (num_tokens,) # pylint: disable=unnecessary-lambda-assignment
grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
_unpermute_kernel[grid]( _unpermute_kernel[grid](
inp, inp,
output, output,
...@@ -348,9 +617,10 @@ def unpermute_with_mask_map( ...@@ -348,9 +617,10 @@ def unpermute_with_mask_map(
merging_probs, merging_probs,
permuted_probs, permuted_probs,
unpermuted_probs, unpermuted_probs,
num_tokens,
num_experts, num_experts,
hidden_size, hidden_size,
row_id_map.stride(0),
row_id_map.stride(1),
inp.stride(0), inp.stride(0),
inp.stride(1), inp.stride(1),
output.stride(0), output.stride(0),
...@@ -360,6 +630,7 @@ def unpermute_with_mask_map( ...@@ -360,6 +630,7 @@ def unpermute_with_mask_map(
permuted_probs.stride(0) if permuted_probs is not None else None, permuted_probs.stride(0) if permuted_probs is not None else None,
unpermuted_probs.stride(0) if unpermuted_probs is not None else None, unpermuted_probs.stride(0) if unpermuted_probs is not None else None,
unpermuted_probs.stride(1) if unpermuted_probs is not None else None, unpermuted_probs.stride(1) if unpermuted_probs is not None else None,
PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
WITH_MERGING_PROBS=merging_probs is not None, WITH_MERGING_PROBS=merging_probs is not None,
PERMUTE_PROBS=permuted_probs is not None, PERMUTE_PROBS=permuted_probs is not None,
) )
...@@ -376,10 +647,11 @@ def _unpermute_bwd_with_merging_probs_kernel( ...@@ -376,10 +647,11 @@ def _unpermute_bwd_with_merging_probs_kernel(
merging_probs_grad_ptr, merging_probs_grad_ptr,
row_id_map_ptr, row_id_map_ptr,
# sizes # sizes
num_tokens, num_experts: tl.constexpr,
num_experts, hidden_size: tl.constexpr,
hidden_size,
# strides # strides
stride_row_id_map_token,
stride_row_id_map_expert,
stride_fwd_output_grad_token, stride_fwd_output_grad_token,
stride_fwd_output_grad_hidden, stride_fwd_output_grad_hidden,
stride_fwd_input_grad_token, stride_fwd_input_grad_token,
...@@ -391,56 +663,63 @@ def _unpermute_bwd_with_merging_probs_kernel( ...@@ -391,56 +663,63 @@ def _unpermute_bwd_with_merging_probs_kernel(
stride_merging_probs_grad_token, stride_merging_probs_grad_token,
stride_merging_probs_grad_expert, stride_merging_probs_grad_expert,
# metas # metas
PROBS_LOAD_WIDTH: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
data_type = fwd_output_grad_ptr.dtype.element_ty data_type = fwd_output_grad_ptr.dtype.element_ty
compute_type = tl.float32 compute_type = tl.float32
pid = tl.program_id(0) pid = tl.program_id(0)
for expert_idx in range(num_experts): map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid) token_probs_grad_off = (
if dst_row != -1: pid * stride_merging_probs_grad_token + stride_merging_probs_grad_expert * map_load_off
prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type) )
current_start = 0 tl.store(merging_probs_grad_ptr + token_probs_grad_off, 0.0, mask=map_load_off < num_experts)
while current_start < hidden_size: n_routed = tl.load(
current_offset = current_start + tl.arange(0, BLOCK_SIZE) row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert
mask = current_offset < hidden_size )
input_off = ( for idx in tl.range(n_routed):
pid * stride_fwd_output_grad_token dst_row = tl.load(
+ current_offset * stride_fwd_output_grad_hidden row_id_map_ptr + pid * stride_row_id_map_token + idx * stride_row_id_map_expert
) )
inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask) expert_idx = tl.load(
inp = inp.to(compute_type) row_id_map_ptr
merging_prob_off = ( + pid * stride_row_id_map_token
pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert + (num_experts + idx) * stride_row_id_map_expert
) )
merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
output = inp * merging_prob current_start = 0
output = output.to(data_type) while current_start < hidden_size:
output_off = ( current_offset = current_start + tl.arange(0, BLOCK_SIZE)
dst_row * stride_fwd_input_grad_token mask = current_offset < hidden_size
+ current_offset * stride_fwd_input_grad_hidden input_off = (
) pid * stride_fwd_output_grad_token + current_offset * stride_fwd_output_grad_hidden
tl.store(fwd_input_grad_ptr + output_off, output, mask=mask)
fwd_input_off = (
dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden
)
fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask)
prob_grad_accum += fwd_input.to(compute_type) * inp
current_start += BLOCK_SIZE
probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty)
probs_grad_off = (
pid * stride_merging_probs_grad_token
+ expert_idx * stride_merging_probs_grad_expert
) )
tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad) inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
else: inp = inp.to(compute_type)
probs_grad_off = ( merging_prob_off = (
pid * stride_merging_probs_grad_token pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
+ expert_idx * stride_merging_probs_grad_expert )
merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
output = inp * merging_prob
output = output.to(data_type)
output_off = (
dst_row * stride_fwd_input_grad_token
+ current_offset * stride_fwd_input_grad_hidden
) )
tl.store(merging_probs_grad_ptr + probs_grad_off, 0.0) tl.store(fwd_input_grad_ptr + output_off, output, mask=mask)
fwd_input_off = (
dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden
)
fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask)
prob_grad_accum += fwd_input.to(compute_type) * inp
current_start += BLOCK_SIZE
probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty)
probs_grad_off = (
pid * stride_merging_probs_grad_token + expert_idx * stride_merging_probs_grad_expert
)
tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad)
try: try:
...@@ -451,6 +730,8 @@ try: ...@@ -451,6 +730,8 @@ try:
triton.Config({"BLOCK_SIZE": 256}), triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}), triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}), triton.Config({"BLOCK_SIZE": 1024}),
triton.Config({"BLOCK_SIZE": 2048}),
triton.Config({"BLOCK_SIZE": 4096}),
], ],
key=["hidden_size"], key=["hidden_size"],
)(_unpermute_bwd_with_merging_probs_kernel) )(_unpermute_bwd_with_merging_probs_kernel)
...@@ -468,7 +749,28 @@ def unpermute_with_mask_map_bwd_with_merging_probs( ...@@ -468,7 +749,28 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
num_out_tokens: int, num_out_tokens: int,
hidden_size: int, hidden_size: int,
): ):
# pylint: disable=missing-function-docstring """
Unpermute backward pass kernel with merging probs.
Parameters
----------
fwd_output_grad: torch.Tensor
The gradient of the output tensor of shape `[num_tokens, hidden_size]`.
row_id_map: torch.Tensor
The token to expert mapping tensor of shape `[num_tokens, num_experts * 2 + 1]`.
fwd_input: torch.Tensor
The input tensor of the forward pass of shape `[num_out_tokens, hidden_size]`.
merging_probs: torch.Tensor
The merging probabilities of the input tensor of shape `[num_tokens, num_experts]`.
num_tokens: int
Number of tokens in the permuted tensor.
num_experts: int
Number of experts in the permuted tensor.
num_out_tokens: int
Number of tokens in the output tensor.
hidden_size: int
Hidden size of the output tensor.
"""
act_grad = torch.empty( act_grad = torch.empty(
(num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda" (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda"
) )
...@@ -483,9 +785,10 @@ def unpermute_with_mask_map_bwd_with_merging_probs( ...@@ -483,9 +785,10 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
merging_probs, merging_probs,
merging_probs_grad, merging_probs_grad,
row_id_map, row_id_map,
num_tokens,
num_experts, num_experts,
hidden_size, hidden_size,
row_id_map.stride(0),
row_id_map.stride(1),
fwd_output_grad.stride(0), fwd_output_grad.stride(0),
fwd_output_grad.stride(1), fwd_output_grad.stride(1),
act_grad.stride(0), act_grad.stride(0),
...@@ -496,34 +799,21 @@ def unpermute_with_mask_map_bwd_with_merging_probs( ...@@ -496,34 +799,21 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
merging_probs.stride(1), merging_probs.stride(1),
merging_probs_grad.stride(0), merging_probs_grad.stride(0),
merging_probs_grad.stride(1), merging_probs_grad.stride(1),
PROBS_LOAD_WIDTH=triton.next_power_of_2(num_experts),
) )
return act_grad, merging_probs_grad return act_grad, merging_probs_grad
@triton.jit @triton.jit
def _sort_chunks_by_idxs_kernel( def _make_chunk_sort_map_kernel(
# pointers # pointers
input_ptr,
split_sizes_ptr, split_sizes_ptr,
sorted_indices_ptr, sorted_indices_ptr,
output_ptr,
dst_rows_ptr, dst_rows_ptr,
probs_ptr,
permuted_probs_ptr,
# sizes # sizes
num_splits, num_splits: tl.constexpr,
hidden_size,
# strides
stride_input_token,
stride_input_hidden,
stride_output_token,
stride_output_hidden,
stride_probs_token,
stride_permuted_probs_token,
# metas # metas
PERMUTE_PROBS: tl.constexpr,
IDX_LOAD_WIDTH: tl.constexpr, IDX_LOAD_WIDTH: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
): ):
pid = tl.program_id(0) pid = tl.program_id(0)
...@@ -533,104 +823,58 @@ def _sort_chunks_by_idxs_kernel( ...@@ -533,104 +823,58 @@ def _sort_chunks_by_idxs_kernel(
) )
# get chunk idx of the current token in the input tensor # get chunk idx of the current token in the input tensor
input_chunk_idx = -1 input_split_sizes = tl.load(
in_chunk_offset = tl.zeros([], dtype=tl.int64) split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0
acc_chunk_sizes = tl.zeros([], dtype=tl.int64) ).to(tl.int32)
cursor = 0 input_split_sizes_cumsum = tl.cumsum(input_split_sizes)
while cursor < num_splits: input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0)
cur_chunk_size = tl.load(split_sizes_ptr + cursor).to(tl.int64) input_chunk_idx = tl.sum(input_split_sizes_mask)
acc_chunk_sizes += cur_chunk_size input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask)
if input_chunk_idx == -1 and acc_chunk_sizes > pid: in_chunk_offset = pid - input_split_sizes_presum
input_chunk_idx = cursor
in_chunk_offset = pid - (acc_chunk_sizes - cur_chunk_size)
cursor += 1
# get chunk idx of the current token in the output tensor # get chunk idx of the current token in the output tensor
output_chunk_idx = 0 output_chunk_mask = tl.where(sorted_indices == input_chunk_idx, 1, 0)
cursor = 0 output_chunk_idx = tl.argmax(output_chunk_mask, axis=-1)
while cursor < num_splits:
cur_input_idx = tl.load(sorted_indices_ptr + cursor)
if cur_input_idx == input_chunk_idx:
output_chunk_idx = cursor
cursor += 1
# make row_id_map # make row_id_map
output_split_sizes = tl.load( output_split_sizes = tl.load(
split_sizes_ptr + sorted_indices, mask=load_split_offset < num_splits split_sizes_ptr + sorted_indices, mask=load_split_offset < num_splits
).to(tl.int64) ).to(tl.int32)
output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0) output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0)
dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset
tl.store(dst_rows_ptr + pid, dst_row) tl.store(dst_rows_ptr + pid, dst_row)
current_start = 0
while current_start < hidden_size:
current_offset = current_start + tl.arange(0, BLOCK_SIZE)
mask = current_offset < hidden_size
input_offsets = pid * stride_input_token + current_offset * stride_input_hidden
output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden
inp = tl.load(input_ptr + input_offsets, mask=mask)
tl.store(output_ptr + output_offsets, inp, mask=mask)
current_start += BLOCK_SIZE
def make_chunk_sort_map(
    split_sizes: torch.Tensor,
    sorted_indices: torch.Tensor,
    num_tokens: int,
    num_splits: int,
):
    """
    Make a row_id_map for chunk sort.

    Launches ``_make_chunk_sort_map_kernel`` with one program per token; each
    program writes the destination row of its token into ``row_id_map``.

    Parameters
    ----------
    split_sizes: torch.Tensor
        The sizes of the chunks of shape `[num_splits,]`.
    sorted_indices: torch.Tensor
        The indices of the sorted chunks of shape `[num_splits,]`.
    num_tokens: int
        Number of tokens in the input tensor.
    num_splits: int
        Number of splits of split_sizes and sorted_indices.

    Returns
    -------
    torch.Tensor
        int32 tensor of shape `[num_tokens,]` mapping each source row to its
        destination row.
    """
    row_id_map = torch.empty((num_tokens,), dtype=torch.int32, device="cuda")
    grid = (num_tokens,)
    _make_chunk_sort_map_kernel[grid](
        split_sizes,
        sorted_indices,
        row_id_map,
        num_splits,
        # The kernel loads all split sizes/indices in one vector op; the load
        # width must be a power of two >= num_splits.
        IDX_LOAD_WIDTH=triton.next_power_of_2(num_splits),
    )
    return row_id_map
@triton.jit @triton.jit
...@@ -642,7 +886,7 @@ def _sort_chunks_by_map_kernel( ...@@ -642,7 +886,7 @@ def _sort_chunks_by_map_kernel(
probs_ptr, probs_ptr,
permuted_probs_ptr, permuted_probs_ptr,
# sizes # sizes
hidden_size, hidden_size: tl.constexpr,
# strides # strides
stride_input_token, stride_input_token,
stride_input_hidden, stride_input_hidden,
...@@ -653,23 +897,28 @@ def _sort_chunks_by_map_kernel( ...@@ -653,23 +897,28 @@ def _sort_chunks_by_map_kernel(
# metas # metas
PERMUTE_PROBS: tl.constexpr, PERMUTE_PROBS: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
FORWARD: tl.constexpr,
): ):
pid = tl.program_id(0) pid_t = tl.program_id(0)
dst_row = tl.load(row_id_map_ptr + pid) pid_h = tl.program_id(1)
current_start = 0 if FORWARD:
while current_start < hidden_size: src_row = pid_t
current_offset = current_start + tl.arange(0, BLOCK_SIZE) dst_row = tl.load(row_id_map_ptr + pid_t)
mask = current_offset < hidden_size else:
input_offsets = dst_row * stride_input_token + current_offset * stride_input_hidden src_row = tl.load(row_id_map_ptr + pid_t)
output_offsets = pid * stride_output_token + current_offset * stride_output_hidden dst_row = pid_t
inp = tl.load(input_ptr + input_offsets, mask=mask) current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
tl.store(output_ptr + output_offsets, inp, mask=mask) mask = current_offset < hidden_size
current_start += BLOCK_SIZE input_offsets = src_row * stride_input_token + current_offset * stride_input_hidden
output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden
inp = tl.load(input_ptr + input_offsets, mask=mask)
tl.store(output_ptr + output_offsets, inp, mask=mask)
if PERMUTE_PROBS: if PERMUTE_PROBS:
prob_off = dst_row * stride_probs_token if pid_h == 0:
prob = tl.load(probs_ptr + prob_off) prob_off = src_row * stride_probs_token
permuted_prob_off = pid * stride_permuted_probs_token prob = tl.load(probs_ptr + prob_off)
tl.store(permuted_probs_ptr + permuted_prob_off, prob) permuted_prob_off = dst_row * stride_permuted_probs_token
tl.store(permuted_probs_ptr + permuted_prob_off, prob)
try: try:
...@@ -680,6 +929,8 @@ try: ...@@ -680,6 +929,8 @@ try:
triton.Config({"BLOCK_SIZE": 256}), triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}), triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}), triton.Config({"BLOCK_SIZE": 1024}),
triton.Config({"BLOCK_SIZE": 2048}),
triton.Config({"BLOCK_SIZE": 4096}),
], ],
key=["hidden_size"], key=["hidden_size"],
)(_sort_chunks_by_map_kernel) )(_sort_chunks_by_map_kernel)
...@@ -693,14 +944,33 @@ def sort_chunks_by_map( ...@@ -693,14 +944,33 @@ def sort_chunks_by_map(
probs: torch.Tensor, probs: torch.Tensor,
num_tokens: int, num_tokens: int,
hidden_size: int, hidden_size: int,
is_forward: bool,
): ):
# pylint: disable=missing-function-docstring """
Sort chunks with row_id_map.
Parameters
----------
inp: torch.Tensor
Input tensor of shape `[num_tokens, hidden_size]`.
row_id_map: torch.Tensor
The token to expert mapping tensor of shape `[num_tokens,]`.
probs: torch.Tensor
The probabilities of the input tensor. If it is not None, it will be permuted.
num_tokens: int
Number of tokens in the input tensor.
hidden_size: int
Hidden size of the input tensor.
is_forward: bool
Whether the sort is for forward or backward.
"""
output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda")
if probs is not None: if probs is not None:
permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device="cuda") permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device="cuda")
else: else:
permuted_probs = None permuted_probs = None
grid = (num_tokens,) # pylint: disable=unnecessary-lambda-assignment
grid = lambda META: (num_tokens, triton.cdiv(hidden_size, META["BLOCK_SIZE"]))
_sort_chunks_by_map_kernel[grid]( _sort_chunks_by_map_kernel[grid](
inp, inp,
output, output,
...@@ -715,5 +985,6 @@ def sort_chunks_by_map( ...@@ -715,5 +985,6 @@ def sort_chunks_by_map(
probs.stride(0) if probs is not None else None, probs.stride(0) if probs is not None else None,
permuted_probs.stride(0) if permuted_probs is not None else None, permuted_probs.stride(0) if permuted_probs is not None else None,
PERMUTE_PROBS=probs is not None, PERMUTE_PROBS=probs is not None,
FORWARD=is_forward,
) )
return output, permuted_probs return output, permuted_probs
...@@ -7,7 +7,7 @@ from __future__ import annotations ...@@ -7,7 +7,7 @@ from __future__ import annotations
import functools import functools
import math import math
import os import os
from typing import Any, Callable, List, Optional, Tuple, Union from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -654,3 +654,110 @@ else: ...@@ -654,3 +654,110 @@ else:
gpu_autocast_ctx = torch.cuda.amp.autocast gpu_autocast_ctx = torch.cuda.amp.autocast
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
# Mapping from torch dtype to the numpy "typestr" advertised through
# `__cuda_array_interface__` (format: byte order + kind + itemsize).
_torch_dtype_to_np_typestr_dict = {
    torch.float16: "<f2",
    torch.float32: "<f4",
    torch.int64: "<i8",
    torch.int32: "<i4",
    torch.int8: "|i1",
    # NOTE(review): fp8 has no numpy equivalent; "|i1" only preserves the
    # 1-byte width, not the numeric interpretation — a raw byte view.
    torch.float8_e4m3fn: "|i1",
    # NOTE(review): qint8 mapped to unsigned bytes ("|u1") — presumably a
    # deliberate byte view; confirm consumers ignore signedness.
    torch.qint8: "|u1",
    torch.bool: "|b1",
    # NOTE(review): bfloat16 advertised as IEEE half ("<f2"); matches the bit
    # width only, so consumers must reinterpret the bits themselves.
    torch.bfloat16: "<f2",
}
class _WeakRefTensor:
"""
A wrapper wraps raw data pointer to a tensor-like object. Could be compatibale with openai triton kernel and be converted to `torch.Tensor` with zero-copy overhead.
"""
def __init__(
self,
data_ptr: int,
dtype: torch.dtype,
shape: Sequence[int],
):
self._data_ptr = data_ptr
self.dtype = dtype
self.shape = shape
def data_ptr(self):
"""Data pointer of the tensor."""
return self._data_ptr
@property
def dtype(self):
"""Dtype of the tensor."""
return self._dtype
@property
def shape(self):
"""Shape of the tensor."""
return getattr(self, "_shape", None)
@dtype.setter
def dtype(self, dtype: torch.dtype):
self._dtype = dtype
@shape.setter
def shape(self, shape: Sequence[int]):
self._shape = tuple(int(i) for i in shape)
def numel(self):
"""Number of elements in the tensor."""
return np.prod(self.shape)
@property
def __cuda_array_interface__(self):
return {
"shape": self.shape,
"typestr": self.torch_dtype_to_np_typestr(),
"data": (self.data_ptr() if self.numel() > 0 else 0, False),
"version": 3,
}
def torch_dtype_to_np_typestr(self):
"""Convert PyTorch dtype to numpy typestr."""
ret = _torch_dtype_to_np_typestr_dict.get(self.dtype)
assert ret is not None, f"Unsupported dtype: {self.dtype}"
return ret
def make_weak_ref(x):
"""
This function is to make a weak reference to the input so that the memory can be released.
"""
def convert_to_torch_tensor(tensor: Union[_WeakRefTensor, torch.Tensor]) -> torch.Tensor:
"""
This function is to convert the `_WeakRefTensor` to torch.Tensor.
"""
if isinstance(tensor, torch.Tensor):
return tensor
old_ptr = tensor.data_ptr()
new_tensor = torch.as_tensor(tensor).view(tensor.dtype)
new_ptr = new_tensor.data_ptr()
if old_ptr != new_ptr:
raise RuntimeError("Data pointer mismatch after converting to torch.Tensor")
return new_tensor
if isinstance(x, torch.Tensor):
return (
convert_to_torch_tensor(_WeakRefTensor(x.data_ptr(), x.dtype, x.shape))
if x.is_cuda
else x
)
if isinstance(x, tuple):
return tuple(make_weak_ref(i) for i in x)
if isinstance(x, list):
return [make_weak_ref(i) for i in x]
if isinstance(x, dict):
return {k: make_weak_ref(v) for k, v in x.items()}
if isinstance(x, (int, float, bool)):
return x
if x is None:
return None
raise TypeError(f"Invalid type {type(x)} to make weak ref")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment