Commit fa94a220 authored by letaoqin

device: add d0 grad

parent b7b7e153
@@ -25,7 +25,7 @@ Kernel outputs:
 #define PRINT_HOST 0
 #define USING_MASK 0
-#define DIM 64 // DIM should be a multiple of 8.
+#define DIM 128 // DIM should be a multiple of 8.
 #include <iostream>
 #include <numeric>
@@ -494,6 +494,7 @@ int run(int argc, char* argv[])
     DeviceMem kgrad_device_buf(sizeof(OutputDataType) * k_gs_ns_ks.mDesc.GetElementSpaceSize());
     DeviceMem vgrad_device_buf(sizeof(OutputDataType) * v_gs_os_ns.mDesc.GetElementSpaceSize());
     DeviceMem ygrad_device_buf(sizeof(InputDataType) * y_gs_ms_os.mDesc.GetElementSpaceSize());
+    DeviceMem dgrad_device_buf(sizeof(Acc0BiasDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize());

     q_device_buf.ToDevice(q_gs_ms_ks.mData.data());
     k_device_buf.ToDevice(k_gs_ns_ks.mData.data());
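Background on the new buffer's size (not part of the commit; the symbols below are chosen here): when an additive bias B enters the pre-softmax logits, the gradient with respect to B is simply the logit gradient, so dB has the same (G, M, N) extents as the bias itself, which is exactly what the d_gs_ms_ns descriptor used above describes:

    S = scale * Q K^T + B,    P = softmax(S),    dL/dB = dL/dS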
@@ -518,6 +519,8 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
         static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_bias;
         nullptr,                                                        // p_acc1_bias;
+        static_cast<Acc0BiasDataType*>(dgrad_device_buf.GetDeviceBuffer()),
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,
@@ -563,6 +566,8 @@ int run(int argc, char* argv[])
         static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
         static_cast<Acc0BiasDataType*>(d_device_buf.GetDeviceBuffer()), // p_acc0_bias;
         nullptr,                                                        // p_acc1_bias;
+        static_cast<Acc0BiasDataType*>(dgrad_device_buf.GetDeviceBuffer()),
+        nullptr,
         q_gs_ms_ks_lengths,
         q_gs_ms_ks_strides,
         k_gs_ns_ks_lengths,
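If the new dBias output is to be checked the way the other gradients in this example are, the usual pattern in the CK examples is to clear the buffer before the run and copy it back into a host tensor afterwards. A minimal sketch, reusing dgrad_device_buf from above and assuming DeviceMem's SetZero()/FromDevice() helpers, a Tensor constructible from the existing descriptor, and a host tensor named d0grad_gs_ms_ns (none of these lines are part of this commit):

    // Sketch only: host-side handling of the new dBias output buffer.
    dgrad_device_buf.SetZero(); // start from zeroed gradients before the launch

    // ... fill in the argument with dgrad_device_buf.GetDeviceBuffer() as in the
    // hunks above, then run the invoker ...

    // Copy dBias back; it has the same extents as the attention bias tensor.
    Tensor<Acc0BiasDataType> d0grad_gs_ms_ns(d_gs_ms_ns.mDesc);
    dgrad_device_buf.FromDevice(d0grad_gs_ms_ns.mData.data());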
......
@@ -65,6 +65,7 @@ __global__ void
            const InputDataType* __restrict__ p_ygrad_grid,
            OutputDataType* __restrict__ p_qgrad_grid,
            OutputDataType* __restrict__ p_kgrad_grid,
+           D0DataType* __restrict__ p_d0grad_grid,
            OutputDataType* __restrict__ p_vgrad_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
@@ -120,13 +121,21 @@ __global__ void
        const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;

        const D0DataType* tmp_p_d0_grid = nullptr;
+       D0DataType* tmp_p_d0grad_grid = nullptr;
        if constexpr(!is_same<D0DataType, void>::value)
        {
            const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
                static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx)));
-           tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+           if(p_d0_grid != nullptr)
+           {
+               tmp_p_d0_grid = p_d0_grid + d0_batch_offset;
+           }
+           if(p_d0grad_grid != nullptr)
+           {
+               tmp_p_d0grad_grid = p_d0grad_grid + d0_batch_offset;
+           }
        }

        if constexpr(Deterministic)
        {
            for(index_t i = 0; i < nblock; i++)
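The hunk above applies one pattern to both pointers: the single per-batch offset from GetD0BasePtr is added to the forward bias and to its gradient, and each pointer is offset only when the caller actually supplied it, so bias-without-bias-grad (and vice versa) still works. A standalone sketch of that pattern with hypothetical names, in plain C++ rather than the kernel's device code:

    #include <cstddef>

    // Sketch: apply one shared per-batch offset to an optional input and an
    // optional output pointer; a null grid pointer stays null instead of being offset.
    template <typename T>
    struct BiasSlices
    {
        const T* bias      = nullptr; // forward attention bias slice for this batch
        T*       bias_grad = nullptr; // dBias slice for this batch
    };

    template <typename T>
    BiasSlices<T> make_bias_slices(const T* p_bias, T* p_bias_grad, std::ptrdiff_t batch_offset)
    {
        BiasSlices<T> slices;
        if(p_bias != nullptr)
            slices.bias = p_bias + batch_offset;
        if(p_bias_grad != nullptr)
            slices.bias_grad = p_bias_grad + batch_offset;
        return slices;
    }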
@@ -142,6 +151,7 @@ __global__ void
                    p_ygrad_grid + c_batch_offset,
                    p_qgrad_grid + a_batch_offset,
                    p_kgrad_grid + b_batch_offset,
+                   tmp_p_d0grad_grid,
                    p_vgrad_grid + b1_batch_offset,
                    p_shared,
                    a_element_op,
@@ -179,6 +189,7 @@ __global__ void
                p_ygrad_grid + c_batch_offset,
                p_qgrad_grid + a_batch_offset,
                p_kgrad_grid + b_batch_offset,
+               tmp_p_d0grad_grid,
                p_vgrad_grid + b1_batch_offset,
                p_shared,
                a_element_op,
@@ -213,6 +224,7 @@ __global__ void
    ignore = p_ygrad_grid;
    ignore = p_qgrad_grid;
    ignore = p_kgrad_grid;
+   ignore = p_d0grad_grid;
    ignore = p_vgrad_grid;
    ignore = a_element_op;
    ignore = b_element_op;
@@ -771,6 +783,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                 OutputDataType* p_vgrad_grid,
                 const D0DataType* p_acc0_bias,
                 const D1DataType* p_acc1_bias,
+                D0DataType* p_d0grad_grid,
+                D1DataType* p_d1grad_grid,
                 const std::vector<index_t>& a_gs_ms_ks_lengths,
                 const std::vector<index_t>& a_gs_ms_ks_strides,
                 const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -806,6 +820,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
              p_qgrad_grid_{p_qgrad_grid},
              p_kgrad_grid_{p_kgrad_grid},
              p_vgrad_grid_{p_vgrad_grid},
+             p_d0grad_grid_{p_d0grad_grid},
              a_grid_desc_ak0_m_ak1_{
                  DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b_grid_desc_bk0_n_bk1_{
@@ -855,6 +870,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
        {
            // TODO: implement bias addition
            ignore = p_acc1_bias;
+           ignore = p_d1grad_grid;
            ignore = acc1_bias_gs_ms_gemm1ns_lengths;
            ignore = acc1_bias_gs_ms_gemm1ns_strides;
@@ -939,6 +955,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
        OutputDataType* p_qgrad_grid_;
        OutputDataType* p_kgrad_grid_;
        OutputDataType* p_vgrad_grid_;
+       D0DataType* p_d0grad_grid_;

        // tensor descriptor
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
@@ -1066,6 +1083,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                arg.p_ygrad_grid_,
                arg.p_qgrad_grid_,
                arg.p_kgrad_grid_,
+               arg.p_d0grad_grid_,
                arg.p_vgrad_grid_,
                arg.a_element_op_,
                arg.b_element_op_,
@@ -1233,6 +1251,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                             OutputDataType* p_vgrad_grid,
                             const D0DataType* p_acc0_bias,
                             const D1DataType* p_acc1_bias,
+                            D0DataType* p_d0grad_grid,
+                            D1DataType* p_d1grad_grid,
                             const std::vector<index_t>& a_gs_ms_ks_lengths,
                             const std::vector<index_t>& a_gs_ms_ks_strides,
                             const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -1270,6 +1290,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                        p_vgrad_grid,
                        p_acc0_bias,
                        p_acc1_bias,
+                       p_d0grad_grid,
+                       p_d1grad_grid,
                        a_gs_ms_ks_lengths,
                        a_gs_ms_ks_strides,
                        b_gs_ns_ks_lengths,
@@ -1311,6 +1333,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                        void* p_vgrad_grid,
                        const void* p_acc0_bias,
                        const void* p_acc1_bias,
+                       D0DataType* p_d0grad_grid,
+                       D1DataType* p_d1grad_grid,
                        const std::vector<index_t>& a_gs_ms_ks_lengths,
                        const std::vector<index_t>& a_gs_ms_ks_strides,
                        const std::vector<index_t>& b_gs_ns_ks_lengths,
@@ -1349,6 +1373,8 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                        static_cast<OutputDataType*>(p_vgrad_grid),
                        static_cast<const D0DataType*>(p_acc0_bias), // cast in struct Argument
                        static_cast<const D1DataType*>(p_acc1_bias), // cast in struct Argument
+                       static_cast<D0DataType*>(p_d0grad_grid),
+                       static_cast<D1DataType*>(p_d1grad_grid),
                        a_gs_ms_ks_lengths,
                        a_gs_ms_ks_strides,
                        b_gs_ns_ks_lengths,
......
@@ -1478,9 +1478,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
    static constexpr auto d0_block_space_size_aligned = math::integer_least_multiple(
        D0Loader::d0_block_write_desc_m0_n0_m1_m2_n1_m3.GetElementSpaceSize(), max_lds_align);

-   static constexpr auto d0_block_space_offset = k_block_space_size_aligned.value *
-                                                 sizeof(GemmDataType) /
-                                                 D0Loader::template TypeTransform<D0DataType>::Size;
+   static constexpr auto d0_block_space_offset =
+       k_block_space_size_aligned.value * sizeof(GemmDataType) /
+       D0Loader::template TypeTransform<D0DataType>::Size;

    // LDS allocation for C shuffle in LDS
    static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
@@ -1537,6 +1537,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                               const InputDataType* __restrict__ p_ygrad_grid,
                               OutputDataType* __restrict__ p_qgrad_grid,
                               OutputDataType* __restrict__ p_kgrad_grid,
+                              D0DataType* __restrict__ p_d0grad_grid,
                               OutputDataType* __restrict__ p_vgrad_grid,
                               void* __restrict__ p_shared,
                               const AElementwiseOperation& a_element_op,
@@ -1562,6 +1563,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
                               const index_t raw_n_padded,
                               const index_t block_idx_n)
    {
+       ignore = p_d0grad_grid;
        const FloatGemmAcc p_dropout = type_convert<FloatGemmAcc>(1.0f - p_drop);
        const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
        const ushort p_dropout_in_16bits =
......