Commit 522d8b2f authored by letaoqin's avatar letaoqin

bias grad update to light version

parent 9dc3e49b
......@@ -518,8 +518,10 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
nullptr, // p_acc0_bias;
nullptr, // p_acc1_bias;
nullptr,
nullptr,
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
......@@ -564,8 +566,10 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
nullptr, // p_acc0_bias;
nullptr, // p_acc1_bias;
nullptr,
nullptr,
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
......
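The batched light-version example above now passes two additional raw pointers after the bias arguments: p_d0grad_grid for the bias-gradient output and p_d1grad_grid, which these kernels ignore. Both are nullptr here because the example does not request dBias. A minimal sketch, assuming a packed bias layout and reusing the acc0_bias_gs_ms_ns lengths the device op already takes, of how a caller could size a dBias buffer before passing its device pointer in the p_d0grad_grid slot:

```cpp
// Hedged sketch, not part of the commit: the kernels store dBias through the
// same grid descriptor they use to read the bias, so a packed dBias buffer
// covers the same gs_ms_ns extent as acc0_bias. Pass nullptr for
// p_d0grad_grid (as in the example above) to skip the output entirely.
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

std::size_t d0grad_element_count(const std::vector<int>& acc0_bias_gs_ms_ns_lengths)
{
    return std::accumulate(acc0_bias_gs_ms_ns_lengths.begin(),
                           acc0_bias_gs_ms_ns_lengths.end(),
                           std::size_t{1},
                           std::multiplies<std::size_t>());
}
```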
......@@ -24,7 +24,7 @@ Kernel outputs:
*/
#define USING_MASK 0
#define DIM 128 // DIM should be a multiple of 8.
#define DIM 32 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
......@@ -616,6 +616,8 @@ int run(int argc, char* argv[])
p_vgrad,
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
{},
{},
problem_descs,
QKVElementOp{},
QKVElementOp{},
......@@ -663,6 +665,8 @@ int run(int argc, char* argv[])
p_vgrad,
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
{},
{},
problem_descs,
QKVElementOp{},
QKVElementOp{},
......
......@@ -123,6 +123,7 @@ __global__ void
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
D0DataType* __restrict__ p_d0grad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
......@@ -176,11 +177,19 @@ __global__ void
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
const D0DataType* tmp_p_d0_grid = nullptr;
D0DataType* tmp_p_d0grad_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
if(p_d0_grid != nullptr)
{
tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
}
if(p_d0grad_grid != nullptr)
{
tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
}
}
if constexpr(Deterministic)
{
......@@ -197,6 +206,7 @@ __global__ void
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
......@@ -233,6 +243,7 @@ __global__ void
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
......@@ -266,6 +277,7 @@ __global__ void
ignore = p_ygrad_grid;
ignore = p_qgrad_grid;
ignore = p_kgrad_grid;
ignore = p_d0grad_grid;
ignore = p_vgrad_grid;
ignore = a_element_op;
ignore = b_element_op;
......@@ -858,6 +870,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
D0DataType* p_d0grad_grid,
D1DataType* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -894,6 +908,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_qgrad_grid_{p_qgrad_grid},
p_kgrad_grid_{p_kgrad_grid},
p_vgrad_grid_{p_vgrad_grid},
p_d0grad_grid_{p_d0grad_grid},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{
......@@ -948,10 +963,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_drop_{p_drop}
{
// TODO: implement bias addition
ignore = p_acc0_bias;
ignore = p_d1grad_grid;
ignore = p_acc1_bias;
ignore = acc0_bias_gs_ms_ns_lengths;
ignore = acc0_bias_gs_ms_ns_strides;
ignore = acc1_bias_gs_ms_gemm1ns_lengths;
ignore = acc1_bias_gs_ms_gemm1ns_strides;
......@@ -1030,6 +1043,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_;
OutputDataType* p_vgrad_grid_;
D0DataType* p_d0grad_grid_;
// tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
......@@ -1191,6 +1205,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
arg.p_ygrad_grid_,
arg.p_qgrad_grid_,
arg.p_kgrad_grid_,
arg.p_d0grad_grid_,
arg.p_vgrad_grid_,
arg.a_element_op_,
arg.b_element_op_,
......@@ -1342,6 +1357,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
D0DataType* p_d0grad_grid,
D1DataType* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -1380,6 +1397,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_vgrad_grid,
p_acc0_bias,
p_acc1_bias,
p_d0grad_grid,
p_d1grad_grid,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
......@@ -1422,6 +1441,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
void* p_vgrad_grid,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
void* p_d0grad_grid,
void* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -1461,6 +1482,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static_cast<OutputDataType*>(p_vgrad_grid),
static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
static_cast<D0DataType*>(p_d0grad_grid),
static_cast<D1DataType*>(p_d1grad_grid),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
......
......@@ -123,6 +123,7 @@ __global__ void
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
D0DataType* __restrict__ p_d0grad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
......@@ -176,11 +177,19 @@ __global__ void
const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;
const D0DataType* tmp_p_d0_grid = nullptr;
D0DataType* tmp_p_d0grad_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
if(p_d0_grid != nullptr)
{
tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
}
if(p_d0grad_grid != nullptr)
{
tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
}
}
if constexpr(Deterministic)
......@@ -198,6 +207,7 @@ __global__ void
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
......@@ -234,6 +244,7 @@ __global__ void
p_ygrad_grid + c_batch_offset,
p_qgrad_grid + a_batch_offset,
p_kgrad_grid + b_batch_offset,
tmp_p_d0grad_grid,
p_vgrad_grid + b1_batch_offset,
p_shared,
a_element_op,
......@@ -267,6 +278,7 @@ __global__ void
ignore = p_ygrad_grid;
ignore = p_qgrad_grid;
ignore = p_kgrad_grid;
ignore = p_d0grad_grid;
ignore = p_vgrad_grid;
ignore = a_element_op;
ignore = b_element_op;
......@@ -874,6 +886,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
D0DataType* p_d0grad_grid,
D1DataType* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -910,6 +924,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_qgrad_grid_{p_qgrad_grid},
p_kgrad_grid_{p_kgrad_grid},
p_vgrad_grid_{p_vgrad_grid},
p_d0grad_grid_{p_d0grad_grid},
a_grid_desc_ak0_m_ak1_{
DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
b_grid_desc_bk0_n_bk1_{
......@@ -964,6 +979,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
{
// TODO: implement bias addition
ignore = p_acc1_bias;
ignore = p_d1grad_grid;
ignore = acc1_bias_gs_ms_gemm1ns_lengths;
ignore = acc1_bias_gs_ms_gemm1ns_strides;
......@@ -1042,6 +1058,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_;
OutputDataType* p_vgrad_grid_;
D0DataType* p_d0grad_grid_;
// tensor descriptor
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
......@@ -1207,6 +1224,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
arg.p_ygrad_grid_,
arg.p_qgrad_grid_,
arg.p_kgrad_grid_,
arg.p_d0grad_grid_,
arg.p_vgrad_grid_,
arg.a_element_op_,
arg.b_element_op_,
......@@ -1374,6 +1392,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
OutputDataType* p_vgrad_grid,
const D0DataType* p_acc0_bias,
const D1DataType* p_acc1_bias,
D0DataType* p_d0grad_grid,
D1DataType* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -1412,6 +1432,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_vgrad_grid,
p_acc0_bias,
p_acc1_bias,
p_d0grad_grid,
p_d1grad_grid,
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
......@@ -1454,6 +1476,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
void* p_vgrad_grid,
const void* p_acc0_bias,
const void* p_acc1_bias,
void* p_d0grad_grid,
void* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -1493,6 +1517,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static_cast<OutputDataType*>(p_vgrad_grid),
static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
static_cast<D0DataType*>(p_d0grad_grid),
static_cast<D1DataType*>(p_d1grad_grid),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
......
......@@ -1333,8 +1333,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
void* p_vgrad_grid,
const void* p_acc0_bias,
const void* p_acc1_bias,
D0DataType* p_d0grad_grid,
D1DataType* p_d1grad_grid,
void* p_d0grad_grid,
void* p_d1grad_grid,
const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths,
......@@ -1373,8 +1373,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static_cast<OutputDataType*>(p_vgrad_grid),
static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
static_cast<const D0DataType*>(p_d0grad_grid),
static_cast<const D1DataType*>(p_d1grad_grid),
static_cast<D0DataType*>(p_d0grad_grid),
static_cast<D1DataType*>(p_d1grad_grid),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
......
......@@ -162,13 +162,16 @@ __global__ void
: arg_ptr[group_id].p_z_grid_ + z_batch_offset);
const D0DataType* tmp_p_d0_grid = nullptr;
D0DataType* tmp_p_d0grad_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
if(arg_ptr[group_id].p_d0_grid_ != nullptr)
tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
if(arg_ptr[group_id].p_d0grad_grid_)
tmp_p_d0grad_grid = arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
}
if constexpr(Deterministic)
{
......@@ -185,6 +188,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared,
a_element_op,
......@@ -222,6 +226,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared,
a_element_op,
......@@ -806,6 +811,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const InputDataType* p_ygrad_grid_;
OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_;
D0DataType* p_d0grad_grid_;
OutputDataType* p_vgrad_grid_;
// tensor descriptors for block/thread-wise copy
......@@ -878,6 +884,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -911,7 +919,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
0 == p_acc1_bias_vec.size()))
0 == p_acc1_bias_vec.size() &&
(group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
ck::type_convert<ck::index_t>(p_d0grads.size() == 0)) &&
0 == p_d1grads.size()))
{
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
}
......@@ -937,7 +948,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]);
auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]);
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
auto p_d0grad_grid =
(ck::type_convert<ck::index_t>(p_d0grads.size()) == group_count_)
? static_cast<D0DataType*>(p_d0grads[i])
: nullptr;
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
const auto& problem_desc = problem_desc_vec[i];
......@@ -1054,6 +1069,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_ygrad_grid,
p_qgrad_grid,
p_kgrad_grid,
p_d0grad_grid,
p_vgrad_grid,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
......@@ -1370,6 +1386,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1392,6 +1410,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_Vgrads,
p_acc0_bias_vec,
p_acc1_bias_vec,
p_d0grads,
p_d1grads,
problem_desc_vec,
a_element_op,
b_element_op,
......@@ -1420,6 +1440,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1442,6 +1464,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
p_Vgrads,
p_acc0_bias_vec, // cast in struct Argument
p_acc1_bias_vec, // cast in struct Argument
p_d0grads,
p_d1grads,
problem_desc_vec,
a_element_op,
b_element_op,
......
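The grouped light-version ops gain two vector arguments, p_d0grads and p_d1grads; the size check added above requires p_d0grads to be empty or to hold one entry per group, and p_d1grads to stay empty. A minimal sketch, with hypothetical names, of assembling those vectors on the caller side:

```cpp
#include <vector>

// Hedged sketch, not part of the commit: the Argument constructor above
// accepts either an empty p_d0grads (no dBias output for any group) or
// exactly group_count entries; a nullptr entry makes the kernel skip the
// store for that group only. p_d1grads must remain empty.
// `group_dbias_ptrs` and both helper names are hypothetical.
std::vector<void*> make_d0grad_args(const std::vector<void*>& group_dbias_ptrs,
                                    bool want_bias_grad)
{
    return want_bias_grad ? group_dbias_ptrs : std::vector<void*>{};
}

std::vector<void*> make_d1grad_args() { return {}; } // required to be empty
```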
......@@ -160,13 +160,17 @@ __global__ void
(arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset);
const D0DataType* tmp_p_d0_grid = nullptr;
D0DataType* tmp_p_d0grad_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value)
{
const long_index_t d0_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
if(arg_ptr[group_id].p_d0_grid_ != nullptr)
tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
if(arg_ptr[group_id].p_d0grad_grid_)
tmp_p_d0grad_grid = arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
}
if constexpr(Deterministic)
......@@ -184,6 +188,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared,
a_element_op,
......@@ -221,6 +226,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared,
a_element_op,
......@@ -876,6 +882,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const InputDataType* p_ygrad_grid_;
OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_;
D0DataType* p_d0grad_grid_;
OutputDataType* p_vgrad_grid_;
// tensor descriptors for block/thread-wise copy
......@@ -948,6 +955,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -981,7 +990,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
0 == p_acc1_bias_vec.size()))
0 == p_acc1_bias_vec.size() &&
(group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
ck::type_convert<ck::index_t>(p_d0grads.size() == 0)) &&
0 == p_d1grads.size()))
{
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
}
......@@ -1007,7 +1019,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]);
auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]);
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
auto p_d0grad_grid =
(ck::type_convert<ck::index_t>(p_d0grads.size()) == group_count_)
? static_cast<D0DataType*>(p_d0grads[i])
: nullptr;
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
const auto& problem_desc = problem_desc_vec[i];
......@@ -1124,6 +1140,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_ygrad_grid,
p_qgrad_grid,
p_kgrad_grid,
p_d0grad_grid,
p_vgrad_grid,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
......@@ -1445,6 +1462,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1467,6 +1486,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_Vgrads,
p_acc0_bias_vec,
p_acc1_bias_vec,
p_d0grads,
p_d1grads,
problem_desc_vec,
a_element_op,
b_element_op,
......@@ -1495,6 +1516,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
......@@ -1517,6 +1540,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
p_Vgrads,
p_acc0_bias_vec, // cast in struct Argument
p_acc1_bias_vec, // cast in struct Argument
p_d0grads,
p_d1grads,
problem_desc_vec,
a_element_op,
b_element_op,
......
......@@ -1215,7 +1215,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
using D0GridDescriptor_M0_N0_M1_M2_N1_M3 =
remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>;
struct D0Loader
struct D0Operator
{
template <typename DataType>
struct TypeTransform
......@@ -1235,13 +1235,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
static_assert(MPerXdl <= KPerBlock);
static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock,
"D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock");
__host__ __device__ static constexpr auto GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3()
__host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3()
{
// D0 bias tile in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, D0M1, Number<NPerBlock>{}, D0M2));
}
__host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2()
__host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2()
{
constexpr auto d0_raw_m0_n_m1 =
make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<NPerBlock>{}, D0M2));
......@@ -1256,15 +1256,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
return d0_n0_n1_m0_m1_m2;
}
static constexpr auto d0_block_write_desc_m0_n0_m1_m2_n1_m3 =
GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3();
static constexpr auto d0_block_read_desc_n0_n1_m0_m1_m2 =
GetD0BlockReadDescriptor_N0_N1_M0_M1_M2();
static constexpr auto d0_block_global_desc_m0_n0_m1_m2_n1_m3 =
GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3();
static constexpr auto d0_block_vgpr_desc_n0_n1_m0_m1_m2 =
GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2();
static constexpr auto d0_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));
using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
......@@ -1275,18 +1275,18 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
1,
BlockSize / NThreadClusterLengths,
NThreadClusterLengths,
1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc
decltype(d0_block_write_desc_m0_n0_m1_m2_n1_m3), // DstDesc
Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder
Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder
4, // SrcVectorDim
5, // DstVectorDim
D0BlockTransferSrcScalarPerVector, // SrcScalarPerVector
4, // DstScalarPerVector
1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc
decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // DstDesc
Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder
Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder
4, // SrcVectorDim
5, // DstVectorDim
D0BlockTransferSrcScalarPerVector, // SrcScalarPerVector
4, // DstScalarPerVector
1,
1,
true,
......@@ -1296,13 +1296,56 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
using D0ThreadWiseCopy =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), // SrcDesc
decltype(d0_thread_desc_), // DstDesc
Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
4, // SrcVectorDim
2, // SrcScalarPerVector
2>;
using D0ThreadCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
typename TypeTransform<D0DataType>::Type,
decltype(d0_thread_desc_),
decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2),
tensor_operation::element_wise::Scale, // CElementwiseOperation
Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // AccessOrder
4, // VectorDim
4, // ScalarPerVector
InMemoryDataOperationEnum::Set, // GlobalMemoryDataOperation
1, // DstScalarStrideInVector
true>;
using D0BlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
Sequence<1,
1,
1,
BlockSize / NThreadClusterLengths,
NThreadClusterLengths,
1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // SrcDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // DstDesc
Sequence<0, 1, 2, 4, 3, 5>, // SrcDimAccessOrder
Sequence<0, 1, 2, 3, 5, 4>, // DstDimAccessOrder
5, // SrcVectorDim
4, // DstVectorDim
4, // SrcScalarPerVector
D0BlockTransferSrcScalarPerVector, // DstScalarPerVector
1,
1,
true,
true, // DstResetCoord
1>;
};
struct SharedMemTrait
......@@ -1337,11 +1380,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
q_block_space_size_aligned.value;
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(),
max_lds_align);
static constexpr auto d0_block_space_offset =
(k_block_space_size_aligned.value + ygrad_block_space_size_aligned.value +
q_block_space_size_aligned.value) *
sizeof(GemmDataType) / D0Loader::template TypeTransform<D0DataType>::Size;
sizeof(GemmDataType) / D0Operator::template TypeTransform<D0DataType>::Size;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
......@@ -1358,7 +1402,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
sizeof(GemmDataType);
const index_t d0_bytes_end =
(SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
D0Loader::template TypeTransform<D0DataType>::Size0;
D0Operator::template TypeTransform<D0DataType>::Size0;
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
......@@ -1381,6 +1425,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
D0DataType* __restrict__ p_d0grad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
......@@ -1848,17 +1893,30 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// gemm0 M loop
index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
// D0
auto d0_block_copy_global_to_lds = typename D0Loader::D0BlockwiseCopy(
auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{},
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3,
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadWiseCopy(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0ThreadCopyVgprToLds(
D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
tensor_operation::element_wise::Scale{rp_dropout});
auto d0_block_copy_lds_to_global = typename D0Operator::D0BlockwiseCopyLdsToGlobal(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{},
d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
if constexpr(Deterministic)
{
block_sync_lds();
......@@ -1994,49 +2052,54 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
// add bias
if constexpr(!is_same<D0DataType, void>::value)
{
static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();
const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Loader::d0_thread_desc_.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) {
// load data to lds
d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
d0_grid_buf);
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
d0_block_copy_global_to_lds.RunWrite(
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
block_sync_lds();
// read data from lds
d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_read_desc_n0_n1_m0_m1_m2,
make_tuple(I0, I0, I0, I0, I0),
d0_block_buf,
D0Loader::d0_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
d0_thread_buf);
// bias add
static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
s_slash_p_thread_buf(Number<c_offset>{}) +=
ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
if(p_d0_grid != nullptr)
{
static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();
const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Operator::d0_thread_desc_.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) {
// load data to lds
d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
d0_grid_buf);
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
d0_block_copy_global_to_lds.RunWrite(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
block_sync_lds();
// read data from lds
d0_thread_copy_lds_to_vgpr.Run(
D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
make_tuple(I0, I0, I0, I0, I0),
d0_block_buf,
D0Operator::d0_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
d0_thread_buf);
// bias add
static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
s_slash_p_thread_buf(Number<c_offset>{}) +=
ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
});
});
});
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
}
}
// P_i: = softmax(scalar * S_i:)
......@@ -2127,6 +2190,45 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V1
: y_dot_ygrad_thread_buf[Number<m>{}]);
});
// output bias grad
if constexpr(!is_same<D0DataType, void>::value)
{
if(p_d0grad_grid != nullptr)
{
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) {
d0grad_thread_copy_vgpr_to_lds.Run(
D0Operator::d0_thread_desc_,
make_tuple(mr, I0, I0, I0, I0),
sgrad_thread_buf,
D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
d0grad_block_buf);
block_sync_lds();
// write data from lds to global
d0_block_copy_lds_to_global.Run(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
d0grad_block_buf,
d0_grid_desc_m0_n0_m1_m2_n1_m3,
d0grad_grid_buf,
I0);
d0_block_copy_lds_to_global.MoveDstSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
});
d0_block_copy_lds_to_global.MoveDstSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
}
}
// gemm dV
// dV = P_drop^T * dY
{
......
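In the gridwise light-V1 kernel above, the new write-out path converts the logit-gradient tile (sgrad_thread_buf) to D0DataType through the Scale{rp_dropout} element-wise op, stages it in the D0 LDS tile, and stores it through the same descriptor used to load the bias, skipping the whole block when p_d0grad_grid is null. A minimal host-side sketch of the relationship this implements, assuming an additive bias on the S = Q*K^T logits and hypothetical names:

```cpp
#include <cstddef>
#include <vector>

// Hedged sketch, not CK code: for an additive bias D0 on the pre-softmax
// logits S, the bias gradient equals the logit gradient dS, rescaled here by
// the reciprocal dropout keep-probability, mirroring the Scale{rp_dropout}
// op used by D0ThreadCopyVgprToLds. A null output pointer means the bias
// gradient was not requested and the store is skipped, as in the kernel.
void reference_bias_grad(const std::vector<float>& dS, // M*N logit gradients
                         float rp_dropout,             // 1 / keep_probability
                         float* p_d0grad,              // may be nullptr
                         std::size_t m,
                         std::size_t n)
{
    if(p_d0grad == nullptr)
        return; // mirrors the `if(p_d0grad_grid != nullptr)` guard

    for(std::size_t i = 0; i < m * n; ++i)
        p_d0grad[i] = rp_dropout * dS[i];
}
```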
......@@ -1294,7 +1294,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
using D0GridDescriptor_M0_N0_M1_M2_N1_M3 =
remove_cvref_t<decltype(MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(D0GridDesc_M_N{}))>;
struct D0Loader
struct D0Operator
{
template <typename DataType>
struct TypeTransform
......@@ -1314,13 +1314,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
static_assert(NPerXdl == 32);
static_assert(D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock,
"D0BlockTransferSrcScalarPerVector * NThreadClusterLengths <= NPerBlock");
__host__ __device__ static constexpr auto GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3()
__host__ __device__ static constexpr auto GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3()
{
// D0 bias tile in LDS memory, dst of blockwise copy
return make_naive_tensor_descriptor_packed(
make_tuple(I1, I1, I1, D0M1, Number<NPerBlock>{}, D0M2));
}
__host__ __device__ static constexpr auto GetD0BlockReadDescriptor_N0_N1_M0_M1_M2()
__host__ __device__ static constexpr auto GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2()
{
constexpr auto d0_raw_m0_n_m1 =
make_naive_tensor_descriptor_packed(make_tuple(D0M1, Number<NPerBlock>{}, D0M2));
......@@ -1335,15 +1335,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
make_tuple(Sequence<2, 3>{}, Sequence<0, 1>{}, Sequence<4>{}));
return d0_n0_n1_m0_m1_m2;
}
static constexpr auto d0_block_write_desc_m0_n0_m1_m2_n1_m3 =
GetD0BlockWriteDescriptor_M0_N0_M1_M2_N1_M3();
static constexpr auto d0_block_read_desc_n0_n1_m0_m1_m2 =
GetD0BlockReadDescriptor_N0_N1_M0_M1_M2();
static constexpr auto d0_block_global_desc_m0_n0_m1_m2_n1_m3 =
GetD0BlockGlobalDescriptor_M0_N0_M1_M2_N1_M3();
static constexpr auto d0_block_vgpr_desc_n0_n1_m0_m1_m2 =
GetD0BlockVgprDescriptor_N0_N1_M0_M1_M2();
static constexpr auto d0_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I4, I1, D0M2));
using D0BlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
using D0BlockwiseCopyGlobalToLds = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
......@@ -1354,34 +1354,77 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
1,
BlockSize / NThreadClusterLengths,
NThreadClusterLengths,
1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc
decltype(d0_block_write_desc_m0_n0_m1_m2_n1_m3), // DstDesc
Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder
Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder
4, // SrcVectorDim
5, // DstVectorDim
D0BlockTransferSrcScalarPerVector, // SrcScalarPerVector
4, // DstScalarPerVector
1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // SrcDesc
decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // DstDesc
Sequence<0, 1, 2, 3, 5, 4>, // SrcDimAccessOrder
Sequence<0, 1, 2, 4, 3, 5>, // DstDimAccessOrder
4, // SrcVectorDim
5, // DstVectorDim
D0BlockTransferSrcScalarPerVector, // SrcScalarPerVector
4, // DstScalarPerVector
1,
1,
true,
true, // DstResetCoord
1>;
using D0ThreadWiseCopy =
using D0ThreadwiseCopyLdsToVgpr =
ThreadwiseTensorSliceTransfer_v4<typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_read_desc_n0_n1_m0_m1_m2), // SrcDesc
decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2), // SrcDesc
decltype(d0_thread_desc_), // DstDesc
Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // DimAccessOrder
4, // SrcVectorDim
2, // SrcScalarPerVector
2>;
using D0ThreadCopyVgprToLds = ThreadwiseTensorSliceTransfer_v1r3<
FloatGemmAcc,
typename TypeTransform<D0DataType>::Type,
decltype(d0_thread_desc_),
decltype(d0_block_vgpr_desc_n0_n1_m0_m1_m2),
tensor_operation::element_wise::Scale, // CElementwiseOperation
Sequence<1, 1, 4, 1, 4>, // SliceLengths
Sequence<0, 1, 2, 3, 4>, // AccessOrder
4, // VectorDim
4, // ScalarPerVector
InMemoryDataOperationEnum::Set, // GlobalMemoryDataOperation
1, // DstScalarStrideInVector
true>;
using D0BlockwiseCopyLdsToGlobal = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
tensor_operation::element_wise::PassThrough,
tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<I1, I1, I1, D0M1, NPerBlock, D0M2>, // BlockSliceLengths
Sequence<1,
1,
1,
BlockSize / NThreadClusterLengths,
NThreadClusterLengths,
1>, // ThreadClusterLengths
Sequence<0, 1, 2, 3, 5, 4>, // ThreadClusterArrangeOrder
typename TypeTransform<D0DataType>::Type, // SrcData
typename TypeTransform<D0DataType>::Type, // DstData
decltype(d0_block_global_desc_m0_n0_m1_m2_n1_m3), // SrcDesc
D0GridDescriptor_M0_N0_M1_M2_N1_M3, // DstDesc
Sequence<0, 1, 2, 4, 3, 5>, // SrcDimAccessOrder
Sequence<0, 1, 2, 3, 5, 4>, // DstDimAccessOrder
5, // SrcVectorDim
4, // DstVectorDim
4, // SrcScalarPerVector
D0BlockTransferSrcScalarPerVector, // DstScalarPerVector
1,
1,
true,
true, // DstResetCoord
1>;
};
struct SharedMemTrait
......@@ -1416,10 +1459,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
math::integer_least_multiple(BlockSize, max_lds_align);
static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(),
max_lds_align);
static constexpr auto d0_block_space_offset =
k_block_space_size_aligned.value * sizeof(GemmDataType) /
D0Loader::template TypeTransform<D0DataType>::Size;
D0Operator::template TypeTransform<D0DataType>::Size;
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
......@@ -1444,7 +1488,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
sizeof(GemmDataType);
const index_t d0_bytes_end =
(SharedMemTrait::d0_block_space_offset + SharedMemTrait::d0_block_space_size_aligned) *
D0Loader::template TypeTransform<D0DataType>::Size0;
D0Operator::template TypeTransform<D0DataType>::Size0;
const index_t c_block_bytes_end =
SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
......@@ -1472,6 +1516,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
const InputDataType* __restrict__ p_ygrad_grid,
OutputDataType* __restrict__ p_qgrad_grid,
OutputDataType* __restrict__ p_kgrad_grid,
D0DataType* __restrict__ p_d0grad_grid,
OutputDataType* __restrict__ p_vgrad_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
......@@ -1969,17 +2014,30 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
index_t gemm0_m_block_outer_index = num_gemm0_m_block_outer_loop - 1;
// D0
auto d0_block_copy_global_to_lds = typename D0Loader::D0BlockwiseCopy(
auto d0_block_copy_global_to_lds = typename D0Operator::D0BlockwiseCopyGlobalToLds(
d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{},
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3,
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
auto d0_thread_copy_lds_to_vgpr = typename D0Loader::D0ThreadWiseCopy(
auto d0_thread_copy_lds_to_vgpr = typename D0Operator::D0ThreadwiseCopyLdsToVgpr(
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0));
auto d0grad_thread_copy_vgpr_to_lds = typename D0Operator::D0ThreadCopyVgprToLds(
D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
make_tuple(wave_id[I1], wave_m_n_id[I1], 0, wave_m_n_id[I0], 0),
tensor_operation::element_wise::Scale{rp_dropout});
auto d0_block_copy_lds_to_global = typename D0Operator::D0BlockwiseCopyLdsToGlobal(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(0, 0, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{},
d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
if constexpr(Deterministic)
{
block_sync_lds();
......@@ -2145,50 +2203,55 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
// add bias
if constexpr(!is_same<D0DataType, void>::value)
{
static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();
if(p_d0_grid != nullptr)
{
static constexpr auto& c_thread_desc = s_blockwise_gemm.GetCThreadDesc();
const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
const auto d0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Loader::d0_thread_desc_.GetElementSpaceSize());
ignore = d0_thread_buf;
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, D0DataType>(
D0Operator::d0_thread_desc_.GetElementSpaceSize());
ignore = d0_thread_buf;
static_for<0, D0M0, 1>{}([&](auto mr) {
// load data to lds
d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
d0_grid_buf);
static_for<0, D0M0, 1>{}([&](auto mr) {
// load data to lds
d0_block_copy_global_to_lds.RunRead(d0_grid_desc_m0_n0_m1_m2_n1_m3,
d0_grid_buf);
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
d0_block_copy_global_to_lds.RunWrite(
D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
block_sync_lds();
// read data from lds
d0_thread_copy_lds_to_vgpr.Run(D0Loader::d0_block_read_desc_n0_n1_m0_m1_m2,
make_tuple(I0, I0, I0, I0, I0),
d0_block_buf,
D0Loader::d0_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
d0_thread_buf);
// bias add
static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
s_slash_p_thread_buf(Number<c_offset>{}) +=
ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
d0_block_copy_global_to_lds.RunWrite(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3, d0_block_buf);
block_sync_lds();
// read data from lds
d0_thread_copy_lds_to_vgpr.Run(
D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
make_tuple(I0, I0, I0, I0, I0),
d0_block_buf,
D0Operator::d0_thread_desc_,
make_tuple(I0, I0, I0, I0, I0),
d0_thread_buf);
// bias add
static_for<0, d0_thread_buf.Size(), 1>{}([&](auto i) {
constexpr index_t c_offset =
c_thread_desc.CalculateOffset(make_tuple(mr, I0, i));
s_slash_p_thread_buf(Number<c_offset>{}) +=
ck::type_convert<FloatGemmAcc>(d0_thread_buf[i]);
});
});
});
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
d0_block_copy_global_to_lds.MoveSrcSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
}
}
// P_i: = softmax(scalar * S_i:)
......@@ -2395,6 +2458,46 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_Light_V2
: y_dot_ygrad_thread_buf[Number<m>{}]);
});
// output bias grad
if constexpr(!is_same<D0DataType, void>::value)
{
if(p_d0grad_grid != nullptr)
{
auto d0grad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0grad_grid, d0_grid_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
auto d0grad_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<D0DataType*>(p_shared) + SharedMemTrait::d0_block_space_offset,
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize());
static_for<0, D0M0, 1>{}([&](auto mr) {
d0grad_thread_copy_vgpr_to_lds.Run(
D0Operator::d0_thread_desc_,
make_tuple(mr, I0, I0, I0, I0),
sgrad_thread_buf,
D0Operator::d0_block_vgpr_desc_n0_n1_m0_m1_m2,
d0grad_block_buf);
block_sync_lds();
// write data from lds to global
d0_block_copy_lds_to_global.Run(
D0Operator::d0_block_global_desc_m0_n0_m1_m2_n1_m3,
d0grad_block_buf,
d0_grid_desc_m0_n0_m1_m2_n1_m3,
d0grad_grid_buf,
I0);
d0_block_copy_lds_to_global.MoveDstSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3, make_multi_index(0, 0, 1, 0, 0, 0));
});
d0_block_copy_lds_to_global.MoveDstSliceWindow(
d0_grid_desc_m0_n0_m1_m2_n1_m3,
make_multi_index(-1, 0, -D0M0.value, 0, 0, 0));
}
}
SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
s_blockwise_gemm.GetWaveIdx()[I1]);
......
......@@ -2151,7 +2151,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
make_multi_index(gemm0_m_block_outer_index, block_work_idx_n, 0, 0, 0, 0),
tensor_operation::element_wise::PassThrough{});
ignore = d0grad_thread_copy_vgpr_to_lds;
if constexpr(Deterministic)
{
block_sync_lds();
......