Commit 86699d01 authored by Jing Zhang
Browse files

add instances for fp32 output

parent 0c3cfcf8
add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_operations)
add_executable(client_grouped_gemm_fixed_nk_bias_fp16_fp32_out grouped_gemm_fixed_nk_bias_fp16_fp32_out.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16_fp32_out PRIVATE composable_kernel::device_operations)
......@@ -139,6 +139,13 @@ struct AddBias
{
e = c + d0;
}
// Explicit specialization of operator() for the case where the output e and
// both inputs c and d0 are float: writes c + d0 into e.
// All three operands already share the same type, so the body performs no
// conversion; callable from both host and device code.
template <>
__host__ __device__ void
operator()<float, float, float>(float& e, const float& c, const float& d0) const
{
e = c + d0;
}
};
struct UnaryConvert
......
......@@ -16,6 +16,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
//fp16_output
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
......@@ -42,33 +43,36 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
PassThrough,
AddBias>>>& instances);
#if 0
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Col,
//fp32_output
// Declaration only; the definition is not part of this view. Receives the
// instance list by non-const reference (presumably to append the
// f16 x f16 -> f32 instances for Row/Row inputs and Row output - confirm in
// the matching *_f16_f16_f32_mk_kn_mn_instance.cpp).
// NOTE(review): the argument list below shows both `F16,` and `F32,` after
// F32_Tuple. This looks like an old/new line pair from the rendered diff;
// given the f32 suffix in the function name, only `F32` is presumably
// intended as the output data type - confirm against the real header.
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F32_Tuple,
F16,
F32,
PassThrough,
PassThrough,
AddBias>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Col,
// Declaration only; the definition is not part of this view. Receives the
// instance list by non-const reference (presumably to append the
// f16 x f16 -> f32 instances for Row A, Col B and Row output - confirm in
// the matching *_f16_f16_f32_mk_nk_mn_instance.cpp).
// NOTE(review): the argument list below shows both `F16,` and `F32,` after
// F32_Tuple. This looks like an old/new line pair from the rendered diff;
// given the f32 suffix in the function name, only `F32` is presumably
// intended as the output data type - confirm against the real header.
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col,
Row_Tuple,
Row,
F16,
F16,
F32_Tuple,
F16,
F32,
PassThrough,
PassThrough,
AddBias>>>& instances);
#endif
template <typename ALayout,
typename BLayout,
......@@ -105,6 +109,7 @@ struct DeviceOperationInstanceFactory<
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
//fp16_output
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
......@@ -118,15 +123,21 @@ struct DeviceOperationInstanceFactory<
{
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
}
if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
}
//fp32_output
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, float>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
// add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_kn_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances(op_ptrs);
}
if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
// add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_nk_mn_instances(op_ptrs);
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(op_ptrs);
}
}
return op_ptrs;
......
# Sources for the device_grouped_gemm_bias_instance library.
# The mk_kn_mn and mk_nk_mn sources are listed for both the f16 and the f32
# output variants; the km_kn_mn and km_nk_mn sources are commented out and
# therefore not part of the library.
add_instance_library(device_grouped_gemm_bias_instance
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp
#device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_kn_mn_instance.cpp
#device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_nk_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instance.cpp
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment