Commit 8ef97116 authored by fsx950223's avatar fsx950223
Browse files

update

parent d9579dc8
......@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
#define VERSION 1
#define DIM 64
#include <iostream>
#include <numeric>
......@@ -62,6 +62,7 @@ using QKVElementOp = PassThrough;
using YElementOp = PassThrough;
using DataType = F16;
using GemmDataType = F16;
using AccDataType = F32;
using ShuffleDataType = F32;
using LSEDataType = F32;
......@@ -89,7 +90,7 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
#if VERSION == 1
#if DIM >=128
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
......@@ -156,7 +157,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#elif VERSION == 2
#elif DIM >64
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
......@@ -223,7 +224,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#elif VERSION == 3
#elif DIM >32
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
NumDimG,
NumDimM,
......@@ -231,6 +232,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
NumDimK,
NumDimO,
DataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
......@@ -297,6 +299,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
NumDimK,
NumDimO,
DataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
......
......@@ -151,6 +151,7 @@ template <index_t NumDimG,
index_t NumDimK,
index_t NumDimO, // NumDimGemm1N
typename DataType,
typename GemmDataType,
typename ZDataType,
typename LSEDataType,
typename Acc0BiasDataType,
......@@ -535,9 +536,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// GridwiseGemm
using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
DataType, // TODO: distinguish A/B datatype
LSEDataType,
GemmDataType,
GemmAccDataType,
CShuffleDataType,
LSEDataType,
AElementwiseOperation,
BElementwiseOperation,
AccElementwiseOperation,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment