Commit 075abf92 authored by danyao12's avatar danyao12
Browse files

bwd template rename

parent c56f28b0
...@@ -36,8 +36,8 @@ Kernel outputs: ...@@ -36,8 +36,8 @@ Kernel outputs:
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp" #include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
...@@ -97,7 +97,7 @@ static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpeciali ...@@ -97,7 +97,7 @@ static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpeciali
// If 64 < Headdim/K/O <= 128, ues prototype2 2nd template. // If 64 < Headdim/K/O <= 128, ues prototype2 2nd template.
#if(RANGE_HDKO == 0) #if(RANGE_HDKO == 0)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -165,7 +165,7 @@ using DeviceGemmInstance = ...@@ -165,7 +165,7 @@ using DeviceGemmInstance =
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#elif(RANGE_HDKO == 1) #elif(RANGE_HDKO == 1)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -233,7 +233,7 @@ using DeviceGemmInstance = ...@@ -233,7 +233,7 @@ using DeviceGemmInstance =
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
// using DeviceGemmInstance = // using DeviceGemmInstance =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle< // ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
// NumDimG, // NumDimG,
// NumDimM, // NumDimM,
// NumDimN, // NumDimN,
...@@ -301,7 +301,7 @@ using DeviceGemmInstance = ...@@ -301,7 +301,7 @@ using DeviceGemmInstance =
// MaskingSpec>; // MaskingSpecialization // MaskingSpec>; // MaskingSpecialization
#elif(RANGE_HDKO == 2) #elif(RANGE_HDKO == 2)
using DeviceGemmInstance = using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
......
...@@ -43,8 +43,8 @@ Kernel outputs: ...@@ -43,8 +43,8 @@ Kernel outputs:
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp" #include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...@@ -176,7 +176,7 @@ using DeviceGemmInstanceFWD = ...@@ -176,7 +176,7 @@ using DeviceGemmInstanceFWD =
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -314,7 +314,7 @@ using DeviceGemmInstanceFWD = ...@@ -314,7 +314,7 @@ using DeviceGemmInstanceFWD =
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
...@@ -382,7 +382,7 @@ using DeviceGemmInstanceBWD = ...@@ -382,7 +382,7 @@ using DeviceGemmInstanceBWD =
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
// using DeviceGemmInstanceBWD = // using DeviceGemmInstanceBWD =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle< // ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
// NumDimG, // NumDimG,
// NumDimM, // NumDimM,
// NumDimN, // NumDimN,
...@@ -520,7 +520,7 @@ using DeviceGemmInstanceFWD = ...@@ -520,7 +520,7 @@ using DeviceGemmInstanceFWD =
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
using DeviceGemmInstanceBWD = using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle< ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
NumDimG, NumDimG,
NumDimM, NumDimM,
NumDimN, NumDimN,
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp" #include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp" #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -50,10 +50,9 @@ template <typename GridwiseGemm, ...@@ -50,10 +50,9 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop> bool HasMainKBlockLoop>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
#endif #endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_pt1( kernel_batched_multihead_attention_backward_xdl_cshuffle_v1(
const DataType* __restrict__ p_a_grid, const DataType* __restrict__ p_a_grid,
const DataType* __restrict__ p_b_grid, const DataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid, ZDataType* __restrict__ p_z_grid,
...@@ -233,7 +232,7 @@ template <index_t NumDimG, ...@@ -233,7 +232,7 @@ template <index_t NumDimG,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
: public BaseOperator // TODO inherit atten bwd op once API stablizes : public BaseOperator // TODO inherit atten bwd op once API stablizes
{ {
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
...@@ -255,7 +254,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -255,7 +254,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static constexpr index_t NumDimGemm1K = NumDimN; static constexpr index_t NumDimGemm1K = NumDimN;
#endif #endif
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1; using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -597,7 +596,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -597,7 +596,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
}; };
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
DataType, // TODO: distinguish A/B datatype DataType, // TODO: distinguish A/B datatype
GemmDataType, GemmDataType,
GemmAccDataType, GemmAccDataType,
...@@ -900,7 +899,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -900,7 +899,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_pt1< const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
GridwiseGemm, GridwiseGemm,
DataType, DataType,
ZDataType, ZDataType,
...@@ -1231,7 +1230,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -1231,7 +1230,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1" str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
...@@ -231,7 +231,7 @@ template <index_t NumDimG, ...@@ -231,7 +231,7 @@ template <index_t NumDimG,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock, index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec, MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default> LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
: public BaseOperator // TODO inherit atten bwd op once API stablizes : public BaseOperator // TODO inherit atten bwd op once API stablizes
{ {
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
...@@ -253,7 +253,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -253,7 +253,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
static constexpr index_t NumDimGemm1K = NumDimN; static constexpr index_t NumDimGemm1K = NumDimN;
#endif #endif
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle; using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -1230,7 +1230,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1230,7 +1230,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle" str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
......
...@@ -85,7 +85,7 @@ template <typename DataType, ...@@ -85,7 +85,7 @@ template <typename DataType,
bool PadN, bool PadN,
bool MaskOutUpperTriangle, bool MaskOutUpperTriangle,
PipelineVersion PipelineVer = PipelineVersion::v1> PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1 struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{ {
static_assert(LoopSched == LoopScheduler::Default, static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported"); "Non-default loop scheduler is currently not supported");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment