Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
f885c131
Commit
f885c131
authored
Aug 09, 2021
by
Chao Liu
Browse files
tidy
parent
80120f0a
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
90 additions
and
147 deletions
+90
-147
composable_kernel/include/utility/amd_address_space.hpp
composable_kernel/include/utility/amd_address_space.hpp
+6
-0
composable_kernel/include/utility/print.hpp
composable_kernel/include/utility/print.hpp
+0
-48
host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp
...ion_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp
+0
-3
host/driver_offline/include/driver_dynamic_contraction_dlops_v1r2.hpp
...offline/include/driver_dynamic_contraction_dlops_v1r2.hpp
+0
-4
host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
...ution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
+0
-4
host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp
...orward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp
+0
-4
host/driver_offline/include/driver_dynamic_gemm_dlops_v1r2.hpp
...driver_offline/include/driver_dynamic_gemm_dlops_v1r2.hpp
+24
-24
host/driver_offline/include/driver_dynamic_gemm_dlops_v1r3.hpp
...driver_offline/include/driver_dynamic_gemm_dlops_v1r3.hpp
+32
-24
host/driver_offline/include/driver_dynamic_gemm_xdlops_v2r3.hpp
...river_offline/include/driver_dynamic_gemm_xdlops_v2r3.hpp
+13
-15
host/driver_offline/src/conv_fwd_driver_offline.cpp
host/driver_offline/src/conv_fwd_driver_offline.cpp
+8
-8
host/host_tensor/include/device.hpp
host/host_tensor/include/device.hpp
+7
-13
No files found.
composable_kernel/include/utility/amd_address_space.hpp
View file @
f885c131
...
...
@@ -20,6 +20,12 @@ __device__ T* cast_pointer_to_generic_address_space(T CONSTANT* p)
return
(
T
*
)
p
;
}
// Converts a plain pointer into a pointer whose pointee type carries the CONSTANT qualifier.
//
// p: pointer to reinterpret; its value is passed through the cast unchanged by this code.
// Returns the same pointer typed as `T CONSTANT*`.
//
// NOTE(review): CONSTANT is a macro defined outside this view; presumably it expands to an
// address-space attribute -- confirm against its definition.
template <typename T>
__host__ __device__ T CONSTANT* cast_pointer_to_constant_address_space(T* p)
{
    using constant_pointer_t = T CONSTANT*;

    // A C-style cast is kept deliberately so the conversion is exactly the one the
    // original expression performed.
    return (constant_pointer_t)p;
}
}
// namespace ck
#endif
composable_kernel/include/utility/print.hpp
View file @
f885c131
...
...
@@ -11,59 +11,11 @@ namespace ck {
// Prints a label, the container size and every element of `a` on a single line.
//
// Output shape: "<s> size <nsize>, {<e0>, <e1>, ... }\n". Each element is brace-converted to
// int32_t and printed with "%d".
//
// s: label printed in front of the size.
// a: container taken by value. It must provide Size() usable in a constant expression and
//    operator[] accepting the index object that static_for passes to the callback.
//
// The previously disabled (`#if 0`) per-type formatting branch and the `data_type` alias that
// only that branch used have been dropped; the code that is actually compiled is unchanged.
template <typename T>
__host__ __device__ void print_array(const char* s, T a)
{
    constexpr index_t nsize = a.Size();

    printf("%s size %d, {", s, nsize);

    // One printf per index in [0, nsize) with step 1, driven by static_for.
    static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); });

    printf("}\n");
}
// Prints a label, the container size and every element of `a` together with its index.
//
// Output shape: "<s> size <nsize>, {[0] <e0>, [1] <e1>, ... }\n". The index comes from
// `i.value` and both index and element are printed with "%d".
//
// s: label printed in front of the size.
// a: container taken by value. It must provide Size() usable in a constant expression and
//    operator[] accepting the index object that static_for passes to the callback.
//
// NOTE(review): elements are passed to printf as-is (no conversion), so "%d" is only correct
// for element types that promote to int -- confirm against callers.
//
// The previously disabled (`#if 0`) per-type formatting branch and the `data_type` alias that
// only that branch used have been dropped; the code that is actually compiled is unchanged.
template <typename T>
__host__ __device__ void print_array_v2(const char* s, T a)
{
    constexpr index_t nsize = a.Size();

    printf("%s size %d, {", s, nsize);

    // One printf per index in [0, nsize) with step 1, driven by static_for.
    static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); });

    printf("}\n");
}
}
// namespace ck
...
...
host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp
View file @
f885c131
...
...
@@ -257,9 +257,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw
const
auto
K
=
out_n_ho_wo_k_lengths
[
I3
];
const
auto
C
=
wei_k_y_x_c_lengths
[
I3
];
const
auto
Hi
=
in_n_hi_wi_c_lengths
[
I1
];
const
auto
Wi
=
in_n_hi_wi_c_lengths
[
I2
];
const
auto
Ho
=
out_n_ho_wo_k_lengths
[
I1
];
const
auto
Wo
=
out_n_ho_wo_k_lengths
[
I2
];
...
...
host/driver_offline/include/driver_dynamic_contraction_dlops_v1r2.hpp
View file @
f885c131
...
...
@@ -194,7 +194,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
...
...
@@ -221,7 +220,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
...
...
@@ -248,7 +246,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
...
...
@@ -275,7 +272,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
...
...
host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp
View file @
f885c131
...
...
@@ -244,7 +244,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_e_k_global_desc
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
...
...
@@ -270,7 +269,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_e_k_global_desc
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
...
...
@@ -296,7 +294,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_e_k_global_desc
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
...
...
@@ -322,7 +319,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_e_k_global_desc
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
...
...
host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp
View file @
f885c131
...
...
@@ -257,7 +257,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_e_k_global_desc
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
...
...
@@ -284,7 +283,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_e_k_global_desc
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
...
...
@@ -311,7 +309,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_e_k_global_desc
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
...
...
@@ -338,7 +335,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
wei_e_k_global_desc
,
p_wei_global
,
in_e_n_ho_wo_global_desc
,
...
...
host/driver_offline/include/driver_dynamic_gemm_dlops_v1r2.hpp
View file @
f885c131
...
...
@@ -189,7 +189,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
...
...
@@ -216,7 +215,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
...
...
@@ -243,7 +241,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
...
...
@@ -270,7 +267,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
...
...
@@ -315,14 +311,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
(
void
CONSTANT
*
)
a_k_m0_m1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
b_k_n0_n1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
());
cast_pointer_to_constant_address_space
(
a_k_m0_m1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_k_n0_n1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
...
...
@@ -343,14 +340,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
(
void
CONSTANT
*
)
a_k_m0_m1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
b_k_n0_n1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
());
cast_pointer_to_constant_address_space
(
a_k_m0_m1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_k_n0_n1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
...
...
@@ -371,14 +369,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
(
void
CONSTANT
*
)
a_k_m0_m1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
b_k_n0_n1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
());
cast_pointer_to_constant_address_space
(
a_k_m0_m1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_k_n0_n1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
}
else
{
...
...
@@ -399,14 +398,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
(
void
CONSTANT
*
)
a_k_m0_m1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
b_k_n0_n1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
());
cast_pointer_to_constant_address_space
(
a_k_m0_m1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_k_n0_n1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
}
return
ave_time
;
...
...
host/driver_offline/include/driver_dynamic_gemm_dlops_v1r3.hpp
View file @
f885c131
...
...
@@ -185,7 +185,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
...
...
@@ -212,7 +211,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
...
...
@@ -239,7 +237,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
...
...
@@ -266,7 +263,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
...
...
@@ -311,14 +307,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
(
void
CONSTANT
*
)
a_k0_m0_m1_k1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
b_k0_n0_n1_k1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
());
cast_pointer_to_constant_address_space
(
a_k0_m0_m1_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_k0_n0_n1_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
}
else
if
(
has_main_k_block_loop
&&
!
has_double_tail_k_block_loop
)
{
...
...
@@ -339,14 +338,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
(
void
CONSTANT
*
)
a_k0_m0_m1_k1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
b_k0_n0_n1_k1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
());
cast_pointer_to_constant_address_space
(
a_k0_m0_m1_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_k0_n0_n1_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
}
else
if
(
!
has_main_k_block_loop
&&
has_double_tail_k_block_loop
)
{
...
...
@@ -367,14 +369,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
(
void
CONSTANT
*
)
a_k0_m0_m1_k1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
b_k0_n0_n1_k1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
());
cast_pointer_to_constant_address_space
(
a_k0_m0_m1_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_k0_n0_n1_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
}
else
{
...
...
@@ -395,14 +400,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
(
void
CONSTANT
*
)
a_k0_m0_m1_k1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
b_k0_n0_n1_k1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
());
cast_pointer_to_constant_address_space
(
a_k0_m0_m1_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_k0_n0_n1_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
}
return
ave_time
;
...
...
host/driver_offline/include/driver_dynamic_gemm_xdlops_v2r3.hpp
View file @
f885c131
...
...
@@ -153,7 +153,6 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
...
...
@@ -173,20 +172,19 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
c_m0_m1_m2_n_grid_desc_dev_buf
.
ToDevice
(
&
c_m0_m1_m2_n_grid_desc
);
c_block_cluster_adaptor_dev_buf
.
ToDevice
(
&
c_block_cluster_adaptor
);
float
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
(
void
CONSTANT
*
)
a_k0_m_k1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
b_k0_n_k1_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
c_m0_m1_m2_n_grid_desc_dev_buf
.
GetDeviceBuffer
(),
(
void
CONSTANT
*
)
c_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
());
float
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
p_a_grid
,
p_b_grid
,
p_c_grid
,
cast_pointer_to_constant_address_space
(
a_k0_m_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
b_k0_n_k1_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_m0_m1_m2_n_grid_desc_dev_buf
.
GetDeviceBuffer
()),
cast_pointer_to_constant_address_space
(
c_block_cluster_adaptor_dev_buf
.
GetDeviceBuffer
()));
#endif
return
ave_time
;
}
...
...
host/driver_offline/src/conv_fwd_driver_offline.cpp
View file @
f885c131
...
...
@@ -142,10 +142,8 @@ int main(int argc, char* argv[])
std
::
vector
<
std
::
size_t
>
in_lengths_host
(
4
),
wei_lengths_host
(
4
),
out_lengths_host
(
4
);
switch
(
layout
)
if
(
layout
==
ConvTensorLayout
::
NCHW
)
{
case
ConvTensorLayout
::
NCHW
:
// NCHW
in_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
in_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
C
);
in_lengths_host
[
2
]
=
static_cast
<
std
::
size_t
>
(
Hi
);
...
...
@@ -158,9 +156,9 @@ int main(int argc, char* argv[])
out_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
K
);
out_lengths_host
[
2
]
=
static_cast
<
std
::
size_t
>
(
Ho
);
out_lengths_host
[
3
]
=
static_cast
<
std
::
size_t
>
(
Wo
);
break
;
case
ConvTensorLayout
::
NHWC
:
// NHWC
}
else
if
(
layout
==
ConvTensorLayout
::
NHWC
)
{
in_lengths_host
[
0
]
=
static_cast
<
std
::
size_t
>
(
N
);
in_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
Hi
);
in_lengths_host
[
2
]
=
static_cast
<
std
::
size_t
>
(
Wi
);
...
...
@@ -173,8 +171,10 @@ int main(int argc, char* argv[])
out_lengths_host
[
1
]
=
static_cast
<
std
::
size_t
>
(
Ho
);
out_lengths_host
[
2
]
=
static_cast
<
std
::
size_t
>
(
Wo
);
out_lengths_host
[
3
]
=
static_cast
<
std
::
size_t
>
(
K
);
break
;
default:
throw
std
::
runtime_error
(
"wrong! not implemented"
);
}
else
{
    // Unsupported layout: the exception must actually be thrown. Merely constructing the
    // std::runtime_error object (as before) discards it and lets execution continue with
    // the length vectors left at their initial values.
    throw std::runtime_error("wrong! not implemented");
}
Tensor
<
in_data_t
>
in
(
in_lengths_host
);
...
...
host/host_tensor/include/device.hpp
View file @
f885c131
...
...
@@ -34,24 +34,16 @@ struct KernelTimer
using
device_stream_t
=
hipStream_t
;
template
<
typename
...
Args
,
typename
F
>
void
launch_kernel
(
F
kernel
,
dim3
grid_dim
,
dim3
block_dim
,
std
::
size_t
lds_byte
,
hipStream_t
stream_id
,
Args
...
args
)
void
launch_kernel
(
F
kernel
,
dim3
grid_dim
,
dim3
block_dim
,
std
::
size_t
lds_byte
,
Args
...
args
)
{
hipStream_t
stream_id
=
nullptr
;
hipLaunchKernelGGL
(
kernel
,
grid_dim
,
block_dim
,
lds_byte
,
stream_id
,
args
...);
}
template
<
typename
...
Args
,
typename
F
>
float
launch_and_time_kernel
(
F
kernel
,
int
nrepeat
,
dim3
grid_dim
,
dim3
block_dim
,
std
::
size_t
lds_byte
,
hipStream_t
stream_id
,
Args
...
args
)
float
launch_and_time_kernel
(
F
kernel
,
int
nrepeat
,
dim3
grid_dim
,
dim3
block_dim
,
std
::
size_t
lds_byte
,
Args
...
args
)
{
KernelTimer
timer
;
...
...
@@ -66,6 +58,8 @@ float launch_and_time_kernel(F kernel,
printf
(
"Warm up
\n
"
);
hipStream_t
stream_id
=
nullptr
;
// warm up
hipLaunchKernelGGL
(
kernel
,
grid_dim
,
block_dim
,
lds_byte
,
stream_id
,
args
...);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment