Unverified commit 12c3e323, authored by hx, committed by GitHub
Browse files

Fix import error on CPU-only devices (#1578)



fix cpu device import error
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent dc40f9fd
...@@ -109,16 +109,6 @@ def make_row_id_map( ...@@ -109,16 +109,6 @@ def make_row_id_map(
return row_id_map return row_id_map
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 64}),
triton.Config({"BLOCK_SIZE": 128}),
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
],
key=["hidden_size"],
)
@triton.jit @triton.jit
def _permute_kernel( def _permute_kernel(
# pointers # pointers
...@@ -164,6 +154,21 @@ def _permute_kernel( ...@@ -164,6 +154,21 @@ def _permute_kernel(
cur_pos += BLOCK_SIZE cur_pos += BLOCK_SIZE
try:
_permute_kernel = triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 64}),
triton.Config({"BLOCK_SIZE": 128}),
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
],
key=["hidden_size"],
)(_permute_kernel)
except RuntimeError:
pass
def permute_with_mask_map( def permute_with_mask_map(
inp: torch.Tensor, inp: torch.Tensor,
row_id_map: torch.Tensor, row_id_map: torch.Tensor,
...@@ -201,16 +206,6 @@ def permute_with_mask_map( ...@@ -201,16 +206,6 @@ def permute_with_mask_map(
return output, permuted_probs return output, permuted_probs
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 64}),
triton.Config({"BLOCK_SIZE": 128}),
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
],
key=["hidden_size"],
)
@triton.jit @triton.jit
def _unpermute_kernel( def _unpermute_kernel(
# pointers # pointers
...@@ -297,6 +292,21 @@ def _unpermute_kernel( ...@@ -297,6 +292,21 @@ def _unpermute_kernel(
current_start += BLOCK_SIZE current_start += BLOCK_SIZE
try:
_unpermute_kernel = triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 64}),
triton.Config({"BLOCK_SIZE": 128}),
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
],
key=["hidden_size"],
)(_unpermute_kernel)
except RuntimeError:
pass
def unpermute_with_mask_map( def unpermute_with_mask_map(
inp: torch.Tensor, inp: torch.Tensor,
row_id_map: torch.Tensor, row_id_map: torch.Tensor,
...@@ -348,16 +358,6 @@ def unpermute_with_mask_map( ...@@ -348,16 +358,6 @@ def unpermute_with_mask_map(
return output, unpermuted_probs return output, unpermuted_probs
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 64}),
triton.Config({"BLOCK_SIZE": 128}),
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
],
key=["hidden_size"],
)
@triton.jit @triton.jit
def _unpermute_bwd_with_merging_probs_kernel( def _unpermute_bwd_with_merging_probs_kernel(
# pointers # pointers
...@@ -450,6 +450,21 @@ def _unpermute_bwd_with_merging_probs_kernel( ...@@ -450,6 +450,21 @@ def _unpermute_bwd_with_merging_probs_kernel(
tl.store(merging_probs_grad_ptr + probs_grad_off, 0.0) tl.store(merging_probs_grad_ptr + probs_grad_off, 0.0)
try:
_unpermute_bwd_with_merging_probs_kernel = triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 64}),
triton.Config({"BLOCK_SIZE": 128}),
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
],
key=["hidden_size"],
)(_unpermute_bwd_with_merging_probs_kernel)
except RuntimeError:
pass
def unpermute_with_mask_map_bwd_with_merging_probs( def unpermute_with_mask_map_bwd_with_merging_probs(
fwd_output_grad: torch.Tensor, fwd_output_grad: torch.Tensor,
row_id_map: torch.Tensor, row_id_map: torch.Tensor,
...@@ -500,16 +515,6 @@ def unpermute_with_mask_map_bwd_with_merging_probs( ...@@ -500,16 +515,6 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
return act_grad, merging_probs_grad return act_grad, merging_probs_grad
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 64}),
triton.Config({"BLOCK_SIZE": 128}),
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
],
key=["hidden_size"],
)
@triton.jit @triton.jit
def _sort_chunks_by_idxs_kernel( def _sort_chunks_by_idxs_kernel(
# pointers # pointers
...@@ -589,6 +594,21 @@ def _sort_chunks_by_idxs_kernel( ...@@ -589,6 +594,21 @@ def _sort_chunks_by_idxs_kernel(
tl.store(permuted_probs_ptr + permuted_prob_off, prob) tl.store(permuted_probs_ptr + permuted_prob_off, prob)
try:
_sort_chunks_by_idxs_kernel = triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 64}),
triton.Config({"BLOCK_SIZE": 128}),
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
],
key=["hidden_size"],
)(_sort_chunks_by_idxs_kernel)
except RuntimeError:
pass
def sort_chunks_by_idx( def sort_chunks_by_idx(
inp: torch.Tensor, inp: torch.Tensor,
split_sizes: torch.Tensor, split_sizes: torch.Tensor,
...@@ -628,18 +648,8 @@ def sort_chunks_by_idx( ...@@ -628,18 +648,8 @@ def sort_chunks_by_idx(
return output, row_id_map, permuted_probs return output, row_id_map, permuted_probs
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 64}),
triton.Config({"BLOCK_SIZE": 128}),
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
],
key=["hidden_size"],
)
@triton.jit @triton.jit
def _sort_chunks_by_map( def _sort_chunks_by_map_kernel(
# pointers # pointers
input_ptr, input_ptr,
output_ptr, output_ptr,
...@@ -677,6 +687,21 @@ def _sort_chunks_by_map( ...@@ -677,6 +687,21 @@ def _sort_chunks_by_map(
tl.store(permuted_probs_ptr + permuted_prob_off, prob) tl.store(permuted_probs_ptr + permuted_prob_off, prob)
try:
_sort_chunks_by_map_kernel = triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 64}),
triton.Config({"BLOCK_SIZE": 128}),
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
],
key=["hidden_size"],
)(_sort_chunks_by_map_kernel)
except RuntimeError:
pass
def sort_chunks_by_map( def sort_chunks_by_map(
inp: torch.Tensor, inp: torch.Tensor,
row_id_map: torch.Tensor, row_id_map: torch.Tensor,
...@@ -691,7 +716,7 @@ def sort_chunks_by_map( ...@@ -691,7 +716,7 @@ def sort_chunks_by_map(
else: else:
permuted_probs = None permuted_probs = None
grid = (num_tokens,) grid = (num_tokens,)
_sort_chunks_by_map[grid]( _sort_chunks_by_map_kernel[grid](
inp, inp,
output, output,
row_id_map, row_id_map,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment