Commit 5729c23c authored by Chao Liu's avatar Chao Liu
Browse files

update example

parent b2bf7d93
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp" #include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" #include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_common_util.hpp" #include "ck/library/utility/host_common_util.hpp"
...@@ -67,22 +68,41 @@ using DeviceInstance = ...@@ -67,22 +68,41 @@ using DeviceInstance =
8, // BetaScalarPerVector 8, // BetaScalarPerVector
8>; // OutScalarPerVector 8>; // OutScalarPerVector
int main() int main(int argc, char* argv[])
{ {
ck::index_t N = 1; ck::index_t N = 128;
ck::index_t H = 16; ck::index_t H = 16;
ck::index_t W = 16; ck::index_t W = 16;
ck::index_t G = 32; ck::index_t G = 32;
ck::index_t C = 40; ck::index_t C = 40;
if(argc == 1)
{
// use default case
}
else if(argc == 6)
{
N = std::stoi(argv[1]);
H = std::stoi(argv[2]);
W = std::stoi(argv[3]);
G = std::stoi(argv[4]);
C = std::stoi(argv[5]);
}
else
{
std::cerr << "arg1 to 5: N, H, W, G, C" << std::endl;
return 1;
}
Tensor<XDataType> x({N, H, W, G, C}); Tensor<XDataType> x({N, H, W, G, C});
Tensor<YDataType> y({N, H, W, G, C}); Tensor<YDataType> y({N, H, W, G, C});
Tensor<GammaDataType> gamma({G, C}); Tensor<GammaDataType> gamma({G, C});
Tensor<BetaDataType> beta({G, C}); Tensor<BetaDataType> beta({G, C});
x.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0}); ck::utils::FillUniformDistribution<XDataType>{0.f, 1.f}(x.begin(), x.end());
gamma.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{0.0, 1.0}); ck::utils::FillUniformDistribution<GammaDataType>{0.f, 1.f}(gamma.begin(), gamma.end());
beta.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{0.0, 1.0}); ck::utils::FillUniformDistribution<BetaDataType>{0.f, 1.f}(beta.begin(), beta.end());
DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize()); DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
...@@ -116,9 +136,8 @@ int main() ...@@ -116,9 +136,8 @@ int main()
return 1; return 1;
}; };
bool time_kernel = true;
auto invoker_ptr = device_instance.MakeInvokerPointer(); auto invoker_ptr = device_instance.MakeInvokerPointer();
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel, true}); float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true, true});
std::size_t num_btype = sizeof(XDataType) * N * H * W * G * C + std::size_t num_btype = sizeof(XDataType) * N * H * W * G * C +
sizeof(YDataType) * N * H * W * G * C + sizeof(GammaDataType) * G * C + sizeof(YDataType) * N * H * W * G * C + sizeof(GammaDataType) * G * C +
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment