Commit 86699d01 authored by Jing Zhang
Browse files

add instances for fp32 output

parent 0c3cfcf8
add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp) add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_operations) target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_operations)
# Builds the client executable from grouped_gemm_fixed_nk_bias_fp16_fp32_out.cpp
# and links it privately against composable_kernel::device_operations.
add_executable(client_grouped_gemm_fixed_nk_bias_fp16_fp32_out grouped_gemm_fixed_nk_bias_fp16_fp32_out.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16_fp32_out PRIVATE composable_kernel::device_operations)
...@@ -139,6 +139,13 @@ struct AddBias ...@@ -139,6 +139,13 @@ struct AddBias
{ {
e = c + d0; e = c + d0;
} }
// Explicit specialization of operator() for the case where e, c and d0 are
// all float: adds d0 to c and stores the result in e.
template <>
__host__ __device__ void
operator()<float, float, float>(float& e, const float& c, const float& d0) const
{
e = c + d0;
}
}; };
struct UnaryConvert struct UnaryConvert
......
...@@ -16,6 +16,7 @@ namespace tensor_operation { ...@@ -16,6 +16,7 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
//fp16_output
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances( void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row, std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row, Row,
...@@ -42,33 +43,36 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances( ...@@ -42,33 +43,36 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
PassThrough, PassThrough,
AddBias>>>& instances); AddBias>>>& instances);
#if 0 //fp32_output
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_kn_mn_instances( void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Col, std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Row, Row,
Row_Tuple, Row_Tuple,
Row, Row,
F16, F16,
F16, F16,
F32_Tuple, F32_Tuple,
F16, F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
AddBias>>>& instances); AddBias>>>& instances);
void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_nk_mn_instances( void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Col, std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row,
Col, Col,
Row_Tuple, Row_Tuple,
Row, Row,
F16, F16,
F16, F16,
F32_Tuple, F32_Tuple,
F16, F32,
PassThrough, PassThrough,
PassThrough, PassThrough,
AddBias>>>& instances); AddBias>>>& instances);
#endif
template <typename ALayout, template <typename ALayout,
typename BLayout, typename BLayout,
...@@ -105,6 +109,7 @@ struct DeviceOperationInstanceFactory< ...@@ -105,6 +109,7 @@ struct DeviceOperationInstanceFactory<
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
//fp16_output
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> && if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>) is_same_v<EDataType, half_t>)
{ {
...@@ -118,15 +123,21 @@ struct DeviceOperationInstanceFactory< ...@@ -118,15 +123,21 @@ struct DeviceOperationInstanceFactory<
{ {
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
} }
if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> && }
//fp32_output
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, float>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>) is_same_v<ELayout, Row>)
{ {
// add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_kn_mn_instances(op_ptrs); add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances(op_ptrs);
} }
if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> && if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>) is_same_v<ELayout, Row>)
{ {
// add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_nk_mn_instances(op_ptrs); add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(op_ptrs);
} }
} }
return op_ptrs; return op_ptrs;
......
add_instance_library(device_grouped_gemm_bias_instance add_instance_library(device_grouped_gemm_bias_instance
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp
#device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_kn_mn_instance.cpp
#device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_km_nk_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instance.cpp
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment