Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ce87bcc7
"docs/source/en/api/pipelines/stable_diffusion/inpaint.md" did not exist on "86ecd4b795f865b5b615b8c54991c177bb3dbef5"
Commit
ce87bcc7
authored
Sep 08, 2023
by
Bartlomiej Kocot
Browse files
tmp
parent
c8a8385f
Changes
123
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
68 additions
and
62 deletions
+68
-62
example/27_layernorm/run_layernorm_example.inc
example/27_layernorm/run_layernorm_example.inc
+2
-2
example/28_grouped_gemm_bias_e_permute/grouped_gemm_bias_e_permute_xdl_fp16.cpp
...m_bias_e_permute/grouped_gemm_bias_e_permute_xdl_fp16.cpp
+2
-2
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp
..._bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp
+2
-1
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp
...m_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp
+2
-1
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc
...multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc
+3
-2
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc
...ple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc
+3
-2
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc
...uped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc
+3
-2
example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc
...le/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc
+2
-1
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc
...cale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc
+2
-1
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc
...tmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc
+3
-4
example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc
...tmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc
+3
-5
example/33_multiple_reduce/dual_reduce_common.hpp
example/33_multiple_reduce/dual_reduce_common.hpp
+4
-4
example/34_batchnorm/batchnorm_backward_nhwc.cpp
example/34_batchnorm/batchnorm_backward_nhwc.cpp
+6
-6
example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp
example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp
+3
-3
example/34_batchnorm/batchnorm_forward_training_nhwc.cpp
example/34_batchnorm/batchnorm_forward_training_nhwc.cpp
+8
-8
example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp
...34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp
+7
-7
example/35_splitK_gemm/run_splitK_gemm_example.inc
example/35_splitK_gemm/run_splitK_gemm_example.inc
+4
-4
example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp
..._sparse_embedding/sparse_embedding3_forward_layernorm.cpp
+3
-3
example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp
..._gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp
+3
-2
example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc
...ultiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc
+3
-2
No files found.
example/27_layernorm/run_layernorm_example.inc
View file @
ce87bcc7
...
@@ -89,8 +89,8 @@ int run_groupnorm_example()
...
@@ -89,8 +89,8 @@ int run_groupnorm_example()
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
y_dev
.
FromDevice
(
y
.
mData
.
data
());
y_dev
.
FromDevice
(
y
.
mData
.
data
());
pass
&=
ck
::
utils
::
check_err
(
y
,
host_y
,
"Error: Incorrect results"
,
1
e
-
3
,
1
e
-
3
);
validator
.
check_err
(
y
,
host_y
,
"Error: Incorrect results"
,
1
e
-
3
,
1
e
-
3
);
}
}
return
(
pass
?
0
:
1
);
return
!
validator
.
is_success
(
);
}
}
example/28_grouped_gemm_bias_e_permute/grouped_gemm_bias_e_permute_xdl_fp16.cpp
View file @
ce87bcc7
...
@@ -458,9 +458,9 @@ int main(int argc, char* argv[])
...
@@ -458,9 +458,9 @@ int main(int argc, char* argv[])
}
}
}
}
pass
&=
ck
::
utils
::
check_err
(
e_device_tensors
[
i
],
e_ms_ns_host_result
);
validator
.
check_err
(
e_device_tensors
[
i
],
e_ms_ns_host_result
);
}
}
}
}
return
pass
?
0
:
1
;
return
!
validator
.
is_success
()
;
}
}
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp
View file @
ce87bcc7
...
@@ -424,7 +424,8 @@ int main(int argc, char* argv[])
...
@@ -424,7 +424,8 @@ int main(int argc, char* argv[])
}
}
}
}
return
ck
::
utils
::
check_err
(
e_gs_ms_ns_device_result
,
e_gs_ms_ns_host_result
)
?
0
:
1
;
validator
.
check_err
(
e_gs_ms_ns_device_result
,
e_gs_ms_ns_host_result
);
return
!
validator
.
is_success
();
}
}
return
0
;
return
0
;
...
...
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_xdl_fp16.cpp
View file @
ce87bcc7
...
@@ -390,7 +390,8 @@ int main(int argc, char* argv[])
...
@@ -390,7 +390,8 @@ int main(int argc, char* argv[])
}
}
}
}
return
ck
::
utils
::
check_err
(
e_gs_ms_ns_device_result
,
e_gs_ms_ns_host_result
)
?
0
:
1
;
validator
.
check_err
(
e_gs_ms_ns_device_result
,
e_gs_ms_ns_host_result
);
return
!
validator
.
is_success
();
}
}
return
0
;
return
0
;
...
...
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_example.inc
View file @
ce87bcc7
...
@@ -256,12 +256,13 @@ bool run_grouped_conv_fwd_bias_relu_add(const ExecutionConfig& config,
...
@@ -256,12 +256,13 @@ bool run_grouped_conv_fwd_bias_relu_add(const ExecutionConfig& config,
#ifdef BUILD_INT4_EXAMPLE
#ifdef BUILD_INT4_EXAMPLE
const
Tensor
<
OutUserDataType
>
out_device_converted
(
out_device
);
const
Tensor
<
OutUserDataType
>
out_device_converted
(
out_device
);
return
ck
::
utils
::
check_err
(
validator
.
check_err
(
out_device_converted
,
out_host
,
"Error: incorrect results!"
,
1
e
-
5
f
,
1
e
-
4
f
);
out_device_converted
,
out_host
,
"Error: incorrect results!"
,
1
e
-
5
f
,
1
e
-
4
f
);
#else
#else
return
ck
::
utils
::
check_err
(
validator
.
check_err
(
out_device
,
out_host
,
"Error: incorrect results!"
,
1
e
-
5
f
,
1
e
-
4
f
);
out_device
,
out_host
,
"Error: incorrect results!"
,
1
e
-
5
f
,
1
e
-
4
f
);
#endif
#endif
return
validator
.
is_success
();
}
}
return
true
;
return
true
;
...
...
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc
View file @
ce87bcc7
...
@@ -254,12 +254,13 @@ bool run_grouped_conv_fwd_bias_relu_add(const ExecutionConfig& config,
...
@@ -254,12 +254,13 @@ bool run_grouped_conv_fwd_bias_relu_add(const ExecutionConfig& config,
#ifdef BUILD_INT4_EXAMPLE
#ifdef BUILD_INT4_EXAMPLE
const
Tensor
<
OutUserDataType
>
out_device_converted
(
out_device
);
const
Tensor
<
OutUserDataType
>
out_device_converted
(
out_device
);
return
ck
::
utils
::
check_err
(
validator
.
check_err
(
out_device_converted
,
out_host
,
"Error: incorrect results!"
,
1
e
-
5
f
,
1
e
-
4
f
);
out_device_converted
,
out_host
,
"Error: incorrect results!"
,
1
e
-
5
f
,
1
e
-
4
f
);
#else
#else
return
ck
::
utils
::
check_err
(
validator
.
check_err
(
out_device
,
out_host
,
"Error: incorrect results!"
,
1
e
-
5
f
,
1
e
-
4
f
);
out_device
,
out_host
,
"Error: incorrect results!"
,
1
e
-
5
f
,
1
e
-
4
f
);
#endif
#endif
return
validator
.
is_success
();
}
}
return
true
;
return
true
;
...
...
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc
View file @
ce87bcc7
...
@@ -191,12 +191,13 @@ bool run_grouped_conv_fwd(const ExecutionConfig& config,
...
@@ -191,12 +191,13 @@ bool run_grouped_conv_fwd(const ExecutionConfig& config,
#ifdef BUILD_INT4_EXAMPLE
#ifdef BUILD_INT4_EXAMPLE
const
Tensor
<
OutUserDataType
>
out_device_converted
(
out_device
);
const
Tensor
<
OutUserDataType
>
out_device_converted
(
out_device
);
return
ck
::
utils
::
check_err
(
validator
.
check_err
(
out_device_converted
.
mData
,
out_host
.
mData
,
"Error: incorrect results!"
,
1
e
-
5
f
,
1
e
-
4
f
);
out_device_converted
.
mData
,
out_host
.
mData
,
"Error: incorrect results!"
,
1
e
-
5
f
,
1
e
-
4
f
);
#else
#else
return
ck
::
utils
::
check_err
(
validator
.
check_err
(
out_device
.
mData
,
out_host
.
mData
,
"Error: incorrect results!"
,
1
e
-
5
f
,
1
e
-
4
f
);
out_device
.
mData
,
out_host
.
mData
,
"Error: incorrect results!"
,
1
e
-
5
f
,
1
e
-
4
f
);
#endif
#endif
return
validator
.
is_success
();
}
}
return
true
;
return
true
;
...
...
example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc
View file @
ce87bcc7
...
@@ -270,7 +270,8 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[])
...
@@ -270,7 +270,8 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[])
c_g_m_o_device_buf
.
FromDevice
(
c_g_m_o_device_result
.
mData
.
data
());
c_g_m_o_device_buf
.
FromDevice
(
c_g_m_o_device_result
.
mData
.
data
());
#endif
#endif
return
ck
::
utils
::
check_err
(
c_g_m_o_device_result
,
c_g_m_o_host_result
);
validator
.
check_err
(
c_g_m_o_device_result
,
c_g_m_o_host_result
);
return
validator
.
is_success
();
}
}
return
true
;
return
true
;
...
...
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc
View file @
ce87bcc7
...
@@ -254,7 +254,8 @@ int run(int argc, char* argv[])
...
@@ -254,7 +254,8 @@ int run(int argc, char* argv[])
ref_gemm1_invoker
.
Run
(
ref_gemm1_argument
);
ref_gemm1_invoker
.
Run
(
ref_gemm1_argument
);
return
ck
::
utils
::
check_err
(
c_g_m_o_device_result
.
mData
,
c_g_m_o_host_result
.
mData
)
?
0
:
1
;
validator
.
check_err
(
c_g_m_o_device_result
.
mData
,
c_g_m_o_host_result
.
mData
);
return
!
validator
.
is_success
();
}
}
return
0
;
return
0
;
...
...
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc
View file @
ce87bcc7
...
@@ -265,13 +265,12 @@ int run(int argc, char* argv[])
...
@@ -265,13 +265,12 @@ int run(int argc, char* argv[])
atol
=
1
e
-
2
;
atol
=
1
e
-
2
;
}
}
return
ck
::
utils
::
check_err
(
c_gs_ms_os_device_result
.
mData
,
validator
.
check_err
(
c_gs_ms_os_device_result
.
mData
,
c_gs_ms_os_host_result
.
mData
,
c_gs_ms_os_host_result
.
mData
,
"Error: Incorrect results!"
,
"Error: Incorrect results!"
,
rtol
,
rtol
,
atol
)
atol
);
?
0
return
!
validator
.
is_success
();
:
1
;
}
}
return
0
;
return
0
;
...
...
example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc
View file @
ce87bcc7
...
@@ -223,7 +223,7 @@ int run(int argc, char* argv[])
...
@@ -223,7 +223,7 @@ int run(int argc, char* argv[])
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
if
(
do_verification
)
if
(
do_verification
)
{
{
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
group_count
;
i
++
)
...
@@ -309,11 +309,9 @@ int run(int argc, char* argv[])
...
@@ -309,11 +309,9 @@ int run(int argc, char* argv[])
self
(
idx
)
=
c_g_m_o_host_result
(
g
,
idx
[
2
],
idx
[
3
]);
self
(
idx
)
=
c_g_m_o_host_result
(
g
,
idx
[
2
],
idx
[
3
]);
});
});
bool
pass_
=
validator
.
check_err
(
c_gs_ms_os_device_result
.
mData
,
c_gs_ms_os_host_result
.
mData
);
ck
::
utils
::
check_err
(
c_gs_ms_os_device_result
.
mData
,
c_gs_ms_os_host_result
.
mData
);
pass
&=
pass_
;
}
}
}
}
return
pass
?
0
:
1
;
return
!
validator
.
is_success
()
;
}
}
example/33_multiple_reduce/dual_reduce_common.hpp
View file @
ce87bcc7
...
@@ -300,15 +300,15 @@ int mean_meansquare_dual_reduce_test(size_t n,
...
@@ -300,15 +300,15 @@ int mean_meansquare_dual_reduce_test(size_t n,
std
::
cout
<<
"Perf: "
<<
avg_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s, "
<<
reduce_name
std
::
cout
<<
"Perf: "
<<
avg_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s, "
<<
reduce_name
<<
std
::
endl
;
<<
std
::
endl
;
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
if
(
do_verification
)
if
(
do_verification
)
{
{
mean_dev
.
FromDevice
(
mean
.
mData
.
data
());
mean_dev
.
FromDevice
(
mean
.
mData
.
data
());
meansquare_dev
.
FromDevice
(
meansquare
.
mData
.
data
());
meansquare_dev
.
FromDevice
(
meansquare
.
mData
.
data
());
pass
=
pass
&&
ck
::
utils
::
check_err
(
mean
,
mean_ref
);
validator
.
check_err
(
mean
,
mean_ref
);
pass
=
pass
&&
ck
::
utils
::
check_err
(
meansquare
,
meansquare_ref
);
validator
.
check_err
(
meansquare
,
meansquare_ref
);
};
};
return
(
pass
?
0
:
1
);
return
!
validator
.
is_success
(
);
}
}
example/34_batchnorm/batchnorm_backward_nhwc.cpp
View file @
ce87bcc7
...
@@ -338,7 +338,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
...
@@ -338,7 +338,7 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
else
else
(
void
)
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
(
void
)
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
if
(
do_verification
)
if
(
do_verification
)
{
{
...
@@ -394,20 +394,20 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
...
@@ -394,20 +394,20 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
dbias_dev
.
FromDevice
(
dbias
.
data
());
dbias_dev
.
FromDevice
(
dbias
.
data
());
// clang-format off
// clang-format off
pass
=
pass
&&
ck
::
utils
::
check_err
(
dbias
.
mData
,
dbias_ref
.
mData
,
"dBias result:"
,
2e-4
,
2e-4
);
validator
.
check_err
(
dbias
.
mData
,
dbias_ref
.
mData
,
"dBias result:"
,
2e-4
,
2e-4
);
pass
=
pass
&&
ck
::
utils
::
check_err
(
dscale
.
mData
,
dscale_ref
.
mData
,
"dScale result:"
,
2e-4
,
2e-4
);
validator
.
check_err
(
dscale
.
mData
,
dscale_ref
.
mData
,
"dScale result:"
,
2e-4
,
2e-4
);
pass
=
pass
&&
ck
::
utils
::
check_err
(
dx
.
mData
,
dx_ref
.
mData
,
"dx result:"
);
validator
.
check_err
(
dx
.
mData
,
dx_ref
.
mData
,
"dx result:"
);
// clang-format on
// clang-format on
};
};
return
(
pa
ss
);
return
validator
.
is_succe
ss
(
);
};
};
static
const
double
epsilon
=
std
::
numeric_limits
<
float
>::
epsilon
();
static
const
double
epsilon
=
std
::
numeric_limits
<
float
>::
epsilon
();
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
if
(
argc
>
1
)
if
(
argc
>
1
)
{
{
...
...
example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp
View file @
ce87bcc7
...
@@ -308,17 +308,17 @@ bool bnorm_infer_nhwc_test(bool do_verification,
...
@@ -308,17 +308,17 @@ bool bnorm_infer_nhwc_test(bool do_verification,
(
void
)
invoker_ptr_ref
->
Run
(
argument_ptr_ref
.
get
());
(
void
)
invoker_ptr_ref
->
Run
(
argument_ptr_ref
.
get
());
y_dev
.
FromDevice
(
y
.
mData
.
data
());
y_dev
.
FromDevice
(
y
.
mData
.
data
());
pass
=
pass
&&
ck
::
utils
::
check_err
(
y
,
y_ref
);
validator
.
check_err
(
y
,
y_ref
);
};
};
return
(
pa
ss
);
return
validator
.
is_succe
ss
(
);
};
};
static
const
double
epsilon
=
std
::
numeric_limits
<
float
>::
epsilon
();
static
const
double
epsilon
=
std
::
numeric_limits
<
float
>::
epsilon
();
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
if
(
argc
>
1
)
if
(
argc
>
1
)
{
{
...
...
example/34_batchnorm/batchnorm_forward_training_nhwc.cpp
View file @
ce87bcc7
...
@@ -362,7 +362,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
...
@@ -362,7 +362,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
else
else
(
void
)
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
(
void
)
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
if
(
do_verification
)
if
(
do_verification
)
{
{
...
@@ -414,7 +414,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
...
@@ -414,7 +414,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
(
void
)
invoker_ptr_ref
->
Run
(
argument_ptr_ref
.
get
());
(
void
)
invoker_ptr_ref
->
Run
(
argument_ptr_ref
.
get
());
y_dev
.
FromDevice
(
y
.
mData
.
data
());
y_dev
.
FromDevice
(
y
.
mData
.
data
());
pass
=
pass
&&
ck
::
utils
::
check_err
(
y
,
y_ref
,
"Incorrect normalized output values"
);
validator
.
check_err
(
y
,
y_ref
,
"Incorrect normalized output values"
);
if
(
updateMovingAverage
)
if
(
updateMovingAverage
)
{
{
...
@@ -424,10 +424,10 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
...
@@ -424,10 +424,10 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
resultRunningMean_dev
.
FromDevice
(
resultRunningMean
.
mData
.
data
());
resultRunningMean_dev
.
FromDevice
(
resultRunningMean
.
mData
.
data
());
resultRunningVariance_dev
.
FromDevice
(
resultRunningVariance
.
mData
.
data
());
resultRunningVariance_dev
.
FromDevice
(
resultRunningVariance
.
mData
.
data
());
pass
=
pass
&&
ck
::
utils
::
check_err
(
resultRunningMean
,
validator
.
check_err
(
resultRunningMean
,
resultRunningMean_ref
,
resultRunningMean_ref
,
"Incorrect running mean values"
);
"Incorrect running mean values"
);
pass
=
pass
&&
ck
::
utils
::
check_err
(
resultRunningVariance
,
validator
.
check_err
(
resultRunningVariance
,
resultRunningVariance_ref
,
resultRunningVariance_ref
,
"Incorrect running variance values"
);
"Incorrect running variance values"
);
};
};
...
@@ -442,15 +442,15 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
...
@@ -442,15 +442,15 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
resultSaveMean_dev
.
FromDevice
(
resultSaveMean
.
mData
.
data
());
resultSaveMean_dev
.
FromDevice
(
resultSaveMean
.
mData
.
data
());
resultSaveInvVariance_dev
.
FromDevice
(
resultSaveInvVariance
.
mData
.
data
());
resultSaveInvVariance_dev
.
FromDevice
(
resultSaveInvVariance
.
mData
.
data
());
pass
=
pass
&&
ck
::
utils
::
check_err
(
validator
.
check_err
(
resultSaveMean
,
resultSaveMean_ref
,
"Incorrect saved mean values"
);
resultSaveMean
,
resultSaveMean_ref
,
"Incorrect saved mean values"
);
pass
=
pass
&&
ck
::
utils
::
check_err
(
resultSaveInvVariance
,
validator
.
check_err
(
resultSaveInvVariance
,
resultSaveInvVariance_ref
,
resultSaveInvVariance_ref
,
"Incorrect saved invvariance values"
);
"Incorrect saved invvariance values"
);
};
};
};
};
return
(
pa
ss
);
return
validator
.
is_succe
ss
(
);
};
};
const
double
epsilon
=
std
::
numeric_limits
<
float
>::
epsilon
();
const
double
epsilon
=
std
::
numeric_limits
<
float
>::
epsilon
();
...
@@ -584,7 +584,7 @@ int main(int argc, char* argv[])
...
@@ -584,7 +584,7 @@ int main(int argc, char* argv[])
averageFactor
,
averageFactor
,
epsilon
);
epsilon
);
pass
=
pass
&&
bnorm_fwd_nhwc_test
<
ck
::
half_t
,
float
,
false
>
(
true
,
bnorm_fwd_nhwc_test
<
ck
::
half_t
,
float
,
false
>
(
true
,
2
,
2
,
false
,
// don't time kernel
false
,
// don't time kernel
{
128
,
16
,
3
,
1024
},
{
128
,
16
,
3
,
1024
},
...
...
example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp
View file @
ce87bcc7
...
@@ -362,7 +362,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
...
@@ -362,7 +362,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
else
else
(
void
)
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
(
void
)
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
if
(
do_verification
)
if
(
do_verification
)
{
{
...
@@ -414,7 +414,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
...
@@ -414,7 +414,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
(
void
)
invoker_ptr_ref
->
Run
(
argument_ptr_ref
.
get
());
(
void
)
invoker_ptr_ref
->
Run
(
argument_ptr_ref
.
get
());
y_dev
.
FromDevice
(
y
.
mData
.
data
());
y_dev
.
FromDevice
(
y
.
mData
.
data
());
pass
=
pass
&&
ck
::
utils
::
check_err
(
y
,
y_ref
,
"Incorrect normalized output values"
);
validator
.
check_err
(
y
,
y_ref
,
"Incorrect normalized output values"
);
if
(
updateMovingAverage
)
if
(
updateMovingAverage
)
{
{
...
@@ -424,10 +424,10 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
...
@@ -424,10 +424,10 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
resultRunningMean_dev
.
FromDevice
(
resultRunningMean
.
mData
.
data
());
resultRunningMean_dev
.
FromDevice
(
resultRunningMean
.
mData
.
data
());
resultRunningVariance_dev
.
FromDevice
(
resultRunningVariance
.
mData
.
data
());
resultRunningVariance_dev
.
FromDevice
(
resultRunningVariance
.
mData
.
data
());
pass
=
pass
&&
ck
::
utils
::
check_err
(
resultRunningMean
,
validator
.
check_err
(
resultRunningMean
,
resultRunningMean_ref
,
resultRunningMean_ref
,
"Incorrect running mean values"
);
"Incorrect running mean values"
);
pass
=
pass
&&
ck
::
utils
::
check_err
(
resultRunningVariance
,
validator
.
check_err
(
resultRunningVariance
,
resultRunningVariance_ref
,
resultRunningVariance_ref
,
"Incorrect running variance values"
);
"Incorrect running variance values"
);
};
};
...
@@ -442,15 +442,15 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
...
@@ -442,15 +442,15 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
resultSaveMean_dev
.
FromDevice
(
resultSaveMean
.
mData
.
data
());
resultSaveMean_dev
.
FromDevice
(
resultSaveMean
.
mData
.
data
());
resultSaveInvVariance_dev
.
FromDevice
(
resultSaveInvVariance
.
mData
.
data
());
resultSaveInvVariance_dev
.
FromDevice
(
resultSaveInvVariance
.
mData
.
data
());
pass
=
pass
&&
ck
::
utils
::
check_err
(
validator
.
check_err
(
resultSaveMean
,
resultSaveMean_ref
,
"Incorrect saved mean values"
);
resultSaveMean
,
resultSaveMean_ref
,
"Incorrect saved mean values"
);
pass
=
pass
&&
ck
::
utils
::
check_err
(
resultSaveInvVariance
,
validator
.
check_err
(
resultSaveInvVariance
,
resultSaveInvVariance_ref
,
resultSaveInvVariance_ref
,
"Incorrect saved invvariance values"
);
"Incorrect saved invvariance values"
);
};
};
};
};
return
(
pa
ss
);
return
validator
.
is_succe
ss
(
);
};
};
const
double
epsilon
=
std
::
numeric_limits
<
float
>::
epsilon
();
const
double
epsilon
=
std
::
numeric_limits
<
float
>::
epsilon
();
...
...
example/35_splitK_gemm/run_splitK_gemm_example.inc
View file @
ce87bcc7
...
@@ -121,7 +121,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
...
@@ -121,7 +121,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
}
}
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
if
(
config
.
do_verification
)
if
(
config
.
do_verification
)
{
{
...
@@ -146,12 +146,12 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
...
@@ -146,12 +146,12 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
if
(
std
::
is_same
<
CDataType
,
ck
::
half_t
>::
value
)
if
(
std
::
is_same
<
CDataType
,
ck
::
half_t
>::
value
)
{
{
pass
&=
ck
::
utils
::
check_err
(
validator
.
check_err
(
c_m_n_device_result
,
c_m_n_host_result
,
"fp16 incorrect result"
,
3
e
-
3
,
1
e
-
3
);
c_m_n_device_result
,
c_m_n_host_result
,
"fp16 incorrect result"
,
3
e
-
3
,
1
e
-
3
);
}
}
else
else
{
{
pass
&=
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
validator
.
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
}
}
}
}
...
@@ -169,7 +169,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
...
@@ -169,7 +169,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
}
}
return
pass
;
return
validator
.
is_success
()
;
}
}
bool
run_splitK_gemm_example
(
int
argc
,
char
*
argv
[])
bool
run_splitK_gemm_example
(
int
argc
,
char
*
argv
[])
...
...
example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp
View file @
ce87bcc7
...
@@ -155,7 +155,7 @@ int main()
...
@@ -155,7 +155,7 @@ int main()
auto
invoker_ptr
=
device_instance
.
MakeInvokerPointer
();
auto
invoker_ptr
=
device_instance
.
MakeInvokerPointer
();
float
time_ms
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
float
time_ms
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
bool
pass
=
true
;
ck
::
utils
::
CorrectnessValidator
validator
;
{
{
Tensor
<
OutType
>
out_from_dev
(
f_host_tensor_desc_2d
(
index_length
,
current_dim
));
Tensor
<
OutType
>
out_from_dev
(
f_host_tensor_desc_2d
(
index_length
,
current_dim
));
ReferenceInstance
ref
;
ReferenceInstance
ref
;
...
@@ -176,7 +176,7 @@ int main()
...
@@ -176,7 +176,7 @@ int main()
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
out_dev
.
FromDevice
(
out_from_dev
.
mData
.
data
());
out_dev
.
FromDevice
(
out_from_dev
.
mData
.
data
());
pass
&=
ck
::
utils
::
check_err
(
out_from_dev
,
out
,
"Error: Incorrect results"
,
1e-3
,
1e-3
);
validator
.
check_err
(
out_from_dev
,
out
,
"Error: Incorrect results"
,
1e-3
,
1e-3
);
}
}
double
total_read
=
current_dim
*
index_length
*
3
*
sizeof
(
EmbType
)
+
double
total_read
=
current_dim
*
index_length
*
3
*
sizeof
(
EmbType
)
+
...
@@ -186,7 +186,7 @@ int main()
...
@@ -186,7 +186,7 @@ int main()
double
gbps
=
(
total_read
+
total_write
)
/
time_ms
/
1e6
;
double
gbps
=
(
total_read
+
total_write
)
/
time_ms
/
1e6
;
std
::
cout
<<
", total bytes:"
<<
(
total_read
+
total_write
)
<<
", time:"
<<
time_ms
std
::
cout
<<
", total bytes:"
<<
(
total_read
+
total_write
)
<<
", time:"
<<
time_ms
<<
", gbps:"
<<
gbps
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
endl
<<
", gbps:"
<<
gbps
<<
", valid:"
<<
(
validator
.
is_success
()
?
"y"
:
"n"
)
<<
std
::
endl
<<
std
::
flush
;
<<
std
::
flush
;
});
});
...
...
example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp
View file @
ce87bcc7
...
@@ -513,8 +513,9 @@ int main(int argc, char* argv[])
...
@@ -513,8 +513,9 @@ int main(int argc, char* argv[])
e1_g_m_o_host_result
.
ForEach
([
&
](
auto
&
,
auto
idx
)
{
e1_g_m_o_host_result
.
ForEach
([
&
](
auto
&
,
auto
idx
)
{
cde1_element_op
(
e1_g_m_o_host_result
(
idx
),
c1_g_m_o
(
idx
),
d1_g_m_o
(
idx
));
cde1_element_op
(
e1_g_m_o_host_result
(
idx
),
c1_g_m_o
(
idx
),
d1_g_m_o
(
idx
));
});
});
ck
::
utils
::
CorrectnessValidator
validator
;
return
ck
::
utils
::
check_err
(
e1_g_m_o_device_result
,
e1_g_m_o_host_result
)
?
0
:
1
;
validator
.
check_err
(
e1_g_m_o_device_result
,
e1_g_m_o_host_result
);
return
!
validator
.
is_success
();
}
}
return
0
;
return
0
;
...
...
example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc
View file @
ce87bcc7
...
@@ -156,8 +156,9 @@ bool run_conv_bwd_data_bias_relu(const ExecutionConfig& config,
...
@@ -156,8 +156,9 @@ bool run_conv_bwd_data_bias_relu(const ExecutionConfig& config,
[
&
](
auto
&
,
auto
idx
)
{
in_element_op
(
in_host
(
idx
),
c_host
(
idx
),
bias
(
idx
));
});
[
&
](
auto
&
,
auto
idx
)
{
in_element_op
(
in_host
(
idx
),
c_host
(
idx
),
bias
(
idx
));
});
in_device_buf
.
FromDevice
(
in_device
.
mData
.
data
());
in_device_buf
.
FromDevice
(
in_device
.
mData
.
data
());
ck
::
utils
::
CorrectnessValidator
validator
;
return
ck
::
utils
::
check_err
(
in_device
,
in_host
);
validator
.
check_err
(
in_device
,
in_host
);
return
validator
.
is_success
();
}
}
return
true
;
return
true
;
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment