"docs/source/api/vscode:/vscode.git/clone" did not exist on "e7457b377d395e5402910eae8540c2bad7613ebf"
Commit e6e8edef authored by rocking's avatar rocking
Browse files

Add groupnorm + swish instances

parent d651dc85
...@@ -16,6 +16,7 @@ using F16 = ck::half_t; ...@@ -16,6 +16,7 @@ using F16 = ck::half_t;
using F32 = float; using F32 = float;
using Pass = ck::tensor_operation::element_wise::PassThrough; using Pass = ck::tensor_operation::element_wise::PassThrough;
using Swish = ck::tensor_operation::element_wise::Swish;
template <typename OutElementwise, index_t Rank, index_t Reduce> template <typename OutElementwise, index_t Rank, index_t Reduce>
// clang-format off // clang-format off
...@@ -64,6 +65,13 @@ void add_device_normalization_rank_5_3_f16_instances( ...@@ -64,6 +65,13 @@ void add_device_normalization_rank_5_3_f16_instances(
add_device_operation_instances(instances, device_normalization_f16_instances<Pass, 5, 3>{}); add_device_operation_instances(instances, device_normalization_f16_instances<Pass, 5, 3>{});
} }
void add_device_normalization_rank_5_3_swish_f16_instances(
std::vector<std::unique_ptr<DeviceNormalization<F16, F16, F16, F32, F16, Swish, 5, 3>>>&
instances)
{
add_device_operation_instances(instances, device_normalization_f16_instances<Swish, 5, 3>{});
}
} // namespace instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -15,7 +15,7 @@ namespace instance { ...@@ -15,7 +15,7 @@ namespace instance {
using F32 = float; using F32 = float;
using Pass = ck::tensor_operation::element_wise::PassThrough; using Pass = ck::tensor_operation::element_wise::PassThrough;
using Swish = ck::tensor_operation::element_wise::Swish;
template <typename OutElementwise, index_t Rank, index_t Reduce> template <typename OutElementwise, index_t Rank, index_t Reduce>
using device_layernorm_f32_instances = std::tuple< using device_layernorm_f32_instances = std::tuple<
// clang-format off // clang-format off
...@@ -63,6 +63,13 @@ void add_device_normalization_rank_5_3_f32_instances( ...@@ -63,6 +63,13 @@ void add_device_normalization_rank_5_3_f32_instances(
add_device_operation_instances(instances, device_layernorm_f32_instances<Pass, 5, 3>{}); add_device_operation_instances(instances, device_layernorm_f32_instances<Pass, 5, 3>{});
} }
void add_device_normalization_rank_5_3_swish_f32_instances(
std::vector<std::unique_ptr<DeviceNormalization<F32, F32, F32, F32, F32, Swish, 5, 3>>>&
instances)
{
add_device_operation_instances(instances, device_layernorm_f32_instances<Swish, 5, 3>{});
}
} // namespace instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment