Commit f01a06c4 authored by danyao12

mi300 test decoder

parent 1128cd3a
@@ -24,7 +24,7 @@ Kernel outputs:
 */
 #define PRINT_HOST 0
-#define USING_MASK 0
+#define USING_MASK 1
 #define DIM 64 // DIM should be a multiple of 8.
 #include <iostream>
...
@@ -24,7 +24,7 @@ Kernel outputs:
 */
 #define PRINT_HOST 0
-#define USING_MASK 0
+#define USING_MASK 1
 #define DIM 128 // DIM should be a multiple of 8.
 #include <iostream>
...
@@ -24,7 +24,7 @@ Kernel outputs:
 */
 #define PRINT_HOST 0
-#define USING_MASK 0
+#define USING_MASK 1
 #define DIM 64 // DIM should be a multiple of 8.
 #include <iostream>
...
@@ -31,7 +31,7 @@ Kernel outputs:
 */
 #define PRINT_HOST 0
-#define USING_MASK 0
+#define USING_MASK 1
 #define DIM 128 // DIM should be a multiple of 8.
 #include <iostream>
...
@@ -31,7 +31,7 @@ Kernel outputs:
 */
 #define PRINT_HOST 0
-#define USING_MASK 0
+#define USING_MASK 1
 #define DIM 128 // DIM should be a multiple of 8.
 #include <iostream>
...
@@ -23,7 +23,7 @@ Kernel outputs:
 */
-#define USING_MASK 0
+#define USING_MASK 1
 #define DIM 64 // DIM should be a multiple of 8.
 #include <iostream>
...
@@ -23,7 +23,7 @@ Kernel outputs:
 */
-#define USING_MASK 0
+#define USING_MASK 1
 #define DIM 128 // DIM should be a multiple of 8.
 #include <iostream>
...
@@ -30,7 +30,7 @@ Kernel outputs:
 */
-#define USING_MASK 0
+#define USING_MASK 1
 #define DIM 64 // DIM should be a multiple of 8.
 #include <iostream>
...
@@ -30,7 +30,7 @@ Kernel outputs:
 */
-#define USING_MASK 0
+#define USING_MASK 1
 #define DIM 64 // DIM should be a multiple of 8.
 #include <iostream>
...
@@ -1867,11 +1867,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
         {
             auto n_block_data_idx_on_grid =
                 __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
-            if(c0_matrix_mask.IsTileSkippable(
-                   m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
-            {
-                continue;
-            }
+            // if(c0_matrix_mask.IsTileSkippable(
+            //        m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
+            // {
+            //     continue;
+            // }
             // gemm dP
             // dP = dY * V^T
...
@@ -1775,11 +1775,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
         {
             auto n_block_data_idx_on_grid =
                 __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
-            if(c0_matrix_mask.IsTileSkippable(
-                   m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
-            {
-                continue;
-            }
+            // if(c0_matrix_mask.IsTileSkippable(
+            //        m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
+            // {
+            //     continue;
+            // }
             // S = Q * K^T
             gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
                 q_grid_desc_k0_m_k1,
...
@@ -1798,11 +1798,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
         {
             auto m_block_data_idx_on_grid =
                 __builtin_amdgcn_readfirstlane(gemm0_m_block_outer_index * MPerBlock);
-            if(c0_matrix_mask.IsTileSkippable(
-                   m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
-            {
-                continue;
-            }
+            // if(c0_matrix_mask.IsTileSkippable(
+            //        m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
+            // {
+            //     continue;
+            // }
             // load ygrad
             gemm_tile_ygrad_blockwise_copy.Run(ygrad_grid_desc_o0_m_o1,
...
@@ -1721,11 +1721,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
         {
             auto m_block_data_idx_on_grid =
                 __builtin_amdgcn_readfirstlane(gemm0_m_block_outer_index * MPerBlock);
-            if(c0_matrix_mask.IsTileSkippable(
-                   m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
-            {
-                continue;
-            }
+            // if(c0_matrix_mask.IsTileSkippable(
+            //        m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
+            // {
+            //     continue;
+            // }
             //
             // calculate Y dot dY
...
@@ -926,11 +926,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V1
         {
             auto n_block_data_idx_on_grid =
                 __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
-            if(c0_matrix_mask.IsTileSkippable(
-                   m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
-            {
-                continue;
-            }
+            // if(c0_matrix_mask.IsTileSkippable(
+            //        m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
+            // {
+            //     continue;
+            // }
             // gemm0
             gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
                                                                    a_block_desc_ak0_m_ak1,
...
@@ -1080,11 +1080,11 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
         {
             auto n_block_data_idx_on_grid =
                 __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
-            if(c0_matrix_mask.IsTileSkippable(
-                   m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
-            {
-                continue;
-            }
+            // if(c0_matrix_mask.IsTileSkippable(
+            //        m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
+            // {
+            //     continue;
+            // }
             // gemm0
             gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
                                                                    a_block_desc_ak0_m_ak1,
...