Commit 075abf92 authored by danyao12's avatar danyao12
Browse files

bwd template rename

parent c56f28b0
......@@ -36,8 +36,8 @@ Kernel outputs:
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
......@@ -97,7 +97,7 @@ static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpeciali
// If 64 < Headdim/K/O <= 128, ues prototype2 2nd template.
#if(RANGE_HDKO == 0)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
......@@ -165,7 +165,7 @@ using DeviceGemmInstance =
MaskingSpec>; // MaskingSpecialization
#elif(RANGE_HDKO == 1)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
......@@ -233,7 +233,7 @@ using DeviceGemmInstance =
MaskingSpec>; // MaskingSpecialization
// using DeviceGemmInstance =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle<
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
// NumDimG,
// NumDimM,
// NumDimN,
......@@ -301,7 +301,7 @@ using DeviceGemmInstance =
// MaskingSpec>; // MaskingSpecialization
#elif(RANGE_HDKO == 2)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
......
......@@ -43,8 +43,8 @@ Kernel outputs:
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
......@@ -176,7 +176,7 @@ using DeviceGemmInstanceFWD =
MaskingSpec>; // MaskingSpecialization
using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
......@@ -314,7 +314,7 @@ using DeviceGemmInstanceFWD =
MaskingSpec>; // MaskingSpecialization
using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
......@@ -382,7 +382,7 @@ using DeviceGemmInstanceBWD =
MaskingSpec>; // MaskingSpecialization
// using DeviceGemmInstanceBWD =
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle<
// ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
// NumDimG,
// NumDimM,
// NumDimN,
......@@ -520,7 +520,7 @@ using DeviceGemmInstanceFWD =
MaskingSpec>; // MaskingSpecialization
using DeviceGemmInstanceBWD =
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle<
ck::tensor_operation::device::DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
......
......@@ -16,7 +16,7 @@
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
......@@ -50,10 +50,9 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, /*CK_MIN_BLOCK_PER_CU*/ 1)
#endif
kernel_batched_multihead_attention_backward_xdl_cshuffle_pt1(
kernel_batched_multihead_attention_backward_xdl_cshuffle_v1(
const DataType* __restrict__ p_a_grid,
const DataType* __restrict__ p_b_grid,
ZDataType* __restrict__ p_z_grid,
......@@ -233,7 +232,7 @@ template <index_t NumDimG,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
: public BaseOperator // TODO inherit atten bwd op once API stablizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
......@@ -255,7 +254,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1;
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -597,7 +596,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
};
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
DataType, // TODO: distinguish A/B datatype
GemmDataType,
GemmAccDataType,
......@@ -900,7 +899,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_pt1<
const auto kernel = kernel_batched_multihead_attention_backward_xdl_cshuffle_v1<
GridwiseGemm,
DataType,
ZDataType,
......@@ -1231,7 +1230,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1"
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -231,7 +231,7 @@ template <index_t NumDimG,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
: public BaseOperator // TODO inherit atten bwd op once API stablizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
......@@ -253,7 +253,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle;
using DeviceOp = DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -1230,7 +1230,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
auto str = std::stringstream();
// clang-format off
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle"
str << "DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
......@@ -85,7 +85,7 @@ template <typename DataType,
bool PadN,
bool MaskOutUpperTriangle,
PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1
struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
{
static_assert(LoopSched == LoopScheduler::Default,
"Non-default loop scheduler is currently not supported");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment