Commit 9da3621b authored by yuguo
parents 16de530e 86f2e9a9
@@ -39,6 +39,7 @@ from transformer_engine.pytorch import (
     Fp8Padding,
     Fp8Unpadding,
 )
+from transformer_engine.pytorch import torch_version
 from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
 from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
 from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
@@ -61,8 +62,10 @@ torch.cuda.manual_seed(seed)
 _cpu_rng_state = torch.get_rng_state()
 _cuda_rng_state = torch.cuda.get_rng_state()
-torch._dynamo.config.recompile_limit = 16
+if torch_version() >= (2, 7, 0):
+    torch._dynamo.config.recompile_limit = 16
+else:
+    torch._dynamo.config.cache_size_limit = 16
 class ModelConfig:
     def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len):
...
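For context on the hunk above: the version gate exists because `torch._dynamo.config.recompile_limit` is the newer name for what older PyTorch releases expose as `cache_size_limit`, with the diff drawing the line at torch 2.7. A minimal sketch of the same idea using feature detection instead of a version check; `set_recompile_limit` is a hypothetical helper, not part of the repo:

```python
import torch
import torch._dynamo  # makes torch._dynamo.config available

def set_recompile_limit(limit: int) -> None:
    """Hypothetical helper: set dynamo's recompile cap under either name."""
    cfg = torch._dynamo.config
    if hasattr(cfg, "recompile_limit"):  # newer name (torch >= 2.7 per the diff)
        cfg.recompile_limit = limit
    else:                                # older name on torch < 2.7
        cfg.cache_size_limit = limit

set_recompile_limit(16)
```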
@@ -253,7 +253,11 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id,
                           num_out_tokens);
     blocks = num_rows;
+#ifdef __HIP_PLATFORM_AMD__
+    threads = std::min(num_cols / kElementsPerAccess, 256);
+#else
     threads = std::min(num_cols / kElementsPerAccess, 1024);
+#endif
     moe_permute_kernel<T, TCompute, 128, false><<<blocks, threads, 0, stream>>>(
         input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols);
   } else {
@@ -305,7 +309,11 @@ void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const f
   static constexpr int kElementsPerAccess = 16 / sizeof(T);
   int blocks = num_rows;
+#ifdef __HIP_PLATFORM_AMD__
+  int threads = std::min(num_cols / kElementsPerAccess, 256);
+#else
   int threads = std::min(num_cols / kElementsPerAccess, 1024);
+#endif
   size_t smem_bytes = topK * sizeof(TCompute);
   if (prob == nullptr) {
...
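Both launcher hunks apply the same per-block thread cap: each thread moves 16 bytes per access (`kElementsPerAccess = 16 / sizeof(T)`), and the thread count is the row width divided by that, capped at 256 on ROCm versus 1024 on CUDA. A sketch of that arithmetic in Python for illustration; the caps come straight from the diff, but the occupancy reasoning behind the lower ROCm cap is an assumption:

```python
def launch_threads(num_cols: int, elem_bytes: int, is_amd: bool) -> int:
    """Mirror of the kernel-launch sizing above (illustrative only)."""
    elems_per_access = 16 // elem_bytes  # kElementsPerAccess: 16-byte vector loads
    cap = 256 if is_amd else 1024        # per-block thread cap from the diff
    return min(num_cols // elems_per_access, cap)

# A 16384-column fp16 row (2 bytes/elem) wants 2048 vectorized threads:
assert launch_threads(16384, 2, is_amd=False) == 1024  # capped on CUDA
assert launch_threads(16384, 2, is_amd=True) == 256    # tighter cap on ROCm
```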
@@ -18,7 +18,8 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
     int64_t window_size_right) {
 #ifdef __HIP_PLATFORM_AMD__
-  assert(false);
+  // assert(false);
+  return NVTE_Fused_Attn_Backend::NVTE_No_Backend;
 #else
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
       static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
...
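With the HIP branch now reporting `NVTE_No_Backend` instead of asserting, a caller can degrade gracefully rather than abort. A hypothetical caller-side dispatch pattern, sketched in Python; everything here except the `NVTE_No_Backend` sentinel is illustrative, not the extension's actual API:

```python
NO_BACKEND = "NVTE_No_Backend"  # stands in for NVTE_Fused_Attn_Backend::NVTE_No_Backend

def run_attention(q, k, v, query_backend, fused_fn, unfused_fn):
    """Illustrative dispatch: fall back instead of aborting on ROCm."""
    backend = query_backend()
    if backend == NO_BACKEND:        # what the HIP build reports after this commit
        return unfused_fn(q, k, v)   # unfused fallback path
    return fused_fn(q, k, v, backend)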
@@ -281,7 +281,7 @@ def cross_entropy_forward(
         rank=rank,
         n_cols=V,
         BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=32,
+        num_warps=16 if IS_HIP_EXTENSION else 32,
     )
     world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group)
@@ -309,7 +309,7 @@ def cross_entropy_forward(
         n_non_ignore=n_rows,
         label_smoothing=label_smoothing,
         BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=32,
+        num_warps=16 if IS_HIP_EXTENSION else 32,
     )
     loss = torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / n_rows)
@@ -335,7 +335,7 @@ def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor):
         grad_output,
         V,
         BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=32,
+        num_warps=16 if IS_HIP_EXTENSION else 32,
     )
     return _input
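The `num_warps` change in the three Triton launches above most likely keeps the block size constant across vendors: Triton sizes a block as `num_warps * warp_size`, NVIDIA warps are 32 threads, and AMD wavefronts are 64 lanes, so halving `num_warps` on HIP still yields 1024 threads per block. The arithmetic as a minimal sketch (assumed rationale; the 16-vs-32 values themselves are from the diff):

```python
WARP_SIZE = {"cuda": 32, "hip": 64}  # NVIDIA warp vs AMD wavefront lanes

def num_warps_for(platform: str, target_threads: int = 1024) -> int:
    """Pick num_warps so num_warps * warp_size hits the target block size."""
    return target_threads // WARP_SIZE[platform]

assert num_warps_for("cuda") == 32  # 32 warps * 32 lanes = 1024 threads
assert num_warps_for("hip") == 16   # 16 wavefronts * 64 lanes = 1024 threads
```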