Commit 522d8b2f authored by letaoqin

bias grad update to light version

parent 9dc3e49b
@@ -518,8 +518,10 @@ int run(int argc, char* argv[])
            static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
            static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
            static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
-           {}, // std::array<void*, 1> p_acc0_biases;
-           {}, // std::array<void*, 1> p_acc1_biases;
+           nullptr, // p_acc0_bias;
+           nullptr, // p_acc1_bias;
+           nullptr,
+           nullptr,
            q_gs_ms_ks_lengths,
            q_gs_ms_ks_strides,
            k_gs_ns_ks_lengths,
@@ -564,8 +566,10 @@ int run(int argc, char* argv[])
            static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
            static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
            static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
-           {}, // std::array<void*, 1> p_acc0_biases;
-           {}, // std::array<void*, 1> p_acc1_biases;
+           nullptr, // p_acc0_bias;
+           nullptr, // p_acc1_bias;
+           nullptr,
+           nullptr,
            q_gs_ms_ks_lengths,
            q_gs_ms_ks_strides,
            k_gs_ns_ks_lengths,
...
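In both hunks above, the two new argument slots come directly after the bias pointers, in the same order as the device op's MakeArgument/MakeArgumentPointer signatures further down in this commit. A minimal call-shape sketch of the changed region only (a fragment, not a complete program; d0_device_buf and d0grad_device_buf are hypothetical DeviceMem objects with the same layout as the D0 attention bias):

    // fragment: ..., p_acc0_bias, p_acc1_bias, p_d0grad_grid, p_d1grad_grid, ...
    static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
    static_cast<const D0DataType*>(d0_device_buf.GetDeviceBuffer()), // p_acc0_bias: attention-bias input
    nullptr,                                                         // p_acc1_bias: unused (TODO in the op)
    static_cast<D0DataType*>(d0grad_device_buf.GetDeviceBuffer()),   // p_d0grad_grid: receives the bias gradient
    nullptr,                                                         // p_d1grad_grid: ignored by the op

Passing nullptr for p_d0grad_grid, as this example does, skips the bias-gradient write at run time.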
@@ -24,7 +24,7 @@ Kernel outputs:
*/
#define USING_MASK 0
-#define DIM 128 // DIM should be a multiple of 8.
+#define DIM 32 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
@@ -616,6 +616,8 @@ int run(int argc, char* argv[])
            p_vgrad,
            {}, // std::array<void*, 1> p_acc0_biases;
            {}, // std::array<void*, 1> p_acc1_biases;
+           {},
+           {},
            problem_descs,
            QKVElementOp{},
            QKVElementOp{},
@@ -663,6 +665,8 @@ int run(int argc, char* argv[])
            p_vgrad,
            {}, // std::array<void*, 1> p_acc0_biases;
            {}, // std::array<void*, 1> p_acc1_biases;
+           {},
+           {},
            problem_descs,
            QKVElementOp{},
            QKVElementOp{},
...
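The grouped example passes empty initializer lists for the new per-group vectors. A sketch of enabling the bias-gradient output instead, assuming a hypothetical d0grad_bufs array of per-group device allocations with the D0 bias layout:

    // one entry per group: the vector size must match the group count
    // (the kernel also null-checks each group's pointer at run time)
    std::vector<void*> p_d0grads;
    for(std::size_t i = 0; i < problem_descs.size(); ++i)
        p_d0grads.push_back(d0grad_bufs[i].GetDeviceBuffer());
    std::vector<void*> p_d1grads{}; // must stay empty; dD1 is not implemented

An empty p_d0grads (the `{}` above) disables the dD0 output for every group; any other size mismatch is rejected by the grouped Argument constructors shown later in this commit.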
@@ -123,6 +123,7 @@ __global__ void
            const InputDataType* __restrict__ p_ygrad_grid,
            OutputDataType* __restrict__ p_qgrad_grid,
            OutputDataType* __restrict__ p_kgrad_grid,
+           D0DataType* __restrict__ p_d0grad_grid,
            OutputDataType* __restrict__ p_vgrad_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
@@ -176,11 +177,19 @@ __global__ void
        const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
        const D0DataType* tmp_p_d0_grid = nullptr;
+       D0DataType* tmp_p_d0grad_grid  = nullptr;
        if constexpr(!is_same<D0DataType, void>::value)
        {
            const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
                static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
-           tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+           if(p_d0_grid != nullptr)
+           {
+               tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+           }
+           if(p_d0grad_grid != nullptr)
+           {
+               tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
+           }
        }
        if constexpr(Deterministic)
        {
@@ -197,6 +206,7 @@ __global__ void
                p_ygrad_grid + c_batch_offset,
                p_qgrad_grid + a_batch_offset,
                p_kgrad_grid + b_batch_offset,
+               tmp_p_d0grad_grid,
                p_vgrad_grid + b1_batch_offset,
                p_shared,
                a_element_op,
@@ -233,6 +243,7 @@ __global__ void
                p_ygrad_grid + c_batch_offset,
                p_qgrad_grid + a_batch_offset,
                p_kgrad_grid + b_batch_offset,
+               tmp_p_d0grad_grid,
                p_vgrad_grid + b1_batch_offset,
                p_shared,
                a_element_op,
@@ -266,6 +277,7 @@ __global__ void
    ignore = p_ygrad_grid;
    ignore = p_qgrad_grid;
    ignore = p_kgrad_grid;
+   ignore = p_d0grad_grid;
    ignore = p_vgrad_grid;
    ignore = a_element_op;
    ignore = b_element_op;
@@ -858,6 +870,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                 OutputDataType* p_vgrad_grid,
                 const D0DataType* p_acc0_bias,
                 const D1DataType* p_acc1_bias,
+                D0DataType* p_d0grad_grid,
+                D1DataType* p_d1grad_grid,
                 const std::vector<index_t>& a_gs_ms_ks_lengths,
                 const std::vector<index_t>& a_gs_ms_ks_strides,
                 const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -894,6 +908,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
              p_qgrad_grid_{p_qgrad_grid},
              p_kgrad_grid_{p_kgrad_grid},
              p_vgrad_grid_{p_vgrad_grid},
+             p_d0grad_grid_{p_d0grad_grid},
              a_grid_desc_ak0_m_ak1_{
                  DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b_grid_desc_bk0_n_bk1_{
@@ -948,10 +963,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
              p_drop_{p_drop}
        {
            // TODO: implement bias addition
-           ignore = p_acc0_bias;
+           ignore = p_d1grad_grid;
            ignore = p_acc1_bias;
-           ignore = acc0_bias_gs_ms_ns_lengths;
-           ignore = acc0_bias_gs_ms_ns_strides;
            ignore = acc1_bias_gs_ms_gemm1ns_lengths;
            ignore = acc1_bias_gs_ms_gemm1ns_strides;
@@ -1030,6 +1043,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
        OutputDataType* p_qgrad_grid_;
        OutputDataType* p_kgrad_grid_;
        OutputDataType* p_vgrad_grid_;
+       D0DataType* p_d0grad_grid_;
        // tensor descriptor
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
@@ -1191,6 +1205,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                arg.p_ygrad_grid_,
                arg.p_qgrad_grid_,
                arg.p_kgrad_grid_,
+               arg.p_d0grad_grid_,
                arg.p_vgrad_grid_,
                arg.a_element_op_,
                arg.b_element_op_,
@@ -1342,6 +1357,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                OutputDataType* p_vgrad_grid,
                const D0DataType* p_acc0_bias,
                const D1DataType* p_acc1_bias,
+               D0DataType* p_d0grad_grid,
+               D1DataType* p_d1grad_grid,
                const std::vector<index_t>& a_gs_ms_ks_lengths,
                const std::vector<index_t>& a_gs_ms_ks_strides,
                const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -1380,6 +1397,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                        p_vgrad_grid,
                        p_acc0_bias,
                        p_acc1_bias,
+                       p_d0grad_grid,
+                       p_d1grad_grid,
                        a_gs_ms_ks_lengths,
                        a_gs_ms_ks_strides,
                        b_gs_ns_ks_lengths,
@@ -1422,6 +1441,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                        void* p_vgrad_grid,
                        const D0DataType* p_acc0_bias,
                        const D1DataType* p_acc1_bias,
+                       void* p_d0grad_grid,
+                       void* p_d1grad_grid,
                        const std::vector<index_t>& a_gs_ms_ks_lengths,
                        const std::vector<index_t>& a_gs_ms_ks_strides,
                        const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -1461,6 +1482,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
            static_cast<OutputDataType*>(p_vgrad_grid),
            static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
            static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
+           static_cast<D0DataType*>(p_d0grad_grid),
+           static_cast<D1DataType*>(p_d1grad_grid),
            a_gs_ms_ks_lengths,
            a_gs_ms_ks_strides,
            b_gs_ns_ks_lengths,
...
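Both the bias input (p_d0_grid) and the new gradient output (p_d0grad_grid) are gated twice: at compile time on D0DataType and at run time on the pointer. A self-contained sketch of that pattern (a stand-in, not the library code itself):

    #include <type_traits>

    template <typename D0DataType>
    void d0grad_path_sketch(D0DataType* p_d0grad_grid)
    {
        // compile-time gate: with D0DataType = void (no bias configured)
        // the whole branch is discarded and no dD0 code is generated
        if constexpr(!std::is_same<D0DataType, void>::value)
        {
            // run-time gate: a bias-enabled instance can still opt out
            // of the dD0 output by passing nullptr
            if(p_d0grad_grid != nullptr)
            {
                // offset by the per-batch base pointer and write dD0
            }
        }
    }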
@@ -123,6 +123,7 @@ __global__ void
            const InputDataType* __restrict__ p_ygrad_grid,
            OutputDataType* __restrict__ p_qgrad_grid,
            OutputDataType* __restrict__ p_kgrad_grid,
+           D0DataType* __restrict__ p_d0grad_grid,
            OutputDataType* __restrict__ p_vgrad_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
@@ -176,11 +177,19 @@ __global__ void
        const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
        const D0DataType* tmp_p_d0_grid = nullptr;
+       D0DataType* tmp_p_d0grad_grid  = nullptr;
        if constexpr(!is_same<D0DataType, void>::value)
        {
            const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
                static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
-           tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+           if(p_d0_grid != nullptr)
+           {
+               tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+           }
+           if(p_d0grad_grid != nullptr)
+           {
+               tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
+           }
        }
        if constexpr(Deterministic)
@@ -198,6 +207,7 @@ __global__ void
                p_ygrad_grid + c_batch_offset,
                p_qgrad_grid + a_batch_offset,
                p_kgrad_grid + b_batch_offset,
+               tmp_p_d0grad_grid,
                p_vgrad_grid + b1_batch_offset,
                p_shared,
                a_element_op,
@@ -234,6 +244,7 @@ __global__ void
                p_ygrad_grid + c_batch_offset,
                p_qgrad_grid + a_batch_offset,
                p_kgrad_grid + b_batch_offset,
+               tmp_p_d0grad_grid,
                p_vgrad_grid + b1_batch_offset,
                p_shared,
                a_element_op,
@@ -267,6 +278,7 @@ __global__ void
    ignore = p_ygrad_grid;
    ignore = p_qgrad_grid;
    ignore = p_kgrad_grid;
+   ignore = p_d0grad_grid;
    ignore = p_vgrad_grid;
    ignore = a_element_op;
    ignore = b_element_op;
@@ -874,6 +886,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                 OutputDataType* p_vgrad_grid,
                 const D0DataType* p_acc0_bias,
                 const D1DataType* p_acc1_bias,
+                D0DataType* p_d0grad_grid,
+                D1DataType* p_d1grad_grid,
                 const std::vector<index_t>& a_gs_ms_ks_lengths,
                 const std::vector<index_t>& a_gs_ms_ks_strides,
                 const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -910,6 +924,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
              p_qgrad_grid_{p_qgrad_grid},
              p_kgrad_grid_{p_kgrad_grid},
              p_vgrad_grid_{p_vgrad_grid},
+             p_d0grad_grid_{p_d0grad_grid},
              a_grid_desc_ak0_m_ak1_{
                  DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b_grid_desc_bk0_n_bk1_{
@@ -964,6 +979,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
        {
            // TODO: implement bias addition
            ignore = p_acc1_bias;
+           ignore = p_d1grad_grid;
            ignore = acc1_bias_gs_ms_gemm1ns_lengths;
            ignore = acc1_bias_gs_ms_gemm1ns_strides;
@@ -1042,6 +1058,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
        OutputDataType* p_qgrad_grid_;
        OutputDataType* p_kgrad_grid_;
        OutputDataType* p_vgrad_grid_;
+       D0DataType* p_d0grad_grid_;
        // tensor descriptor
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
@@ -1207,6 +1224,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                arg.p_ygrad_grid_,
                arg.p_qgrad_grid_,
                arg.p_kgrad_grid_,
+               arg.p_d0grad_grid_,
                arg.p_vgrad_grid_,
                arg.a_element_op_,
                arg.b_element_op_,
@@ -1374,6 +1392,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                OutputDataType* p_vgrad_grid,
                const D0DataType* p_acc0_bias,
                const D1DataType* p_acc1_bias,
+               D0DataType* p_d0grad_grid,
+               D1DataType* p_d1grad_grid,
                const std::vector<index_t>& a_gs_ms_ks_lengths,
                const std::vector<index_t>& a_gs_ms_ks_strides,
                const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -1412,6 +1432,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                        p_vgrad_grid,
                        p_acc0_bias,
                        p_acc1_bias,
+                       p_d0grad_grid,
+                       p_d1grad_grid,
                        a_gs_ms_ks_lengths,
                        a_gs_ms_ks_strides,
                        b_gs_ns_ks_lengths,
@@ -1454,6 +1476,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                        void* p_vgrad_grid,
                        const void* p_acc0_bias,
                        const void* p_acc1_bias,
+                       void* p_d0grad_grid,
+                       void* p_d1grad_grid,
                        const std::vector<index_t>& a_gs_ms_ks_lengths,
                        const std::vector<index_t>& a_gs_ms_ks_strides,
                        const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -1493,6 +1517,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
            static_cast<OutputDataType*>(p_vgrad_grid),
            static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
            static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
+           static_cast<D0DataType*>(p_d0grad_grid),
+           static_cast<D1DataType*>(p_d1grad_grid),
            a_gs_ms_ks_lengths,
            a_gs_ms_ks_strides,
            b_gs_ns_ks_lengths,
...
@@ -1333,8 +1333,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                        void* p_vgrad_grid,
                        const void* p_acc0_bias,
                        const void* p_acc1_bias,
-                       D0DataType* p_d0grad_grid,
-                       D1DataType* p_d1grad_grid,
+                       void* p_d0grad_grid,
+                       void* p_d1grad_grid,
                        const std::vector<index_t>& a_gs_ms_ks_lengths,
                        const std::vector<index_t>& a_gs_ms_ks_strides,
                        const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -1373,8 +1373,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
            static_cast<OutputDataType*>(p_vgrad_grid),
            static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
            static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
-           static_cast<const D0DataType*>(p_d0grad_grid),
-           static_cast<const D1DataType*>(p_d1grad_grid),
+           static_cast<D0DataType*>(p_d0grad_grid),
+           static_cast<D1DataType*>(p_d1grad_grid),
            a_gs_ms_ks_lengths,
            a_gs_ms_ks_strides,
            b_gs_ns_ks_lengths,
...
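With the switch to void*, the type-erased MakeArgumentPointer above no longer leaks D0DataType/D1DataType into its signature, and dropping the stray const keeps the gradient buffers writable after the cast. A standalone sketch of the pattern (stand-in types, not the library's classes):

    #include <cstdint>

    struct DeviceOpSketch
    {
        using D0DataType = std::uint16_t; // stand-in for the op's bias type

        void MakeArgumentPointer(const void* p_acc0_bias, void* p_d0grad_grid)
        {
            // the read-only bias keeps const through its cast ...
            const auto* p_bias = static_cast<const D0DataType*>(p_acc0_bias);
            // ... while the gradient cast must be non-const: a
            // const D0DataType* could not initialize the mutable
            // D0DataType* pointer the Argument stores for writing
            auto* p_dbias = static_cast<D0DataType*>(p_d0grad_grid);
            (void)p_bias;
            (void)p_dbias;
        }
    };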
@@ -162,13 +162,16 @@ __global__ void
                       : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
        const D0DataType* tmp_p_d0_grid = nullptr;
+       D0DataType* tmp_p_d0grad_grid  = nullptr;
        if constexpr(!is_same<D0DataType, void>::value)
        {
            const long_index_t d0_batch_offset =
                __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
                    arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
-           tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
+           if(arg_ptr[group_id].p_d0_grid_ != nullptr)
+               tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
+           if(arg_ptr[group_id].p_d0grad_grid_)
+               tmp_p_d0grad_grid = arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
        }
        if constexpr(Deterministic)
        {
@@ -185,6 +188,7 @@ __global__ void
                arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
                arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
                arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+               tmp_p_d0grad_grid,
                arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
                p_shared,
                a_element_op,
@@ -222,6 +226,7 @@ __global__ void
                arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
                arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
                arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+               tmp_p_d0grad_grid,
                arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
                p_shared,
                a_element_op,
@@ -806,6 +811,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
            const InputDataType* p_ygrad_grid_;
            OutputDataType* p_qgrad_grid_;
            OutputDataType* p_kgrad_grid_;
+           D0DataType* p_d0grad_grid_;
            OutputDataType* p_vgrad_grid_;
            // tensor descriptors for block/thread-wise copy
@@ -878,6 +884,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                 std::vector<void*>& p_Vgrads,
                 const std::vector<const void*>& p_acc0_bias_vec,
                 const std::vector<const void*>& p_acc1_bias_vec,
+                const std::vector<void*>& p_d0grads,
+                const std::vector<void*>& p_d1grads,
                 const std::vector<ProblemDesc>& problem_desc_vec,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
@@ -911,7 +919,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                 group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) &&
                 (group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
                  ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
-                0 == p_acc1_bias_vec.size()))
+                0 == p_acc1_bias_vec.size() &&
+                (group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
+                 ck::type_convert<ck::index_t>(p_d0grads.size() == 0)) &&
+                0 == p_d1grads.size()))
            {
                throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
            }
@@ -937,7 +948,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
                auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]);
                auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]);
-               auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
+               auto p_d0grad_grid =
+                   (ck::type_convert<ck::index_t>(p_d0grads.size()) == group_count_)
+                       ? static_cast<D0DataType*>(p_d0grads[i])
+                       : nullptr;
+               auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
                const auto& problem_desc = problem_desc_vec[i];
@@ -1054,6 +1069,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                    p_ygrad_grid,
                    p_qgrad_grid,
                    p_kgrad_grid,
+                   p_d0grad_grid,
                    p_vgrad_grid,
                    a_grid_desc_ak0_m_ak1,
                    b_grid_desc_bk0_n_bk1,
@@ -1370,6 +1386,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                        std::vector<void*>& p_Vgrads,
                        const std::vector<const void*>& p_acc0_bias_vec,
                        const std::vector<const void*>& p_acc1_bias_vec,
+                       const std::vector<void*>& p_d0grads,
+                       const std::vector<void*>& p_d1grads,
                        const std::vector<ProblemDesc>& problem_desc_vec,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
@@ -1392,6 +1410,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                        p_Vgrads,
                        p_acc0_bias_vec,
                        p_acc1_bias_vec,
+                       p_d0grads,
+                       p_d1grads,
                        problem_desc_vec,
                        a_element_op,
                        b_element_op,
@@ -1420,6 +1440,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                        std::vector<void*>& p_Vgrads,
                        const std::vector<const void*>& p_acc0_bias_vec,
                        const std::vector<const void*>& p_acc1_bias_vec,
+                       const std::vector<void*>& p_d0grads,
+                       const std::vector<void*>& p_d1grads,
                        const std::vector<ProblemDesc>& problem_desc_vec,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
@@ -1442,6 +1464,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                        p_Vgrads,
                        p_acc0_bias_vec, // cast in struct Argument
                        p_acc1_bias_vec, // cast in struct Argument
+                       p_d0grads,
+                       p_d1grads,
                        problem_desc_vec,
                        a_element_op,
                        b_element_op,
...
@@ -160,13 +160,17 @@ __global__ void
            (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
                                                    : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
        const D0DataType* tmp_p_d0_grid = nullptr;
+       D0DataType* tmp_p_d0grad_grid  = nullptr;
        if constexpr(!is_same<D0DataType, void>::value)
        {
            const long_index_t d0_batch_offset =
                __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
                    arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
-           tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
+           if(arg_ptr[group_id].p_d0_grid_ != nullptr)
+               tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
+           if(arg_ptr[group_id].p_d0grad_grid_)
+               tmp_p_d0grad_grid = arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
        }
        if constexpr(Deterministic)
@@ -184,6 +188,7 @@ __global__ void
                arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
                arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
                arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+               tmp_p_d0grad_grid,
                arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
                p_shared,
                a_element_op,
@@ -221,6 +226,7 @@ __global__ void
                arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
                arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
                arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
+               tmp_p_d0grad_grid,
                arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
                p_shared,
                a_element_op,
@@ -876,6 +882,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
            const InputDataType* p_ygrad_grid_;
            OutputDataType* p_qgrad_grid_;
            OutputDataType* p_kgrad_grid_;
+           D0DataType* p_d0grad_grid_;
            OutputDataType* p_vgrad_grid_;
            // tensor descriptors for block/thread-wise copy
@@ -948,6 +955,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                 std::vector<void*>& p_Vgrads,
                 const std::vector<const void*>& p_acc0_bias_vec,
                 const std::vector<const void*>& p_acc1_bias_vec,
+                const std::vector<void*>& p_d0grads,
+                const std::vector<void*>& p_d1grads,
                 const std::vector<ProblemDesc>& problem_desc_vec,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
@@ -981,7 +990,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                 group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) &&
                 (group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
                  ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
-                0 == p_acc1_bias_vec.size()))
+                0 == p_acc1_bias_vec.size() &&
+                (group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
+                 ck::type_convert<ck::index_t>(p_d0grads.size() == 0)) &&
+                0 == p_d1grads.size()))
            {
                throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
            }
@@ -1007,7 +1019,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
                auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]);
                auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]);
-               auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
+               auto p_d0grad_grid =
+                   (ck::type_convert<ck::index_t>(p_d0grads.size()) == group_count_)
+                       ? static_cast<D0DataType*>(p_d0grads[i])
+                       : nullptr;
+               auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
                const auto& problem_desc = problem_desc_vec[i];
@@ -1124,6 +1140,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                    p_ygrad_grid,
                    p_qgrad_grid,
                    p_kgrad_grid,
+                   p_d0grad_grid,
                    p_vgrad_grid,
                    a_grid_desc_ak0_m_ak1,
                    b_grid_desc_bk0_n_bk1,
@@ -1445,6 +1462,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                        std::vector<void*>& p_Vgrads,
                        const std::vector<const void*>& p_acc0_bias_vec,
                        const std::vector<const void*>& p_acc1_bias_vec,
+                       const std::vector<void*>& p_d0grads,
+                       const std::vector<void*>& p_d1grads,
                        const std::vector<ProblemDesc>& problem_desc_vec,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
@@ -1467,6 +1486,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                        p_Vgrads,
                        p_acc0_bias_vec,
                        p_acc1_bias_vec,
+                       p_d0grads,
+                       p_d1grads,
                        problem_desc_vec,
                        a_element_op,
                        b_element_op,
@@ -1495,6 +1516,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                        std::vector<void*>& p_Vgrads,
                        const std::vector<const void*>& p_acc0_bias_vec,
                        const std::vector<const void*>& p_acc1_bias_vec,
+                       const std::vector<void*>& p_d0grads,
+                       const std::vector<void*>& p_d1grads,
                        const std::vector<ProblemDesc>& problem_desc_vec,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
@@ -1517,6 +1540,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                        p_Vgrads,
                        p_acc0_bias_vec, // cast in struct Argument
                        p_acc1_bias_vec, // cast in struct Argument
+                       p_d0grads,
+                       p_d1grads,
                        problem_desc_vec,
                        a_element_op,
                        b_element_op,
...
@@ -1215,7 +1215,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
    using D0GridDescriptor_M0_N0_M1_M2_N1_M3 =
        remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>;
-   struct D0Loader
+   struct D0Operator
    {
        template <typename DataType>
        struct TypeTransform
@@ -1235,13 +1235,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
        static_assert(MPerXdl <= KPerBlock);
        static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock,
                      "D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock");
-       __host__ __device__ static constexpr auto GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3()
+       __host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3()
        {
            // B1 matrix in LDS memory, dst of blockwise copy
            return make_naive_tensor_descriptor_packed(
                make_tuple(I1, I1, I1, D0M1, Number<NPerBlock>{}, D0M2));
        }
-       __host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2()
+       __host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2()
        {
            constexpr auto d0_raw_m0_n_m1 =
                make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<NPerBlock>{}, D0M2));
@@ -1256,15 +1256,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
            return d0_n0_n1_m0_m1_m2;
        }
-       static constexpr auto d0_block_write_desc_m0_n0_m1_m2_n1_m3 =
-           GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3();
-       static constexpr auto d0_block_read_desc_n0_n1_m0_m1_m2 =
-           GetD0BlockReadDescriptor_N0_N1_M0_M1_M2();
+       static constexpr auto d0_block_global_desc_m0_n0_m1_m2_n1_m3 =
+           GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3();
+       static constexpr auto d0_block_vgpr_desc_n0_n1_m0_m1_m2 =
+           GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2();
        static constexpr auto d0_thread_desc_ =
            make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));
-       using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
+       using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            tensor_operation::element_wise::PassThrough,
            tensor_operation::element_wise::PassThrough,
@@ -1275,18 +1275,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
            1,
            BlockSize / NThreadClusterLengths,
            NThreadClusterLengths,
            1>,                                              // ThreadClusterLengths
            Sequence<0, 1, 2, 3, 5, 4>,                      // ThreadClusterArrangeOrder
            typename TypeTransform<D0DataType>::Type,        // SrcData
            typename TypeTransform<D0DataType>::Type,        // DstData
            D0GridDescriptor_M0_N0_M1_M2_N1_M3,              // SrcDesc
-           decltype(d0_block_write_desc_m0_n0_m1_m2_n1_m3), // DstDesc
+           decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // DstDesc
            Sequence<0, 1, 2, 3, 5, 4>,                      // SrcDimAccessOrder
            Sequence<0, 1, 2, 4, 3, 5>,                      // DstDimAccessOrder
            4,                                               // SrcVectorDim
            5,                                               // DstVectorDim
            D0BlockTransferSrcScalarPerVector,               // SrcScalarPerVector
            4,                                               // DstScalarPerVector
            1,
            1,
            true,
@@ -1296,13 +1296,56 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
        using D0ThreadWiseCopy =
            ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
                                             typename TypeTransform<D0DataType>::Type, // DstData
-                                            decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
+                                            decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), // SrcDesc
                                             decltype(d0_thread_desc_), // DstDesc
                                             Sequence<1, 1, 4, 1, 4>,   // SliceLengths
                                             Sequence<0, 1, 2, 3, 4>,   // DimAccessOrder
                                             4,                         // SrcVectorDim
                                             2,                         // SrcScalarPerVector
                                             2>;
+
+       using D0ThreadCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
+           FloatGemmAcc,
+           typename TypeTransform<D0DataType>::Type,
+           decltype(d0_thread_desc_),
+           decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2),
+           tensor_operation::element_wise::Scale, // CElementwiseOperation
+           Sequence<1, 1, 4, 1, 4>,               // SliceLengths
+           Sequence<0, 1, 2, 3, 4>,               // AccessOrder
+           4,                                     // VectorDim
+           4,                                     // ScalarPerVector
+           InMemoryDataOperationEnum::Set,        // GlobalMemoryDataOperation
+           1,                                     // DstScalarStrideInVector
+           true>;
+
+       using D0BlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
+           ThisThreadBlock,
+           tensor_operation::element_wise::PassThrough,
+           tensor_operation::element_wise::PassThrough,
+           InMemoryDataOperationEnum::Set,
+           Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
+           Sequence<1,
+                    1,
+                    1,
+                    BlockSize / NThreadClusterLengths,
+                    NThreadClusterLengths,
+                    1>,                                     // ThreadClusterLengths
+           Sequence<0, 1, 2, 3, 5, 4>,                      // ThreadClusterArrangeOrder
+           typename TypeTransform<D0DataType>::Type,        // SrcData
+           typename TypeTransform<D0DataType>::Type,        // DstData
+           decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // SrcDesc
+           D0GridDescriptor_M0_N0_M1_M2_N1_M3,              // DstDesc
+           Sequence<0, 1, 2, 4, 3, 5>,                      // SrcDimAccessOrder
+           Sequence<0, 1, 2, 3, 5, 4>,                      // DstDimAccessOrder
+           5,                                               // SrcVectorDim
+           4,                                               // DstVectorDim
+           4,                                               // SrcScalarPerVector
+           D0BlockTransferSrcScalarPerVector,               // DstScalarPerVector
+           1,
+           1,
+           true,
+           true, // DstResetCoord
+           1>;
    };

    struct SharedMemTrait
@@ -1337,11 +1380,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                q_block_space_size_aligned.value;
        static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
-           D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
+           D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(),
+           max_lds_align);
        static constexpr auto d0_block_space_offset =
            (k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
             q_block_space_size_aligned.value) *
-           sizeof(GemmDataType) / D0Loader::template TypeTransform<D0DataType>::Size;
+           sizeof(GemmDataType) / D0Operator::template TypeTransform<D0DataType>::Size;

        // LDS allocation for C shuffle in LDS
        static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
@@ -1358,7 +1402,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
            sizeof(GemmDataType);
        const index_t d0_bytes_end =
            (SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
-           D0Loader::template TypeTransform<D0DataType>::Size0;
+           D0Operator::template TypeTransform<D0DataType>::Size0;
        const index_t c_block_bytes_end =
            SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
@@ -1381,6 +1425,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                               const InputDataType* __restrict__ p_ygrad_grid,
                               OutputDataType* __restrict__ p_qgrad_grid,
                               OutputDataType* __restrict__ p_kgrad_grid,
+                              D0DataType* __restrict__ p_d0grad_grid,
                               OutputDataType* __restrict__ p_vgrad_grid,
                               void* __restrict__ p_shared,
                               const AElementwiseOperation& a_element_op,
@@ -1848,17 +1893,30 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
        // gemm0 M loop
        index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
        // D0
-       auto d0_block_copy_global_to_lds = typename D0Loader::D0BlockwiseCopy(
+       auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
            d0_grid_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{},
-           D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3,
+           D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(0, 0, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{});
-       auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
+       auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadWiseCopy(
            make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
+
+       auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0ThreadCopyVgprToLds(
+           D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
+           make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
+           tensor_operation::element_wise::Scale{rp_dropout});
+
+       auto d0_block_copy_lds_to_global = typename D0Operator::D0BlockwiseCopyLdsToGlobal(
+           D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
+           make_multi_index(0, 0, 0, 0, 0, 0),
+           tensor_operation::element_wise::PassThrough{},
+           d0_grid_desc_m0_n0_m1_m2_n1_m3,
+           make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
+           tensor_operation::element_wise::PassThrough{});
+
        if constexpr(Deterministic)
        {
            block_sync_lds();
@@ -1994,49 +2052,54 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
            // add bias
            if constexpr(!is_same<D0DataType, void>::value)
            {
+               if(p_d0_grid != nullptr)
+               {
                    static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();

                    const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                        p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());

                    auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                        static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
-                       D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+                       D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());

                    auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
-                       D0Loader::d0_thread_desc_.GetElementSpaceSize());
+                       D0Operator::d0_thread_desc_.GetElementSpaceSize());

                    static_for<0, D0M0, 1>{}([&](auto mr) {
                        // load data to lds
                        d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
                                                            d0_grid_buf);
                        d0_block_copy_global_to_lds.MoveSrcSliceWindow(
                            d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
                        d0_block_copy_global_to_lds.RunWrite(
-                           D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
+                           D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
                        block_sync_lds();
                        // read data from lds
-                       d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_read_desc_n0_n1_m0_m1_m2,
+                       d0_thread_copy_lds_to_vgpr.Run(
+                           D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
                            make_tuple(I0, I0, I0, I0, I0),
                            d0_block_buf,
-                           D0Loader::d0_thread_desc_,
+                           D0Operator::d0_thread_desc_,
                            make_tuple(I0, I0, I0, I0, I0),
                            d0_thread_buf);

                        // bias add
                        static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
                            constexpr index_t c_offset =
                                c_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
                            s_slash_p_thread_buf(Number<c_offset>{}) +=
                                ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
                        });
                    });

                    d0_block_copy_global_to_lds.MoveSrcSliceWindow(
-                       d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
+                       d0_grid_desc_m0_n0_m1_m2_n1_m3,
+                       make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
+               }
            }

            // P_i: = softmax(scalar * S_i:)
@@ -2127,6 +2190,45 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
                              : y_dot_ygrad_thread_buf[Number<m>{}]);
            });

+           // output bias grad
+           if constexpr(!is_same<D0DataType, void>::value)
+           {
+               if(p_d0grad_grid != nullptr)
+               {
+                   auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                       p_d0grad_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+
+                   auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+                       static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
+                       D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
+
+                   static_for<0, D0M0, 1>{}([&](auto mr) {
+                       d0grad_thread_copy_vgpr_to_lds.Run(
+                           D0Operator::d0_thread_desc_,
+                           make_tuple(mr, I0, I0, I0, I0),
+                           sgrad_thread_buf,
+                           D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
+                           d0grad_block_buf);
+                       block_sync_lds();
+                       // write data from lds to global
+                       d0_block_copy_lds_to_global.Run(
+                           D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
+                           d0grad_block_buf,
+                           d0_grid_desc_m0_n0_m1_m2_n1_m3,
+                           d0grad_grid_buf,
+                           I0);
+                       d0_block_copy_lds_to_global.MoveDstSliceWindow(
+                           d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
+                   });
+
+                   d0_block_copy_lds_to_global.MoveDstSliceWindow(
+                       d0_grid_desc_m0_n0_m1_m2_n1_m3,
+                       make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
+               }
+           }
+
            // gemm dV
            // dV = P_drop^T * dY
            {
...
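A note on why the gridwise dD0 path above can reuse sgrad_thread_buf directly: the bias enters the logits additively, so its gradient is the logit gradient already held in registers, and the only extra work is the Scale{rp_dropout} element op applied on the VGPR-to-LDS copy. A hedged summary in LaTeX (assuming rp_dropout is the usual reciprocal keep-probability that undoes the dropout factor carried in sgrad):

    % the bias add is elementwise on the pre-softmax logits
    S' = S + D_0, \qquad P = \operatorname{softmax}(\mathrm{scalar}\cdot S')
    \quad\Longrightarrow\quad
    \frac{\partial L}{\partial D_0} \;=\; \frac{\partial L}{\partial S'}
    % the kernel stores  rp\_dropout \cdot sgrad\_thread\_buf  as dD0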
@@ -1294,7 +1294,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
    using D0GridDescriptor_M0_N0_M1_M2_N1_M3 =
        remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>;
-   struct D0Loader
+   struct D0Operator
    {
        template <typename DataType>
        struct TypeTransform
@@ -1314,13 +1314,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
        static_assert(NPerXdl == 32);
        static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock,
                      "D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock");
-       __host__ __device__ static constexpr auto GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3()
+       __host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3()
        {
            // B1 matrix in LDS memory, dst of blockwise copy
            return make_naive_tensor_descriptor_packed(
                make_tuple(I1, I1, I1, D0M1, Number<NPerBlock>{}, D0M2));
        }
-       __host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2()
+       __host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2()
        {
            constexpr auto d0_raw_m0_n_m1 =
                make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<NPerBlock>{}, D0M2));
@@ -1335,15 +1335,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
            return d0_n0_n1_m0_m1_m2;
        }

        static constexpr auto d0_block_global_desc_m0_n0_m1_m2_n1_m3 =
            GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3();
        static constexpr auto d0_block_vgpr_desc_n0_n1_m0_m1_m2 =
            GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2();
        static constexpr auto d0_thread_desc_ =
            make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));

        using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            tensor_operation::element_wise::PassThrough,
            tensor_operation::element_wise::PassThrough,
@@ -1354,34 +1354,77 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                     1,
                     BlockSize / NThreadClusterLengths,
                     NThreadClusterLengths,
                     1>,                                      // ThreadClusterLengths
            Sequence<0, 1, 2, 3, 5, 4>,                       // ThreadClusterArrangeOrder
            typename TypeTransform<D0DataType>::Type,         // SrcData
            typename TypeTransform<D0DataType>::Type,         // DstData
            D0GridDescriptor_M0_N0_M1_M2_N1_M3,               // SrcDesc
            decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // DstDesc
            Sequence<0, 1, 2, 3, 5, 4>,                       // SrcDimAccessOrder
            Sequence<0, 1, 2, 4, 3, 5>,                       // DstDimAccessOrder
            4,                                                // SrcVectorDim
            5,                                                // DstVectorDim
            D0BlockTransferSrcScalarPerVector,                // SrcScalarPerVector
            4,                                                // DstScalarPerVector
            1,
            1,
            true,
            true, // DstResetCoord
            1>;

        using D0ThreadwiseCopyLdsToVgpr =
            ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type,    // SrcData
                                             typename TypeTransform<D0DataType>::Type,    // DstData
                                             decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), // SrcDesc
                                             decltype(d0_thread_desc_),                   // DstDesc
                                             Sequence<1, 1, 4, 1, 4>,                     // SliceLengths
                                             Sequence<0, 1, 2, 3, 4>,                     // DimAccessOrder
                                             4,                                           // SrcVectorDim
                                             2,                                           // SrcScalarPerVector
                                             2>;
        using D0ThreadCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
            FloatGemmAcc,
            typename TypeTransform<D0DataType>::Type,
            decltype(d0_thread_desc_),
            decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2),
            tensor_operation::element_wise::Scale, // CElementwiseOperation
            Sequence<1, 1, 4, 1, 4>,               // SliceLengths
            Sequence<0, 1, 2, 3, 4>,               // AccessOrder
            4,                                     // VectorDim
            4,                                     // ScalarPerVector
            InMemoryDataOperationEnum::Set,        // GlobalMemoryDataOperation
            1,                                     // DstScalarStrideInVector
            true>;

        using D0BlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            tensor_operation::element_wise::PassThrough,
            tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
            Sequence<1,
                     1,
                     1,
                     BlockSize / NThreadClusterLengths,
                     NThreadClusterLengths,
                     1>,                                      // ThreadClusterLengths
            Sequence<0, 1, 2, 3, 5, 4>,                       // ThreadClusterArrangeOrder
            typename TypeTransform<D0DataType>::Type,         // SrcData
            typename TypeTransform<D0DataType>::Type,         // DstData
            decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // SrcDesc
            D0GridDescriptor_M0_N0_M1_M2_N1_M3,               // DstDesc
            Sequence<0, 1, 2, 4, 3, 5>,                       // SrcDimAccessOrder
            Sequence<0, 1, 2, 3, 5, 4>,                       // DstDimAccessOrder
            5,                                                // SrcVectorDim
            4,                                                // DstVectorDim
            4,                                                // SrcScalarPerVector
            D0BlockTransferSrcScalarPerVector,                // DstScalarPerVector
            1,
            1,
            true,
            true, // DstResetCoord
            1>;
    };
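    // ------------------------------------------------------------------------
    // Illustrative sketch (editor's aid; a hypothetical standalone kernel, not
    // the CK transfer classes above): why the bias-grad path stages data
    // through LDS before the global store. Accumulator fragments sit in the
    // xdlops register order, so the tile is first rearranged in shared memory
    // and then written out row-major, letting consecutive lanes issue
    // contiguous, vectorizable global stores — the effect the access orders
    // and ScalarPerVector settings of D0BlockwiseCopyLdsToGlobal achieve.
    // TileM/TileN and the column-major stand-in for the register order are
    // assumptions.
    //
    // template <int TileM, int TileN>
    // __global__ void store_bias_grad_tile(const float* __restrict__ sgrad, // gemm order
    //                                      float* __restrict__ d0grad,      // [M, N] row-major
    //                                      int n_cols,
    //                                      float rp_dropout)
    // {
    //     __shared__ float tile[TileM][TileN];
    //     const int tid = threadIdx.x;
    //
    //     // VGPR->LDS: read in register (column-major stand-in) order, apply
    //     // the Scale element-op, land row-major in LDS.
    //     for(int i = tid; i < TileM * TileN; i += blockDim.x)
    //         tile[i % TileM][i / TileM] = sgrad[i] * rp_dropout;
    //     __syncthreads();
    //
    //     // LDS->global: row-major walk, adjacent lanes hit adjacent addresses.
    //     const int m0 = blockIdx.y * TileM;
    //     const int n0 = blockIdx.x * TileN;
    //     for(int i = tid; i < TileM * TileN; i += blockDim.x)
    //         d0grad[(m0 + i / TileN) * n_cols + (n0 + i % TileN)] =
    //             tile[i / TileN][i % TileN];
    // }
    // ------------------------------------------------------------------------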
    struct SharedMemTrait
@@ -1416,10 +1459,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
            math::integer_least_multiple(BlockSize, max_lds_align);

        static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
            D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(),
            max_lds_align);
        static constexpr auto d0_block_space_offset =
            k_block_space_size_aligned.value * sizeof(GemmDataType) /
            D0Operator::template TypeTransform<D0DataType>::Size;

        // LDS allocation for C shuffle in LDS
        static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
@@ -1444,7 +1488,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
            sizeof(GemmDataType);
        const index_t d0_bytes_end =
            (SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
            D0Operator::template TypeTransform<D0DataType>::Size0;
        const index_t c_block_bytes_end =
            SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
@@ -1472,6 +1516,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
        const InputDataType* __restrict__ p_ygrad_grid,
        OutputDataType* __restrict__ p_qgrad_grid,
        OutputDataType* __restrict__ p_kgrad_grid,
        D0DataType* __restrict__ p_d0grad_grid,
        OutputDataType* __restrict__ p_vgrad_grid,
        void* __restrict__ p_shared,
        const AElementwiseOperation& a_element_op,
@@ -1969,17 +2014,30 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
        index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;

        // D0
        auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
            d0_grid_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{},
            D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(0, 0, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{});
        auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
            make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
        auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0ThreadCopyVgprToLds(
            D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
            make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
            tensor_operation::element_wise::Scale{rp_dropout});

        auto d0_block_copy_lds_to_global = typename D0Operator::D0BlockwiseCopyLdsToGlobal(
            D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(0, 0, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{},
            d0_grid_desc_m0_n0_m1_m2_n1_m3,
            make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{});
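        // --------------------------------------------------------------------
        // Illustrative sketch (editor's aid; assumed semantics, not the CK
        // classes): the slice-window bookkeeping the loops below rely on. Each
        // copy object keeps a running origin into the grid tensor; Run()
        // transfers one tile there and MoveSrc/DstSliceWindow() shifts it. Per
        // outer iteration the kernel takes D0M0 forward steps of +1 along the
        // M0 dimension, then rewinds once by (-1, 0, -D0M0, 0, 0, 0), landing
        // on M0 = 0 of the previous outer M-block (the outer loop walks
        // gemm0_m_block_outer_index downward). A toy model:
        //
        // struct ToySliceWindow { int m_outer = 0; int m0 = 0; };
        //
        // void advance_one_outer_iteration(ToySliceWindow& w, int d0m0)
        // {
        //     for(int mr = 0; mr < d0m0; ++mr)
        //         w.m0 += 1;  // MoveDstSliceWindow(..., {0, 0, 1, 0, 0, 0})
        //     w.m_outer -= 1; // MoveDstSliceWindow(..., {-1, 0, -d0m0, ...})
        //     w.m0 -= d0m0;   // net effect: m0 back to 0, previous M-block
        // }
        // --------------------------------------------------------------------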
        if constexpr(Deterministic)
        {
            block_sync_lds();
@@ -2145,50 +2203,55 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
            // add bias
            if constexpr(!is_same<D0DataType, void>::value)
            {
                if(p_d0_grid != nullptr)
                {
                    static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();

                    const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                        p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());

                    auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                        static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
                        D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());

                    auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
                        D0Operator::d0_thread_desc_.GetElementSpaceSize());
                    ignore = d0_thread_buf;

                    static_for<0, D0M0, 1>{}([&](auto mr) {
                        // load data to lds
                        d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
                                                            d0_grid_buf);
                        d0_block_copy_global_to_lds.MoveSrcSliceWindow(
                            d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
                        d0_block_copy_global_to_lds.RunWrite(
                            D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
                        block_sync_lds();
                        // read data from lds
                        d0_thread_copy_lds_to_vgpr.Run(
                            D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
                            make_tuple(I0, I0, I0, I0, I0),
                            d0_block_buf,
                            D0Operator::d0_thread_desc_,
                            make_tuple(I0, I0, I0, I0, I0),
                            d0_thread_buf);

                        // bias add
                        static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
                            constexpr index_t c_offset =
                                c_thread_desc.CalculateOffset(make_tuple(mr, I0, i));

                            s_slash_p_thread_buf(Number<c_offset>{}) +=
                                ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
                        });
                    });
                    d0_block_copy_global_to_lds.MoveSrcSliceWindow(
                        d0_grid_desc_m0_n0_m1_m2_n1_m3,
                        make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
                }
            }
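            // ----------------------------------------------------------------
            // Illustrative sketch (editor's aid; hypothetical names, assumes
            // <vector>, <cmath>, <algorithm>): host-side reference of what the
            // block above contributes per score row. The D0 tile is added to
            // the raw scores, and the softmax that follows consumes the biased
            // row. Where the 1/sqrt(d) scale is applied (gemm element-op vs.
            // softmax) is orthogonal to the bias add; here it is assumed
            // already folded into s.
            //
            // std::vector<float> reference_biased_softmax_row(
            //     std::vector<float> s, const std::vector<float>& d0)
            // {
            //     for(std::size_t j = 0; j < s.size(); ++j)
            //         s[j] += d0[j]; // the "bias add" loop above
            //
            //     const float mx  = *std::max_element(s.begin(), s.end());
            //     float       sum = 0.f;
            //     for(float& x : s)
            //     {
            //         x = std::exp(x - mx);
            //         sum += x;
            //     }
            //     for(float& x : s)
            //         x /= sum;
            //     return s; // P_i:
            // }
            // ----------------------------------------------------------------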
            // P_i: = softmax(scalar * S_i:)
@@ -2395,6 +2458,46 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
                                     : y_dot_ygrad_thread_buf[Number<m>{}]);
                });
            // output bias grad
            if constexpr(!is_same<D0DataType, void>::value)
            {
                if(p_d0grad_grid != nullptr)
                {
                    auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                        p_d0grad_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());

                    auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                        static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
                        D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());

                    static_for<0, D0M0, 1>{}([&](auto mr) {
                        d0grad_thread_copy_vgpr_to_lds.Run(
                            D0Operator::d0_thread_desc_,
                            make_tuple(mr, I0, I0, I0, I0),
                            sgrad_thread_buf,
                            D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
                            d0grad_block_buf);
                        block_sync_lds();
                        // write data from lds to global
                        d0_block_copy_lds_to_global.Run(
                            D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
                            d0grad_block_buf,
                            d0_grid_desc_m0_n0_m1_m2_n1_m3,
                            d0grad_grid_buf,
                            I0);
                        d0_block_copy_lds_to_global.MoveDstSliceWindow(
                            d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
                    });
                    d0_block_copy_lds_to_global.MoveDstSliceWindow(
                        d0_grid_desc_m0_n0_m1_m2_n1_m3,
                        make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
                }
            }
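            // ----------------------------------------------------------------
            // Illustrative sketch (editor's aid; hypothetical names): why
            // sgrad_thread_buf can be stored directly as the bias gradient.
            // The bias enters as S' = S + D0, so dD0 = dS element-wise. dS
            // itself comes from the softmax backward, using the precomputed
            // row dot product y_dot_ygrad = sum_j Y(j) * dY(j) (equal to
            // sum_j P(j) * dP(j)) that y_dot_ygrad_thread_buf holds:
            //
            // std::vector<float> softmax_backward_row(const std::vector<float>& p,
            //                                         const std::vector<float>& dp,
            //                                         float y_dot_ygrad)
            // {
            //     std::vector<float> ds(p.size());
            //     for(std::size_t j = 0; j < p.size(); ++j)
            //         ds[j] = p[j] * (dp[j] - y_dot_ygrad); // == dD0 row
            //     return ds;
            // }
            // ----------------------------------------------------------------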
            SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
                                                            s_blockwise_gemm.GetWaveIdx()[I1]);
...
@@ -2151,7 +2151,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
            make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
            tensor_operation::element_wise::PassThrough{});
ignore = d0grad_thread_copy_vgpr_to_lds;
        if constexpr(Deterministic)
        {
            block_sync_lds();
...