Unverified Commit 53ea4713 authored by Qianfeng's avatar Qianfeng Committed by GitHub
Browse files

Batchnorm-forward and Batchnorm-infer Implemented using generic kernels (#320)

* Implement multiple-reduction in one kernel (kernels, device ops, examples)

* Add generic elementwise kernel and device interface

* Add generator for normal-distributed data initialization

* Add host refer implementation of batchnorm-forward and batchnorm-infer

* Add examples for implementing batchnorm-forward and batchnorm-infer using generic kernels

* Remove un-needed including in batchnorm example

* Renaming generic_elementwise to elementiwise in kernel and device classes/functions

* Change in gemm_layernorm examples to use DeviceElementwise instead of Device5AryElementwise

* Change in exampe 19_binary_elementwise to use DeviceElementwise instead of DeviceBinaryElementwise

* Change in device_cgemm_4gemm_xdl_cshuffle.hpp to use kernel_elementwise instead of kernel_binary_elementwise

* Add DeviceElementwiseBase and use it in device_normalize_instance.cpp

* Removing and renaming files

* Update to synchronize gemm_layernorm client example to the generic element-wise device op API

* Update to synchronize with the latest headers directory and HostTensorDescriptor interface renaming

* Merge two static member functions in device_elementwise.hpp

* Remove unary_elementwise_1d kernel and device
parent 5ee30459
...@@ -18,8 +18,11 @@ namespace device { ...@@ -18,8 +18,11 @@ namespace device {
namespace instance { namespace instance {
using Normalize = ck::tensor_operation::element_wise::Normalize; using Normalize = ck::tensor_operation::element_wise::Normalize;
using DeviceNormalizeFromMeanMeanSquarePtr = using DeviceNormalizeFromMeanMeanSquarePtr = ck::tensor_operation::device::DeviceElementwiseBasePtr<
ck::tensor_operation::device::DeviceElementwisePtr<5, 1, 2, Normalize>; Tuple<half_t, float, float, half_t, half_t>,
Tuple<half_t>,
Normalize,
2>;
void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances( void add_device_normalize_from_mean_squaremean_f16_f32_f32_f16_f16_instances(
std::vector<DeviceNormalizeFromMeanMeanSquarePtr>& instances); std::vector<DeviceNormalizeFromMeanMeanSquarePtr>& instances);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment