Commit 5e1b0dd8 authored by Anthony Chang

IsSupportedArgument checks

parent c3920de4
@@ -544,7 +544,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
c_element_op_{c_element_op},
batch_count_(Batch),
compute_base_ptr_of_batch_{
BatchStrideA, BatchStrideB, BatchStrideB1, c_grid_desc_g_m_n_}
BatchStrideA, BatchStrideB, BatchStrideB1, c_grid_desc_g_m_n_},
raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw},
c_extent_lowest_{c_gs_ms_gemm1ns_lengths.back()},
c_stride_lowest_{c_gs_ms_gemm1ns_strides.back()}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
@@ -578,6 +581,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
CElementwiseOperation c_element_op_;
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
// For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_m_n_k_o_;
index_t c_extent_lowest_;
index_t c_stride_lowest_;
};
// Invoker
@@ -692,7 +700,35 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
return false;
}
// TODO: Check A/B0/B1 length & stride and scalar per vector
// Note: we need raw lengths since threadwise copy cannot handle a vector load when part of
// the vector is out of bounds
const auto MRaw = arg.raw_lengths_m_n_k_o_[0];
const auto NRaw = arg.raw_lengths_m_n_k_o_[1];
const auto KRaw = arg.raw_lengths_m_n_k_o_[2];
const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
// Check scalar per vector requirement
const auto a_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
const auto b_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
const auto b1_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
const auto c_extent_lowest = arg.c_extent_lowest_;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
// Check vector store requirement; assumes the last dimension (in N) is contiguous
if(arg.c_stride_lowest_ != 1)
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
......
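To make the intent of the new check concrete, here is a minimal standalone sketch (hypothetical helper names, not code from this commit) of the divisibility test above: a vectorized load or store of `scalar_per_vector` elements along the fastest-varying dimension stays in bounds only when the raw, unpadded extent of that dimension is a multiple of the vector width. The padded descriptor length can divide evenly even when the raw length does not, which is exactly the out-of-bounds case the note in the diff is guarding against.

```cpp
#include <cstdint>
#include <iostream>

// Hypothetical helper mirroring the check above: a vectorized copy of
// `scalar_per_vector` elements along the lowest (fastest-varying) dimension is
// only fully in bounds when the raw (unpadded) extent divides evenly.
bool lowest_extent_supports_vector_width(std::int64_t raw_extent_lowest,
                                         std::int64_t scalar_per_vector)
{
    return raw_extent_lowest % scalar_per_vector == 0;
}

int main()
{
    // Example: a row-major A tile with KRaw = 30 that the grid descriptor pads
    // to K = 32. The padded length passes a width-4 check, but the raw length
    // does not, so the final 4-wide load would straddle the pad boundary.
    // This is why IsSupportedArgument() consults raw_lengths_m_n_k_o_.
    std::cout << lowest_extent_supports_vector_width(32, 4) << '\n'; // 1 (padded length, misleading)
    std::cout << lowest_extent_supports_vector_width(30, 4) << '\n'; // 0 (raw length, correct verdict)

    // The permute variant additionally requires the lowest C stride to be 1,
    // i.e. the output's last (Gemm1N) dimension must be contiguous for the
    // vectorized CShuffle store to be legal.
    return 0;
}
```

Caching the raw lengths and the lowest C extent/stride in the Argument struct appears to be what keeps this check straightforward, since the padded, transformed grid descriptors no longer expose the original extents directly.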
@@ -459,7 +459,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
b1_element_op_{b1_element_op},
c_element_op_{c_element_op},
batch_count_(Batch),
compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC}
compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC},
raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw}
{
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
b_grid_desc_bk0_n_bk1_,
@@ -492,6 +493,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
CElementwiseOperation c_element_op_;
index_t batch_count_;
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
// For robust IsSupportedArgument() check
std::vector<index_t> raw_lengths_m_n_k_o_;
};
// Invoker
@@ -595,6 +599,31 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
return false;
}
// Note: we need raw lengths since threadwise copy cannot handle a vector load when part of
// the vector is out of bounds
const auto MRaw = arg.raw_lengths_m_n_k_o_[0];
const auto NRaw = arg.raw_lengths_m_n_k_o_[1];
const auto KRaw = arg.raw_lengths_m_n_k_o_[2];
const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3];
// Check scalar per vector requirement
const auto a_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, ALayout> ? KRaw : MRaw;
const auto b_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, BLayout> ? NRaw : KRaw;
const auto b1_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, B1Layout> ? Gemm1NRaw : NRaw;
const auto c_extent_lowest =
is_same_v<tensor_layout::gemm::RowMajor, CLayout> ? Gemm1NRaw : MRaw;
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 &&
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
{
return false;
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.b1_grid_desc_bk0_n_bk1_,
......
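The layout-dependent selection of which extent to test (the ternaries on ALayout/BLayout/B1Layout/CLayout above) can be read as a small pure function. The sketch below (hypothetical names, not part of the commit) folds the four selections and the divisibility test into a single predicate over the raw problem sizes; it only restates the logic of IsSupportedArgument() for illustration.

```cpp
#include <cstdint>

// Hypothetical sketch, not code from this commit: restates the layout-dependent
// extent selection and the scalar-per-vector divisibility test as one predicate.
enum class Layout { RowMajor, ColumnMajor };

// For a row-major [rows, cols] operand the fastest-varying extent is cols;
// for a column-major operand it is rows. Mirrors the ternaries on
// ALayout / BLayout / B1Layout / CLayout in IsSupportedArgument().
constexpr std::int64_t lowest_extent(Layout layout, std::int64_t rows, std::int64_t cols)
{
    return layout == Layout::RowMajor ? cols : rows;
}

// Applies the divisibility requirement to the raw (unpadded) extents of
// A[M,K], B0[K,N], B1[N,O] and C[M,O], with per-operand vector widths.
constexpr bool extents_support_vector_widths(std::int64_t m_raw, std::int64_t n_raw,
                                             std::int64_t k_raw, std::int64_t o_raw,
                                             Layout a, Layout b0, Layout b1, Layout c,
                                             std::int64_t a_vec, std::int64_t b0_vec,
                                             std::int64_t b1_vec, std::int64_t c_vec)
{
    return lowest_extent(a, m_raw, k_raw) % a_vec == 0 &&
           lowest_extent(b0, k_raw, n_raw) % b0_vec == 0 &&
           lowest_extent(b1, n_raw, o_raw) % b1_vec == 0 &&
           lowest_extent(c, m_raw, o_raw) % c_vec == 0;
}

// M=128, N=120, K=64, O=128, everything row-major, vector width 8: K, N and O
// all divide evenly, so the shape passes.
static_assert(extents_support_vector_widths(128, 120, 64, 128,
                                            Layout::RowMajor, Layout::RowMajor,
                                            Layout::RowMajor, Layout::RowMajor,
                                            8, 8, 8, 8),
              "shape should be supported");

// O=130 makes the B1 and C extents indivisible by 8, so the shape is rejected.
static_assert(!extents_support_vector_widths(128, 120, 64, 130,
                                             Layout::RowMajor, Layout::RowMajor,
                                             Layout::RowMajor, Layout::RowMajor,
                                             8, 8, 8, 8),
              "shape should be rejected");
```

The non-permute variant can derive the C extent from CLayout in the same way as A/B0/B1, whereas the permute variant reads the lowest C extent and stride directly from the supplied gs/ms/gemm1ns lengths and strides, since its output layout is given as explicit dimension lists rather than a RowMajor/ColumnMajor tag.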