Commit 29448ffd authored by Harisankar Sadasivan
Browse files

merge from develop and revision for PR #881

parents 9223a5e2 8f84a012
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
...@@ -172,18 +172,19 @@ int main() ...@@ -172,18 +172,19 @@ int main()
BLayout, BLayout,
CLayout>(); CLayout>();
const auto normalize_ptrs =
ck::tensor_operation::device::instance::get_device_normalize_from_mean_meansquare_instances<
CDataType,
ReduceDataType,
ReduceDataType,
GammaDataType,
BetaDataType,
LayerNormOutDataType>();
std::cout << "found " << gemm_reduce_ptrs.size() std::cout << "found " << gemm_reduce_ptrs.size()
<< " gemm_reduceMean_reduceSquareMean instances" << std::endl; << " gemm_reduceMean_reduceSquareMean instances" << std::endl;
using NormalizeDeviceOp = ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<CDataType, ReduceDataType, ReduceDataType, GammaDataType, BetaDataType>,
ck::Tuple<LayerNormOutDataType>,
ck::tensor_operation::element_wise::Normalize,
2>;
const auto normalize_ptrs =
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
NormalizeDeviceOp>::GetInstances();
std::cout << "found " << normalize_ptrs.size() << " normalize instances" << std::endl; std::cout << "found " << normalize_ptrs.size() << " normalize instances" << std::endl;
auto f_matrix_space_size = auto f_matrix_space_size =
......
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
...@@ -100,6 +100,10 @@ int main(int argc, char* argv[]) ...@@ -100,6 +100,10 @@ int main(int argc, char* argv[])
if(op_ptr->IsSupportedArgument(argument_ptr.get())) if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
SimpleDeviceMem workspace(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N + std::size_t num_byte = sizeof(XDataType) * M * N + sizeof(GammaDataType) * N +
...@@ -153,6 +157,10 @@ int main(int argc, char* argv[]) ...@@ -153,6 +157,10 @@ int main(int argc, char* argv[])
if(op_ptr->IsSupportedArgument(argument_ptr.get())) if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
SimpleDeviceMem workspace(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
} }
......
File mode changed from 100644 to 100755
...@@ -53,12 +53,35 @@ int main(int argc, char* argv[]) ...@@ -53,12 +53,35 @@ int main(int argc, char* argv[])
SimpleDeviceMem in(sizeof(InDataType) * num_elements); SimpleDeviceMem in(sizeof(InDataType) * num_elements);
SimpleDeviceMem out(sizeof(OutDataType) * num_elements); SimpleDeviceMem out(sizeof(OutDataType) * num_elements);
using DeviceOp = ck::tensor_operation::device:: using DeviceOp = ck::tensor_operation::device::DeviceSoftmax<InDataType,
DeviceSoftmax<InDataType, AccDataType, OutDataType, PassThrough, PassThrough, Rank>; AccDataType,
OutDataType,
PassThrough,
PassThrough,
Rank,
NumReduceDim>;
// get device op instances // get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances(); DeviceOp>::GetInstances();
auto& generic_op_ptr = op_ptrs[0];
auto generic_argument_ptr = generic_op_ptr->MakeArgumentPointer(in_lengths,
in_strides,
reduce_dims,
alpha,
beta,
in.GetDeviceBuffer(),
out.GetDeviceBuffer(),
PassThrough{},
PassThrough{});
if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get()))
{
throw std::runtime_error(
"The generic kernel instance should be able to support any input shapes");
};
std::cout << "found " << op_ptrs.size() << " instances" << std::endl; std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name; std::string best_op_name;
...@@ -74,11 +97,6 @@ int main(int argc, char* argv[]) ...@@ -74,11 +97,6 @@ int main(int argc, char* argv[])
{ {
auto& op_ptr = op_ptrs[i]; auto& op_ptr = op_ptrs[i];
if(op_ptr->GetRank() != Rank || op_ptr->GetNumReduceDim() != NumReduceDim)
{
continue;
}
auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths, auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths,
in_strides, in_strides,
reduce_dims, reduce_dims,
......
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment