"awq/modules/git@developer.sourcefind.cn:OpenDAS/autoawq.git" did not exist on "66adba4a3397dadd4c39d62c504902ed8d0a38dc"
Commit c75a3c17 authored by guangzlu's avatar guangzlu
Browse files

added z tensor for dropout storing

parent 6b041227
...@@ -79,6 +79,7 @@ template <index_t NumDimG, ...@@ -79,6 +79,7 @@ template <index_t NumDimG,
typename B0DataType, typename B0DataType,
typename B1DataType, typename B1DataType,
typename CDataType, typename CDataType,
typename ZDataType,
typename LSEDataType, typename LSEDataType,
typename Acc0BiasDataType, typename Acc0BiasDataType,
typename Acc1BiasDataType, typename Acc1BiasDataType,
...@@ -104,6 +105,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermuteTrain : public BaseOperator ...@@ -104,6 +105,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermuteTrain : public BaseOperator
std::vector<index_t> c_gs_ms_os_lengths; std::vector<index_t> c_gs_ms_os_lengths;
std::vector<index_t> c_gs_ms_os_strides; std::vector<index_t> c_gs_ms_os_strides;
std::vector<index_t> z_gs_ms_os_lengths;
std::vector<index_t> z_gs_ms_os_strides;
std::vector<index_t> lse_gs_ms_lengths; std::vector<index_t> lse_gs_ms_lengths;
std::vector<index_t> lse_gs_ms_strides; std::vector<index_t> lse_gs_ms_strides;
...@@ -119,6 +123,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermuteTrain : public BaseOperator ...@@ -119,6 +123,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermuteTrain : public BaseOperator
std::vector<const void*> p_b0_vec, std::vector<const void*> p_b0_vec,
std::vector<const void*> p_b1_vec, std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec, std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec, std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec, std::vector<std::vector<const void*>> p_acc1_biases_vec,
......
...@@ -92,6 +92,8 @@ __global__ void ...@@ -92,6 +92,8 @@ __global__ void
arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(arg_ptr[group_id].compute_base_ptr_of_batch_.GetZBasePtr(g_idx)));
const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>( const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx))); arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));
...@@ -100,6 +102,7 @@ __global__ void ...@@ -100,6 +102,7 @@ __global__ void
arg_ptr[group_id].p_b_grid_ + b_batch_offset, arg_ptr[group_id].p_b_grid_ + b_batch_offset,
arg_ptr[group_id].p_b1_grid_ + b1_batch_offset, arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
arg_ptr[group_id].p_c_grid_ + c_batch_offset, arg_ptr[group_id].p_c_grid_ + c_batch_offset,
arg_ptr[group_id].p_z_grid_ + z_batch_offset,
arg_ptr[group_id].p_lse_grid_ + lse_batch_offset, arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
p_shared, p_shared,
a_element_op, a_element_op,
...@@ -111,6 +114,7 @@ __global__ void ...@@ -111,6 +114,7 @@ __global__ void
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_, arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, ////////
arg_ptr[group_id].lse_grid_desc_m_, arg_ptr[group_id].lse_grid_desc_m_,
arg_ptr[group_id].block_2_ctile_map_, arg_ptr[group_id].block_2_ctile_map_,
arg_ptr[group_id].c0_matrix_mask_, arg_ptr[group_id].c0_matrix_mask_,
...@@ -361,11 +365,13 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -361,11 +365,13 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
const BGridDesc_G_N_K& b_grid_desc_g_n_k, const BGridDesc_G_N_K& b_grid_desc_g_n_k,
const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k,
const CGridDesc_G_M_N& c_grid_desc_g_m_n, const CGridDesc_G_M_N& c_grid_desc_g_m_n,
const ZGridDesc_G_M_N& z_grid_desc_g_m_n,
index_t BatchStrideLSE) index_t BatchStrideLSE)
: a_grid_desc_g_m_k_(a_grid_desc_g_m_k), : a_grid_desc_g_m_k_(a_grid_desc_g_m_k),
b_grid_desc_g_n_k_(b_grid_desc_g_n_k), b_grid_desc_g_n_k_(b_grid_desc_g_n_k),
b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k),
c_grid_desc_g_m_n_(c_grid_desc_g_m_n), c_grid_desc_g_m_n_(c_grid_desc_g_m_n),
z_grid_desc_g_m_n_(z_grid_desc_g_m_n),
BatchStrideLSE_(BatchStrideLSE) BatchStrideLSE_(BatchStrideLSE)
{ {
} }
...@@ -390,6 +396,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -390,6 +396,11 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
} }
__host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
{
return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
}
__host__ __device__ constexpr long_index_t GetLSEBasePtr(index_t g_idx) const __host__ __device__ constexpr long_index_t GetLSEBasePtr(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideLSE_); return g_idx * static_cast<long_index_t>(BatchStrideLSE_);
...@@ -400,6 +411,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -400,6 +411,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
BGridDesc_G_N_K b_grid_desc_g_n_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_;
B1GridDesc_G_N_K b1_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
CGridDesc_G_M_N c_grid_desc_g_m_n_; CGridDesc_G_M_N c_grid_desc_g_m_n_;
ZGridDesc_G_M_N& z_grid_desc_g_m_n_;
index_t BatchStrideLSE_; index_t BatchStrideLSE_;
}; };
...@@ -612,6 +624,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -612,6 +624,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
b_grid_desc_g_n_k, b_grid_desc_g_n_k,
b1_grid_desc_g_n_k, b1_grid_desc_g_n_k,
c_grid_desc_g_m_n, c_grid_desc_g_m_n,
z_grid_desc_g_m_n,
type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize())); type_convert<index_t>(lse_grid_desc_m.GetElementSpaceSize()));
// C0 mask // C0 mask
...@@ -635,11 +648,13 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -635,11 +648,13 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
p_b_grid, p_b_grid,
p_b1_grid, p_b1_grid,
p_c_grid, p_c_grid,
p_z_grid,
p_lse_grid, p_lse_grid,
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
b1_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_desc_mblock_mperblock_nblock_nperblock,
z_grid_desc_g_m_n,
lse_grid_desc_m, lse_grid_desc_m,
block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n), block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n),
compute_base_ptr_of_batch, compute_base_ptr_of_batch,
...@@ -912,6 +927,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -912,6 +927,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
std::vector<const void*> p_b_vec, std::vector<const void*> p_b_vec,
std::vector<const void*> p_b1_vec, std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec, std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec, std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec, std::vector<std::vector<const void*>> p_acc1_biases_vec,
...@@ -928,6 +944,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -928,6 +944,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
p_b_vec, p_b_vec,
p_b1_vec, p_b1_vec,
p_c_vec, p_c_vec,
p_z_vec,
p_lse_vec, p_lse_vec,
p_acc0_biases_vec, p_acc0_biases_vec,
p_acc1_biases_vec, p_acc1_biases_vec,
...@@ -949,6 +966,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -949,6 +966,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
std::vector<const void*> p_b_vec, std::vector<const void*> p_b_vec,
std::vector<const void*> p_b1_vec, std::vector<const void*> p_b1_vec,
std::vector<void*> p_c_vec, std::vector<void*> p_c_vec,
std::vector<void*> p_z_vec,
std::vector<void*> p_lse_vec, std::vector<void*> p_lse_vec,
std::vector<std::vector<const void*>> p_acc0_biases_vec, std::vector<std::vector<const void*>> p_acc0_biases_vec,
std::vector<std::vector<const void*>> p_acc1_biases_vec, std::vector<std::vector<const void*>> p_acc1_biases_vec,
...@@ -965,6 +983,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle ...@@ -965,6 +983,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
p_b_vec, p_b_vec,
p_b1_vec, p_b1_vec,
p_c_vec, p_c_vec,
p_z_vec,
p_lse_vec, p_lse_vec,
p_acc0_biases_vec, p_acc0_biases_vec,
p_acc1_biases_vec, p_acc1_biases_vec,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment