Commit b2df7018 authored by letaoqin's avatar letaoqin
Browse files

fix ComputeBasePtrOfStridedBatch init bug

parent 4b0a5069
...@@ -183,8 +183,8 @@ int run(int argc, char* argv[]) ...@@ -183,8 +183,8 @@ int run(int argc, char* argv[])
// TODO ANT: replace array with vector? // TODO ANT: replace array with vector?
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument( auto argument =
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()), static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()), static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
...@@ -213,8 +213,8 @@ int run(int argc, char* argv[]) ...@@ -213,8 +213,8 @@ int run(int argc, char* argv[])
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, // dropout ratio p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be at least the number of {seed, offset}); // dropout random seed and offset, offset should be at
// elements on a thread // least the number of elements on a thread
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
...@@ -230,7 +230,9 @@ int run(int argc, char* argv[]) ...@@ -230,7 +230,9 @@ int run(int argc, char* argv[])
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
sizeof(DDataType) * M * N * std::is_void<DDataType>::value?1:0) * sizeof(DDataType) * M * N * std::is_void<DDataType>::value
? 0
: 1) *
BatchCount; BatchCount;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -243,8 +245,8 @@ int run(int argc, char* argv[]) ...@@ -243,8 +245,8 @@ int run(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
// run for storing z tensor // run for storing z tensor
argument = gemm.MakeArgument( argument =
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()), static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()), static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
...@@ -273,8 +275,8 @@ int run(int argc, char* argv[]) ...@@ -273,8 +275,8 @@ int run(int argc, char* argv[])
b1_element_op, b1_element_op,
c_element_op, c_element_op,
p_drop, // dropout ratio p_drop, // dropout ratio
{seed, offset}); // dropout random seed and offset, offset should be at least the number {seed, offset}); // dropout random seed and offset, offset should be
// of elements on a thread // at least the number of elements on a thread
c_device_buf.SetZero(); c_device_buf.SetZero();
lse_device_buf.SetZero(); lse_device_buf.SetZero();
invoker.Run(argument, StreamConfig{nullptr, false}); invoker.Run(argument, StreamConfig{nullptr, false});
......
...@@ -697,6 +697,15 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2 ...@@ -697,6 +697,15 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N( d0_grid_desc_g_m_n_ = Transform::MakeCGridDescriptor_G_M_N(
acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides); acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides);
compute_base_ptr_of_batch_ = ComputeBasePtrOfStridedBatch(
a_grid_desc_g_m_k_,
b_grid_desc_g_n_k_,
d0_grid_desc_g_m_n_,
b1_grid_desc_g_n_k_,
c_grid_desc_g_m_n_,
z_grid_desc_g_m_n_,
type_convert<index_t>(lse_grid_desc_m_.GetElementSpaceSize()));
d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_lengths[NumDimG + NumDimM]); d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_lengths[NumDimG + NumDimM]);
d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_strides[NumDimG + NumDimM]); d0_n_length_stride_.push_back(acc0_biases_gs_ms_ns_strides[NumDimG + NumDimM]);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment