Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3317bfe2
Commit
3317bfe2
authored
Apr 17, 2020
by
Jing Zhang
Browse files
format
parent
2b8e3ece
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
95 additions
and
112 deletions
+95
-112
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp
...olution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp
+0
-1
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_fp16_bfp16.hpp
...lude/tensor_operation/gridwise_gemm_xdlops_fp16_bfp16.hpp
+8
-25
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+1
-1
driver/include/device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp
...olution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp
+76
-75
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+10
-10
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp
View file @
3317bfe2
...
@@ -139,7 +139,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw
...
@@ -139,7 +139,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
constexpr
auto
out_gemmm_gemmn_global_desc
=
constexpr
auto
out_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
out_n_k_ho_wo_global_desc
,
transform_tensor_descriptor
(
out_n_k_ho_wo_global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
Merge
<
Sequence
<
N
,
Ho
,
Wo
>>
{}),
make_tuple
(
PassThrough
<
K
>
{},
Merge
<
Sequence
<
N
,
Ho
,
Wo
>>
{}),
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_fp16_bfp16.hpp
View file @
3317bfe2
...
@@ -219,12 +219,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
...
@@ -219,12 +219,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
// 2D indexes are computed with vectorized value in mind (e.g. float, half2, half4),
// 2D indexes are computed with vectorized value in mind (e.g. float, half2, half4),
// we recast datatype from a single half to 4 packed half/2 packed bfloat16
// we recast datatype from a single half to 4 packed half/2 packed bfloat16
// respectively.
// respectively.
auto
p_a_block_vec
=
auto
p_a_block_vec
=
reinterpret_cast
<
const
half4_t
*>
(
p_a_block_now
);
reinterpret_cast
<
const
half4_t
*>
(
auto
p_b_block_vec
=
reinterpret_cast
<
const
half4_t
*>
(
p_b_block_now
);
p_a_block_now
);
auto
p_b_block_vec
=
reinterpret_cast
<
const
half4_t
*>
(
p_b_block_now
);
blockwise_gemm
.
Run
(
p_a_block_vec
,
p_b_block_vec
,
p_c_thread
);
blockwise_gemm
.
Run
(
p_a_block_vec
,
p_b_block_vec
,
p_c_thread
);
// LDS double buffer: store next data to LDS
// LDS double buffer: store next data to LDS
...
@@ -252,12 +248,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
...
@@ -252,12 +248,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
b_blockwise_copy
.
RunLoadThreadBuffer
(
p_b_global
,
p_b_thread_buffer
);
b_blockwise_copy
.
RunLoadThreadBuffer
(
p_b_global
,
p_b_thread_buffer
);
// LDS double buffer: GEMM on 2nd-last data
// LDS double buffer: GEMM on 2nd-last data
auto
p_a_block_vec
=
auto
p_a_block_vec
=
reinterpret_cast
<
const
half4_t
*>
(
p_a_block_double
);
reinterpret_cast
<
const
half4_t
*>
(
auto
p_b_block_vec
=
reinterpret_cast
<
const
half4_t
*>
(
p_b_block_double
);
p_a_block_double
);
auto
p_b_block_vec
=
reinterpret_cast
<
const
half4_t
*>
(
p_b_block_double
);
blockwise_gemm
.
Run
(
p_a_block_vec
,
p_b_block_vec
,
p_c_thread
);
blockwise_gemm
.
Run
(
p_a_block_vec
,
p_b_block_vec
,
p_c_thread
);
// LDS double buffer: store last data to LDS
// LDS double buffer: store last data to LDS
...
@@ -269,12 +261,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
...
@@ -269,12 +261,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
__syncthreads
();
__syncthreads
();
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
p_a_block_vec
=
p_a_block_vec
=
reinterpret_cast
<
const
half4_t
*>
(
p_a_block_double
+
a_block_space
);
reinterpret_cast
<
const
half4_t
*>
(
p_b_block_vec
=
reinterpret_cast
<
const
half4_t
*>
(
p_b_block_double
+
b_block_space
);
p_a_block_double
+
a_block_space
);
p_b_block_vec
=
reinterpret_cast
<
const
half4_t
*>
(
p_b_block_double
+
b_block_space
);
blockwise_gemm
.
Run
(
p_a_block_vec
,
p_b_block_vec
,
p_c_thread
);
blockwise_gemm
.
Run
(
p_a_block_vec
,
p_b_block_vec
,
p_c_thread
);
}
}
else
// if has 1 iteration left
else
// if has 1 iteration left
...
@@ -282,12 +270,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
...
@@ -282,12 +270,8 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
__syncthreads
();
__syncthreads
();
// LDS double buffer: GEMM on last data
// LDS double buffer: GEMM on last data
auto
p_a_block_vec
=
auto
p_a_block_vec
=
reinterpret_cast
<
const
half4_t
*>
(
p_a_block_double
);
reinterpret_cast
<
const
half4_t
*>
(
auto
p_b_block_vec
=
reinterpret_cast
<
const
half4_t
*>
(
p_b_block_double
);
p_a_block_double
);
auto
p_b_block_vec
=
reinterpret_cast
<
const
half4_t
*>
(
p_b_block_double
);
blockwise_gemm
.
Run
(
p_a_block_vec
,
p_b_block_vec
,
p_c_thread
);
blockwise_gemm
.
Run
(
p_a_block_vec
,
p_b_block_vec
,
p_c_thread
);
}
}
}
}
...
@@ -348,7 +332,6 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
...
@@ -348,7 +332,6 @@ struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
}
}
}
}
};
};
}
}
#endif
#endif
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
3317bfe2
...
@@ -810,7 +810,7 @@ struct XdlopsGemm_t
...
@@ -810,7 +810,7 @@ struct XdlopsGemm_t
(
mfma_type
.
group_size
*
mfma_type
.
num_input_blks
);
(
mfma_type
.
group_size
*
mfma_type
.
num_input_blks
);
index_t
bindex
=
blk_td
;
index_t
bindex
=
blk_td
;
p_c_thread
[
m
+
c_off
]
+=
inner_product_with_conversion
<
FloatC
>
{}(
p_c_thread
[
m
+
c_off
]
+=
inner_product_with_conversion
<
FloatC
>
{}(
p_a_wave
[
aindex
+
a_off
],
p_b_wave
[
bindex
+
b_off
]);
p_a_wave
[
aindex
+
a_off
],
p_b_wave
[
bindex
+
b_off
]);
}
}
}
}
}
}
...
...
driver/include/device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp
View file @
3317bfe2
...
@@ -13,16 +13,16 @@ template <class T,
...
@@ -13,16 +13,16 @@ template <class T,
class
InLeftPads
,
class
InLeftPads
,
class
InRightPads
>
class
InRightPads
>
void
device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw
(
InDesc
,
void
device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw
(
InDesc
,
const
Tensor
<
T
>&
in_nchw
,
const
Tensor
<
T
>&
in_nchw
,
WeiDesc
,
WeiDesc
,
const
Tensor
<
T
>&
wei_kcyx
,
const
Tensor
<
T
>&
wei_kcyx
,
OutDesc
,
OutDesc
,
Tensor
<
T
>&
out_nkhw
,
Tensor
<
T
>&
out_nkhw
,
ConvStrides
,
ConvStrides
,
ConvDilations
,
ConvDilations
,
InLeftPads
,
InLeftPads
,
InRightPads
,
InRightPads
,
ck
::
index_t
nrepeat
)
ck
::
index_t
nrepeat
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
@@ -60,7 +60,7 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
...
@@ -60,7 +60,7 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmKPACK
=
4
;
constexpr
index_t
GemmKPACK
=
4
;
constexpr
index_t
GemmMPerWave
=
64
;
constexpr
index_t
GemmMPerWave
=
64
;
constexpr
index_t
GemmNPerWave
=
64
;
constexpr
index_t
GemmNPerWave
=
64
;
...
@@ -76,7 +76,7 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
...
@@ -76,7 +76,7 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
using
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN_GemmKPACK
=
Sequence
<
1
,
2
,
4
>
;
using
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN_GemmKPACK
=
Sequence
<
1
,
2
,
4
>
;
using
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN_GemmKPACK
=
Sequence
<
4
,
32
,
1
>
;
using
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN_GemmKPACK
=
Sequence
<
4
,
32
,
1
>
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmKPACK
=
1
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmKPACK
=
1
;
constexpr
index_t
GemmM
=
K
;
constexpr
index_t
GemmM
=
K
;
...
@@ -87,51 +87,52 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
...
@@ -87,51 +87,52 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
constexpr
auto
gridwise_conv
=
GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw
<
constexpr
auto
gridwise_conv
=
GridSize
,
GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp16_nchw_kcyx_nkhw
<
BlockSize
,
GridSize
,
half
,
BlockSize
,
float
,
half
,
decltype
(
in_nchw_desc
),
float
,
decltype
(
wei_kcyx_desc
),
decltype
(
in_nchw_desc
),
decltype
(
out_nkhw_desc
),
decltype
(
wei_kcyx_desc
),
ConvStrides
,
decltype
(
out_nkhw_desc
),
ConvDilations
,
ConvStrides
,
InLeftPads
,
ConvDilations
,
InRightPads
,
InLeftPads
,
GemmMPerBlock
,
InRightPads
,
GemmNPerBlock
,
GemmMPerBlock
,
GemmKPerBlock
,
GemmNPerBlock
,
GemmKPACK
,
GemmKPerBlock
,
GemmMPerWave
,
GemmKPACK
,
GemmNPerWave
,
GemmMPerWave
,
ThreadGemmDataPerReadM
,
GemmNPerWave
,
ThreadGemmDataPerReadN
,
ThreadGemmDataPerReadM
,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM_GemmKPACK
,
ThreadGemmDataPerReadN
,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM_GemmKPACK
,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM_GemmKPACK
,
GemmABlockCopySrcDataPerRead_GemmKPACK
,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM_GemmKPACK
,
GemmABlockCopyDstDataPerWrite_GemmKPACK
,
GemmABlockCopySrcDataPerRead_GemmKPACK
,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN_GemmKPACK
,
GemmABlockCopyDstDataPerWrite_GemmKPACK
,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN_GemmKPACK
,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN_GemmKPACK
,
GemmBBlockCopySrcDataPerRead_GemmN
,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN_GemmKPACK
,
GemmBBlockCopyDstDataPerWrite_GemmKPACK
>
{};
GemmBBlockCopySrcDataPerRead_GemmN
,
GemmBBlockCopyDstDataPerWrite_GemmKPACK
>
{};
for
(
index_t
i
=
0
;
i
<
10
;
++
i
)
for
(
index_t
i
=
0
;
i
<
10
;
++
i
)
{
{
float
time
=
float
time
=
launch_and_time_kernel
(
run_gridwise_convolution_kernel
<
decltype
(
gridwise_conv
),
T
>
,
launch_and_time_kernel
(
run_gridwise_convolution_kernel
<
decltype
(
gridwise_conv
),
T
>
,
dim3
(
GridSize
),
dim3
(
GridSize
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
0
,
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
printf
(
"Elapsed time : %f ms, %f TFlop/s
\n
"
,
printf
(
"Elapsed time : %f ms, %f TFlop/s
\n
"
,
time
,
time
,
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
time
);
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
time
);
}
}
// warm up
// warm up
...
@@ -139,14 +140,14 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
...
@@ -139,14 +140,14 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
{
launch_kernel
(
run_gridwise_convolution_kernel
<
decltype
(
gridwise_conv
),
T
>
,
launch_kernel
(
run_gridwise_convolution_kernel
<
decltype
(
gridwise_conv
),
T
>
,
dim3
(
GridSize
),
dim3
(
GridSize
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
0
,
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
}
}
printf
(
"Start running %d times...
\n
"
,
nrepeat
);
printf
(
"Start running %d times...
\n
"
,
nrepeat
);
...
@@ -156,25 +157,25 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
...
@@ -156,25 +157,25 @@ void device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(InDesc,
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
{
launch_kernel
(
run_gridwise_convolution_kernel
<
decltype
(
gridwise_conv
),
T
>
,
launch_kernel
(
run_gridwise_convolution_kernel
<
decltype
(
gridwise_conv
),
T
>
,
dim3
(
GridSize
),
dim3
(
GridSize
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
0
,
0
,
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
}
}
cudaDeviceSynchronize
();
cudaDeviceSynchronize
();
auto
end
=
std
::
chrono
::
steady_clock
::
now
();
auto
end
=
std
::
chrono
::
steady_clock
::
now
();
float
ave_time
=
std
::
chrono
::
duration
<
float
,
std
::
milli
>
(
end
-
start
).
count
()
/
nrepeat
;
float
ave_time
=
std
::
chrono
::
duration
<
float
,
std
::
milli
>
(
end
-
start
).
count
()
/
nrepeat
;
printf
(
"Average elapsed time : %f ms, %f TFlop/s
\n
"
,
printf
(
"Average elapsed time : %f ms, %f TFlop/s
\n
"
,
ave_time
,
ave_time
,
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
);
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
);
out_nkhw_device_buf
.
FromDevice
(
out_nkhw
.
mData
.
data
());
out_nkhw_device_buf
.
FromDevice
(
out_nkhw
.
mData
.
data
());
}
}
driver/src/conv_driver.cpp
View file @
3317bfe2
...
@@ -618,16 +618,16 @@ int main(int argc, char* argv[])
...
@@ -618,16 +618,16 @@ int main(int argc, char* argv[])
nrepeat
);
nrepeat
);
#elif 1
#elif 1
device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw
(
in_nchw_desc
,
device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx_desc
,
wei_kcyx
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_desc
,
out_nkhw_device
,
out_nkhw_device
,
ConvStrides
{},
ConvStrides
{},
ConvDilations
{},
ConvDilations
{},
LeftPads
{},
LeftPads
{},
RightPads
{},
RightPads
{},
nrepeat
);
nrepeat
);
#endif
#endif
if
(
do_verification
)
if
(
do_verification
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment