Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ce87bcc7
Commit
ce87bcc7
authored
Sep 08, 2023
by
Bartlomiej Kocot
Browse files
tmp
parent
c8a8385f
Changes
123
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
48 additions
and
47 deletions
+48
-47
example/17_convnd_bwd_data/convnd_bwd_data_common.hpp
example/17_convnd_bwd_data/convnd_bwd_data_common.hpp
+2
-1
example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp
...e/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp
+6
-6
example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp
example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp
+2
-2
example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp
example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp
+2
-3
example/19_binary_elementwise/elementwise_add_1d.cpp
example/19_binary_elementwise/elementwise_add_1d.cpp
+2
-3
example/19_binary_elementwise/elementwise_add_4d.cpp
example/19_binary_elementwise/elementwise_add_4d.cpp
+2
-3
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc
...d_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc
+1
-1
example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp
...layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp
+2
-2
example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp
...yernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp
+2
-3
example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp
example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp
+2
-2
example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp
...layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp
+3
-3
example/22_cgemm/cgemm_xdl_common.hpp
example/22_cgemm/cgemm_xdl_common.hpp
+5
-5
example/23_softmax/softmax_blockwise.cpp
example/23_softmax/softmax_blockwise.cpp
+2
-3
example/24_batched_gemm/run_batched_gemm_example.inc
example/24_batched_gemm/run_batched_gemm_example.inc
+3
-4
example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp
..._bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp
+2
-1
example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp
..._bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp
+2
-1
example/26_contraction/contraction_bilinear_xdl_fp32.cpp
example/26_contraction/contraction_bilinear_xdl_fp32.cpp
+2
-1
example/26_contraction/contraction_bilinear_xdl_fp64.cpp
example/26_contraction/contraction_bilinear_xdl_fp64.cpp
+2
-1
example/26_contraction/contraction_scale_xdl_fp32.cpp
example/26_contraction/contraction_scale_xdl_fp32.cpp
+2
-1
example/26_contraction/contraction_scale_xdl_fp64.cpp
example/26_contraction/contraction_scale_xdl_fp64.cpp
+2
-1
No files found.
example/17_convnd_bwd_data/convnd_bwd_data_common.hpp
View file @
ce87bcc7
...
@@ -146,7 +146,8 @@ int run_conv_bwd_data(bool do_verification,
...
@@ -146,7 +146,8 @@ int run_conv_bwd_data(bool do_verification,
in_device_buf
.
FromDevice
(
in_device
.
mData
.
data
());
in_device_buf
.
FromDevice
(
in_device
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
in_device
,
in_host
)
?
0
:
1
;
validator
.
check_err
(
in_device
,
in_host
);
return
!
validator
.
is_success
()
}
}
return
0
;
return
0
;
...
...
example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp
View file @
ce87bcc7
...
@@ -293,19 +293,19 @@ int main(int argc, char* argv[])
...
@@ -293,19 +293,19 @@ int main(int argc, char* argv[])
}
}
}
}
pass
=
ck
::
utils
::
check_err
(
validator
.
check_err
(
c_g_m_n_host_result
,
c_g_m_n_device_result
,
"Error: Incorrect results c"
)
&&
c_g_m_n_host_result
,
c_g_m_n_device_result
,
"Error: Incorrect results c"
)
;
ck
::
utils
::
check_err
(
d0_g_m_device_result
,
validator
.
check_err
(
d0_g_m_device_result
,
d0_g_m_host_result
,
d0_g_m_host_result
,
"Error: Incorrect results! D0"
,
"Error: Incorrect results! D0"
,
1e-4
,
1e-4
,
1e-5
)
&&
1e-5
)
;
ck
::
utils
::
check_err
(
d1_g_m_device_result
,
validator
.
check_err
(
d1_g_m_device_result
,
d1_g_m_host_result
,
d1_g_m_host_result
,
"Error: Incorrect results! D1"
,
"Error: Incorrect results! D1"
,
1e-3
,
1e-3
,
1e-5
);
1e-5
);
}
}
return
pass
?
0
:
1
;
return
!
validator
.
is_success
()
;
}
}
example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp
View file @
ce87bcc7
...
@@ -129,8 +129,8 @@ int main()
...
@@ -129,8 +129,8 @@ int main()
host_broadcast2D
<
Tensor
<
ABDataType
>
,
Tensor
<
ABDataType
>
,
Tensor
<
CDataType
>
,
Add
,
0
>
(
host_broadcast2D
<
Tensor
<
ABDataType
>
,
Tensor
<
ABDataType
>
,
Tensor
<
CDataType
>
,
Add
,
0
>
(
host_c_m_n
,
a_m_n
,
b_n
,
M
,
N
,
Add
{});
host_c_m_n
,
a_m_n
,
b_n
,
M
,
N
,
Add
{});
pass
&=
ck
::
utils
::
check_err
(
c_m_n
,
host_c_m_n
,
"Error: Incorrect results c"
,
1e-3
,
1e-3
);
validator
.
check_err
(
c_m_n
,
host_c_m_n
,
"Error: Incorrect results c"
,
1e-3
,
1e-3
);
}
}
return
pass
?
0
:
1
;
return
!
validator
.
is_success
()
;
}
}
example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp
View file @
ce87bcc7
...
@@ -112,9 +112,8 @@ int main()
...
@@ -112,9 +112,8 @@ int main()
host_broadcast3D_am_bmnk
<
Tensor
<
ABDataType
>
,
Tensor
<
ABDataType
>
,
Tensor
<
CDataType
>
,
Add
>
(
host_broadcast3D_am_bmnk
<
Tensor
<
ABDataType
>
,
Tensor
<
ABDataType
>
,
Tensor
<
CDataType
>
,
Add
>
(
host_c_m_n_k
,
a_m
,
b_m_n_k
,
mnk
,
Add
{});
host_c_m_n_k
,
a_m
,
b_m_n_k
,
mnk
,
Add
{});
pass
&=
validator
.
check_err
(
c_m_n_k
,
host_c_m_n_k
,
"Error: Incorrect results c"
,
1e-3
,
1e-3
);
ck
::
utils
::
check_err
(
c_m_n_k
,
host_c_m_n_k
,
"Error: Incorrect results c"
,
1e-3
,
1e-3
);
}
}
return
pass
?
0
:
1
;
return
!
validator
.
is_success
()
;
}
}
example/19_binary_elementwise/elementwise_add_1d.cpp
View file @
ce87bcc7
...
@@ -95,7 +95,6 @@ int main()
...
@@ -95,7 +95,6 @@ int main()
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms"
<<
std
::
endl
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms"
<<
std
::
endl
;
bool
pass
=
true
;
if
(
do_verification
)
if
(
do_verification
)
{
{
c_m_device_buf
.
FromDevice
(
c_m
.
mData
.
data
());
c_m_device_buf
.
FromDevice
(
c_m
.
mData
.
data
());
...
@@ -104,8 +103,8 @@ int main()
...
@@ -104,8 +103,8 @@ int main()
host_elementwise1D
<
Tensor
<
ABDataType
>
,
Tensor
<
ABDataType
>
,
Tensor
<
CDataType
>
,
Add
>
(
host_elementwise1D
<
Tensor
<
ABDataType
>
,
Tensor
<
ABDataType
>
,
Tensor
<
CDataType
>
,
Add
>
(
host_c_m
,
a_m
,
b_m
,
M
,
Add
{});
host_c_m
,
a_m
,
b_m
,
M
,
Add
{});
pass
&=
ck
::
utils
::
check_err
(
c_m
,
host_c_m
,
"Error: Incorrect results c"
,
1e-3
,
1e-3
);
validator
.
check_err
(
c_m
,
host_c_m
,
"Error: Incorrect results c"
,
1e-3
,
1e-3
);
}
}
return
pass
?
0
:
1
;
return
!
validator
.
is_success
()
;
}
}
example/19_binary_elementwise/elementwise_add_4d.cpp
View file @
ce87bcc7
...
@@ -104,7 +104,6 @@ int main()
...
@@ -104,7 +104,6 @@ int main()
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms"
<<
std
::
endl
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms"
<<
std
::
endl
;
bool
pass
=
true
;
if
(
do_verification
)
if
(
do_verification
)
{
{
c_device_buf
.
FromDevice
(
c
.
mData
.
data
());
c_device_buf
.
FromDevice
(
c
.
mData
.
data
());
...
@@ -113,8 +112,8 @@ int main()
...
@@ -113,8 +112,8 @@ int main()
host_elementwise4D
<
Tensor
<
ABDataType
>
,
Tensor
<
ABDataType
>
,
Tensor
<
CDataType
>
,
Add
>
(
host_elementwise4D
<
Tensor
<
ABDataType
>
,
Tensor
<
ABDataType
>
,
Tensor
<
CDataType
>
,
Add
>
(
host_c
,
a
,
b
,
nchw
,
Add
{});
host_c
,
a
,
b
,
nchw
,
Add
{});
pass
&=
ck
::
utils
::
check_err
(
c
,
host_c
,
"Error: Incorrect results c"
,
1e-3
,
1e-3
);
validator
.
check_err
(
c
,
host_c
,
"Error: Incorrect results c"
,
1e-3
,
1e-3
);
}
}
return
pass
?
0
:
1
;
return
!
validator
.
is_success
()
;
}
}
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc
View file @
ce87bcc7
...
@@ -157,7 +157,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
...
@@ -157,7 +157,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
wei_device_buf
.
FromDevice
(
wei_device_result
.
mData
.
data
());
wei_device_buf
.
FromDevice
(
wei_device_result
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
wei_device_result
.
mData
,
wei_host_result
.
mData
);
return
validator
.
check_err
(
wei_device_result
.
mData
,
wei_host_result
.
mData
);
}
}
return
true
;
return
true
;
...
...
example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp
View file @
ce87bcc7
...
@@ -371,7 +371,7 @@ int main()
...
@@ -371,7 +371,7 @@ int main()
N
);
N
);
layerNorm_device_buf
.
FromDevice
(
layerNorm_m_n
.
mData
.
data
());
layerNorm_device_buf
.
FromDevice
(
layerNorm_m_n
.
mData
.
data
());
pass
&=
ck
::
utils
::
check_err
(
layerNorm_m_n
,
validator
.
check_err
(
layerNorm_m_n
,
host_layerNorm_m_n
,
host_layerNorm_m_n
,
"Error: Incorrect results layerNorm_m_n"
,
"Error: Incorrect results layerNorm_m_n"
,
1e-2
,
1e-2
,
...
@@ -401,5 +401,5 @@ int main()
...
@@ -401,5 +401,5 @@ int main()
gemm_reduce_mean_reduce_square_mean_ave_time
,
normalize_ave_time
,
M
,
N
,
K
);
gemm_reduce_mean_reduce_square_mean_ave_time
,
normalize_ave_time
,
M
,
N
,
K
);
}
}
return
pass
?
0
:
1
;
return
!
validator
.
is_success
()
;
}
}
example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp
View file @
ce87bcc7
...
@@ -255,9 +255,8 @@ int main()
...
@@ -255,9 +255,8 @@ int main()
epsilon
);
epsilon
);
h_device_buf
.
FromDevice
(
h_m_n
.
mData
.
data
());
h_device_buf
.
FromDevice
(
h_m_n
.
mData
.
data
());
pass
&=
validator
.
check_err
(
h_m_n
,
h_m_n_host
,
"Error: Incorrect results h_m_n"
,
1e-2
,
1e-2
);
ck
::
utils
::
check_err
(
h_m_n
,
h_m_n_host
,
"Error: Incorrect results h_m_n"
,
1e-2
,
1e-2
);
}
}
return
pass
?
0
:
1
;
return
!
validator
.
is_success
()
;
}
}
example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp
View file @
ce87bcc7
...
@@ -345,7 +345,7 @@ int main()
...
@@ -345,7 +345,7 @@ int main()
N
);
N
);
layerNorm_device_buf
.
FromDevice
(
layerNorm_m_n
.
mData
.
data
());
layerNorm_device_buf
.
FromDevice
(
layerNorm_m_n
.
mData
.
data
());
pass
&=
ck
::
utils
::
check_err
(
validator
.
check_err
(
layerNorm_m_n
,
host_layerNorm_m_n
,
"Error: Incorrect results d1"
,
1e-3
,
1e-3
);
layerNorm_m_n
,
host_layerNorm_m_n
,
"Error: Incorrect results d1"
,
1e-3
,
1e-3
);
}
}
...
@@ -370,5 +370,5 @@ int main()
...
@@ -370,5 +370,5 @@ int main()
gemm_reduce_mean_reduce_square_mean_ave_time
,
normalize_ave_time
,
M
,
N
,
K
);
gemm_reduce_mean_reduce_square_mean_ave_time
,
normalize_ave_time
,
M
,
N
,
K
);
}
}
return
pass
?
0
:
1
;
return
!
validator
.
is_success
()
;
}
}
example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp
View file @
ce87bcc7
...
@@ -274,14 +274,14 @@ int main(int argc, char* argv[])
...
@@ -274,14 +274,14 @@ int main(int argc, char* argv[])
if
constexpr
(
std
::
is_same
<
CShuffleDataType
,
F32
>::
value
)
if
constexpr
(
std
::
is_same
<
CShuffleDataType
,
F32
>::
value
)
{
{
pass
&=
ck
::
utils
::
check_err
(
validator
.
check_err
(
c_m_n_device_result
,
c_m_n_host_result
,
"Error: Incorrect results c"
);
c_m_n_device_result
,
c_m_n_host_result
,
"Error: Incorrect results c"
);
}
}
else
if
constexpr
(
std
::
is_same
<
CShuffleDataType
,
F16
>::
value
)
else
if
constexpr
(
std
::
is_same
<
CShuffleDataType
,
F16
>::
value
)
{
{
pass
&=
ck
::
utils
::
check_err
(
validator
.
check_err
(
c_m_n_device_result
,
c_m_n_host_result
,
"Error: Incorrect results c"
,
1e-2
,
1e-2
);
c_m_n_device_result
,
c_m_n_host_result
,
"Error: Incorrect results c"
,
1e-2
,
1e-2
);
}
}
}
}
return
pass
?
0
:
1
;
return
!
validator
.
is_success
()
;
}
}
example/22_cgemm/cgemm_xdl_common.hpp
View file @
ce87bcc7
...
@@ -220,12 +220,12 @@ bool run_cgemm_xdl(ck::index_t M,
...
@@ -220,12 +220,12 @@ bool run_cgemm_xdl(ck::index_t M,
const
Tensor
<
CDataType
>
c_m_n_real_device_result_converted
(
c_m_n_real_device_result
);
const
Tensor
<
CDataType
>
c_m_n_real_device_result_converted
(
c_m_n_real_device_result
);
const
Tensor
<
CDataType
>
c_m_n_imag_device_result_converted
(
c_m_n_imag_device_result
);
const
Tensor
<
CDataType
>
c_m_n_imag_device_result_converted
(
c_m_n_imag_device_result
);
result
=
ck
::
utils
::
check_err
(
c_m_n_real_device_result_converted
,
validator
.
check_err
(
c_m_n_real_device_result_converted
,
c_m_n_real_host_result
,
c_m_n_real_host_result
,
"Verification error: incorrect results in real part!"
,
"Verification error: incorrect results in real part!"
,
1e-2
f
,
1e-2
f
,
1e-1
f
);
1e-1
f
);
result
=
result
&&
ck
::
utils
::
check_err
(
validator
.
check_err
(
c_m_n_imag_device_result_converted
,
c_m_n_imag_device_result_converted
,
c_m_n_imag_host_result
,
c_m_n_imag_host_result
,
"Verification error: incorrect results in imaginary part!"
,
"Verification error: incorrect results in imaginary part!"
,
...
@@ -235,12 +235,12 @@ bool run_cgemm_xdl(ck::index_t M,
...
@@ -235,12 +235,12 @@ bool run_cgemm_xdl(ck::index_t M,
else
else
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
{
{
result
=
ck
::
utils
::
check_err
(
c_m_n_real_device_result
,
validator
.
check_err
(
c_m_n_real_device_result
,
c_m_n_real_host_result
,
c_m_n_real_host_result
,
"Verification error: incorrect results in real part!"
,
"Verification error: incorrect results in real part!"
,
1e-2
f
,
1e-2
f
,
1e-1
f
);
1e-1
f
);
result
=
result
&&
ck
::
utils
::
check_err
(
validator
.
check_err
(
c_m_n_imag_device_result
,
c_m_n_imag_device_result
,
c_m_n_imag_host_result
,
c_m_n_imag_host_result
,
"Verification error: incorrect results in imaginary part!"
,
"Verification error: incorrect results in imaginary part!"
,
...
@@ -248,7 +248,7 @@ bool run_cgemm_xdl(ck::index_t M,
...
@@ -248,7 +248,7 @@ bool run_cgemm_xdl(ck::index_t M,
1e-1
f
);
1e-1
f
);
}
}
return
result
;
return
validator
.
is_success
()
;
}
}
return
true
;
return
true
;
}
}
example/23_softmax/softmax_blockwise.cpp
View file @
ce87bcc7
...
@@ -240,13 +240,12 @@ int main(int argc, char* argv[])
...
@@ -240,13 +240,12 @@ int main(int argc, char* argv[])
auto
invoker_ptr
=
device_instance
.
MakeInvokerPointer
();
auto
invoker_ptr
=
device_instance
.
MakeInvokerPointer
();
bool
pass
=
true
;
if
(
args
.
do_verification
)
if
(
args
.
do_verification
)
{
{
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
out_dev
.
FromDevice
(
out
.
mData
.
data
());
out_dev
.
FromDevice
(
out
.
mData
.
data
());
// LogRangeAsType<float>(std::cout << "tensor out: " , out.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "tensor out: " , out.mData, ",") << std::endl;
pass
=
pass
&&
ck
::
utils
::
check_err
(
out
,
out_ref
);
validator
.
check_err
(
out
,
out_ref
);
};
};
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
args
.
time_kernel
});
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
args
.
time_kernel
});
...
@@ -260,5 +259,5 @@ int main(int argc, char* argv[])
...
@@ -260,5 +259,5 @@ int main(int argc, char* argv[])
std
::
cout
<<
"Perf: "
<<
avg_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s, "
<<
instance_name
std
::
cout
<<
"Perf: "
<<
avg_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s, "
<<
instance_name
<<
std
::
endl
;
<<
std
::
endl
;
return
(
pass
?
0
:
1
);
return
!
validator
.
is_success
(
);
}
}
example/24_batched_gemm/run_batched_gemm_example.inc
View file @
ce87bcc7
...
@@ -146,7 +146,6 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -146,7 +146,6 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
}
}
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
bool
pass
=
true
;
if
(
config
.
do_verification
)
if
(
config
.
do_verification
)
{
{
...
@@ -174,10 +173,10 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -174,10 +173,10 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
#ifdef BUILD_INT4_EXAMPLE
#ifdef BUILD_INT4_EXAMPLE
const
Tensor
<
EDataType
>
e_device_result_converted
(
e_g_m_n_device_result
);
const
Tensor
<
EDataType
>
e_device_result_converted
(
e_g_m_n_device_result
);
pass
&=
ck
::
utils
::
check_err
(
e_device_result_converted
,
e_g_m_n_host_result
);
validator
.
check_err
(
e_device_result_converted
,
e_g_m_n_host_result
);
#else
#else
pass
=
ck
::
utils
::
check_err
(
validator
.
check_err
(
e_g_m_n_device_result
,
e_g_m_n_host_result
,
"Error: Incorrect results c"
);
e_g_m_n_device_result
,
e_g_m_n_host_result
,
"Error: Incorrect results c"
);
#endif
#endif
}
}
...
@@ -197,7 +196,7 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -197,7 +196,7 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
}
}
return
pass
?
0
:
1
;
return
!
validator
.
is_success
()
;
}
}
bool
run_batched_gemm_example
(
int
argc
,
char
*
argv
[])
bool
run_batched_gemm_example
(
int
argc
,
char
*
argv
[])
...
...
example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp
View file @
ce87bcc7
...
@@ -390,7 +390,8 @@ int main(int argc, char* argv[])
...
@@ -390,7 +390,8 @@ int main(int argc, char* argv[])
}
}
}
}
return
ck
::
utils
::
check_err
(
e_gs_ms_ns_device_result
,
e_gs_ms_ns_host_result
)
?
0
:
1
;
validator
.
check_err
(
e_gs_ms_ns_device_result
,
e_gs_ms_ns_host_result
);
return
!
validator
.
is_success
();
}
}
return
0
;
return
0
;
...
...
example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp
View file @
ce87bcc7
...
@@ -391,7 +391,8 @@ int main(int argc, char* argv[])
...
@@ -391,7 +391,8 @@ int main(int argc, char* argv[])
}
}
}
}
return
ck
::
utils
::
check_err
(
e_gs_ms_ns_device_result
,
e_gs_ms_ns_host_result
)
?
0
:
1
;
validator
.
check_err
(
e_gs_ms_ns_device_result
,
e_gs_ms_ns_host_result
);
return
!
validator
.
is_success
();
}
}
return
0
;
return
0
;
...
...
example/26_contraction/contraction_bilinear_xdl_fp32.cpp
View file @
ce87bcc7
...
@@ -286,7 +286,8 @@ int main(int argc, char* argv[])
...
@@ -286,7 +286,8 @@ int main(int argc, char* argv[])
}
}
}
}
return
ck
::
utils
::
check_err
(
e_ms_ns_device_result
,
e_ms_ns_host_result
)
?
0
:
1
;
validator
.
check_err
(
e_ms_ns_device_result
,
e_ms_ns_host_result
);
return
!
validator
.
is_success
();
}
}
return
0
;
return
0
;
...
...
example/26_contraction/contraction_bilinear_xdl_fp64.cpp
View file @
ce87bcc7
...
@@ -286,7 +286,8 @@ int main(int argc, char* argv[])
...
@@ -286,7 +286,8 @@ int main(int argc, char* argv[])
}
}
}
}
return
ck
::
utils
::
check_err
(
e_ms_ns_device_result
,
e_ms_ns_host_result
)
?
0
:
1
;
validator
.
check_err
(
e_ms_ns_device_result
,
e_ms_ns_host_result
);
return
!
validator
.
is_success
();
}
}
return
0
;
return
0
;
...
...
example/26_contraction/contraction_scale_xdl_fp32.cpp
View file @
ce87bcc7
...
@@ -269,7 +269,8 @@ int main(int argc, char* argv[])
...
@@ -269,7 +269,8 @@ int main(int argc, char* argv[])
}
}
}
}
return
ck
::
utils
::
check_err
(
e_ms_ns_device_result
,
e_ms_ns_host_result
)
?
0
:
1
;
validator
.
check_err
(
e_ms_ns_device_result
,
e_ms_ns_host_result
);
return
!
validator
.
is_success
();
}
}
return
0
;
return
0
;
...
...
example/26_contraction/contraction_scale_xdl_fp64.cpp
View file @
ce87bcc7
...
@@ -269,7 +269,8 @@ int main(int argc, char* argv[])
...
@@ -269,7 +269,8 @@ int main(int argc, char* argv[])
}
}
}
}
return
ck
::
utils
::
check_err
(
e_ms_ns_device_result
,
e_ms_ns_host_result
)
?
0
:
1
;
validator
.
check_err
(
e_ms_ns_device_result
,
e_ms_ns_host_result
);
return
!
validator
.
is_success
();
}
}
return
0
;
return
0
;
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment