Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
56863b9a
Commit
56863b9a
authored
Aug 16, 2023
by
Jing Zhang
Browse files
add fp8 support
parents
54df59bf
d4c84256
Changes
250
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
237 additions
and
296 deletions
+237
-296
profiler/include/profiler/profile_pool3d_fwd_impl.hpp
profiler/include/profiler/profile_pool3d_fwd_impl.hpp
+16
-7
profiler/src/CMakeLists.txt
profiler/src/CMakeLists.txt
+27
-24
profiler/src/profile_avg_pool2d_fwd.cpp
profiler/src/profile_avg_pool2d_fwd.cpp
+0
-141
profiler/src/profile_gemm.cpp
profiler/src/profile_gemm.cpp
+7
-1
profiler/src/profile_grouped_gemm.cpp
profiler/src/profile_grouped_gemm.cpp
+2
-2
profiler/src/profile_max_pool3d_fwd.cpp
profiler/src/profile_max_pool3d_fwd.cpp
+69
-48
script/cmake-ck-dev.sh
script/cmake-ck-dev.sh
+2
-1
test/CMakeLists.txt
test/CMakeLists.txt
+1
-1
test/batched_gemm/CMakeLists.txt
test/batched_gemm/CMakeLists.txt
+20
-15
test/batched_gemm_gemm/CMakeLists.txt
test/batched_gemm_gemm/CMakeLists.txt
+7
-5
test/batched_gemm_multi_d/CMakeLists.txt
test/batched_gemm_multi_d/CMakeLists.txt
+2
-3
test/batched_gemm_reduce/CMakeLists.txt
test/batched_gemm_reduce/CMakeLists.txt
+6
-4
test/batched_gemm_softmax_gemm/CMakeLists.txt
test/batched_gemm_softmax_gemm/CMakeLists.txt
+7
-5
test/batched_gemm_softmax_gemm_permute/CMakeLists.txt
test/batched_gemm_softmax_gemm_permute/CMakeLists.txt
+19
-15
test/elementwise_normalization/CMakeLists.txt
test/elementwise_normalization/CMakeLists.txt
+6
-7
test/gemm_layernorm/CMakeLists.txt
test/gemm_layernorm/CMakeLists.txt
+2
-0
test/gemm_reduce/CMakeLists.txt
test/gemm_reduce/CMakeLists.txt
+5
-3
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
...uped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
+12
-0
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp
...d_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp
+12
-13
test/grouped_convnd_fwd/grouped_convnd_fwd.cpp
test/grouped_convnd_fwd/grouped_convnd_fwd.cpp
+15
-1
No files found.
profiler/include/profiler/profile_pool3d_fwd_impl.hpp
View file @
56863b9a
...
...
@@ -21,6 +21,8 @@ template <typename InDataType,
typename
OutDataType
,
typename
ComputeDataType
,
typename
IndexDataType
,
typename
InLayout
,
typename
OutLayout
,
ck
::
ReduceTensorOp
ReduceOpId
,
bool
PropagateNan
,
bool
OutputIndex
>
...
...
@@ -31,6 +33,7 @@ bool profile_pool3d_fwd_impl(int do_verification,
std
::
vector
<
index_t
>
in_length
,
// NCDHW
std
::
vector
<
index_t
>
window_spatial_lengths
,
std
::
vector
<
index_t
>
window_strides
,
std
::
vector
<
index_t
>
window_dilations
,
std
::
vector
<
index_t
>
input_left_pads
,
std
::
vector
<
index_t
>
input_right_pads
)
{
...
...
@@ -38,8 +41,8 @@ bool profile_pool3d_fwd_impl(int do_verification,
constexpr
index_t
WindowRank
=
3
;
if
(
in_length
.
size
()
!=
InOutRank
||
window_spatial_lengths
.
size
()
!=
WindowRank
||
window_strides
.
size
()
!=
WindowRank
||
in
put_left_pad
s
.
size
()
!=
WindowRank
||
input_right_pads
.
size
()
!=
WindowRank
)
window_strides
.
size
()
!=
WindowRank
||
w
in
dow_dilation
s
.
size
()
!=
WindowRank
||
input_left_pads
.
size
()
!=
WindowRank
||
input_right_pads
.
size
()
!=
WindowRank
)
return
false
;
std
::
vector
<
index_t
>
out_length
(
InOutRank
);
...
...
@@ -53,11 +56,13 @@ bool profile_pool3d_fwd_impl(int do_verification,
// Calculate Do, Ho, Wo
for
(
int
i
=
2
;
i
<
InOutRank
;
++
i
)
{
auto
pad1
=
input_left_pads
[
i
-
2
];
auto
pad2
=
input_right_pads
[
i
-
2
];
auto
windows_size
=
window_spatial_lengths
[
i
-
2
];
auto
windows_stride
=
window_strides
[
i
-
2
];
out_length
[
i
]
=
(
in_length
[
i
]
+
pad1
+
pad2
-
windows_size
)
/
windows_stride
+
1
;
auto
pad1
=
input_left_pads
[
i
-
2
];
auto
pad2
=
input_right_pads
[
i
-
2
];
auto
windows_size
=
window_spatial_lengths
[
i
-
2
];
auto
windows_stride
=
window_strides
[
i
-
2
];
auto
windows_dilation
=
window_dilations
[
i
-
2
];
auto
eff
=
(
windows_size
-
1
)
*
windows_dilation
+
1
;
out_length
[
i
]
=
(
in_length
[
i
]
+
pad1
+
pad2
-
eff
)
/
windows_stride
+
1
;
}
int
Di
=
in_length
[
2
];
...
...
@@ -104,6 +109,8 @@ bool profile_pool3d_fwd_impl(int do_verification,
InDataType
,
OutDataType
,
IndexDataType
,
InLayout
,
OutLayout
,
ReduceOpId
,
OutputIndex
>
;
...
...
@@ -136,6 +143,7 @@ bool profile_pool3d_fwd_impl(int do_verification,
out_indices_n_c_do_ho_wo_host
,
window_spatial_lengths
,
window_strides
,
window_dilations
,
input_left_pads
,
input_right_pads
);
auto
ref_invoker
=
ref
.
MakeInvoker
();
...
...
@@ -157,6 +165,7 @@ bool profile_pool3d_fwd_impl(int do_verification,
{
Do
*
C
*
Ho
*
Wo
,
1
,
C
*
Ho
*
Wo
,
Wo
*
C
,
C
},
{
Do
*
C
*
Ho
*
Wo
,
1
,
C
*
Ho
*
Wo
,
Wo
*
C
,
C
},
window_strides
,
window_dilations
,
input_left_pads
,
input_right_pads
,
{
2
,
3
,
4
});
...
...
profiler/src/CMakeLists.txt
View file @
56863b9a
...
...
@@ -3,20 +3,11 @@ set(PROFILER_SOURCES
profiler.cpp
profile_gemm.cpp
profile_gemm_splitk.cpp
profile_gemm_streamk.cpp
profile_gemm_bilinear.cpp
profile_gemm_bias_add_reduce.cpp
profile_gemm_add_add_fastgelu.cpp
profile_gemm_add_multiply.cpp
profile_gemm_add_fastgelu.cpp
profile_gemm_add_relu_add_layernorm.cpp
profile_gemm_fastgelu.cpp
profile_gemm_reduce.cpp
profile_batched_gemm.cpp
profile_batched_gemm_gemm.cpp
profile_batched_gemm_add_relu_gemm_add.cpp
profile_batched_gemm_reduce.cpp
profile_grouped_gemm.cpp
profile_conv_fwd.cpp
profile_conv_fwd_bias_relu.cpp
profile_conv_fwd_bias_relu_add.cpp
...
...
@@ -26,13 +17,11 @@ set(PROFILER_SOURCES
profile_reduce.cpp
profile_groupnorm.cpp
profile_layernorm.cpp
profile_avg_pool2d_fwd.cpp
profile_max_pool3d_fwd.cpp
profile_softmax.cpp
profile_batchnorm_fwd.cpp
profile_batchnorm_bwd.cpp
profile_batchnorm_infer.cpp
profile_grouped_gemm_fastgelu.cpp
profile_contraction_bilinear.cpp
profile_contraction_scale.cpp
profile_grouped_conv_bwd_data.cpp
...
...
@@ -40,6 +29,18 @@ set(PROFILER_SOURCES
if
(
DL_KERNELS
)
list
(
APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp
)
endif
()
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
list
(
APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_streamk.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp
)
list
(
APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp
)
list
(
APPEND PROFILER_SOURCES profile_grouped_gemm.cpp
)
list
(
APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp
)
endif
()
set
(
PROFILER_EXECUTABLE ckProfiler
)
...
...
@@ -49,20 +50,11 @@ target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE utility
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_splitk_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_streamk_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_bilinear_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_add_fastgelu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_multiply_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_fastgelu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_fastgelu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_relu_add_layernorm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_reduce_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_bias_add_reduce_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_gemm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_add_relu_gemm_add_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_reduce_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_gemm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_conv2d_fwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv1d_fwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv2d_fwd_instance
)
...
...
@@ -79,13 +71,24 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instan
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_softmax_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_reduce_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batchnorm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_gemm_fastgelu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_contraction_bilinear_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_contraction_scale_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_pool_fwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_pool3d_fwd_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv2d_bwd_data_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv3d_bwd_data_instance
)
if
(
DL_KERNELS
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_multi_d_instance
)
endif
()
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv2d_bwd_data_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv3d_bwd_data_instance
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_fastgelu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_relu_add_layernorm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_bilinear_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_add_add_fastgelu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_streamk_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_fastgelu_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_gemm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_add_relu_gemm_add_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_gemm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_gemm_fastgelu_instance
)
endif
()
rocm_install
(
TARGETS
${
PROFILER_EXECUTABLE
}
COMPONENT profiler
)
profiler/src/profile_avg_pool2d_fwd.cpp
deleted
100644 → 0
View file @
54df59bf
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include <unordered_map>
#include "profiler/data_type_enum.hpp"
#include "profiler/profile_pool2d_fwd_impl.hpp"
#include "profiler_operation_registry.hpp"
using
ck
::
index_t
;
struct
avgPoolFwdArgParser
{
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
int
>>
long_opts
=
{
{
"length"
,
{}},
{
"wsize"
,
{}},
{
"wstride"
,
{}},
{
"pad1"
,
{}},
{
"pad2"
,
{}}};
bool
parse_opt
(
int
argc
,
char
*
argv
[],
const
std
::
string
&
key
,
int
i
)
{
if
(
std
::
string
(
"--"
)
+
key
==
argv
[
i
])
{
int
pos
=
i
;
while
(
++
i
<
argc
&&
argv
[
i
][
0
]
!=
'-'
)
{}
int
end
=
i
;
for
(
int
j
=
pos
+
1
;
j
<
end
;
j
++
)
{
long_opts
[
key
].
push_back
(
std
::
stoi
(
argv
[
j
]));
}
return
true
;
}
return
false
;
}
void
operator
()(
int
argc
,
char
*
argv
[])
{
for
(
auto
&
kv
:
long_opts
)
{
for
(
int
i
=
1
;
i
<
argc
;
i
++
)
{
if
(
parse_opt
(
argc
,
argv
,
kv
.
first
,
i
))
break
;
}
}
}
};
void
print_help_avg_pool2d_fwd
()
{
std
::
cout
<<
"arg1: data type (0: fp16; 1: fp32)
\n
"
<<
"arg2: verification (0: no; 1: yes)
\n
"
<<
"arg3: initialization (0: no init; 1: integer value; 2: decimal value)
\n
"
<<
"arg4: print tensor value (0: no; 1: yes)
\n
"
<<
"arg5: time kernel (0=no, 1=yes)
\n
"
<<
"--length: input tensor length for NDHW(e.g, --length 2 32 30 30)
\n
"
<<
"--wsize: window size for YX (e.g, --wsize 2 2)
\n
"
<<
"--wstride: window stride for HW (e.g, --wstride 2 2)
\n
"
<<
"--pad1: left side of padding in HW (e.g, --pad1 1 1)
\n
"
<<
"--pad2: right side of padding in HW (e.g, --pad2 1 1)
\n
"
<<
"eg: ckProfiler avg_pool2d_fwd 0 1 2 0 1 0 --length 2 32 30 30 --wsize 2 2 "
"--wstride 2 2 --pad1 1 1 --pad2 1 1"
<<
std
::
endl
;
}
int
profile_avg_pool2d_fwd
(
int
argc
,
char
*
argv
[])
{
ck
::
DataTypeEnum
data_type
=
ck
::
DataTypeEnum
::
Half
;
bool
do_verification
=
true
;
int
init_method
=
0
;
bool
do_log
=
false
;
bool
time_kernel
=
true
;
std
::
vector
<
index_t
>
in_length
=
{
2
,
32
,
30
,
30
};
std
::
vector
<
index_t
>
wsize
=
{
2
,
2
};
std
::
vector
<
index_t
>
wstride
=
{
2
,
2
};
std
::
vector
<
index_t
>
pad1
=
{
1
,
1
};
std
::
vector
<
index_t
>
pad2
=
{
1
,
1
};
if
(
argc
!=
2
&&
argc
!=
25
)
{
print_help_avg_pool2d_fwd
();
return
0
;
}
else
if
(
argc
==
25
)
{
data_type
=
static_cast
<
ck
::
DataTypeEnum
>
(
std
::
stoi
(
argv
[
2
]));
do_verification
=
std
::
stoi
(
argv
[
3
]);
init_method
=
std
::
stoi
(
argv
[
4
]);
do_log
=
std
::
stoi
(
argv
[
5
]);
time_kernel
=
std
::
stoi
(
argv
[
6
]);
// parse the long options
avgPoolFwdArgParser
arg_parser
;
arg_parser
(
argc
,
argv
);
in_length
=
arg_parser
.
long_opts
[
"length"
];
wsize
=
arg_parser
.
long_opts
[
"wsize"
];
wstride
=
arg_parser
.
long_opts
[
"wstride"
];
pad1
=
arg_parser
.
long_opts
[
"pad1"
];
pad2
=
arg_parser
.
long_opts
[
"pad2"
];
}
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
I32
=
int32_t
;
constexpr
auto
ReduceOpId
=
ck
::
ReduceTensorOp
::
AVG
;
if
(
data_type
==
ck
::
DataTypeEnum
::
Half
)
{
ck
::
profiler
::
profile_pool2d_fwd_impl
<
F16
,
F16
,
F32
,
I32
,
ReduceOpId
,
false
,
false
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
in_length
,
wsize
,
wstride
,
pad1
,
pad2
);
}
else
if
(
data_type
==
ck
::
DataTypeEnum
::
Float
)
{
ck
::
profiler
::
profile_pool2d_fwd_impl
<
F32
,
F32
,
F32
,
I32
,
ReduceOpId
,
false
,
false
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
in_length
,
wsize
,
wstride
,
pad1
,
pad2
);
}
else
{
throw
std
::
runtime_error
(
"not implemented yet"
);
}
return
0
;
}
REGISTER_PROFILER_OPERATION
(
"avg_pool2d_fwd"
,
"avg_pool2d fwd"
,
profile_avg_pool2d_fwd
);
profiler/src/profile_gemm.cpp
View file @
56863b9a
...
...
@@ -121,7 +121,10 @@ int profile_gemm(int argc, char* argv[])
return
pass
?
0
:
1
;
};
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
if
(
false
)
;
#ifdef __fp32__
else
if
(
data_type
==
GemmDataType
::
F32_F32_F32
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
profile
(
Row
{},
Row
{},
Row
{},
F32
{},
F32
{},
F32
{},
F32
{});
}
...
...
@@ -137,6 +140,8 @@ int profile_gemm(int argc, char* argv[])
{
return
profile
(
Col
{},
Col
{},
Row
{},
F32
{},
F32
{},
F32
{},
F32
{});
}
#endif
#ifdef __fp16__
else
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
return
profile
(
Row
{},
Row
{},
Row
{},
F16
{},
F16
{},
F32
{},
F16
{});
...
...
@@ -153,6 +158,7 @@ int profile_gemm(int argc, char* argv[])
{
return
profile
(
Col
{},
Col
{},
Row
{},
F16
{},
F16
{},
F32
{},
F16
{});
}
#endif
#ifdef __bf16__
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
...
...
profiler/src/profile_grouped_gemm.cpp
View file @
56863b9a
...
...
@@ -88,7 +88,7 @@ int profile_grouped_gemm(int argc, char* argv[])
const
auto
StrideBs
=
argToIntArray
(
argv
[
12
]);
const
auto
StrideCs
=
argToIntArray
(
argv
[
13
]);
const
int
kbatch
=
argc
==
15
?
std
::
stoi
(
argv
[
14
])
:
1
;
#ifdef __fp16__
if
(
data_type
==
GemmDataType
::
F16_F16_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
ck
::
profiler
::
profile_grouped_gemm_impl
<
ck
::
half_t
,
...
...
@@ -173,7 +173,7 @@ int profile_grouped_gemm(int argc, char* argv[])
{
throw
std
::
runtime_error
(
"wrong! this GEMM data_type & layout is not implemented"
);
}
#endif
return
0
;
}
...
...
profiler/src/profile_max_pool3d_fwd.cpp
View file @
56863b9a
...
...
@@ -13,8 +13,12 @@ using ck::index_t;
struct
maxPoolFwdArgParser
{
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
int
>>
long_opts
=
{
{
"length"
,
{}},
{
"wsize"
,
{}},
{
"wstride"
,
{}},
{
"pad1"
,
{}},
{
"pad2"
,
{}}};
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
int
>>
long_opts
=
{{
"length"
,
{}},
{
"wsize"
,
{}},
{
"wstride"
,
{}},
{
"wdilation"
,
{}},
{
"pad1"
,
{}},
{
"pad2"
,
{}}};
bool
parse_opt
(
int
argc
,
char
*
argv
[],
const
std
::
string
&
key
,
int
i
)
{
...
...
@@ -56,10 +60,11 @@ void print_help_max_pool3d_fwd()
<<
"--length: input tensor length for NCDHW(e.g, --length 2 32 30 30 30)
\n
"
<<
"--wsize: window size for ZYX (e.g, --wsize 2 2 2)
\n
"
<<
"--wstride: window stride for DHW (e.g, --wstride 2 2 2)
\n
"
<<
"--wdilation: window dilation for DHW (e.g, --wdilation 1 1 1)
\n
"
<<
"--pad1: left side of padding in DHW (e.g, --pad1 1 1 1)
\n
"
<<
"--pad2: right side of padding in DHW (e.g, --pad2 1 1 1)
\n
"
<<
"eg: ckProfiler max_pool3d_fwd 0 1 2 0 1 0 --length 2 32 30 30 30 --wsize 2 2 2 "
"--wstride 2 2 2 --pad1 1 1 1 --pad2 1 1 1"
"--wstride 2 2 2
--wdilation 1 1 1
--pad1 1 1 1 --pad2 1 1 1"
<<
std
::
endl
;
}
...
...
@@ -75,15 +80,16 @@ int profile_max_pool3d_fwd(int argc, char* argv[])
std
::
vector
<
index_t
>
in_length
=
{
2
,
32
,
30
,
30
,
30
};
std
::
vector
<
index_t
>
wsize
=
{
2
,
2
,
2
};
std
::
vector
<
index_t
>
wstride
=
{
2
,
2
,
2
};
std
::
vector
<
index_t
>
wdilation
=
{
1
,
1
,
1
};
std
::
vector
<
index_t
>
pad1
=
{
1
,
1
,
1
};
std
::
vector
<
index_t
>
pad2
=
{
1
,
1
,
1
};
if
(
argc
!=
2
&&
argc
!=
3
0
)
if
(
argc
!=
2
&&
argc
!=
3
4
)
{
print_help_max_pool3d_fwd
();
return
0
;
}
else
if
(
argc
==
3
0
)
else
if
(
argc
==
3
4
)
{
data_type
=
static_cast
<
ck
::
DataTypeEnum
>
(
std
::
stoi
(
argv
[
2
]));
do_verification
=
std
::
stoi
(
argv
[
3
]);
...
...
@@ -98,64 +104,79 @@ int profile_max_pool3d_fwd(int argc, char* argv[])
in_length
=
arg_parser
.
long_opts
[
"length"
];
wsize
=
arg_parser
.
long_opts
[
"wsize"
];
wstride
=
arg_parser
.
long_opts
[
"wstride"
];
wdilation
=
arg_parser
.
long_opts
[
"wdilation"
];
pad1
=
arg_parser
.
long_opts
[
"pad1"
];
pad2
=
arg_parser
.
long_opts
[
"pad2"
];
}
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
I32
=
int32_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
I32
=
int32_t
;
using
NDHWC
=
ck
::
tensor_layout
::
convolution
::
NDHWC
;
#if 1
constexpr
auto
ReduceOpId
=
ck
::
ReduceTensorOp
::
MAX
;
#else
constexpr
auto
ReduceOpId
=
ck
::
ReduceTensorOp
::
AVG
;
#endif
if
(
data_type
==
ck
::
DataTypeEnum
::
Half
)
{
if
(
return_index
)
ck
::
profiler
::
profile_pool3d_fwd_impl
<
F16
,
F16
,
F16
,
I32
,
ReduceOpId
,
false
,
true
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
in_length
,
wsize
,
wstride
,
pad1
,
pad2
);
ck
::
profiler
::
profile_pool3d_fwd_impl
<
F16
,
F16
,
F16
,
I32
,
NDHWC
,
NDHWC
,
ReduceOpId
,
false
,
true
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
in_length
,
wsize
,
wstride
,
wdilation
,
pad1
,
pad2
);
else
ck
::
profiler
::
profile_pool3d_fwd_impl
<
F16
,
F16
,
F16
,
I32
,
ReduceOpId
,
false
,
false
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
in_length
,
wsize
,
wstride
,
pad1
,
pad2
);
ck
::
profiler
::
profile_pool3d_fwd_impl
<
F16
,
F16
,
F16
,
I32
,
NDHWC
,
NDHWC
,
ReduceOpId
,
false
,
false
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
in_length
,
wsize
,
wstride
,
wdilation
,
pad1
,
pad2
);
}
else
if
(
data_type
==
ck
::
DataTypeEnum
::
Float
)
{
if
(
return_index
)
ck
::
profiler
::
profile_pool3d_fwd_impl
<
F32
,
F32
,
F32
,
I32
,
ReduceOpId
,
false
,
true
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
in_length
,
wsize
,
wstride
,
pad1
,
pad2
);
ck
::
profiler
::
profile_pool3d_fwd_impl
<
F32
,
F32
,
F32
,
I32
,
NDHWC
,
NDHWC
,
ReduceOpId
,
false
,
true
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
in_length
,
wsize
,
wstride
,
wdilation
,
pad1
,
pad2
);
else
ck
::
profiler
::
profile_pool3d_fwd_impl
<
F32
,
F32
,
F32
,
I32
,
ReduceOpId
,
false
,
false
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
in_length
,
wsize
,
wstride
,
pad1
,
pad2
);
ck
::
profiler
::
profile_pool3d_fwd_impl
<
F32
,
F32
,
F32
,
I32
,
NDHWC
,
NDHWC
,
ReduceOpId
,
false
,
false
>
(
do_verification
,
init_method
,
do_log
,
time_kernel
,
in_length
,
wsize
,
wstride
,
wdilation
,
pad1
,
pad2
);
}
else
{
...
...
script/cmake-ck-dev.sh
View file @
56863b9a
...
...
@@ -12,7 +12,8 @@ cmake
-save-temps=
$PWD
"
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
BUILD_DEV
=
ON
\
-D
GPU_TARGETS
=
"gfx908;gfx90a;gfx940"
\
-D
DTYPES
=
"fp32;fp16;bf16;fp8"
\
-D
GPU_TARGETS
=
"gfx90a"
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
${
MY_PROJECT_SOURCE
}
test/CMakeLists.txt
View file @
56863b9a
...
...
@@ -60,6 +60,6 @@ add_subdirectory(contraction)
add_subdirectory
(
pool_fwd
)
add_subdirectory
(
batched_gemm_multi_d
)
add_subdirectory
(
grouped_convnd_bwd_data
)
if
(
GPU_TARGETS MATCHES
"gfx11
00
"
)
if
(
GPU_TARGETS MATCHES
"gfx11"
)
add_subdirectory
(
wmma_op
)
endif
()
test/batched_gemm/CMakeLists.txt
View file @
56863b9a
...
...
@@ -2,21 +2,26 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_test_executable
(
test_batched_gemm_fp16 batched_gemm_fp16.cpp
)
target_link_libraries
(
test_batched_gemm_fp16 PRIVATE utility
)
target_link_libraries
(
test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance
)
add_test_executable
(
test_batched_gemm_fp32 batched_gemm_fp32.cpp
)
target_link_libraries
(
test_batched_gemm_fp32 PRIVATE utility
)
target_link_libraries
(
test_batched_gemm_fp32 PRIVATE device_batched_gemm_instance
)
add_test_executable
(
test_batched_gemm_bf16 batched_gemm_bf16.cpp
)
target_link_libraries
(
test_batched_gemm_bf16 PRIVATE utility
)
target_link_libraries
(
test_batched_gemm_bf16 PRIVATE device_batched_gemm_instance
)
add_test_executable
(
test_batched_gemm_int8 batched_gemm_int8.cpp
)
target_link_libraries
(
test_batched_gemm_int8 PRIVATE utility
)
target_link_libraries
(
test_batched_gemm_int8 PRIVATE device_batched_gemm_instance
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_test_executable
(
test_batched_gemm_fp16 batched_gemm_fp16.cpp
)
target_link_libraries
(
test_batched_gemm_fp16 PRIVATE utility
)
target_link_libraries
(
test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance
)
endif
()
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_test_executable
(
test_batched_gemm_fp32 batched_gemm_fp32.cpp
)
target_link_libraries
(
test_batched_gemm_fp32 PRIVATE utility
)
target_link_libraries
(
test_batched_gemm_fp32 PRIVATE device_batched_gemm_instance
)
endif
()
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_test_executable
(
test_batched_gemm_bf16 batched_gemm_bf16.cpp
)
target_link_libraries
(
test_batched_gemm_bf16 PRIVATE utility
)
target_link_libraries
(
test_batched_gemm_bf16 PRIVATE device_batched_gemm_instance
)
endif
()
if
(
DTYPES MATCHES
"int8"
OR NOT DEFINED DTYPES
)
add_test_executable
(
test_batched_gemm_int8 batched_gemm_int8.cpp
)
target_link_libraries
(
test_batched_gemm_int8 PRIVATE utility
)
target_link_libraries
(
test_batched_gemm_int8 PRIVATE device_batched_gemm_instance
)
endif
()
set
(
target 1
)
endif
()
endforeach
()
\ No newline at end of file
test/batched_gemm_gemm/CMakeLists.txt
View file @
56863b9a
...
...
@@ -2,10 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_custom_target
(
test_batched_gemm_gemm
)
add_gtest_executable
(
test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp
)
target_link_libraries
(
test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance
)
add_dependencies
(
test_batched_gemm_gemm test_batched_gemm_gemm_fp16
)
set
(
target 1
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_custom_target
(
test_batched_gemm_gemm
)
add_gtest_executable
(
test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp
)
target_link_libraries
(
test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance
)
add_dependencies
(
test_batched_gemm_gemm test_batched_gemm_gemm_fp16
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
\ No newline at end of file
test/batched_gemm_multi_d/CMakeLists.txt
View file @
56863b9a
# TODO: Enable for gfx90a after complier fix
if
(
DL_KERNELS
)
add_gtest_executable
(
test_batched_gemm_multi_d test_batched_gemm_multi_d.cpp
)
target_link_libraries
(
test_batched_gemm_multi_d PRIVATE utility device_batched_gemm_multi_d_instance
)
add_gtest_executable
(
test_batched_gemm_multi_d test_batched_gemm_multi_d.cpp
)
target_link_libraries
(
test_batched_gemm_multi_d PRIVATE utility device_batched_gemm_multi_d_instance
)
endif
()
test/batched_gemm_reduce/CMakeLists.txt
View file @
56863b9a
...
...
@@ -2,9 +2,11 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_test_executable
(
test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp
)
target_link_libraries
(
test_batched_gemm_reduce_fp16 PRIVATE utility
)
target_link_libraries
(
test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance
)
set
(
target 1
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_test_executable
(
test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp
)
target_link_libraries
(
test_batched_gemm_reduce_fp16 PRIVATE utility
)
target_link_libraries
(
test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
test/batched_gemm_softmax_gemm/CMakeLists.txt
View file @
56863b9a
...
...
@@ -2,10 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_custom_target
(
test_batched_gemm_softmax_gemm
)
add_gtest_executable
(
test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16
)
set
(
target 1
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_custom_target
(
test_batched_gemm_softmax_gemm
)
add_gtest_executable
(
test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
\ No newline at end of file
test/batched_gemm_softmax_gemm_permute/CMakeLists.txt
View file @
56863b9a
...
...
@@ -2,21 +2,25 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_custom_target
(
test_batched_gemm_softmax_gemm_permute
)
add_gtest_executable
(
test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp
)
add_gtest_executable
(
test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16
)
add_gtest_executable
(
test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp
)
add_gtest_executable
(
test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp
)
target_link_libraries
(
test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
target_link_libraries
(
test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16
)
if
(
DTYPES MATCHES
"fp16"
OR DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_custom_target
(
test_batched_gemm_softmax_gemm_permute
)
endif
()
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_gtest_executable
(
test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp
)
add_gtest_executable
(
test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
target_link_libraries
(
test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16
)
endif
()
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_gtest_executable
(
test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp
)
add_gtest_executable
(
test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp
)
target_link_libraries
(
test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
target_link_libraries
(
test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16
)
add_dependencies
(
test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16
)
endif
()
set
(
target 1
)
endif
()
endforeach
()
\ No newline at end of file
test/elementwise_normalization/CMakeLists.txt
View file @
56863b9a
add_custom_target
(
test_elementwise_normalization
)
add_gtest_executable
(
test_elementwise_layernorm_fp16 test_elementwise_layernorm_fp16.cpp
)
target_link_libraries
(
test_elementwise_layernorm_fp16 PRIVATE utility device_elementwise_normalization_instance
)
add_dependencies
(
test_elementwise_normalization test_elementwise_layernorm_fp16
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_custom_target
(
test_elementwise_normalization
)
add_gtest_executable
(
test_elementwise_layernorm_fp16 test_elementwise_layernorm_fp16.cpp
)
target_link_libraries
(
test_elementwise_layernorm_fp16 PRIVATE utility device_elementwise_normalization_instance
)
add_dependencies
(
test_elementwise_normalization test_elementwise_layernorm_fp16
)
endif
()
\ No newline at end of file
test/gemm_layernorm/CMakeLists.txt
View file @
56863b9a
...
...
@@ -2,10 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_custom_target
(
test_gemm_layernorm
)
add_gtest_executable
(
test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp
)
target_link_libraries
(
test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance
)
add_dependencies
(
test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
test/gemm_reduce/CMakeLists.txt
View file @
56863b9a
add_test_executable
(
test_gemm_reduce_fp16 gemm_reduce_fp16.cpp
)
target_link_libraries
(
test_gemm_reduce_fp16 PRIVATE utility
)
target_link_libraries
(
test_gemm_reduce_fp16 PRIVATE device_gemm_reduce_instance
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_test_executable
(
test_gemm_reduce_fp16 gemm_reduce_fp16.cpp
)
target_link_libraries
(
test_gemm_reduce_fp16 PRIVATE utility
)
target_link_libraries
(
test_gemm_reduce_fp16 PRIVATE device_gemm_reduce_instance
)
endif
()
\ No newline at end of file
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
View file @
56863b9a
...
...
@@ -100,6 +100,9 @@ TYPED_TEST(TestGroupedConvndBwdWeight1d, Test1D)
this
->
conv_params
.
push_back
({
1
,
2
,
128
,
128
,
256
,
{
1
},
{
14
},
{
2
},
{
1
},
{
0
},
{
0
}});
this
->
conv_params
.
push_back
({
1
,
2
,
32
,
128
,
256
,
{
3
},
{
28
},
{
1
},
{
1
},
{
1
},
{
1
}});
this
->
conv_params
.
push_back
({
1
,
2
,
128
,
128
,
256
,
{
1
},
{
3
},
{
1
},
{
1
},
{
0
},
{
0
}});
this
->
conv_params
.
push_back
({
1
,
1
,
1
,
1
,
32
,
{
3
},
{
32
},
{
1
},
{
1
},
{
1
},
{
1
}});
this
->
conv_params
.
push_back
({
1
,
1
,
1
,
64
,
3
,
{
3
},
{
32
},
{
1
},
{
1
},
{
1
},
{
1
}});
this
->
conv_params
.
push_back
({
1
,
1
,
1
,
1
,
1
,
{
3
},
{
32
},
{
1
},
{
1
},
{
1
},
{
1
}});
this
->
Run
();
}
...
...
@@ -112,6 +115,9 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
{
2
,
2
,
4
,
128
,
256
,
{
3
,
3
},
{
14
,
14
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
(
{
2
,
2
,
128
,
128
,
256
,
{
1
,
1
},
{
3
,
3
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
32
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
64
,
3
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
1
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
Run
();
}
...
...
@@ -124,5 +130,11 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D)
{
3
,
2
,
2
,
128
,
256
,
{
3
,
3
,
3
},
{
14
,
14
,
3
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
2
,
32
,
128
,
256
,
{
1
,
1
,
1
},
{
3
,
3
,
3
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
0
,
0
,
0
},
{
0
,
0
,
0
}});
this
->
conv_params
.
push_back
(
{
3
,
1
,
1
,
1
,
32
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
1
,
1
,
64
,
3
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
1
,
1
,
1
,
1
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
Run
();
}
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp
View file @
56863b9a
...
...
@@ -70,10 +70,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
ck
::
utils
::
conv
::
make_output_host_tensor_descriptor_g_n_k_wos_packed
<
OutLayout
>
(
conv_param
);
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_
spatial_
lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_
spatial_
lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_
spatial_
lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
input_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
filter_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
output_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
input_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
weights_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
output_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
{};
...
...
@@ -82,10 +83,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
auto
range_copy
=
[](
const
auto
&
from
,
auto
to
)
{
std
::
copy
(
begin
(
from
),
end
(
from
),
to
);
};
range_copy
(
conv_param
.
input_spatial_lengths_
,
begin
(
input_spatial_lengths
));
range_copy
(
conv_param
.
filter_spatial_lengths_
,
begin
(
filter_spatial_lengths
));
range_copy
(
conv_param
.
output_spatial_lengths_
,
begin
(
output_spatial_lengths
));
range_copy
(
in_g_n_c_wis_desc
.
GetLengths
(),
begin
(
input_lengths
));
range_copy
(
in_g_n_c_wis_desc
.
GetStrides
(),
begin
(
input_strides
));
range_copy
(
wei_g_k_c_xs_desc
.
GetLengths
(),
begin
(
filter_lengths
));
range_copy
(
wei_g_k_c_xs_desc
.
GetStrides
(),
begin
(
weights_strides
));
range_copy
(
out_g_n_k_wos_desc
.
GetLengths
(),
begin
(
output_lengths
));
range_copy
(
out_g_n_k_wos_desc
.
GetStrides
(),
begin
(
output_strides
));
range_copy
(
conv_param
.
conv_filter_strides_
,
begin
(
conv_filter_strides
));
range_copy
(
conv_param
.
conv_filter_dilations_
,
begin
(
conv_filter_dilations
));
...
...
@@ -97,14 +99,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
auto
argument
=
conv
.
MakeArgument
(
nullptr
,
nullptr
,
nullptr
,
conv_param
.
G_
,
conv_param
.
N_
,
conv_param
.
K_
,
conv_param
.
C_
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
input_lengths
,
input_strides
,
filter_lengths
,
weights_strides
,
output_lengths
,
output_strides
,
conv_filter_strides
,
conv_filter_dilations
,
...
...
test/grouped_convnd_fwd/grouped_convnd_fwd.cpp
View file @
56863b9a
...
...
@@ -22,6 +22,8 @@ TEST_F(TestGroupedConvNdFwd, GroupedConv1dFwdGNWC)
conv_params
.
push_back
({
1
,
2
,
128
,
128
,
256
,
{
1
},
{
14
},
{
2
},
{
1
},
{
0
},
{
0
}});
conv_params
.
push_back
({
1
,
2
,
128
,
128
,
256
,
{
3
},
{
28
},
{
1
},
{
1
},
{
1
},
{
1
}});
conv_params
.
push_back
({
1
,
2
,
128
,
128
,
256
,
{
1
},
{
3
},
{
1
},
{
1
},
{
0
},
{
0
}});
conv_params
.
push_back
({
1
,
1
,
1
,
1
,
32
,
{
3
},
{
32
},
{
1
},
{
1
},
{
1
},
{
1
}});
conv_params
.
push_back
({
1
,
1
,
1
,
64
,
3
,
{
3
},
{
32
},
{
1
},
{
1
},
{
1
},
{
1
}});
for
(
auto
&
param
:
conv_params
)
{
...
...
@@ -96,6 +98,9 @@ TEST_F(TestGroupedConvNdFwd, GroupedConv2dFwdGNHWC)
conv_params
.
push_back
({
2
,
2
,
128
,
128
,
256
,
{
1
,
1
},
{
7
,
7
},
{
2
,
2
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
conv_params
.
push_back
({
2
,
2
,
128
,
128
,
256
,
{
3
,
3
},
{
14
,
14
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
conv_params
.
push_back
({
2
,
2
,
128
,
128
,
256
,
{
1
,
1
},
{
3
,
3
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
32
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
conv_params
.
push_back
({
2
,
1
,
1
,
64
,
3
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
1
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
for
(
auto
&
param
:
conv_params
)
{
...
...
@@ -173,6 +178,12 @@ TEST_F(TestGroupedConvNdFwd, GroupedConv3dFwdGNDHWC)
{
3
,
2
,
128
,
128
,
256
,
{
3
,
3
,
3
},
{
14
,
14
,
3
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
conv_params
.
push_back
(
{
3
,
2
,
128
,
128
,
256
,
{
1
,
1
,
1
},
{
3
,
3
,
3
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
0
,
0
,
0
},
{
0
,
0
,
0
}});
conv_params
.
push_back
(
{
3
,
1
,
1
,
1
,
32
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
1
,
1
,
64
,
3
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
conv_params
.
push_back
(
{
3
,
1
,
1
,
1
,
1
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
for
(
auto
&
param
:
conv_params
)
{
...
...
@@ -247,6 +258,9 @@ TEST_F(TestGroupedConvNdFwd, GroupedConv2dFwdNHWGC)
conv_params
.
push_back
({
2
,
2
,
128
,
128
,
256
,
{
1
,
1
},
{
7
,
7
},
{
2
,
2
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
conv_params
.
push_back
({
2
,
2
,
128
,
128
,
256
,
{
3
,
3
},
{
14
,
14
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
conv_params
.
push_back
({
2
,
2
,
128
,
128
,
256
,
{
1
,
1
},
{
3
,
3
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
32
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
conv_params
.
push_back
({
2
,
1
,
1
,
64
,
3
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
1
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
for
(
auto
&
param
:
conv_params
)
{
...
...
@@ -255,7 +269,7 @@ TEST_F(TestGroupedConvNdFwd, GroupedConv2dFwdNHWGC)
// fp16
pass
=
ck
::
profiler
::
profile_grouped_conv_fwd_impl
<
2
,
ck
::
tensor_layout
::
convolution
::
NHWGC
,
ck
::
tensor_layout
::
convolution
::
KYX
G
C
,
ck
::
tensor_layout
::
convolution
::
G
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWGK
,
ck
::
half_t
,
ck
::
half_t
,
...
...
Prev
1
…
8
9
10
11
12
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment