Commit d670c5a6 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed

parent 963bc7a3
...@@ -3,3 +3,6 @@ add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp) ...@@ -3,3 +3,6 @@ add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp)
target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_operations) target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_operations)
target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_operations) target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_operations)
add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp)
target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_operations)
...@@ -37,7 +37,9 @@ int main() ...@@ -37,7 +37,9 @@ int main()
OutDataType, OutDataType,
InLayout, InLayout,
WeiLayout, WeiLayout,
OutLayout>( OutLayout,
3,
ck::f8_t>(
{N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
? EXIT_SUCCESS ? EXIT_SUCCESS
: EXIT_FAILURE; : EXIT_FAILURE;
......
...@@ -378,7 +378,7 @@ template <ck::index_t NumDimSpatial, ...@@ -378,7 +378,7 @@ template <ck::index_t NumDimSpatial,
typename InDataType, typename InDataType,
typename WeiDataType, typename WeiDataType,
typename OutDataType, typename OutDataType,
typename ComputeType = InDataType> typename ComputeType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
NumDimSpatial, NumDimSpatial,
InLayout, InLayout,
...@@ -521,14 +521,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe ...@@ -521,14 +521,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t && is_same_v<ComputeType, half_t>>) is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
{ {
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
} }
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> && else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t && is_same_v<ComputeType, ck::f8_t>>) is_same_v<OutDataType, half_t> && is_same_v<ComputeType, ck::f8_t>)
{ {
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances(
op_ptrs);
} }
else if constexpr(is_same_v<InDataType, ck::bhalf_t> && else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<WeiDataType, ck::bhalf_t> &&
......
...@@ -24,21 +24,24 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instance ...@@ -24,21 +24,24 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instance
PassThrough, PassThrough,
F8>>>& instances) F8>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_f16_comp_f8_instances<3, device_grouped_conv_fwd_xdl_f16_comp_f8_instances<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
Empty_Tuple, Empty_Tuple,
NDHWGK, NDHWGK,
ConvFwdDefault>{}); ConvFwdDefault>{});
add_device_operation_instances(instances, add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_f16_comp_f8_instances<3, device_grouped_conv_fwd_xdl_f16_comp_f8_instances<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
Empty_Tuple, Empty_Tuple,
NDHWGK, NDHWGK,
ConvFwd1x1P0>{}); ConvFwd1x1P0>{});
add_device_operation_instances(instances, add_device_operation_instances(
instances,
device_grouped_conv_fwd_xdl_f16_comp_f8_instances<3, device_grouped_conv_fwd_xdl_f16_comp_f8_instances<3,
NDHWGC, NDHWGC,
GKZYXC, GKZYXC,
......
...@@ -6,14 +6,13 @@ rm -rf CMakeFiles ...@@ -6,14 +6,13 @@ rm -rf CMakeFiles
MY_PROJECT_SOURCE=$1 MY_PROJECT_SOURCE=$1
cmake \ cmake \
-D INSTANCES_ONLY=ON \
-D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker \ -D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker \
-save-temps=$PWD" \ -save-temps=$PWD" \
-D CMAKE_BUILD_TYPE=Release \ -D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \ -D BUILD_DEV=ON \
-D GPU_TARGETS="gfx908" \ -D GPU_TARGETS="gfx908;gfx90a;gfx940" \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \ -D USE_BITINT_EXTENSION_INT4=OFF \
${MY_PROJECT_SOURCE} ${MY_PROJECT_SOURCE}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment