Commit 5e1b0dd8 authored by Anthony Chang

IsSupportedArgument checks

parent c3920de4
@@ -544,7 +544,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
           c_element_op_{c_element_op},
           batch_count_(Batch),
           compute_base_ptr_of_batch_{
-              BatchStrideA, BatchStrideB, BatchStrideB1, c_grid_desc_g_m_n_}
+              BatchStrideA, BatchStrideB, BatchStrideB1, c_grid_desc_g_m_n_},
+          raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw},
+          c_extent_lowest_{c_gs_ms_gemm1ns_lengths.back()},
+          c_stride_lowest_{c_gs_ms_gemm1ns_strides.back()}
       {
           if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                          b_grid_desc_bk0_n_bk1_,
@@ -578,6 +581,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
       CElementwiseOperation c_element_op_;
       index_t batch_count_;
       ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
+
+      // For robust IsSupportedArgument() check
+      std::vector<index_t> raw_lengths_m_n_k_o_;
+      index_t c_extent_lowest_;
+      index_t c_stride_lowest_;
   };

   // Invoker
@@ -692,7 +700,35 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
           return false;
       }

-      // TODO: Check A/B0/B1 length & stride and scalar per vector
+      // Note: we need raw lengths since threadwise copy can not handle vector load when part of
+      // vector is out of bounds
+      const auto MRaw      = arg.raw_lengths_m_n_k_o_[0];
+      const auto NRaw      = arg.raw_lengths_m_n_k_o_[1];
+      const auto KRaw      = arg.raw_lengths_m_n_k_o_[2];
+      const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
+
+      // Check scalar per vector requirement
+      const auto a_extent_lowest =
+          is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
+      const auto b_extent_lowest =
+          is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
+      const auto b1_extent_lowest =
+          is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
+      const auto c_extent_lowest = arg.c_extent_lowest_;
+
+      if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
+           b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
+           b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
+           c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
+      {
+          return false;
+      }
+
+      // Check vector store requirement; assumes last dimension in N to be contiguous
+      if(arg.c_stride_lowest_ != 1)
+      {
+          return false;
+      }
+
       return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                          arg.b_grid_desc_bk0_n_bk1_,
...
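The added checks boil down to two conditions: every tensor's fastest-varying raw extent must be divisible by its SrcScalarPerVector width (otherwise the threadwise copy would issue a vector load that runs past the tensor boundary), and the output C must have stride 1 in its last dimension so the vectorized shuffle store is legal. A minimal standalone sketch of that rule follows; the vector-width constants and the is_supported helper are hypothetical stand-ins for the device op's template parameters, not library code.

#include <cstdio>

// Hypothetical stand-ins for the compile-time transfer widths of the device op.
constexpr int ABlockTransferSrcScalarPerVector               = 8;
constexpr int BBlockTransferSrcScalarPerVector               = 8;
constexpr int B1BlockTransferSrcScalarPerVector              = 4;
constexpr int CShuffleBlockTransferScalarPerVector_NPerBlock = 8;

// True only if every fastest-varying (lowest) extent is a multiple of its vector
// width and C is contiguous in that dimension, mirroring the checks added above.
bool is_supported(int a_extent_lowest,
                  int b_extent_lowest,
                  int b1_extent_lowest,
                  int c_extent_lowest,
                  int c_stride_lowest)
{
    if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
         b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
         b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
         c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
    {
        return false;
    }
    return c_stride_lowest == 1; // vector store needs a contiguous last dimension
}

int main()
{
    std::printf("%d\n", is_supported(64, 128, 128, 128, 1)); // 1: all extents aligned
    std::printf("%d\n", is_supported(60, 128, 128, 128, 1)); // 0: 60 % 8 != 0 breaks the A load
    std::printf("%d\n", is_supported(64, 128, 128, 128, 2)); // 0: strided C forbids vector store
}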
@@ -459,7 +459,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
           b1_element_op_{b1_element_op},
           c_element_op_{c_element_op},
           batch_count_(Batch),
-          compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC}
+          compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC},
+          raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw}
       {
           if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                          b_grid_desc_bk0_n_bk1_,
@@ -492,6 +493,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
       CElementwiseOperation c_element_op_;
       index_t batch_count_;
       ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
+
+      // For robust IsSupportedArgument() check
+      std::vector<index_t> raw_lengths_m_n_k_o_;
   };

   // Invoker
@@ -595,6 +599,31 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
           return false;
       }

+      // Note: we need raw lengths since threadwise copy can not handle vector load when part of
+      // vector is out of bounds
+      const auto MRaw      = arg.raw_lengths_m_n_k_o_[0];
+      const auto NRaw      = arg.raw_lengths_m_n_k_o_[1];
+      const auto KRaw      = arg.raw_lengths_m_n_k_o_[2];
+      const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
+
+      // Check scalar per vector requirement
+      const auto a_extent_lowest =
+          is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
+      const auto b_extent_lowest =
+          is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
+      const auto b1_extent_lowest =
+          is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
+      const auto c_extent_lowest =
+          is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
+
+      if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
+           b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
+           b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
+           c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
+      {
+          return false;
+      }
+
       return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                          arg.b_grid_desc_bk0_n_bk1_,
                                          arg.b1_grid_desc_bk0_n_bk1_,
...
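Both variants pick the extent to test from the layout tag: for a row-major A[M, K] the contiguous dimension is K, for a column-major A it is M, and likewise for B0[K, N] and B1[N, O]. The Permute variant has no CLayout tag, so it captures C's last length and stride at argument construction, while the plain variant derives C's contiguous extent from CLayout. A small sketch of that layout-dependent selection; the RowMajor/ColumnMajor tags and helper functions below are local illustrations, not ck::tensor_layout::gemm itself.

#include <cstdio>
#include <type_traits>

// Local stand-ins for the library's layout tags.
struct RowMajor    {};
struct ColumnMajor {};

// A is M x K: row-major stores K contiguously, column-major stores M contiguously.
template <typename ALayout>
int a_extent_lowest(int MRaw, int KRaw)
{
    return std::is_same_v<RowMajor, ALayout> ? KRaw : MRaw;
}

// B0 is K x N (B1 is N x O analogously): row-major -> N contiguous, otherwise -> K.
template <typename BLayout>
int b_extent_lowest(int KRaw, int NRaw)
{
    return std::is_same_v<RowMajor, BLayout> ? NRaw : KRaw;
}

int main()
{
    // M=256, K=60: the same problem vectorizes along K for row-major A but along M
    // for column-major A, so an 8-wide load check rejects one and accepts the other.
    std::printf("%d %d\n",
                a_extent_lowest<RowMajor>(256, 60),     // 60  -> 60 % 8 != 0, rejected
                a_extent_lowest<ColumnMajor>(256, 60)); // 256 -> 256 % 8 == 0, accepted
    std::printf("%d\n", b_extent_lowest<RowMajor>(60, 128)); // 128: N is the tested extent
}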