Unverified commit fe531d6f authored by Yuhao Yao, committed by GitHub

[Bug] Fix Issue#10215 (#10572)

parent c4e314f9
@@ -1025,8 +1025,6 @@ struct CollectiveMmaArrayMixedInput<
     // src: tCrA_load, dst: tCrA_mma
     Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0);
-    tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
     // Unroll the K mode manually to set scale D to 1
     CUTLASS_PRAGMA_UNROLL
     for (int chunk_id = 0; chunk_id < NumChunksPerTileK; ++chunk_id) {
@@ -1060,6 +1058,8 @@ struct CollectiveMmaArrayMixedInput<
       }
     }
+    warpgroup_wait<0>();
     CUTLASS_PRAGMA_UNROLL
     for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_) {
       warpgroup_fence_operand(intermediate_array[chunk_id_]);
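The `warpgroup_wait<0>()` added above runs immediately before the loop that fences and reads `intermediate_array`. The key semantics: `warpgroup_wait<N>()` blocks only until at most N committed WGMMA batches remain in flight, so `warpgroup_wait<0>()` is the only variant that guarantees every MMA writing the intermediate accumulators has retired before they are read. A minimal single-threaded C++ sketch of that counter semantics (the deque model and the concrete values are illustrative assumptions, not CUTLASS code):

```cpp
#include <cassert>
#include <cstdio>
#include <deque>

// Toy, single-threaded model of wgmma wait-group semantics: commit_batch()
// registers a batch of in-flight MMAs; wait(n) blocks until at most n batches
// remain pending. Only wait(0) guarantees every accumulator write has landed.
struct WarpgroupModel {
  std::deque<int> pending;  // ids of committed but unfinished batches
  void commit_batch(int id) { pending.push_back(id); }
  void wait(int n) {        // models warpgroup_wait<N>()
    while (static_cast<int>(pending.size()) > n) {
      std::printf("batch %d retires\n", pending.front());
      pending.pop_front();
    }
  }
};

int main() {
  constexpr int NumChunksPerTileK = 4;  // illustrative value
  constexpr int K_WAIT_MAX = 2;         // illustrative partial-wait depth
  WarpgroupModel wg;
  for (int chunk = 0; chunk < NumChunksPerTileK; ++chunk) {
    wg.commit_batch(chunk);             // one committed batch per chunk
  }
  wg.wait(K_WAIT_MAX);                  // old code: 2 batches may still be in flight
  assert(wg.pending.size() == 2);       // reading intermediate_array now would race
  wg.wait(0);                           // fixed code: drain everything first
  assert(wg.pending.empty());           // safe to fence and read the accumulators
  return 0;
}
```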
@@ -1114,7 +1114,6 @@ struct CollectiveMmaArrayMixedInput<
         1,
         smem_pipe_read.index());
-      warpgroup_wait<K_WAIT_MAX>();
       Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0);
     }
   }
@@ -1148,8 +1147,6 @@ struct CollectiveMmaArrayMixedInput<
     tiled_mma.accumulate_ = GMMA::ScaleOut::One;
     warpgroup_commit_batch();
-    warpgroup_wait<K_WAIT_MAX>(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, so we can
-    // release prior barrier
     if (k_block == K_BLOCK_MAX - 1) {
       pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
       ++smem_pipe_release;
@@ -1162,6 +1159,8 @@ struct CollectiveMmaArrayMixedInput<
     if (k_block == K_BLOCK_MAX - 1) {
       // The last k_block
+      warpgroup_wait<0>();
       CUTLASS_PRAGMA_UNROLL
       for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_) {
         warpgroup_fence_operand(intermediate_array[chunk_id_]);
@@ -1241,7 +1240,6 @@ struct CollectiveMmaArrayMixedInput<
     tiled_mma.accumulate_ = GMMA::ScaleOut::One;
     warpgroup_commit_batch();
-    warpgroup_wait<K_WAIT_MAX>();
     if (k_block == K_BLOCK_MAX - 1) {
       // release prior barrier
       pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
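The `consumer_release` / `++smem_pipe_release` pairs in these hunks are the consumer side of CUTLASS's multi-stage shared-memory pipeline: a stage may only be handed back to the producer once every MMA reading its smem has completed, which is why the placement of the waits matters. A toy single-threaded sketch of the rotation (the stage count and the simplified types are assumptions for illustration; real CUTLASS pipelines add barriers and producer-side acquire/commit):

```cpp
#include <cstdio>

// Minimal stand-ins for the pipeline objects named in the diff.
struct PipelineState {
  int index = 0;
  int stages;
  explicit PipelineState(int s) : stages(s) {}
  PipelineState& operator++() { index = (index + 1) % stages; return *this; }
};

struct Pipeline {
  void consumer_release(PipelineState const& s) {
    // Hands stage s.index back to the producer, which may now refill its smem.
    std::printf("stage %d released for reuse\n", s.index);
  }
};

int main() {
  constexpr int Stages = 3;  // illustrative stage count
  Pipeline pipeline;
  PipelineState smem_pipe_release(Stages);
  for (int k_tile = 0; k_tile < 5; ++k_tile) {
    // ... all WGMMAs reading this stage's smem must be waited on first ...
    pipeline.consumer_release(smem_pipe_release);  // done _computing_ on it
    ++smem_pipe_release;
  }
  return 0;
}
```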
@@ -1264,6 +1262,8 @@ struct CollectiveMmaArrayMixedInput<
     if ((k_block + 1) % NumMMAsPerChunk == 0) {
       tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
+      warpgroup_wait<0>();
       warpgroup_fence_operand(intermediate);
       // Apply the group-wise scaling
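This hunk issues the full wait right before `warpgroup_fence_operand(intermediate)` and the group-wise scaling, since the scaling step reads the per-chunk partial accumulator. A plain-C++ stand-in for that ordering (the scale values and the exact scaling arithmetic below are illustrative assumptions, not the kernel's math):

```cpp
#include <array>
#include <cstdio>

int main() {
  constexpr int NumChunks = 4, Elems = 8;
  std::array<float, Elems> intermediate{};  // per-chunk partial accumulator
  std::array<float, Elems> accum{};
  const std::array<float, NumChunks> scale{1.0f, 0.5f, 2.0f, 1.5f};  // per-group scales (illustrative)

  for (int chunk = 0; chunk < NumChunks; ++chunk) {
    // ... WGMMAs for this chunk accumulate into `intermediate` asynchronously ...
    for (int e = 0; e < Elems; ++e) intermediate[e] = 1.0f;  // stand-in for MMA results
    // warpgroup_wait<0>() must happen HERE, before the reads below; otherwise a
    // still-in-flight WGMMA could write `intermediate` while it is being scaled.
    for (int e = 0; e < Elems; ++e) accum[e] += scale[chunk] * intermediate[e];
    intermediate.fill(0.0f);  // models ScaleOut::Zero resetting for the next chunk
  }
  std::printf("accum[0] = %.1f\n", accum[0]);  // 1.0 + 0.5 + 2.0 + 1.5 = 5.0
  return 0;
}
```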
@@ -1296,7 +1296,7 @@ struct CollectiveMmaArrayMixedInput<
   smem_pipe_release.advance(k_tile_count);
   // Wait on all GMMAs to complete
-  warpgroup_wait<0>();
+  // warpgroup_wait<0>();
   for (int count = 0; count < prologue_mma_count; ++count) {
     pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
...
@@ -157,8 +157,8 @@ def _per_tensor_quant_fp8(
     reason="cutlass_w4a8_moe_mm is only supported on sm90",
 )
 @pytest.mark.parametrize("batch_size", [2, 4, 8, 16, 32])
-@pytest.mark.parametrize("k", [512, 1024, 2048, 4096, 7168])
-@pytest.mark.parametrize("n", [256, 512, 1024, 2048])
+@pytest.mark.parametrize("k", [256, 512, 1024, 2048, 4096, 7168])
+@pytest.mark.parametrize("n", [256, 512, 1024, 2048, 7168])
 @pytest.mark.parametrize("num_experts", [2, 4, 6, 8])
 def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
     torch.manual_seed(0)
...