Unverified commit 45fdf1f7, authored by Yi Pan and committed by GitHub

Fix shared memory OOM on sm86 GPUs. (#4797)

parent d89c0e4b
@@ -341,8 +341,8 @@ def extend_attention_fwd(
         else:
             BLOCK_M, BLOCK_N = (32, 64)
     elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
-        # 8.9 has a much smaller shared memory size (100K) than 8.0 (160K)
-        if CUDA_CAPABILITY[1] == 9:
+        # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
+        if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
             if Lq <= 128:
                 BLOCK_M, BLOCK_N = (64, 128)
             elif Lq <= 256:
...
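For context, the Triton path above selects its (BLOCK_M, BLOCK_N) tile sizes from `CUDA_CAPABILITY`, which presumably comes from `torch.cuda.get_device_capability()`. Below is a minimal sketch of that check, assuming that API; the device examples and the fallback tile values are illustrative, not taken from the patch:

```python
import torch

# Assumption: CUDA_CAPABILITY in the hunk above is the (major, minor)
# tuple returned by this torch API, e.g. (8, 0) on A100, (8, 6) on
# A10/RTX 30-series, (8, 9) on L4/RTX 40-series.
major, minor = torch.cuda.get_device_capability()

if major >= 8 and minor in (6, 9):
    # sm86/sm89 have ~100KB of shared memory per SM versus ~160KB on
    # sm80, so tiles sized for an A100 overflow shared memory here.
    BLOCK_M, BLOCK_N = (64, 128)  # the sm86/sm89 choice for Lq <= 128
else:
    BLOCK_M, BLOCK_N = (128, 128)  # illustrative placeholder for sm80
print(f"sm{major}{minor}: BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}")
```

The pre-patch code only special-cased minor version 9, so sm86 devices fell through to the larger sm80 tiles and ran out of shared memory; the patch routes sm86 onto the same reduced tiles as sm89.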
@@ -703,8 +703,8 @@ torch::Tensor int8_scaled_mm(
     sm75_dispatch_shape<cutlass::half_t, cutlass::arch::Sm75, cutlass::gemm::GemmShape<8, 8, 16>>(
         out, mat_a, mat_b, scales_a, scales_b, bias);
   } else if (sm_version >= 80 && sm_version < 90) {
-    // sm89 has a much smaller shared memory size (100K) than sm80 (160K)
-    if (sm_version == 89) {
+    // sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
+    if (sm_version == 86 || sm_version == 89) {
       if (out_dtype == torch::kBFloat16) {
         sm89_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
             out, mat_a, mat_b, scales_a, scales_b, bias);
...
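The CUTLASS path makes the same distinction with an integer `sm_version`. A rough sketch of how such a value is conventionally derived (an assumption about this codebase, not visible in the hunk), again via torch's device properties:

```python
import torch

# Sketch only: sm_version is assumed to be major * 10 + minor, which
# matches the 75/80/86/89/90 literals compared against in the C++ hunk.
props = torch.cuda.get_device_properties(torch.cuda.current_device())
sm_version = props.major * 10 + props.minor

# sm86 and sm89 share the reduced ~100KB shared-memory budget, so both
# now take the sm89_dispatch_shape branch with smaller GemmShape tiles.
needs_small_tiles = sm_version in (86, 89)
print(sm_version, needs_small_tiles)
```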