Commit a39dd61f authored by danyao12

refactor grouped bwd example and fix some bugs

parent 79f3caf8
@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
-#define RANGE_HDKO 2 // 0~2
+#define RANGE_HDKO 1 // 0~2
#include <iostream>
#include <numeric>
@@ -523,7 +523,7 @@ int run(int argc, char* argv[])
G0 = std::stoi(argv[8]);
G1 = std::stoi(argv[9]);
alpha = std::stof(argv[10]);
p_drop = std::stof(argv[11]);
input_permute = std::stoi(argv[12]);
@@ -540,9 +540,9 @@ int run(int argc, char* argv[])
exit(0);
}
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
std::cout << "do_verification: " << do_verification << std::endl;
std::cout << "init_method: " << init_method << std::endl;
@@ -678,7 +678,6 @@ int run(int argc, char* argv[])
// = 0
}
// calculate y & log-sum-exp beforehand
Tensor<DataType> q_g_m_k({BatchCount, M, K});
Tensor<DataType> k_g_n_k({BatchCount, N, K});
Tensor<ZDataType> z_g_m_n({BatchCount, M, N});
......
@@ -23,20 +23,20 @@ Kernel outputs:
*/
#define PRINT_HOST 0
#define USING_MASK 0
-#define DIM 64
+#define DIM 64 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <fstream>
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" #include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_multihead_attention_backward_xdl_cshuffle_pt1.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
...@@ -47,13 +47,13 @@ Kernel outputs: ...@@ -47,13 +47,13 @@ Kernel outputs:
#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_dropout.hpp"
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using BF16 = ck::bhalf_t;
using U16 = unsigned short; using F32 = float;
using U16 = unsigned short;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale; using Scale = ck::tensor_operation::element_wise::Scale;
...@@ -90,8 +90,13 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali ...@@ -90,8 +90,13 @@ static constexpr auto TensorSpecK = ck::tensor_operation::device::TensorSpeciali
static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecV = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto TensorSpecY = ck::tensor_operation::device::TensorSpecialization::Default;
-#if DIM >=128
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
+// DIM should be a multiple of 8.
+// If DIM <= 32, use prototype1 1st template.
+// If 32 < DIM <= 64, use prototype1 2nd template.
+// If 64 < DIM <= 128, use prototype2 2nd template.
+#if(DIM <= 32)
+using DeviceGemmInstance =
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -119,8 +124,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
256,
128, // MPerBlock
128, // NPerBlock
-64, // KPerBlock
-128, // Gemm1NPerBlock
+32, // KPerBlock
+32, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
@@ -129,8 +134,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
-4, // Gemm1NXdlPerWave
-2, // Gemm2NXdlPerWave
+1, // Gemm1NXdlPerWave
+1, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
@@ -153,12 +158,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
-4, // CShuffleNXdlPerWavePerShuffle
-S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
+1, // CShuffleNXdlPerWavePerShuffle
+S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
-#elif DIM >64
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
+#elif(DIM <= 64)
+using DeviceGemmInstance =
+    ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1<
NumDimG,
NumDimM,
NumDimN,
@@ -188,7 +194,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
128, // NPerBlock
64, // KPerBlock
64, // Gemm1NPerBlock
-64, // Gemm1KPerBlock
+32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
@@ -216,7 +222,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
S<0, 2, 1>,
S<0, 2, 1>,
1,
-2,
+4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
@@ -224,8 +230,77 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#elif DIM >32
-using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
+// using DeviceGemmInstance =
// ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
// NumDimG,
// NumDimM,
// NumDimN,
// NumDimK,
// NumDimO,
// DataType,
// GemmDataType,
// ZDataType,
// LSEDataType,
// Acc0BiasDataType,
// Acc1BiasDataType,
// AccDataType,
// ShuffleDataType,
// QKVElementOp,
// QKVElementOp,
// Scale,
// QKVElementOp,
// YElementOp,
// GemmSpec,
// TensorSpecQ,
// TensorSpecK,
// TensorSpecV,
// TensorSpecY,
// 1,
// 256,
// 128, // MPerBlock
// 128, // NPerBlock
// 64, // KPerBlock
// 64, // Gemm1NPerBlock
// 64, // Gemm1KPerBlock
// 8, // AK1
// 8, // BK1
// 2, // B1K1
// 32, // MPerXDL
// 32, // NPerXDL
// 1, // MXdlPerWave
// 4, // NXdlPerWave
// 2, // Gemm1NXdlPerWave
// 2, // Gemm2NXdlPerWave
// S<4, 64, 1>, // ABlockTransfer
// S<1, 0, 2>,
// S<1, 0, 2>,
// 2,
// 8,
// 8,
// true,
// S<4, 64, 1>, // BBlockTransfer
// S<1, 0, 2>,
// S<1, 0, 2>,
// 2,
// 8,
// 8,
// true,
// S<8, 32, 1>, // B1BlockTransfer
// S<0, 2, 1>,
// S<0, 2, 1>,
// 1,
// 2,
// 2,
// false,
// 1, // CShuffleMXdlPerWavePerShuffle
// 2, // CShuffleNXdlPerWavePerShuffle
// S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
// 8, // CShuffleBlockTransferScalarPerVector_NPerBlock
// MaskingSpec>; // MaskingSpecialization
#elif(DIM <= 128)
using DeviceGemmInstance =
ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2<
NumDimG,
NumDimM,
NumDimN,
@@ -254,7 +329,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
128, // MPerBlock
128, // NPerBlock
64, // KPerBlock
-64, // Gemm1NPerBlock
+128, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
@@ -263,7 +338,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
-2, // Gemm1NXdlPerWave
+4, // Gemm1NXdlPerWave
2, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
@@ -287,78 +362,12 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadA
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
-2, // CShuffleNXdlPerWavePerShuffle
+4, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#else
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDimO,
DataType,
GemmDataType,
ZDataType,
LSEDataType,
Acc0BiasDataType,
Acc1BiasDataType,
AccDataType,
ShuffleDataType,
QKVElementOp,
QKVElementOp,
Scale,
QKVElementOp,
YElementOp,
GemmSpec,
TensorSpecQ,
TensorSpecK,
TensorSpecV,
TensorSpecY,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
32, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
1, // Gemm1NXdlPerWave
1, // Gemm2NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpec>; // MaskingSpecialization
#endif
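The preprocessor chain above selects the backward instance from the head dimension: prototype 1 (the V1 device op) for DIM up to 64 and prototype 2 (the V2 device op) for 64 < DIM <= 128, with DIM itself expected to be a multiple of 8 (matching the 8-wide AK1/BK1 vector accesses). A hypothetical compile-time guard restating those constraints, relying on the example's DIM macro and not part of the commit:

// Hypothetical guard restating the assumptions of the #if chain above.
static_assert(DIM % 8 == 0, "DIM should be a multiple of 8");
static_assert(DIM <= 128, "no DeviceGemmInstance is defined for DIM > 128");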
// Ref Gemm0: S = alpha * Q * K^T
// fp16 in, fp32 out
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<DataType,
@@ -393,7 +402,7 @@ using ReferenceGemmGradInstance = ck::tensor_operation::host::ReferenceBatchedGe
PassThrough,
PassThrough,
Scale>;
// Ref dropout
using ReferenceDropoutInstance =
ck::tensor_operation::host::ReferenceDropout<ushort, DataType, DataType>;
@@ -428,14 +437,12 @@ void run_attention_fwd_host(const TensorQ& q_g_m_k,
ref_gemm0_invoker.Run(ref_gemm0_argument);
// masking
#if USING_MASK
auto N = s_g_m_n.GetLengths()[2];
const auto mask = DeviceGemmInstance::C0MatrixMask(N);
s_g_m_n.ForEach([&](auto& self, auto idx) {
if(mask.IsMaskedElement(idx[1], idx[2]))
self(idx) = -ck::NumericLimits<float>::Infinity();
});
#endif
// P = Softmax(S)
auto ref_softmax = ReferenceSoftmaxInstance{};
@@ -470,16 +477,14 @@ int run(int argc, char* argv[])
// y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
// y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
// y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
float alpha = 1.f / std::sqrt(DIM);
float p_drop = 0.2;
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
bool input_permute = false;
-bool output_permute = true;
+bool output_permute = false;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
if(argc == 1)
{
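For reference, the alpha declared above is the usual scaled-dot-product-attention factor, so the softmax input computed later as S = alpha * Q * K^T equals Q * K^T / sqrt(DIM). The same expression from the example, shown with its default value:

// Same line as in the example; with the default head dimension it evaluates to 1/8.
float alpha = 1.f / std::sqrt(DIM); // DIM = 64 -> alpha = 0.125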
@@ -497,7 +502,7 @@ int run(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
-alpha = std::stof(argv[4]);
+p_drop = std::stof(argv[4]);
input_permute = std::stoi(argv[5]);
output_permute = std::stoi(argv[6]);
@@ -513,6 +518,10 @@ int run(int argc, char* argv[])
exit(0);
}
float p_dropout = 1 - p_drop;
uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
float rp_dropout = 1.0 / p_dropout;
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
std::vector<DeviceGemmInstance::ProblemDesc> problem_descs;
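The p_dropout / p_dropout_in_16bits / rp_dropout lines added above convert the drop probability into what the kernel consumes: a keep probability, a 16-bit threshold, and the rescale factor. A minimal sketch of the intended semantics, assuming the z mask holds uniform 16-bit draws compared against that threshold (the helper below is hypothetical; the real logic lives in ReferenceDropout and the device kernel):

#include <cstdint>

// Hypothetical illustration: keep an element when its 16-bit draw falls below the
// threshold, then rescale by 1 / (1 - p_drop) so the expected value is preserved.
inline float apply_dropout(float value, uint16_t z_draw, uint16_t threshold, float rp_dropout)
{
    return z_draw <= threshold ? value * rp_dropout : 0.f;
}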
@@ -520,7 +529,8 @@ int run(int argc, char* argv[])
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<const void*> p_q;
std::vector<const void*> p_k;
-std::vector<void*> p_z;
+std::vector<void*> p_z; // for result verification
+std::vector<void*> p_z_nullptr; // for time test
std::vector<const void*> p_v;
std::vector<const void*> p_y;
std::vector<const void*> p_lse;
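Two z-pointer vectors are kept deliberately: p_z_nullptr feeds the timed run, so the kernel never materializes the dropout mask while it is being benchmarked, while p_z holds the real device buffers and is used by the verification pass further down so the mask can be copied back and replayed on the host. This inverts the old behaviour, where the timed run wrote z and verification then overwrote p_z with nullptrs. A trivial sketch of how the null vector is populated (one entry per group, mirroring the push_back calls below; the helper name is hypothetical):

#include <cstddef>
#include <vector>

// Hypothetical helper equivalent to pushing one nullptr per group, as the example does.
std::vector<void*> make_null_z_pointers(std::size_t group_count)
{
    return std::vector<void*>(group_count, nullptr);
}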
@@ -560,16 +570,16 @@ int run(int argc, char* argv[])
std::vector<DeviceMemPtr> ygrad_tensors_device;
std::vector<DeviceMemPtr> kgrad_tensors_device;
std::vector<DeviceMemPtr> vgrad_tensors_device;
-std::size_t group_count = 3;
+std::size_t group_count = 10;
std::size_t flop = 0, num_byte = 0;
for(std::size_t i = 0; i < group_count; i++)
{
-int M = 128 * (rand() % 4 + 1);
-int N = 128 * (rand() % 4 + 1);
+int M = 128 * (rand() % 8) + (rand() % 128);
+int N = 128 * (rand() % 8) + (rand() % 128);
int K = DIM;
int O = DIM;
-int G0 = rand() % 3 + 1;
-int G1 = rand() % 2 + 1;
+int G0 = rand() % 4 + 1;
+int G1 = rand() % 4 + 1;
std::vector<ck::index_t> q_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> q_gs_ms_ks_strides =
input_permute
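Note the new problem-size generation above: M and N are no longer forced to multiples of 128, so the padded tile paths get exercised, and both G0 and G1 now go up to 4 with ten groups in total. A sketch of the ranges those expressions imply (assuming rand() is left unseeded, as in the example; the helper name is hypothetical):

#include <cstdlib>

// Ranges produced above:
//   M, N: 128 * [0, 7] + [0, 127], i.e. up to 1023 and generally not 128-aligned
//   G0, G1: [1, 4], so up to 16 batch/head combinations per group
int sample_extent() { return 128 * (rand() % 8) + (rand() % 128); }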
@@ -645,6 +655,7 @@ int run(int argc, char* argv[])
{
std::cout << "q_gs_ms_ks: " << q_gs_ms_ks.mDesc << std::endl;
std::cout << "k_gs_ns_ks: " << k_gs_ns_ks.mDesc << std::endl;
std::cout << "z_gs_ms_ns: " << z_gs_ms_ns.mDesc << std::endl;
std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl; std::cout << "v_gs_os_ns: " << v_gs_os_ns.mDesc << std::endl;
std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl; std::cout << "y_gs_ms_os: " << y_gs_ms_os.mDesc << std::endl;
std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl; std::cout << "lse_gs_ms_os: " << lse_gs_ms.mDesc << std::endl;
...@@ -727,33 +738,9 @@ int run(int argc, char* argv[]) ...@@ -727,33 +738,9 @@ int run(int argc, char* argv[])
k_gs_ns_ks.ForEach([&](auto& self, auto idx) { k_gs_ns_ks.ForEach([&](auto& self, auto idx) {
k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); k_g_n_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
}); });
z_gs_ms_ns.ForEach([&](auto& self, auto idx) {
z_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
v_gs_os_ns.ForEach([&](auto& self, auto idx) {
v_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx);
});
lse_gs_ms.ForEach(
[&](auto& self, auto idx) { lse_g_m(idx[0] * G1 + idx[1], idx[2]) = self(idx); });
run_attention_fwd_host(q_g_m_k,
k_g_n_k,
v_g_n_o,
alpha,
s_g_m_n,
p_g_m_n,
y_g_m_o,
lse_g_m,
p_drop_g_m_n,
z_g_m_n,
p_dropout_in_16bits,
rp_dropout);
y_gs_ms_os.ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_o(idx[0] * G1 + idx[1], idx[2], idx[3]);
});
lse_gs_ms.ForEach(
[&](auto& self, auto idx) { self(idx) = lse_g_m(idx[0] * G1 + idx[1], idx[2]); });
q_g_m_ks.push_back(q_g_m_k);
k_g_n_ks.push_back(k_g_n_k);
@@ -763,6 +750,7 @@ int run(int argc, char* argv[])
p_g_m_ns.push_back(p_g_m_n);
y_g_m_os.push_back(y_g_m_o);
lse_g_ms.push_back(lse_g_m);
p_drop_g_m_ns.push_back(p_drop_g_m_n);
q_tensors.push_back(q_gs_ms_ks);
k_tensors.push_back(k_gs_ns_ks);
v_tensors.push_back(v_gs_os_ns);
@@ -770,7 +758,6 @@ int run(int argc, char* argv[])
z_tensors.push_back(z_gs_ms_ns);
lse_tensors.push_back(lse_gs_ms);
ygrad_tensors.push_back(ygrad_gs_ms_os);
p_drop_g_m_ns.push_back(p_drop_g_m_n);
q_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(DataType) * q_gs_ms_ks.GetElementSpaceSize()));
k_tensors_device.emplace_back(
@@ -795,15 +782,11 @@ int run(int argc, char* argv[])
k_tensors_device.back()->ToDevice(k_gs_ns_ks.data());
z_tensors_device.back()->ToDevice(z_gs_ms_ns.data());
v_tensors_device.back()->ToDevice(v_gs_os_ns.data());
y_tensors_device.back()->ToDevice(y_gs_ms_os.data());
lse_tensors_device.back()->ToDevice(lse_gs_ms.data());
qgrad_tensors_device.back()->SetZero();
kgrad_tensors_device.back()->SetZero();
vgrad_tensors_device.back()->SetZero();
ygrad_tensors_device.back()->ToDevice(ygrad_gs_ms_os.data());
p_q.push_back(q_tensors_device.back()->GetDeviceBuffer());
p_k.push_back(k_tensors_device.back()->GetDeviceBuffer());
p_z.push_back(z_tensors_device.back()->GetDeviceBuffer());
p_z_nullptr.push_back(nullptr);
p_v.push_back(v_tensors_device.back()->GetDeviceBuffer());
p_y.push_back(y_tensors_device.back()->GetDeviceBuffer());
p_lse.push_back(lse_tensors_device.back()->GetDeviceBuffer());
@@ -815,7 +798,7 @@ int run(int argc, char* argv[])
auto argument =
gemm.MakeArgument(p_q,
p_k,
-p_z,
+p_z_nullptr,
p_v,
p_y,
p_lse,
@@ -857,33 +840,7 @@ int run(int argc, char* argv[])
bool pass = true;
if(do_verification)
{
// get z matrix
for(int i = 0; i < group_count; i++)
{
int G1 = v_tensors[i].GetLengths()[1];
z_tensors_device[i]->FromDevice(z_g_m_ns[i].data());
run_attention_fwd_host(q_g_m_ks[i],
k_g_n_ks[i],
v_g_n_os[i],
alpha,
s_g_m_ns[i],
p_g_m_ns[i],
y_g_m_os[i],
lse_g_ms[i],
p_drop_g_m_ns[i],
z_g_m_ns[i],
p_dropout_in_16bits,
rp_dropout);
y_tensors[i].ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_os[i](idx[0] * G1 + idx[1], idx[2], idx[3]);
});
y_tensors_device[i]->ToDevice(y_tensors[i].data());
qgrad_tensors_device[i]->SetZero();
kgrad_tensors_device[i]->SetZero();
vgrad_tensors_device[i]->SetZero();
}
p_z = std::vector<void*>(p_z.size(), nullptr);
argument =
gemm.MakeArgument(p_q,
p_k,
@@ -905,8 +862,8 @@ int run(int argc, char* argv[])
YElementOp{},
p_drop,
std::tuple<unsigned long long, unsigned long long>(seed, offset));
-DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
-gemm.SetWorkSpacePointer(&argument, problem_desc_workspace.GetDeviceBuffer());
+DeviceMem problem_desc_workspace_verify(gemm.GetWorkSpaceSize(&argument));
+gemm.SetWorkSpacePointer(&argument, problem_desc_workspace_verify.GetDeviceBuffer());
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
@@ -914,6 +871,43 @@ int run(int argc, char* argv[])
return 0;
}
invoker.Run(argument, StreamConfig{nullptr, false});
for(std::size_t i = 0; i < group_count; i++)
{
int G1 = v_tensors[i].GetLengths()[1];
// copy z matrix data from device
z_tensors_device[i]->FromDevice(z_tensors[i].mData.data());
z_tensors[i].ForEach([&](auto& self, auto idx) {
z_g_m_ns[i](idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx);
});
run_attention_fwd_host(q_g_m_ks[i],
k_g_n_ks[i],
v_g_n_os[i],
alpha,
s_g_m_ns[i],
p_g_m_ns[i],
y_g_m_os[i],
lse_g_ms[i],
p_drop_g_m_ns[i],
z_g_m_ns[i],
p_dropout_in_16bits,
rp_dropout);
y_tensors[i].ForEach([&](auto& self, auto idx) {
self(idx) = y_g_m_os[i](idx[0] * G1 + idx[1], idx[2], idx[3]);
});
y_tensors_device[i]->ToDevice(y_tensors[i].data());
lse_tensors[i].ForEach([&](auto& self, auto idx) {
self(idx) = lse_g_ms[i](idx[0] * G1 + idx[1], idx[2]);
});
lse_tensors_device[i]->ToDevice(lse_tensors[i].data());
qgrad_tensors_device[i]->SetZero();
kgrad_tensors_device[i]->SetZero();
vgrad_tensors_device[i]->SetZero();
}
invoker.Run(argument, StreamConfig{nullptr, false});
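Taken together, the verification path added above runs the kernel twice: the first Run writes the dropout mask into the z buffers, the per-group loop copies that mask back to the host, rebuilds y and lse with run_attention_fwd_host, uploads them, and zeroes the gradient buffers, and the second Run then produces the gradients that are checked below. A condensed outline of that flow, using the names from the example:

// Outline only, restating the two-pass verification added above (per group i):
//   invoker.Run(argument, ...);                        // pass 1: kernel writes z
//   z_tensors_device[i]->FromDevice(...);              // captured mask to host
//   run_attention_fwd_host(..., z_g_m_ns[i], ...);     // rebuild y and lse with that mask
//   y_tensors_device[i]->ToDevice(...);                // upload the host y / lse
//   qgrad/kgrad/vgrad SetZero();                       // clear gradient outputs
//   invoker.Run(argument, ...);                        // pass 2: gradients for comparison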
for(std::size_t i = 0; i < group_count; i++)
{
......
@@ -98,7 +98,7 @@ __global__ void
unsigned short* z_matrix_ptr =
(arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset);
GridwiseGemm::template Run<HasMainKBlockLoop>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
@@ -211,7 +211,7 @@ template <index_t NumDimG,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
MaskingSpecialization MaskingSpec,
LoopScheduler LoopSched = LoopScheduler::Default>
-struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
+struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1
: public BaseOperator // TODO inherit atten bwd op once API stablizes
{
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
@@ -223,7 +223,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// TODO: implement bias combination
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented");
-using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1;
+using DeviceOp = DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1;
struct ProblemDesc
{
std::vector<index_t> a_gs_ms_ks_lengths;
@@ -448,7 +448,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
}
}
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
using B1GridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {}));
@@ -534,7 +533,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
};
// GridwiseGemm
-using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_PT1<
+using GridwiseGemm = GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1<
DataType, // TODO: distinguish A/B datatype
GemmDataType,
GemmAccDataType,
@@ -711,7 +710,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
}
grid_size_ = 0;
-for(std::size_t i = 0; i < group_count_; i++)
+for(index_t i = 0; i < group_count_; i++)
{
const auto p_a_grid = static_cast<const DataType*>(p_As[i]);
const auto p_b_grid = static_cast<const DataType*>(p_Bs[i]);
@@ -895,7 +894,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
// for(std::size_t i = 0; i < arg.group_count_; i++)
// {
// const auto K =
-// arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
+// arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
+// arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
// const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K);
// all_has_main_k_block_loop &= y;
// some_has_main_k_block_loop |= y;
@@ -976,7 +976,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
return false;
}
-for(std::size_t i = 0; i < arg.group_count_; i++)
+for(index_t i = 0; i < arg.group_count_; i++)
{
// TODO: Check if tensor specialization & strides mismatch
const auto& kernel_arg = arg.group_kernel_args_[i];
@@ -986,7 +986,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
const index_t c_m = kernel_arg.y_grid_desc_m_o_.GetLength(I0);
const index_t c_gemm1n = kernel_arg.y_grid_desc_m_o_.GetLength(I1);
const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
-const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) * kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
+const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I0) *
+    kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I2);
if(!(c_g == device_arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
{
@@ -1160,7 +1161,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1
auto str = std::stringstream();
// clang-format off
-str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_PT1"
+str << "DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V1"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
......
@@ -16,7 +16,7 @@
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v2.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -98,7 +98,7 @@ __global__ void
unsigned short* z_matrix_ptr =
(arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset);
GridwiseGemm::template Run<HasMainKBlockLoop>(
arg_ptr[group_id].p_a_grid_ + a_batch_offset,
arg_ptr[group_id].p_b_grid_ + b_batch_offset,
@@ -703,7 +703,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
}
grid_size_ = 0;
-for(std::size_t i = 0; i < group_count_; i++)
+for(index_t i = 0; i < group_count_; i++)
{
const auto p_a_grid = static_cast<const DataType*>(p_As[i]);
const auto p_b_grid = static_cast<const DataType*>(p_Bs[i]);
@@ -884,10 +884,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
bool all_has_main_k_block_loop = true;
bool some_has_main_k_block_loop = false;
-for(std::size_t i = 0; i < arg.group_count_; i++)
+for(index_t i = 0; i < arg.group_count_; i++)
{
-const auto K =
-arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
+const auto K = arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
+    arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K);
all_has_main_k_block_loop &= y;
some_has_main_k_block_loop |= y;
@@ -968,7 +968,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle_V2
return false;
}
-for(std::size_t i = 0; i < arg.group_count_; i++)
+for(index_t i = 0; i < arg.group_count_; i++)
{
// TODO: Check if tensor specialization & strides mismatch
const auto& kernel_arg = arg.group_kernel_args_[i];
......