Commit 2c61a639 authored by wangshaojie6
Browse files

format

parent aac6c294
......@@ -408,7 +408,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const { return n >= NRaw_; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
return n >= NRaw_;
}
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
{
......
......@@ -323,7 +323,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const { return n >= NRaw_; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
return n >= NRaw_;
}
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
{
......
......@@ -420,7 +420,10 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const { return n >= NRaw_; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
return n >= NRaw_;
}
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
{
......@@ -646,7 +649,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
problem_desc_vec[i].BatchStrideB1,
c_grid_desc_g_m_n);
// C0 mask
// C0 mask
const auto c0_matrix_mask = C0MatrixMask(problem_desc_vec[i].N);
grid_size_ += grid_size_grp;
......
......@@ -763,7 +763,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
auto gemm0_n_block_idx =
__builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
if(c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid, gemm0_n_block_idx) &&
c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid + MPerBlock - 1, gemm0_n_block_idx))
c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid + MPerBlock - 1,
gemm0_n_block_idx))
{
continue;
}
......
......@@ -30,8 +30,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded =
ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
......
......@@ -26,8 +26,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded =
ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
// c[g, m, n] = a[g, m, k] * b[g, n, k]
using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment