Commit de6a70f7 authored by Jing Zhang
Browse files

add ds

parent 1d11426a
...@@ -33,9 +33,9 @@ using CShuffleDataType = F16; ...@@ -33,9 +33,9 @@ using CShuffleDataType = F16;
using DsDataType = ck::Tuple<>; using DsDataType = ck::Tuple<>;
using EDataType = F16; using EDataType = F16;
using ALayout = Row; using ALayout = Row;
using BLayout = Col; using BLayout = Col;
using ELayout = Row; using ELayout = Row;
using AElementOp = PassThrough; using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
...@@ -63,9 +63,9 @@ int main(int argc, char* argv[]) ...@@ -63,9 +63,9 @@ int main(int argc, char* argv[])
int init_method = 1; int init_method = 1;
bool time_kernel = false; bool time_kernel = false;
const int M = 256; const int M = 256 * (rand() % 16 + 1);
const int N = 128; const int N = 128 * (rand() % 16 + 1);
const int K = 64; const int K = 64 * (rand() % 16 + 1);
const int stride_A = K; const int stride_A = K;
const int stride_B = K; const int stride_B = K;
...@@ -112,12 +112,12 @@ int main(int argc, char* argv[]) ...@@ -112,12 +112,12 @@ int main(int argc, char* argv[])
Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(batch_count, M, K, stride_A, ALayout{})); Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(batch_count, M, K, stride_A, ALayout{}));
Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(batch_count, K, N, stride_B, BLayout{})); Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(batch_count, K, N, stride_B, BLayout{}));
Tensor<EDataType> c_g_m_n_device_result( Tensor<EDataType> e_g_m_n_device_result(
f_host_tensor_descriptor(batch_count, M, N, stride_C, ELayout{})); f_host_tensor_descriptor(batch_count, M, N, stride_C, ELayout{}));
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
std::cout << "c_g_m_n: " << c_g_m_n_device_result.mDesc << std::endl; std::cout << "e_g_m_n: " << e_g_m_n_device_result.mDesc << std::endl;
switch(init_method) switch(init_method)
{ {
...@@ -134,35 +134,38 @@ int main(int argc, char* argv[]) ...@@ -134,35 +134,38 @@ int main(int argc, char* argv[])
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(EDataType) * c_g_m_n_device_result.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(EDataType) * e_g_m_n_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_g_m_k.mData.data()); a_device_buf.ToDevice(a_g_m_k.mData.data());
b_device_buf.ToDevice(b_g_k_n.mData.data()); b_device_buf.ToDevice(b_g_k_n.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto c_element_op = CDEElementOp{}; auto cde_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
// do GEMM // do GEMM
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()), b_device_buf.GetDeviceBuffer(),
static_cast<EDataType*>(c_device_buf.GetDeviceBuffer()), {},
c_device_buf.GetDeviceBuffer(),
M, M,
N, N,
K, K,
stride_A, stride_A,
stride_B, stride_B,
{},
stride_C, stride_C,
batch_stride_A, batch_stride_A,
batch_stride_B, batch_stride_B,
{},
batch_stride_C, batch_stride_C,
batch_count, batch_count,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); cde_element_op);
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
...@@ -189,32 +192,21 @@ int main(int argc, char* argv[]) ...@@ -189,32 +192,21 @@ int main(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); c_device_buf.FromDevice(e_g_m_n_device_result.mData.data());
auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
auto ref_invoker = ref_batched_gemm.MakeInvoker(); auto ref_invoker = ref_batched_gemm.MakeInvoker();
Tensor<EDataType> c_g_m_n_host_result = HostTensorDescriptor( Tensor<EDataType> e_g_m_n_host_result(
std::vector<std::size_t>({batch_count, M, N}), std::vector<std::size_t>({M * N, N, 1})); f_host_tensor_descriptor(batch_count, M, N, stride_C, ELayout{}));
auto ref_argument = ref_batched_gemm.MakeArgument( auto ref_argument = ref_batched_gemm.MakeArgument(
a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op); a_g_m_k, b_g_k_n, e_g_m_n_host_result, a_element_op, b_element_op, cde_element_op);
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
//for(int b = 0; b < batch_count; b++)
//{
//for(int m = 0; m < M; m++)
//{
//for(int n = 0; n < N; n++)
//{
//c_g_m_n_host_result(b, m, n) = c_g_m_n_host_result(b, m, n);
//}
//}
//}
pass = ck::utils::check_err( pass = ck::utils::check_err(
c_g_m_n_host_result.mData, c_g_m_n_device_result.mData, "Error: Incorrect results c"); e_g_m_n_host_result.mData, e_g_m_n_device_result.mData, "Error: Incorrect results c");
} }
return pass ? 0 : 1; return pass ? 0 : 1;
......
...@@ -29,16 +29,18 @@ struct DeviceBatchedGemmMultiD : public BaseOperator ...@@ -29,16 +29,18 @@ struct DeviceBatchedGemmMultiD : public BaseOperator
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a, MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c, void* p_c,
ck::index_t M, ck::index_t M,
ck::index_t N, ck::index_t N,
ck::index_t K, ck::index_t K,
ck::index_t StrideA, ck::index_t StrideA,
ck::index_t StrideB, ck::index_t StrideB,
//std::array<ck::index_t, NumDTensor> StrideDs, std::array<ck::index_t, NumDTensor> StrideDs,
ck::index_t StrideE, ck::index_t StrideE,
ck::index_t BatchStrideA, ck::index_t BatchStrideA,
ck::index_t BatchStrideB, ck::index_t BatchStrideB,
std::array<ck::index_t, NumDTensor> BatchStrideDs,
ck::index_t BatchStrideE, ck::index_t BatchStrideE,
ck::index_t Batch, ck::index_t Batch,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
...@@ -58,16 +60,17 @@ template <typename ALayout, ...@@ -58,16 +60,17 @@ template <typename ALayout,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CDEElementwiseOperation> typename CDEElementwiseOperation>
using DeviceBatchedGemmMultiDPtr = std::unique_ptr<DeviceBatchedGemmMultiD<ALayout, using DeviceBatchedGemmMultiDPtr =
BLayout, std::unique_ptr<DeviceBatchedGemmMultiD<ALayout,
CLayout, BLayout,
ADataType, CLayout,
BDataType, ADataType,
DsDataType, BDataType,
EDataType, DsDataType,
AElementwiseOperation, EDataType,
BElementwiseOperation, AElementwiseOperation,
CDEElementwiseOperation>>; BElementwiseOperation,
CDEElementwiseOperation>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment