Commit e6e8edef authored by rocking's avatar rocking
Browse files

Add groupnorm + swish instances

parent d651dc85
......@@ -15,7 +15,8 @@ namespace instance {
using F16 = ck::half_t;
using F32 = float;
using Pass = ck::tensor_operation::element_wise::PassThrough;
using Pass = ck::tensor_operation::element_wise::PassThrough;
using Swish = ck::tensor_operation::element_wise::Swish;
template <typename OutElementwise, index_t Rank, index_t Reduce>
// clang-format off
......@@ -64,6 +65,13 @@ void add_device_normalization_rank_5_3_f16_instances(
add_device_operation_instances(instances, device_normalization_f16_instances<Pass, 5, 3>{});
}
void add_device_normalization_rank_5_3_swish_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Swish, 5, 3>>>&
instances)
{
add_device_operation_instances(instances, device_normalization_f16_instances<Swish, 5, 3>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
......
......@@ -15,7 +15,7 @@ namespace instance {
using F32 = float;
using Pass = ck::tensor_operation::element_wise::PassThrough;
using Swish = ck::tensor_operation::element_wise::Swish;
template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_layernorm_f32_instances = std::tuple<
// clang-format off
......@@ -63,6 +63,13 @@ void add_device_normalization_rank_5_3_f32_instances(
add_device_operation_instances(instances, device_layernorm_f32_instances<Pass, 5, 3>{});
}
void add_device_normalization_rank_5_3_swish_f32_instances(
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Swish, 5, 3>>>&
instances)
{
add_device_operation_instances(instances, device_layernorm_f32_instances<Swish, 5, 3>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment