Commit b2df7018 authored by letaoqin's avatar letaoqin
Browse files

fix ComputeBasePtrOfStridedBatch init bug

parent 4b0a5069
...@@ -181,40 +181,40 @@ int run(int argc, char* argv[]) ...@@ -181,40 +181,40 @@ int run(int argc, char* argv[])
// do GEMM // do GEMM
// TODO ANT: replace array with vector? // TODO ANT: replace array with vector?
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument( auto argument =
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()), static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()), static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(nullptr), static_cast<ZDataType*>(nullptr),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()), // static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()), //
nullptr, nullptr,
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths, b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides, b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths, b1_gs_os_ns_lengths,
b1_gs_os_ns_strides, b1_gs_os_ns_strides,
c_gs_ms_os_lengths, c_gs_ms_os_lengths,
c_gs_ms_os_strides, c_gs_ms_os_strides,
z_gs_ms_ns_lengths, z_gs_ms_ns_lengths,
z_gs_ms_ns_strides, z_gs_ms_ns_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths d_gs_ms_ns_lengths, // acc0_biases_gs_ms_ns_lengths
d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides d_gs_ms_ns_strides, // acc0_biases_gs_ms_ns_strides
{}, // std::vector<ck::index_t> {}, // std::vector<ck::index_t>
{}, // std::vector<ck::index_t> {}, // std::vector<ck::index_t>
a_element_op, a_element_op,
b0_element_op, b0_element_op,
acc0_element_op, acc0_element_op,
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, // dropout ratio p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be at least the number of {seed, offset}); // dropout random seed and offset, offset should be at
// elements on a thread // least the number of elements on a thread
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
...@@ -229,8 +229,10 @@ int run(int argc, char* argv[]) ...@@ -229,8 +229,10 @@ int run(int argc, char* argv[])
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
sizeof(DDataType) * M * N * std::is_void<DDataType>::value?1:0) * sizeof(DDataType) * M * N * std::is_void<DDataType>::value
? 0
: 1) *
BatchCount; BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -243,38 +245,38 @@ int run(int argc, char* argv[]) ...@@ -243,38 +245,38 @@ int run(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
// run for storing z tensor // run for storing z tensor
argument = gemm.MakeArgument( argument =
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()), static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()), static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()), static_cast<ZDataType*>(z_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()), static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()), static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()),
nullptr, nullptr,
a_gs_ms_ks_lengths, a_gs_ms_ks_lengths,
a_gs_ms_ks_strides, a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths, b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides, b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths, b1_gs_os_ns_lengths,
b1_gs_os_ns_strides, b1_gs_os_ns_strides,
c_gs_ms_os_lengths, c_gs_ms_os_lengths,
c_gs_ms_os_strides, c_gs_ms_os_strides,
z_gs_ms_ns_lengths, z_gs_ms_ns_lengths,
z_gs_ms_ns_strides, z_gs_ms_ns_strides,
lse_gs_ms_lengths, lse_gs_ms_lengths,
d_gs_ms_ns_lengths, d_gs_ms_ns_lengths,
d_gs_ms_ns_strides, d_gs_ms_ns_strides,
{}, {},
{}, {},
a_element_op, a_element_op,
b0_element_op, b0_element_op,
acc0_element_op, acc0_element_op,
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, // dropout ratio p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be at least the number {seed, offset}); // dropout random seed and offset, offset should be
// of elements on a thread // at least the number of elements on a thread
c_device_buf.SetZero(); c_device_buf.SetZero();
lse_device_buf.SetZero(); lse_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false}); invoker.Run(argument, StreamConfig{nullptr, false});
......
...@@ -697,6 +697,15 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -697,6 +697,15 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N( d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides); acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides);
compute_base_ptr_of_batch_ = ComputeBasePtrOfStridedBatch(
a_grid_desc_g_m_k_,
b_grid_desc_g_n_k_,
d0_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
z_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize()));
d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_lengths[NumDimG + NumDimM]); d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_lengths[NumDimG + NumDimM]);
d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_strides[NumDimG + NumDimM]); d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_strides[NumDimG + NumDimM]);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment