Commit bd8a8a1d authored by Jing Zhang
Browse files

changed example

parent b7bc3c2b
...@@ -41,16 +41,18 @@ using BLayout = Col; ...@@ -41,16 +41,18 @@ using BLayout = Col;
using DLayout = Row; using DLayout = Row;
using ELayout = Row; using ELayout = Row;
struct MultiATest struct Add
{ {
template <typename A, typename A0, typename A1>
__host__ __device__ constexpr void operator()(A& a, const A0& a0, const A1& a1) const;
template <>
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()(ck::half2_t& a, const ck::half2_t& a0, const ck::half2_t& a1) const operator()(ck::half2_t& a, const ck::half2_t& a0, const ck::half2_t& a1) const
{ {
a = (a0 + a1) / 2; a = a0 + a1;
}
__host__ __device__ constexpr void
operator()(ck::half_t& a, const ck::half_t& a0, const ck::half_t& a1) const
{
a = a0 + a1;
} }
static constexpr ck::index_t vec_len = 2; static constexpr ck::index_t vec_len = 2;
...@@ -74,7 +76,7 @@ struct AlphaBetaAdd ...@@ -74,7 +76,7 @@ struct AlphaBetaAdd
float beta_; float beta_;
}; };
using AElementOp = MultiATest; using AElementOp = Add;
using BElementOp = PassThrough; using BElementOp = PassThrough;
using CDEElementOp = AlphaBetaAdd; using CDEElementOp = AlphaBetaAdd;
...@@ -205,13 +207,15 @@ int main(int argc, char* argv[]) ...@@ -205,13 +207,15 @@ int main(int argc, char* argv[])
} }
}; };
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor<ADataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<ADataType> a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
...@@ -220,22 +224,26 @@ int main(int argc, char* argv[]) ...@@ -220,22 +224,26 @@ int main(int argc, char* argv[])
{ {
case 0: break; case 0: break;
case 1: case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); a0_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
a1_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5}); d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
break; break;
default: default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a0_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
a1_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5}); d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem a0_device_buf(sizeof(ADataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(ADataType) * a1_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data()); a0_device_buf.ToDevice(a0_m_k.mData.data());
a1_device_buf.ToDevice(a1_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
d_device_buf.ToDevice(d_m_n.mData.data()); d_device_buf.ToDevice(d_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data()); e_device_buf.ToDevice(e_m_n_device_result.mData.data());
...@@ -247,21 +255,22 @@ int main(int argc, char* argv[]) ...@@ -247,21 +255,22 @@ int main(int argc, char* argv[])
// do GEMM // do GEMM
auto device_op = DeviceOpInstance{}; auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker(); auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument( auto argument =
std::array<const void*, 2>{a_device_buf.GetDeviceBuffer(), a_device_buf.GetDeviceBuffer()}, device_op.MakeArgument(std::array<const void*, 2>{a0_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{b_device_buf.GetDeviceBuffer()}, a1_device_buf.GetDeviceBuffer()},
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()}, std::array<const void*, 1>{b_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(), std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
M, e_device_buf.GetDeviceBuffer(),
N, M,
K, N,
std::array<ck::index_t, 2>{StrideA, StrideA}, K,
std::array<ck::index_t, 1>{StrideB}, std::array<ck::index_t, 2>{StrideA, StrideA},
std::array<ck::index_t, 1>{StrideD}, std::array<ck::index_t, 1>{StrideB},
StrideE, std::array<ck::index_t, 1>{StrideD},
a_element_op, StrideE,
b_element_op, a_element_op,
cde_element_op); b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument)) if(!device_op.IsSupportedArgument(argument))
{ {
...@@ -289,6 +298,16 @@ int main(int argc, char* argv[]) ...@@ -289,6 +298,16 @@ int main(int argc, char* argv[])
{ {
Tensor<CShuffleDataType> c_m_n({M, N}); Tensor<CShuffleDataType> c_m_n({M, N});
Tensor<ADataType> a_m_k({M, K});
for(int m = 0; m < M; ++m)
{
for(int k = 0; k < K; ++k)
{
a_element_op(a_m_k(m, k), a0_m_k(m, k), a1_m_k(m, k));
}
}
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType, BDataType,
CShuffleDataType, CShuffleDataType,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment