Commit 2c61a639 authored by wangshaojie6's avatar wangshaojie6
Browse files

format

parent aac6c294
...@@ -408,7 +408,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -408,7 +408,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; } __host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const { return n >= NRaw_; } __host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
return n >= NRaw_;
}
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const __host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
{ {
......
...@@ -323,7 +323,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -323,7 +323,10 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; } __host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const { return n >= NRaw_; } __host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
return n >= NRaw_;
}
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const __host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
{ {
......
...@@ -420,7 +420,10 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -420,7 +420,10 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; } __host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const { return n >= NRaw_; } __host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
{
return n >= NRaw_;
}
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const __host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
{ {
...@@ -646,7 +649,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle ...@@ -646,7 +649,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
problem_desc_vec[i].BatchStrideB1, problem_desc_vec[i].BatchStrideB1,
c_grid_desc_g_m_n); c_grid_desc_g_m_n);
// C0 mask // C0 mask
const auto c0_matrix_mask = C0MatrixMask(problem_desc_vec[i].N); const auto c0_matrix_mask = C0MatrixMask(problem_desc_vec[i].N);
grid_size_ += grid_size_grp; grid_size_ += grid_size_grp;
......
...@@ -763,7 +763,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -763,7 +763,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
auto gemm0_n_block_idx = auto gemm0_n_block_idx =
__builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock); __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
if(c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid, gemm0_n_block_idx) && if(c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid, gemm0_n_block_idx) &&
c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid + MPerBlock - 1, gemm0_n_block_idx)) c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid + MPerBlock - 1,
gemm0_n_block_idx))
{ {
continue; continue;
} }
......
...@@ -30,8 +30,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; ...@@ -30,8 +30,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded = static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
// c[g, m, n] = a[g, m, k] * b[g, n, k] // c[g, m, n] = a[g, m, k] * b[g, n, k]
using device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = using device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
......
...@@ -26,8 +26,7 @@ using S = ck::Sequence<Is...>; ...@@ -26,8 +26,7 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded = static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
// c[g, m, n] = a[g, m, k] * b[g, n, k] // c[g, m, n] = a[g, m, k] * b[g, n, k]
using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment