"git@developer.sourcefind.cn:change/sglang.git" did not exist on "2d3ae4e1258791a04a28279044359c08c16af99e"
Commit 592b0649 authored by letaoqin's avatar letaoqin
Browse files

grouped bwd add bias grad

parent 0dba17c3
...@@ -603,7 +603,7 @@ int run(int argc, char* argv[]) ...@@ -603,7 +603,7 @@ int run(int argc, char* argv[])
(sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N + (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) + sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N + sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N) * sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N * size_t(2)) *
BatchCount + BatchCount +
sizeof(LSEDataType) * M * BatchCount; sizeof(LSEDataType) * M * BatchCount;
......
...@@ -24,7 +24,7 @@ Kernel outputs: ...@@ -24,7 +24,7 @@ Kernel outputs:
*/ */
#define USING_MASK 0 #define USING_MASK 0
#define DIM 128 // DIM should be a multiple of 8. #define DIM 64 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
...@@ -333,6 +333,7 @@ int run(int argc, char* argv[]) ...@@ -333,6 +333,7 @@ int run(int argc, char* argv[])
std::vector<const void*> p_lse; std::vector<const void*> p_lse;
std::vector<void*> p_qgrad; std::vector<void*> p_qgrad;
std::vector<void*> p_kgrad; std::vector<void*> p_kgrad;
std::vector<void*> p_d0grad;
std::vector<void*> p_vgrad; std::vector<void*> p_vgrad;
std::vector<const void*> p_ygrad; std::vector<const void*> p_ygrad;
...@@ -356,6 +357,7 @@ int run(int argc, char* argv[]) ...@@ -356,6 +357,7 @@ int run(int argc, char* argv[])
std::vector<Tensor<LSEDataType>> lse_tensors; std::vector<Tensor<LSEDataType>> lse_tensors;
std::vector<Tensor<OutputDataType>> qgrad_tensors; std::vector<Tensor<OutputDataType>> qgrad_tensors;
std::vector<Tensor<OutputDataType>> kgrad_tensors; std::vector<Tensor<OutputDataType>> kgrad_tensors;
std::vector<Tensor<Acc0BiasDataType>> d0grad_tensors;
std::vector<Tensor<OutputDataType>> vgrad_tensors; std::vector<Tensor<OutputDataType>> vgrad_tensors;
std::vector<Tensor<InputDataType>> ygrad_tensors; std::vector<Tensor<InputDataType>> ygrad_tensors;
...@@ -369,6 +371,7 @@ int run(int argc, char* argv[]) ...@@ -369,6 +371,7 @@ int run(int argc, char* argv[])
std::vector<DeviceMemPtr> qgrad_tensors_device; std::vector<DeviceMemPtr> qgrad_tensors_device;
std::vector<DeviceMemPtr> ygrad_tensors_device; std::vector<DeviceMemPtr> ygrad_tensors_device;
std::vector<DeviceMemPtr> kgrad_tensors_device; std::vector<DeviceMemPtr> kgrad_tensors_device;
std::vector<DeviceMemPtr> d0grad_tensors_device;
std::vector<DeviceMemPtr> vgrad_tensors_device; std::vector<DeviceMemPtr> vgrad_tensors_device;
std::size_t group_count = 10; std::size_t group_count = 10;
std::size_t flop = 0, num_byte = 0; std::size_t flop = 0, num_byte = 0;
...@@ -445,12 +448,13 @@ int run(int argc, char* argv[]) ...@@ -445,12 +448,13 @@ int run(int argc, char* argv[])
int BatchCount = G0 * G1; int BatchCount = G0 * G1;
flop += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount; flop += (size_t(3) * M * N * K + size_t(2) * M * N * O) * 2 * BatchCount;
// Q/K/V/Y, dQ/dK/dV/dY, LSE // Q/K/V/Y, dQ/dK/dV/dY, LSE
num_byte += (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N + num_byte +=
sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) + (sizeof(InputDataType) * M * K + sizeof(InputDataType) * K * N +
sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N + sizeof(InputDataType) * N * O + sizeof(InputDataType) * M * O * size_t(2) +
sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N) * sizeof(OutputDataType) * M * K + sizeof(OutputDataType) * K * N +
BatchCount + sizeof(OutputDataType) * N * O + sizeof(Acc0BiasDataType) * M * N * size_t(2)) *
sizeof(LSEDataType) * M * BatchCount; BatchCount +
sizeof(LSEDataType) * M * BatchCount;
Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); Tensor<InputDataType> q_gs_ms_ks(q_gs_ms_ks_lengths, q_gs_ms_ks_strides);
Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides); Tensor<InputDataType> k_gs_ns_ks(k_gs_ns_ks_lengths, k_gs_ns_ks_strides);
...@@ -600,6 +604,8 @@ int run(int argc, char* argv[]) ...@@ -600,6 +604,8 @@ int run(int argc, char* argv[])
std::make_unique<DeviceMem>(sizeof(OutputDataType) * q_gs_ms_ks.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(OutputDataType) * q_gs_ms_ks.GetElementSpaceSize()));
kgrad_tensors_device.emplace_back( kgrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(OutputDataType) * k_gs_ns_ks.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(OutputDataType) * k_gs_ns_ks.GetElementSpaceSize()));
d0grad_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(Acc0BiasDataType) * d0_gs_ms_ns.GetElementSpaceSize()));
vgrad_tensors_device.emplace_back( vgrad_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(OutputDataType) * v_gs_os_ns.GetElementSpaceSize())); std::make_unique<DeviceMem>(sizeof(OutputDataType) * v_gs_os_ns.GetElementSpaceSize()));
ygrad_tensors_device.emplace_back( ygrad_tensors_device.emplace_back(
...@@ -619,6 +625,7 @@ int run(int argc, char* argv[]) ...@@ -619,6 +625,7 @@ int run(int argc, char* argv[])
p_y.push_back(y_tensors_device.back()->GetDeviceBuffer()); p_y.push_back(y_tensors_device.back()->GetDeviceBuffer());
p_lse.push_back(lse_tensors_device.back()->GetDeviceBuffer()); p_lse.push_back(lse_tensors_device.back()->GetDeviceBuffer());
p_kgrad.push_back(kgrad_tensors_device.back()->GetDeviceBuffer()); p_kgrad.push_back(kgrad_tensors_device.back()->GetDeviceBuffer());
p_d0grad.push_back(d0grad_tensors_device.back()->GetDeviceBuffer());
p_vgrad.push_back(vgrad_tensors_device.back()->GetDeviceBuffer()); p_vgrad.push_back(vgrad_tensors_device.back()->GetDeviceBuffer());
p_ygrad.push_back(ygrad_tensors_device.back()->GetDeviceBuffer()); p_ygrad.push_back(ygrad_tensors_device.back()->GetDeviceBuffer());
p_qgrad.push_back(qgrad_tensors_device.back()->GetDeviceBuffer()); p_qgrad.push_back(qgrad_tensors_device.back()->GetDeviceBuffer());
...@@ -636,6 +643,8 @@ int run(int argc, char* argv[]) ...@@ -636,6 +643,8 @@ int run(int argc, char* argv[])
p_vgrad, p_vgrad,
p_d0, p_d0,
{}, {},
p_d0grad,
{},
problem_descs, problem_descs,
QKVElementOp{}, QKVElementOp{},
QKVElementOp{}, QKVElementOp{},
...@@ -682,6 +691,8 @@ int run(int argc, char* argv[]) ...@@ -682,6 +691,8 @@ int run(int argc, char* argv[])
p_vgrad, p_vgrad,
p_d0, p_d0,
{}, {},
p_d0grad,
{},
problem_descs, problem_descs,
QKVElementOp{}, QKVElementOp{},
QKVElementOp{}, QKVElementOp{},
...@@ -732,6 +743,7 @@ int run(int argc, char* argv[]) ...@@ -732,6 +743,7 @@ int run(int argc, char* argv[])
lse_tensors_device[i]->ToDevice(lse_tensors[i].data()); lse_tensors_device[i]->ToDevice(lse_tensors[i].data());
qgrad_tensors_device[i]->SetZero(); qgrad_tensors_device[i]->SetZero();
kgrad_tensors_device[i]->SetZero(); kgrad_tensors_device[i]->SetZero();
d0grad_tensors_device[i]->SetZero();
vgrad_tensors_device[i]->SetZero(); vgrad_tensors_device[i]->SetZero();
} }
...@@ -804,6 +816,8 @@ int run(int argc, char* argv[]) ...@@ -804,6 +816,8 @@ int run(int argc, char* argv[])
q_tensors[i].GetStrides()); q_tensors[i].GetStrides());
Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(), Tensor<OutputDataType> kgrad_gs_ns_ks_host_result(k_tensors[i].GetLengths(),
k_tensors[i].GetStrides()); k_tensors[i].GetStrides());
Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_host_result(d0_tensors[i].GetLengths(),
d0_tensors[i].GetStrides());
Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(), Tensor<OutputDataType> vgrad_gs_os_ns_host_result(v_tensors[i].GetLengths(),
v_tensors[i].GetStrides()); v_tensors[i].GetStrides());
...@@ -811,11 +825,14 @@ int run(int argc, char* argv[]) ...@@ -811,11 +825,14 @@ int run(int argc, char* argv[])
q_tensors[i].GetStrides()); q_tensors[i].GetStrides());
Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(), Tensor<OutputDataType> kgrad_gs_ns_ks_device_result(k_tensors[i].GetLengths(),
k_tensors[i].GetStrides()); k_tensors[i].GetStrides());
Tensor<Acc0BiasDataType> d0grad_gs_ms_ns_device_result(d0_tensors[i].GetLengths(),
d0_tensors[i].GetStrides());
Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(), Tensor<OutputDataType> vgrad_gs_os_ns_device_result(v_tensors[i].GetLengths(),
v_tensors[i].GetStrides()); v_tensors[i].GetStrides());
qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data()); qgrad_tensors_device[i]->FromDevice(qgrad_gs_ms_ks_device_result.data());
kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data()); kgrad_tensors_device[i]->FromDevice(kgrad_gs_ns_ks_device_result.data());
d0grad_tensors_device[i]->FromDevice(d0grad_gs_ms_ns_device_result.data());
vgrad_tensors_device[i]->FromDevice(vgrad_gs_os_ns_device_result.data()); vgrad_tensors_device[i]->FromDevice(vgrad_gs_os_ns_device_result.data());
// permute // permute
qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) { qgrad_gs_ms_ks_host_result.ForEach([&](auto& self, auto idx) {
...@@ -834,6 +851,14 @@ int run(int argc, char* argv[]) ...@@ -834,6 +851,14 @@ int run(int argc, char* argv[])
self(idx) = kgrad_g_n_k(g, idx[2], idx[3]); self(idx) = kgrad_g_n_k(g, idx[2], idx[3]);
}); });
d0grad_gs_ms_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0];
const size_t& g1 = idx[1];
const size_t g = g0 * G1 + g1;
self(idx) = sgrad_g_m_n(g, idx[2], idx[3]);
});
vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) { vgrad_gs_os_ns_host_result.ForEach([&](auto& self, auto idx) {
const size_t& g0 = idx[0]; const size_t& g0 = idx[0];
const size_t& g1 = idx[1]; const size_t& g1 = idx[1];
...@@ -861,6 +886,12 @@ int run(int argc, char* argv[]) ...@@ -861,6 +886,12 @@ int run(int argc, char* argv[])
"error", "error",
1e-2, 1e-2,
1e-2); 1e-2);
std::cout << "Checking d0grad:\n";
pass &= ck::utils::check_err(d0grad_gs_ms_ns_device_result.mData,
d0grad_gs_ms_ns_host_result.mData,
"error",
1e-2,
1e-2);
} }
} }
......
...@@ -103,13 +103,17 @@ __global__ void ...@@ -103,13 +103,17 @@ __global__ void
: arg_ptr[group_id].p_z_grid_ + z_batch_offset); : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
const D0DataType* tmp_p_d0_grid = nullptr; const D0DataType* tmp_p_d0_grid = nullptr;
D0DataType* tmp_p_d0grad_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
const long_index_t d0_batch_offset = const long_index_t d0_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>( __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset; if(arg_ptr[group_id].p_d0_grid_ != nullptr)
tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
if(arg_ptr[group_id].p_d0grad_grid_)
tmp_p_d0grad_grid = arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
} }
if constexpr(Deterministic) if constexpr(Deterministic)
{ {
...@@ -126,6 +130,7 @@ __global__ void ...@@ -126,6 +130,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset, arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset, arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset, arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset, arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
...@@ -164,6 +169,7 @@ __global__ void ...@@ -164,6 +169,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset, arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset, arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset, arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset, arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
...@@ -696,6 +702,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -696,6 +702,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const InputDataType* p_ygrad_grid_; const InputDataType* p_ygrad_grid_;
OutputDataType* p_qgrad_grid_; OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_; OutputDataType* p_kgrad_grid_;
D0DataType* p_d0grad_grid_;
OutputDataType* p_vgrad_grid_; OutputDataType* p_vgrad_grid_;
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
...@@ -760,6 +767,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -760,6 +767,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -792,7 +801,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -792,7 +801,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) && group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) || (group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) && ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
0 == p_acc1_bias_vec.size())) 0 == p_acc1_bias_vec.size() &&
(group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
ck::type_convert<ck::index_t>(p_d0grads.size() == 0)) &&
0 == p_d1grads.size()))
{ {
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size"); throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
} }
...@@ -816,7 +828,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -816,7 +828,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]); const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]); auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]);
auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]); auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]);
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]); auto p_d0grad_grid =
(ck::type_convert<ck::index_t>(p_d0grads.size()) == group_count_)
? static_cast<D0DataType*>(p_d0grads[i])
: nullptr;
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
const auto& problem_desc = problem_desc_vec[i]; const auto& problem_desc = problem_desc_vec[i];
...@@ -925,6 +941,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -925,6 +941,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_ygrad_grid, p_ygrad_grid,
p_qgrad_grid, p_qgrad_grid,
p_kgrad_grid, p_kgrad_grid,
p_d0grad_grid,
p_vgrad_grid, p_vgrad_grid,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
...@@ -1214,6 +1231,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1214,6 +1231,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1235,6 +1254,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1235,6 +1254,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_Vgrads, p_Vgrads,
p_acc0_bias_vec, p_acc0_bias_vec,
p_acc1_bias_vec, p_acc1_bias_vec,
p_d0grads,
p_d1grads,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -1262,6 +1283,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1262,6 +1283,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1283,6 +1306,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -1283,6 +1306,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_Vgrads, p_Vgrads,
p_acc0_bias_vec, // cast in struct Argument p_acc0_bias_vec, // cast in struct Argument
p_acc1_bias_vec, // cast in struct Argument p_acc1_bias_vec, // cast in struct Argument
p_d0grads,
p_d1grads,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
......
...@@ -102,13 +102,16 @@ __global__ void ...@@ -102,13 +102,16 @@ __global__ void
(arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
: arg_ptr[group_id].p_z_grid_ + z_batch_offset); : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
const D0DataType* tmp_p_d0_grid = nullptr; const D0DataType* tmp_p_d0_grid = nullptr;
D0DataType* tmp_p_d0grad_grid = nullptr;
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
const long_index_t d0_batch_offset = const long_index_t d0_batch_offset =
__builtin_amdgcn_readfirstlane(static_cast<long_index_t>( __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetD0BasePtr(g_idx)));
if(arg_ptr[group_id].p_d0_grid_ != nullptr)
tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset; tmp_p_d0_grid = arg_ptr[group_id].p_d0_grid_ + d0_batch_offset;
if(arg_ptr[group_id].p_d0grad_grid_)
tmp_p_d0grad_grid = arg_ptr[group_id].p_d0grad_grid_ + d0_batch_offset;
} }
if constexpr(Deterministic) if constexpr(Deterministic)
...@@ -126,6 +129,7 @@ __global__ void ...@@ -126,6 +129,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset, arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset, arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset, arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset, arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
...@@ -164,6 +168,7 @@ __global__ void ...@@ -164,6 +168,7 @@ __global__ void
arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset, arg_ptr[group_id].p_ygrad_grid_ + c_batch_offset,
arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset, arg_ptr[group_id].p_qgrad_grid_ + a_batch_offset,
arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset, arg_ptr[group_id].p_kgrad_grid_ + b_batch_offset,
tmp_p_d0grad_grid,
arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset, arg_ptr[group_id].p_vgrad_grid_ + b1_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
...@@ -767,6 +772,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -767,6 +772,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const InputDataType* p_ygrad_grid_; const InputDataType* p_ygrad_grid_;
OutputDataType* p_qgrad_grid_; OutputDataType* p_qgrad_grid_;
OutputDataType* p_kgrad_grid_; OutputDataType* p_kgrad_grid_;
D0DataType* p_d0grad_grid_;
OutputDataType* p_vgrad_grid_; OutputDataType* p_vgrad_grid_;
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
...@@ -831,6 +837,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -831,6 +837,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -863,7 +871,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -863,7 +871,10 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) && group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) || (group_count_ == ck::type_convert<ck::index_t>(p_acc0_bias_vec.size()) ||
ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) && ck::type_convert<ck::index_t>(p_acc0_bias_vec.size() == 0)) &&
0 == p_acc1_bias_vec.size())) 0 == p_acc1_bias_vec.size() &&
(group_count_ == ck::type_convert<ck::index_t>(p_d0grads.size()) ||
ck::type_convert<ck::index_t>(p_d0grads.size() == 0)) &&
0 == p_d1grads.size()))
{ {
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size"); throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
} }
...@@ -887,7 +898,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -887,7 +898,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]); const auto p_ygrad_grid = static_cast<const InputDataType*>(p_Ygrads[i]);
auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]); auto p_qgrad_grid = static_cast<OutputDataType*>(p_Qgrads[i]);
auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]); auto p_kgrad_grid = static_cast<OutputDataType*>(p_Kgrads[i]);
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]); auto p_d0grad_grid =
(ck::type_convert<ck::index_t>(p_d0grads.size()) == group_count_)
? static_cast<D0DataType*>(p_d0grads[i])
: nullptr;
auto p_vgrad_grid = static_cast<OutputDataType*>(p_Vgrads[i]);
const auto& problem_desc = problem_desc_vec[i]; const auto& problem_desc = problem_desc_vec[i];
...@@ -996,6 +1011,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -996,6 +1011,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_ygrad_grid, p_ygrad_grid,
p_qgrad_grid, p_qgrad_grid,
p_kgrad_grid, p_kgrad_grid,
p_d0grad_grid,
p_vgrad_grid, p_vgrad_grid,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
...@@ -1290,6 +1306,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1290,6 +1306,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1311,6 +1329,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1311,6 +1329,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_Vgrads, p_Vgrads,
p_acc0_bias_vec, p_acc0_bias_vec,
p_acc1_bias_vec, p_acc1_bias_vec,
p_d0grads,
p_d1grads,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -1338,6 +1358,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1338,6 +1358,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
std::vector<void*>& p_Vgrads, std::vector<void*>& p_Vgrads,
const std::vector<const void*>& p_acc0_bias_vec, const std::vector<const void*>& p_acc0_bias_vec,
const std::vector<const void*>& p_acc1_bias_vec, const std::vector<const void*>& p_acc1_bias_vec,
const std::vector<void*>& p_d0grads,
const std::vector<void*>& p_d1grads,
const std::vector<ProblemDesc>& problem_desc_vec, const std::vector<ProblemDesc>& problem_desc_vec,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1359,6 +1381,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -1359,6 +1381,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
p_Vgrads, p_Vgrads,
p_acc0_bias_vec, // cast in struct Argument p_acc0_bias_vec, // cast in struct Argument
p_acc1_bias_vec, // cast in struct Argument p_acc1_bias_vec, // cast in struct Argument
p_d0grads,
p_d1grads,
problem_desc_vec, problem_desc_vec,
a_element_op, a_element_op,
b_element_op, b_element_op,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment