Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
01e94729
Commit
01e94729
authored
Jan 28, 2021
by
Chao Liu
Browse files
fix control flow issue for padding case
parent
68ea43b1
Changes
10
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
369 additions
and
452 deletions
+369
-452
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+273
-1
composable_kernel/include/gridwise_operation_wrapper.hpp
composable_kernel/include/gridwise_operation_wrapper.hpp
+5
-1
composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp
...lude/tensor_description/dynamic_multi_index_transform.hpp
+11
-3
composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp
.../include/tensor_description/dynamic_tensor_descriptor.hpp
+5
-0
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
...kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
+0
-30
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
...or_operation/threadwise_dynamic_tensor_slice_transfer.hpp
+29
-405
composable_kernel/include/utility/config.amd.hpp.in
composable_kernel/include/utility/config.amd.hpp.in
+1
-1
driver/include/device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+3
-3
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+37
-3
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+5
-5
No files found.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
01e94729
...
@@ -188,7 +188,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
...
@@ -188,7 +188,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
0
,
0
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmABlockTransferDstScalarPerVector_GemmM
,
GemmABlockTransferDstScalarPerVector_GemmM
,
tru
e
,
// move back src coordinate after threadwise copy
fals
e
,
//
don't
move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
...
@@ -623,5 +623,277 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
...
@@ -623,5 +623,277 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
}
}
};
};
template
<
index_t
BlockSize
,
typename
Float
,
typename
AccFloat
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmKPerBlock
,
index_t
GemmMPerThread
,
index_t
GemmNPerThread
,
index_t
GemmKPerThread
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
typename
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
typename
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
,
index_t
GemmABlockTransferDstScalarPerVector_GemmM
,
typename
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
typename
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
,
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
,
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
>
struct
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
{
template
<
typename
...
Wei
,
typename
...
In
,
typename
...
Out
>
__host__
void
Run
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_k_ho_wo_global_desc
,
const
MultiIndex
<
2
>
conv_strides
,
const
MultiIndex
<
2
>
conv_dilations
,
const
MultiIndex
<
2
>
in_left_pads
,
const
MultiIndex
<
2
>
in_right_pads
,
const
Float
*
__restrict__
p_wei_global
,
const
Float
*
__restrict__
p_in_global
,
Float
*
__restrict__
p_out_global
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
const
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I0
);
const
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I1
);
const
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I1
);
const
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I2
);
const
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I3
);
const
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I2
);
const
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I3
);
const
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLength
(
I2
);
const
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLength
(
I3
);
const
index_t
ConvStrideH
=
conv_strides
[
I0
];
const
index_t
ConvStrideW
=
conv_strides
[
I1
];
const
index_t
ConvDilationH
=
conv_dilations
[
I0
];
const
index_t
ConvDilationW
=
conv_dilations
[
I1
];
const
index_t
InLeftPadH
=
in_left_pads
[
I0
];
const
index_t
InLeftPadW
=
in_left_pads
[
I1
];
const
index_t
InRightPadH
=
in_right_pads
[
I0
];
const
index_t
InRightPadW
=
in_right_pads
[
I1
];
if
(
!
(
Y
==
1
&&
X
==
1
&&
ConvStrideH
==
1
&&
ConvStrideW
==
1
&&
ConvDilationH
==
1
&&
ConvDilationW
==
1
&&
InLeftPadH
==
0
&&
InLeftPadW
==
0
&&
InRightPadH
==
0
&&
InRightPadW
==
0
))
{
throw
std
::
runtime_error
(
"wrong! 1x1, stride 1, no padding"
);
}
// weight tensor
const
auto
wei_gemmk_gemmm_global_desc
=
transform_dynamic_tensor_descriptor
(
make_dynamic_naive_tensor_descriptor_packed
<
2
>
(
make_multi_index
(
K
,
C
)),
make_tuple
(
DynamicPassThrough
{
K
},
DynamicPassThrough
{
C
}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}));
// input tensor
const
auto
in_gemmk_gemmn_global_desc
=
transform_dynamic_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
make_tuple
(
DynamicPassThrough
{
C
},
DynamicMerge
<
3
>
{
make_multi_index
(
N
,
Ho
,
Wo
)}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// output tensor
const
auto
out_gemmm_gemmn_global_desc
=
transform_dynamic_tensor_descriptor
(
make_dynamic_naive_tensor_descriptor_packed
<
3
>
(
make_multi_index
(
N
,
K
,
Ho
*
Wo
)),
make_tuple
(
DynamicPassThrough
{
K
},
DynamicMerge
<
2
>
{
make_multi_index
(
N
,
Ho
*
Wo
)}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
index_t
GemmM
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I0
);
const
index_t
GemmN
=
out_gemmm_gemmn_global_desc
.
GetLength
(
I1
);
const
index_t
GemmK
=
wei_gemmk_gemmm_global_desc
.
GetLength
(
I0
);
if
(
!
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
))
{
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
}
constexpr
index_t
GemmM1
=
GemmMPerThread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
;
constexpr
index_t
GemmN1
=
GemmNPerThread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
;
const
index_t
GemmM0
=
GemmM
/
GemmM1
;
const
index_t
GemmN0
=
GemmN
/
GemmN1
;
const
auto
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
=
transform_dynamic_tensor_descriptor
(
out_gemmm_gemmn_global_desc
,
make_tuple
(
DynamicUnMerge
<
2
>
{
make_multi_index
(
GemmM0
,
GemmM1
)},
DynamicUnMerge
<
2
>
{
make_multi_index
(
GemmN0
,
GemmN1
)}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
// GEMM
using
gridwise_gemm
=
GridwiseDynamicGemm_km_kn_mn_v1
<
BlockSize
,
Float
,
AccFloat
,
InMemoryDataOperation
::
Set
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThread
,
GemmNPerThread
,
GemmKPerThread
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
0
,
GemmABlockTransferSrcScalarPerVector_GemmK
,
GemmABlockTransferDstScalarPerVector_GemmM
,
false
,
// don't move back src coordinate after threadwise copy
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
,
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmBBlockTransferSrcScalarPerVector_GemmN
,
GemmBBlockTransferDstScalarPerVector_GemmN
,
false
,
// don't move back src coordinate after threadwise copy, which will be fused with
// MoveSrcSliceWindow() to save addr computation
Sequence
<
2
,
3
,
0
,
1
>
,
3
,
GemmCThreadTransferDstScalarPerVector_GemmN1
>
;
const
index_t
GridSize
=
(
GemmM
/
GemmMPerBlock
)
*
(
GemmN
/
GemmNPerBlock
);
const
bool
has_main_k_block_loop
=
(
GemmK
+
GemmKPerBlock
)
/
(
2
*
GemmKPerBlock
)
>
1
;
const
bool
has_double_tail_k_block_loop
=
(
GemmK
/
GemmKPerBlock
)
%
2
==
0
;
if
(
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
Float
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
Float
*
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
Float
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
true
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
}
else
{
const
auto
kernel
=
run_gridwise_operation
<
gridwise_gemm
,
decltype
(
wei_gemmk_gemmm_global_desc
),
const
Float
*
,
decltype
(
in_gemmk_gemmn_global_desc
),
const
Float
*
,
decltype
(
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
),
Float
*
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
false
>>
;
launch_kernel
(
kernel
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_gemmk_gemmm_global_desc
,
p_wei_global
,
in_gemmk_gemmn_global_desc
,
p_in_global
,
out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc
,
p_out_global
,
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
}
}
};
}
// namespace ck
}
// namespace ck
#endif
#endif
composable_kernel/include/gridwise_operation_wrapper.hpp
View file @
01e94729
...
@@ -2,7 +2,11 @@
...
@@ -2,7 +2,11 @@
#define CK_GRIDWISE_OPERATION_KERNEL_WRAPPER
#define CK_GRIDWISE_OPERATION_KERNEL_WRAPPER
template
<
typename
GridwiseOp
,
typename
...
Xs
>
template
<
typename
GridwiseOp
,
typename
...
Xs
>
__global__
void
run_gridwise_operation
(
Xs
...
xs
)
__global__
void
#if 1
__launch_bounds__
(
256
,
2
)
#endif
run_gridwise_operation
(
Xs
...
xs
)
{
{
GridwiseOp
{}.
Run
(
xs
...);
GridwiseOp
{}.
Run
(
xs
...);
}
}
...
...
composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp
View file @
01e94729
...
@@ -848,13 +848,18 @@ struct DynamicMerge
...
@@ -848,13 +848,18 @@ struct DynamicMerge
do_carry
=
idx_low_tmp
>=
low_lengths_
[
i
];
do_carry
=
idx_low_tmp
>=
low_lengths_
[
i
];
#if 0
#if 0
// TODO: use exec-mask inline asm
// TODO: use exec-mask inline asm
, which use 1 VALU
if(do_carry)
if(do_carry)
{
{
idx_diff_low(i) -= low_lengths_[i];
idx_diff_low(i) -= low_lengths_[i];
}
}
#else
#elif
1
// this use 2 VALU
idx_diff_low
(
i
)
=
do_carry
?
idx_diff_low
[
i
]
-
low_lengths_
[
i
]
:
idx_diff_low
[
i
];
idx_diff_low
(
i
)
=
do_carry
?
idx_diff_low
[
i
]
-
low_lengths_
[
i
]
:
idx_diff_low
[
i
];
#elif 1
// this use 2 VALU
index_t
idx_diff_low_tmp
=
idx_diff_low
[
i
]
-
low_lengths_
[
i
];
idx_diff_low
(
i
)
=
do_carry
?
idx_diff_low_tmp
:
idx_diff_low
[
i
];
#endif
#endif
idx_low
(
i
)
+=
idx_diff_low
[
i
];
idx_low
(
i
)
+=
idx_diff_low
[
i
];
...
@@ -885,8 +890,11 @@ struct DynamicMerge
...
@@ -885,8 +890,11 @@ struct DynamicMerge
{
{
idx_diff_low(i) += low_lengths_[i];
idx_diff_low(i) += low_lengths_[i];
}
}
#el
se
#el
if
1
idx_diff_low
(
i
)
=
do_borrow
?
idx_diff_low
[
i
]
+
low_lengths_
[
i
]
:
idx_diff_low
[
i
];
idx_diff_low
(
i
)
=
do_borrow
?
idx_diff_low
[
i
]
+
low_lengths_
[
i
]
:
idx_diff_low
[
i
];
#elif 1
index_t
idx_diff_low_tmp
=
idx_diff_low
[
i
]
+
low_lengths_
[
i
];
idx_diff_low
(
i
)
=
do_borrow
?
idx_diff_low_tmp
:
idx_diff_low
[
i
];
#endif
#endif
idx_low
(
i
)
+=
idx_diff_low
[
i
];
idx_low
(
i
)
+=
idx_diff_low
[
i
];
...
...
composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp
View file @
01e94729
...
@@ -541,7 +541,12 @@ coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc& te
...
@@ -541,7 +541,12 @@ coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc& te
const
auto
idx_up
=
const
auto
idx_up
=
get_container_subset
(
idx_hidden
,
TensorDesc
::
GetUpperDimensionIdss
().
At
(
itran
));
get_container_subset
(
idx_hidden
,
TensorDesc
::
GetUpperDimensionIdss
().
At
(
itran
));
#if 0 // debug
// Comment: this implemenetation results in weird control flow in ISA
valid = valid && tran.IsValidUpperIndexMappedToValidLowerIndex(idx_up);
valid = valid && tran.IsValidUpperIndexMappedToValidLowerIndex(idx_up);
#else
valid
&=
tran
.
IsValidUpperIndexMappedToValidLowerIndex
(
idx_up
);
#endif
}
}
});
});
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
View file @
01e94729
...
@@ -322,7 +322,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
...
@@ -322,7 +322,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
}
while
(
k_block_data_begin
<
K
-
2
*
KPerBlock
);
}
while
(
k_block_data_begin
<
K
-
2
*
KPerBlock
);
}
}
#if 1
// LDS double buffer: tail
// LDS double buffer: tail
if
constexpr
(
HasDoubleTailKBlockLoop
)
// if has 2 iteration left
if
constexpr
(
HasDoubleTailKBlockLoop
)
// if has 2 iteration left
{
{
...
@@ -356,7 +355,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
...
@@ -356,7 +355,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
// LDS double buffer: GEMM on last data
// LDS double buffer: GEMM on last data
blockwise_gemm
.
Run
(
p_a_block_double
,
p_b_block_double
,
p_c_thread
);
blockwise_gemm
.
Run
(
p_a_block_double
,
p_b_block_double
,
p_c_thread
);
}
}
#endif
// output: register to global memory
// output: register to global memory
{
{
...
@@ -385,33 +383,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
...
@@ -385,33 +383,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
const
index_t
n_thread_data_on_global
=
const
index_t
n_thread_data_on_global
=
n_block_data_on_global
+
c_thread_mtx_on_block
.
col
;
n_block_data_on_global
+
c_thread_mtx_on_block
.
col
;
#if 0
ThreadwiseDynamicTensorSliceTransfer_v1r2<
AccFloat,
Float,
decltype(c_m0_m1_n0_n1_thread_desc),
decltype(c_m0_m1_n0_n1_global_desc),
Sequence<MRepeat, MPerThread, NRepeat, NPerThread>,
CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim,
1,
CThreadTransferDstScalarPerVector,
AddressSpace::Vgpr,
AddressSpace::Global,
CGlobalMemoryDataOperation,
1,
1,
true,
true>(c_m0_m1_n0_n1_thread_desc,
make_multi_index(0, 0, 0, 0),
c_m0_m1_n0_n1_global_desc,
make_multi_index(m_thread_data_on_global / M1,
m_thread_data_on_global % M1,
n_thread_data_on_global / N1,
n_thread_data_on_global % N1))
.Run_hack(
c_m0_m1_n0_n1_thread_desc, p_c_thread, c_m0_m1_n0_n1_global_desc, p_c_global);
#else
ThreadwiseDynamicTensorSliceTransfer_v1r3
<
ThreadwiseDynamicTensorSliceTransfer_v1r3
<
AccFloat
,
AccFloat
,
Float
,
Float
,
...
@@ -432,7 +403,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
...
@@ -432,7 +403,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
n_thread_data_on_global
/
N1
,
n_thread_data_on_global
/
N1
,
n_thread_data_on_global
%
N1
))
n_thread_data_on_global
%
N1
))
.
Run_hack
(
p_c_thread
,
c_m0_m1_n0_n1_global_desc
,
p_c_global
);
.
Run_hack
(
p_c_thread
,
c_m0_m1_n0_n1_global_desc
,
p_c_global
);
#endif
}
}
}
}
...
...
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
View file @
01e94729
This diff is collapsed.
Click to expand it.
composable_kernel/include/utility/config.amd.hpp.in
View file @
01e94729
...
@@ -87,7 +87,7 @@
...
@@ -87,7 +87,7 @@
// thread-invariant, otherwise it's a bug
// thread-invariant, otherwise it's a bug
// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread"
// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread"
#ifndef CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
#ifndef CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
#define CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
0
#define CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
1
#endif
#endif
// workaround: put all workaround here
// workaround: put all workaround here
...
...
driver/include/device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
01e94729
...
@@ -187,7 +187,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -187,7 +187,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
#elif
1
#elif
0
// cdata = 64, BlockSize = 256, 128x128x8
// cdata = 64, BlockSize = 256, 128x128x8
// b threadwise copy 2x2
// b threadwise copy 2x2
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -221,7 +221,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -221,7 +221,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
#elif
1
#elif
0
// cdata = 64, BlockSize = 256, 128x128x8
// cdata = 64, BlockSize = 256, 128x128x8
// vector 4
// vector 4
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -323,7 +323,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -323,7 +323,7 @@ void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
4
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
4
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
4
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
4
;
#elif
0
#elif
1
// cdata = 64, BlockSize = 256, 128x128x16
// cdata = 64, BlockSize = 256, 128x128x16
// GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
01e94729
...
@@ -104,7 +104,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
...
@@ -104,7 +104,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
2
,
1
>
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
2
,
1
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
2
,
128
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
2
,
128
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
1
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
2
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
2
,
1
>
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
2
,
1
>
;
...
@@ -145,7 +145,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
...
@@ -145,7 +145,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
1
;
#elif
1
#elif
0
// cdata = 64, BlockSize = 256, 128x128x8
// cdata = 64, BlockSize = 256, 128x128x8
// b thread copy 2x2
// b thread copy 2x2
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -176,6 +176,38 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
...
@@ -176,6 +176,38 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
1
;
#elif 1
// cdata = 64, BlockSize = 256, 128x128x16
// GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
16
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
2
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
64
>
;
constexpr
index_t
GemmABlockTransferSrcScalarPerVector_GemmK
=
4
;
constexpr
index_t
GemmABlockTransferDstScalarPerVector_GemmM
=
1
;
using
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN
=
Sequence
<
2
,
4
>
;
using
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN
=
Sequence
<
8
,
32
>
;
constexpr
index_t
GemmBBlockTransferSrcScalarPerVector_GemmN
=
4
;
constexpr
index_t
GemmBBlockTransferDstScalarPerVector_GemmN
=
4
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
4
;
#endif
#endif
const
index_t
N
=
out_n_k_ho_wo_desc
.
GetLength
(
I0
);
const
index_t
N
=
out_n_k_ho_wo_desc
.
GetLength
(
I0
);
...
@@ -203,8 +235,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
...
@@ -203,8 +235,10 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
constexpr
auto
conv_driver
=
constexpr
auto
conv_driver
=
#if 1
#if 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
#el
se
#el
if 0
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
#elif 1
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
#endif
#endif
<
BlockSize
,
<
BlockSize
,
TDevice
,
TDevice
,
...
...
driver/src/conv_driver.cpp
View file @
01e94729
...
@@ -103,7 +103,7 @@ int main(int argc, char* argv[])
...
@@ -103,7 +103,7 @@ int main(int argc, char* argv[])
constexpr
index_t
C
=
96
;
constexpr
index_t
C
=
96
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
K
=
96
;
constexpr
index_t
K
=
128
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
constexpr
index_t
X
=
3
;
...
@@ -133,7 +133,7 @@ int main(int argc, char* argv[])
...
@@ -133,7 +133,7 @@ int main(int argc, char* argv[])
constexpr
index_t
C
=
128
;
constexpr
index_t
C
=
128
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
128
;
constexpr
index_t
K
=
384
;
constexpr
index_t
Y
=
7
;
constexpr
index_t
Y
=
7
;
constexpr
index_t
X
=
1
;
constexpr
index_t
X
=
1
;
...
@@ -175,10 +175,10 @@ int main(int argc, char* argv[])
...
@@ -175,10 +175,10 @@ int main(int argc, char* argv[])
#elif 0
#elif 0
// 3x3, 147x147
// 3x3, 147x147
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
32
;
constexpr
index_t
C
=
128
;
constexpr
index_t
HI
=
147
;
constexpr
index_t
HI
=
147
;
constexpr
index_t
WI
=
147
;
constexpr
index_t
WI
=
147
;
constexpr
index_t
K
=
64
;
constexpr
index_t
K
=
128
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
constexpr
index_t
X
=
3
;
...
@@ -457,7 +457,7 @@ int main(int argc, char* argv[])
...
@@ -457,7 +457,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif
0
#elif
1
// 1x1, 7x7
// 1x1, 7x7
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
512
;
constexpr
index_t
C
=
512
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment