Commit 499dfe39 authored by letaoqin's avatar letaoqin
Browse files

change name to NumD0Tensor

parent 381a7317
...@@ -312,13 +312,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -312,13 +312,13 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
"Number of dimension must be greater than 0"); "Number of dimension must be greater than 0");
static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); static constexpr index_t NumD0Tensor = Acc0BiasDataType::Size();
static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); static constexpr index_t NumD1Tensor = Acc1BiasDataType::Size();
// TODO ANT: implement bias combination // TODO ANT: implement bias combination
static_assert(NumAcc0Bias <= 1, "Acc0 Bias addition is max support one bias"); static_assert(NumD0Tensor <= 1, "Acc0 Bias addition is max support one bias");
static_assert(NumAcc1Bias == 0, "Acc1 Bias addition is unimplemented"); static_assert(NumD1Tensor == 0, "Acc1 Bias addition is unimplemented");
static_assert(NumAcc1Bias == 0 static_assert(NumD1Tensor == 0
? true ? true
: std::is_same_v<ADataType, ck::tuple_element_t<0, Acc0BiasDataType>>); : std::is_same_v<ADataType, ck::tuple_element_t<0, Acc0BiasDataType>>);
using DDataType = ADataType; using DDataType = ADataType;
...@@ -580,8 +580,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -580,8 +580,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
CDataType* p_c_grid, CDataType* p_c_grid,
ZDataType* p_z_grid, ZDataType* p_z_grid,
LSEDataType* p_lse_grid, LSEDataType* p_lse_grid,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const std::array<void*, NumD0Tensor> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::array<void*, NumD1Tensor> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -593,11 +593,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -593,11 +593,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::array<std::vector<ck::index_t>, NumD1Tensor>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::array<std::vector<ck::index_t>, NumD1Tensor>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -610,7 +610,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -610,7 +610,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_b1_grid_{p_b1_grid}, p_b1_grid_{p_b1_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
p_d_grid_{NumAcc0Bias == 0 ? nullptr p_d_grid_{NumD0Tensor == 0 ? nullptr
: static_cast<const DDataType*>(p_acc0_biases[0])}, : static_cast<const DDataType*>(p_acc0_biases[0])},
p_z_grid_{p_z_grid}, p_z_grid_{p_z_grid},
p_lse_grid_{p_lse_grid}, p_lse_grid_{p_lse_grid},
...@@ -622,7 +622,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -622,7 +622,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
d_grid_desc_m_n_{NumAcc0Bias == 0 d_grid_desc_m_n_{NumD0Tensor == 0
? DGridDesc_M_N{} ? DGridDesc_M_N{}
: MakeZGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths[0], : MakeZGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths[0],
acc0_biases_gs_ms_ns_strides[0])}, acc0_biases_gs_ms_ns_strides[0])},
...@@ -636,7 +636,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -636,7 +636,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)},
c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths, c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths,
c_gs_ms_gemm1ns_strides)}, c_gs_ms_gemm1ns_strides)},
d_grid_desc_g_m_n_{NumAcc0Bias == 0 ? DGridDesc_G_M_N{} d_grid_desc_g_m_n_{NumD0Tensor == 0 ? DGridDesc_G_M_N{}
: Transform::MakeCGridDescriptor_G_M_N( : Transform::MakeCGridDescriptor_G_M_N(
acc0_biases_gs_ms_ns_lengths[0], acc0_biases_gs_ms_ns_lengths[0],
acc0_biases_gs_ms_ns_strides[0])}, acc0_biases_gs_ms_ns_strides[0])},
...@@ -1058,8 +1058,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -1058,8 +1058,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
CDataType* p_c, CDataType* p_c,
ZDataType* p_z, ZDataType* p_z,
LSEDataType* p_lse, LSEDataType* p_lse,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const std::array<void*, NumD0Tensor> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::array<void*, NumD1Tensor> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1071,11 +1071,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -1071,11 +1071,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::array<std::vector<ck::index_t>, NumD1Tensor>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::array<std::vector<ck::index_t>, NumD1Tensor>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
...@@ -1128,8 +1128,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -1128,8 +1128,8 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
void* p_c, void* p_c,
void* p_z, void* p_z,
void* p_lse, void* p_lse,
const std::array<void*, NumAcc0Bias> p_acc0_biases, const std::array<void*, NumD0Tensor> p_acc0_biases,
const std::array<void*, NumAcc1Bias> p_acc1_biases, const std::array<void*, NumD1Tensor> p_acc1_biases,
const std::vector<index_t>& a_gs_ms_ks_lengths, const std::vector<index_t>& a_gs_ms_ks_lengths,
const std::vector<index_t>& a_gs_ms_ks_strides, const std::vector<index_t>& a_gs_ms_ks_strides,
const std::vector<index_t>& b_gs_ns_ks_lengths, const std::vector<index_t>& b_gs_ns_ks_lengths,
...@@ -1141,11 +1141,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2 ...@@ -1141,11 +1141,11 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2R2
const std::vector<index_t>& z_gs_ms_ns_lengths, const std::vector<index_t>& z_gs_ms_ns_lengths,
const std::vector<index_t>& z_gs_ms_ns_strides, const std::vector<index_t>& z_gs_ms_ns_strides,
const std::vector<index_t>& lse_gs_ms_lengths, const std::vector<index_t>& lse_gs_ms_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_lengths,
const std::array<std::vector<ck::index_t>, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, const std::array<std::vector<ck::index_t>, NumD0Tensor> acc0_biases_gs_ms_ns_strides,
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::array<std::vector<ck::index_t>, NumD1Tensor>
acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths
const std::array<std::vector<ck::index_t>, NumAcc1Bias> const std::array<std::vector<ck::index_t>, NumD1Tensor>
acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment