"...torch/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "861b5ce2bdd697b17b0ce759a5f5f126cd3a8915"
Commit 67e10a6a authored by letaoqin's avatar letaoqin
Browse files

v1 group finished

parent 8efd67d8
...@@ -24,7 +24,7 @@ Kernel outputs: ...@@ -24,7 +24,7 @@ Kernel outputs:
*/ */
#define USING_MASK 0 #define USING_MASK 0
#define DIM 128 // DIM should be a multiple of 8. #define DIM 64 // DIM should be a multiple of 8.
#include <iostream> #include <iostream>
#include <numeric> #include <numeric>
......
...@@ -134,6 +134,7 @@ __global__ void ...@@ -134,6 +134,7 @@ __global__ void
c_element_op, c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_, arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
...@@ -171,6 +172,7 @@ __global__ void ...@@ -171,6 +172,7 @@ __global__ void
c_element_op, c_element_op,
arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].d0_grid_desc_m0_n0_m1_m2_n1_m3_,
arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_, arg_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_, arg_ptr[group_id].y_grid_desc_mblock_mperblock_oblock_operblock_,
...@@ -297,11 +299,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -297,11 +299,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
std::vector<index_t> lse_gs_ms_lengths; std::vector<index_t> lse_gs_ms_lengths;
std::vector<index_t> lse_gs_ms_strides; std::vector<index_t> lse_gs_ms_strides;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_lengths; std::vector<index_t> acc0_biases_gs_ms_ns_lengths;
std::vector<std::vector<index_t>> acc0_biases_gs_ms_ns_strides; std::vector<index_t> acc0_biases_gs_ms_ns_strides;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_lengths; std::vector<index_t> acc1_biases_gs_ms_os_lengths;
std::vector<std::vector<index_t>> acc1_biases_gs_ms_os_strides; std::vector<index_t> acc1_biases_gs_ms_os_strides;
}; };
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -495,10 +497,22 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -495,10 +497,22 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
} }
} }
// D in Gemm0 C position // D in Gemm0 C position
static auto MakeDGridDescriptor_M_N(const std::vector<index_t>& d_gs_ms_ns_lengths_vec, static auto
const std::vector<index_t>& d_gs_ms_ns_strides_vec) MakeD0GridDescriptor_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
{ {
return Transform::MakeCGridDescriptor_M_N(d_gs_ms_ns_lengths_vec, d_gs_ms_ns_strides_vec);
return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides);
}
static auto
MakeD0GridDescriptor_G_M_N(const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_lengths,
const std::vector<ck::index_t>& acc0_biases_gs_ms_ns_strides)
{
return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths,
acc0_biases_gs_ms_ns_strides);
} }
using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
...@@ -508,12 +522,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -508,12 +522,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1)); using LSEGridDesc_M = decltype(MakeLSEGridDescriptor_M(1));
using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
using D0GridDesc_G_M_N = decltype(MakeD0GridDescriptor_G_M_N({}, {}));
using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {}));
using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); using ZGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {})); using KGridDesc_N_K = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
using D0GridDesc_M_N = decltype(MakeDGridDescriptor_M_N({}, {})); using D0GridDesc_M_N = decltype(MakeD0GridDescriptor_M_N({}, {}));
using YGradGridDesc_O0_M_O1 = decltype(MakeYGradGridDescriptor_O0_M_O1({}, {})); using YGradGridDesc_O0_M_O1 = decltype(MakeYGradGridDescriptor_O0_M_O1({}, {}));
using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {})); using ZGridDesc_M_N = decltype(MakeZGridDescriptor_M_N({}, {}));
...@@ -538,12 +553,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -538,12 +553,14 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{ {
ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k,
const BGridDesc_G_N_K& b_grid_desc_g_n_k, const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const D0GridDesc_G_M_N& d0_grid_desc_g_m_n,
const ZGridDesc_G_M_N& z_grid_desc_g_m_n, const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n, const CGridDesc_G_M_N& c_grid_desc_g_m_n,
index_t batch_stride_lse) index_t batch_stride_lse)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k), b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
d0_grid_desc_g_m_n_(d0_grid_desc_g_m_n),
z_grid_desc_g_m_n_(z_grid_desc_g_m_n), z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n), c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
...@@ -561,6 +578,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -561,6 +578,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx) const
{
return d0_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
{ {
return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
...@@ -584,6 +606,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -584,6 +606,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
private: private:
AGridDesc_G_M_K a_grid_desc_g_m_k_; AGridDesc_G_M_K a_grid_desc_g_m_k_;
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
D0GridDesc_G_M_N d0_grid_desc_g_m_n_;
ZGridDesc_G_M_N z_grid_desc_g_m_n_; ZGridDesc_G_M_N z_grid_desc_g_m_n_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
...@@ -663,6 +686,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -663,6 +686,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// pointers // pointers
const InputDataType* p_a_grid_; const InputDataType* p_a_grid_;
const InputDataType* p_b_grid_; const InputDataType* p_b_grid_;
const D0DataType* p_d0_grid_;
ZDataType* p_z_grid_; ZDataType* p_z_grid_;
const InputDataType* p_b1_grid_; const InputDataType* p_b1_grid_;
const InputDataType* p_c_grid_; const InputDataType* p_c_grid_;
...@@ -675,6 +699,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -675,6 +699,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
// tensor descriptors for block/thread-wise copy // tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
...@@ -714,6 +739,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -714,6 +739,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
index_t batch_count_; index_t batch_count_;
// raw data
std::vector<ck::index_t> d0_n_length_stride_;
}; };
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -759,24 +787,26 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -759,24 +787,26 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
group_count_ == ck::type_convert<ck::index_t>(p_Qgrads.size()) && group_count_ == ck::type_convert<ck::index_t>(p_Qgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Kgrads.size()) && group_count_ == ck::type_convert<ck::index_t>(p_Kgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Vgrads.size()) && group_count_ == ck::type_convert<ck::index_t>(p_Vgrads.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()))) group_count_ == ck::type_convert<ck::index_t>(p_LSEs.size()) &&
(group_count_ == ck::type_convert<ck::index_t>(p_acc0_biases.size()) ||
ck::type_convert<ck::index_t>(p_acc0_biases.size() == 0)) &&
0 == p_acc1_biases.size()))
{ {
throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size"); throw std::runtime_error("wrong! group_count_ != p_As/b/b1/c.size");
} }
if(!(p_acc0_biases.size() == p_acc1_biases.size()))
{
throw std::runtime_error("wrong! acc0_bias_vec.size != acc1_bias_vec.size");
}
grid_size_ = 0; grid_size_ = 0;
index_t z_random_matrix_offset = 0; index_t z_random_matrix_offset = 0;
for(index_t i = 0; i < group_count_; i++) for(index_t i = 0; i < group_count_; i++)
{ {
const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]); const auto p_a_grid = static_cast<const InputDataType*>(p_As[i]);
const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]); const auto p_b_grid = static_cast<const InputDataType*>(p_Bs[i]);
const auto p_d0_grid =
(ck::type_convert<ck::index_t>(p_acc0_biases.size()) == group_count_)
? static_cast<const D0DataType*>(p_acc0_biases[i])
: nullptr;
auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]); auto p_z_grid = static_cast<ZDataType*>(p_Zs[i]);
const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]); const auto p_b1_grid = static_cast<const InputDataType*>(p_B1s[i]);
const auto p_c_grid = static_cast<const InputDataType*>(p_Cs[i]); const auto p_c_grid = static_cast<const InputDataType*>(p_Cs[i]);
...@@ -792,6 +822,23 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -792,6 +822,23 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides); problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1( const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides); problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
std::vector<index_t> tmp_d0_gs_ms_ns_lengths;
std::vector<index_t> tmp_d0_gs_ms_ns_strides;
if constexpr(!is_same<D0DataType, void>::value)
{
tmp_d0_gs_ms_ns_lengths = problem_desc.acc0_biases_gs_ms_ns_lengths;
tmp_d0_gs_ms_ns_strides = problem_desc.acc0_biases_gs_ms_ns_strides;
}
else
{
tmp_d0_gs_ms_ns_lengths = {1, 1, 1, 1};
tmp_d0_gs_ms_ns_strides = {0, 0, 0, 0};
}
const D0GridDesc_M_N d0_grid_desc_m_n{DeviceOp::MakeD0GridDescriptor_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides)};
const auto d0_grid_desc_m0_n0_m1_m2_n1_m3 =
GridwiseGemm::MakeD0GridDescriptor_M0_N0_M1_M2_N1_M3(d0_grid_desc_m_n);
const auto z_grid_desc_m_n = DeviceOp::MakeZGridDescriptor_M_N( const auto z_grid_desc_m_n = DeviceOp::MakeZGridDescriptor_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides); problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeVGridDescriptor_O0_N_O1( const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeVGridDescriptor_O0_N_O1(
...@@ -811,6 +858,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -811,6 +858,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides); problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K( const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K(
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides); problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
const auto d0_grid_desc_g_m_n = DeviceOp::MakeD0GridDescriptor_G_M_N(
tmp_d0_gs_ms_ns_lengths, tmp_d0_gs_ms_ns_strides);
const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N( const auto z_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N(
problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides); problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K( const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K(
...@@ -847,6 +896,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -847,6 +896,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch( const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch(
a_grid_desc_g_m_k, a_grid_desc_g_m_k,
b_grid_desc_g_n_k, b_grid_desc_g_n_k,
d0_grid_desc_g_m_n,
z_grid_desc_g_m_n, z_grid_desc_g_m_n,
b1_grid_desc_g_n_k, b1_grid_desc_g_n_k,
c_grid_desc_g_m_n, c_grid_desc_g_m_n,
...@@ -865,6 +915,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -865,6 +915,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
group_kernel_args_.push_back({p_a_grid, group_kernel_args_.push_back({p_a_grid,
p_b_grid, p_b_grid,
p_d0_grid,
p_z_grid, p_z_grid,
p_b1_grid, p_b1_grid,
p_c_grid, p_c_grid,
...@@ -875,6 +926,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -875,6 +926,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
p_vgrad_grid, p_vgrad_grid,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
d0_grid_desc_m0_n0_m1_m2_n1_m3,
z_grid_desc_m_n, z_grid_desc_m_n,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
y_grid_desc_m_o, y_grid_desc_m_o,
...@@ -896,6 +948,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -896,6 +948,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
z_random_matrix_offset = z_random_matrix_offset =
z_random_matrix_offset + raw_m_padded * raw_n_padded * batch_count; z_random_matrix_offset + raw_m_padded * raw_n_padded * batch_count;
// for check
std::vector<ck::index_t> d0_n_length_stride;
d0_n_length_stride.push_back(tmp_d0_gs_ms_ns_lengths[NumDimG + NumDimM]);
d0_n_length_stride.push_back(tmp_d0_gs_ms_ns_strides[NumDimG + NumDimM]);
group_device_args_.push_back( group_device_args_.push_back(
{{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
problem_desc.b_gs_ns_ks_lengths[NumDimG + NumDimN - 1], problem_desc.b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
...@@ -910,7 +967,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1 ...@@ -910,7 +967,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1], {problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1],
problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]}, problem_desc.c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]},
c_grid_desc_g_m_n, c_grid_desc_g_m_n,
batch_count}); batch_count,
d0_n_length_stride});
} }
// TODO: implement bias addition // TODO: implement bias addition
// ignore = p_acc0_biases; // ignore = p_acc0_biases;
......
...@@ -708,7 +708,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -708,7 +708,6 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_; typename GridwiseGemm::D0GridDescriptor_M0_N0_M1_M2_N1_M3 d0_grid_desc_m0_n0_m1_m2_n1_m3_;
;
ZGridDesc_M_N z_grid_desc_m_n_; ZGridDesc_M_N z_grid_desc_m_n_;
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
YGridDesc_M_O y_grid_desc_m_o_; YGridDesc_M_O y_grid_desc_m_o_;
...@@ -835,9 +834,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2 ...@@ -835,9 +834,9 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides); problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides);
const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1( const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides); problem_desc.b_gs_ns_ks_lengths, problem_desc.b_gs_ns_ks_strides);
std::vector<index_t> tmp_d0_gs_ms_ns_lengths; std::vector<index_t> tmp_d0_gs_ms_ns_lengths;
std::vector<index_t> tmp_d0_gs_ms_ns_strides; std::vector<index_t> tmp_d0_gs_ms_ns_strides;
if constexpr(!is_same<D0DataType, void>::value) if constexpr(!is_same<D0DataType, void>::value)
{ {
tmp_d0_gs_ms_ns_lengths = problem_desc.acc0_biases_gs_ms_ns_lengths; tmp_d0_gs_ms_ns_lengths = problem_desc.acc0_biases_gs_ms_ns_lengths;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment