Commit 8ef97116 authored by fsx950223
Browse files

update

parent d9579dc8
...@@ -25,7 +25,7 @@ Kernel outputs: ...@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0 #define PRINT_HOST 0
#define USING_MASK 0 #define USING_MASK 0
#define VERSION 1 #define DIM 64
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -62,6 +62,7 @@ using QKVElementOp = PassThrough; ...@@ -62,6 +62,7 @@ using QKVElementOp = PassThrough;
using YElementOp = PassThrough; using YElementOp = PassThrough;
using DataType = F16; using DataType = F16;
using GemmDataType = F16;
using AccDataType = F32; using AccDataType = F32;
using ShuffleDataType = F32; using ShuffleDataType = F32;
using LSEDataType = F32; using LSEDataType = F32;
...@@ -89,7 +90,7 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali ...@@ -89,7 +90,7 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
#if VERSION == 1 #if DIM >=128
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2< using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
NumDimG, NumDimG,
NumDimM, NumDimM,
...@@ -156,7 +157,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA ...@@ -156,7 +157,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#elif VERSION == 2 #elif DIM >64
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2< using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
NumDimG, NumDimG,
NumDimM, NumDimM,
...@@ -223,7 +224,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA ...@@ -223,7 +224,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization MaskingSpec>; // MaskingSpecialization
#elif VERSION == 3 #elif DIM >32
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1< using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
NumDimG, NumDimG,
NumDimM, NumDimM,
...@@ -231,6 +232,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA ...@@ -231,6 +232,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, DataType,
GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, Acc0BiasDataType,
...@@ -297,6 +299,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA ...@@ -297,6 +299,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
NumDimK, NumDimK,
NumDimO, NumDimO,
DataType, DataType,
GemmDataType,
ZDataType, ZDataType,
LSEDataType, LSEDataType,
Acc0BiasDataType, Acc0BiasDataType,
......
...@@ -151,6 +151,7 @@ template <index_t NumDimG, ...@@ -151,6 +151,7 @@ template <index_t NumDimG,
index_t NumDimK, index_t NumDimK,
index_t NumDimO, // NumDimGemm1N index_t NumDimO, // NumDimGemm1N
typename DataType, typename DataType,
typename GemmDataType,
typename ZDataType, typename ZDataType,
typename LSEDataType, typename LSEDataType,
typename Acc0BiasDataType, typename Acc0BiasDataType,
...@@ -535,9 +536,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1 ...@@ -535,9 +536,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1< using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
DataType, // TODO: distinguish A/B datatype DataType, // TODO: distinguish A/B datatype
LSEDataType, GemmDataType,
GemmAccDataType, GemmAccDataType,
CShuffleDataType, CShuffleDataType,
LSEDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment