gaoqiong / composable_kernel_ROCM

Commit e92395d9, authored Dec 27, 2024 by coderfeli

    Merge remote-tracking branch 'origin/cka8w8_devtimer' into update_cka8w8_uc

Parents: 842d910e, 7efafa11

Changes: 81 files in this commit. This page shows 20 changed files, with 829 additions and 308 deletions (+829 / -308).
Files shown on this page:

  example/ck_tile/16_batched_gemm/batched_gemm.hpp                                                 +1   -5
  example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc                                     +31  -4
  example/ck_tile/17_grouped_gemm/grouped_gemm.hpp                                                 +13  -7
  example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc                                     +21  -13
  include/ck/config.h.in                                                                           +16  -0
  include/ck/host_utility/flush_cache.hpp                                                          +77  -90
  include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp                      +6   -0
  include/ck/tensor_operation/gpu/device/device_base.hpp                                           +34  -0
  include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp                           +2   -1
  include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp   +29  -16
  include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp                      +1   -0
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp                   +8   -8
  include/ck/utility/amd_ck_fp8.hpp                                                                +18  -6
  include/ck_tile/core/arch/amd_buffer_addressing.hpp                                              +2   -2
  include/ck_tile/core/container/meta_data_buffer.hpp                                              +3   -3
  include/ck_tile/core/tensor/static_distributed_tensor.hpp                                        +1   -0
  include/ck_tile/host/arg_parser.hpp                                                              +44  -2
  include/ck_tile/host/reference/reference_gemm.hpp                                                +11  -151
  include/ck_tile/ops/flatmm.hpp                                                                   +1   -0
  include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp (new file)          +510 -0
example/ck_tile/16_batched_gemm/batched_gemm.hpp

@@ -29,10 +29,6 @@ using BDataType = Types::BDataType;
 using AccDataType = Types::AccDataType;
 using CDataType   = Types::CDataType;
 
-struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs
-{
-};
-
 auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
@@ -60,4 +56,4 @@ auto create_args(int argc, char* argv[])
 }
 
 // host API
-float batched_gemm(batched_gemm_kargs args, const ck_tile::stream_config& s);
+float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s);
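The example-local batched_gemm_kargs wrapper is dropped in favour of passing ck_tile::BatchedGemmHostArgs straight to the host API. A hedged sketch of the new call site, based only on the fields visible in run_batched_gemm_example.inc below; the remaining fields and the stream_config construction are assumptions, not something this diff confirms:

// Sketch only: assumes it sits inside the example translation unit, where
// batched_gemm() and the DeviceMem buffers already exist.
ck_tile::BatchedGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); // fields used by the example below
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
// ... sizes, strides and batch count are further BatchedGemmHostArgs fields,
//     whose names are not shown in this diff, so they are omitted here ...

ck_tile::stream_config s{}; // default-constructed stream config (assumed sufficient)
float ave_time = batched_gemm(args, s);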
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc

@@ -20,7 +20,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                           int n_warmup,
                           int n_repeat)
 {
-    batched_gemm_kargs args;
+    ck_tile::BatchedGemmHostArgs args;
     args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
     args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
     args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
@@ -188,15 +188,33 @@ int run_batched_gemm_example_with_layouts(int argc,
         c_m_n_gpu_ref.SetZero();
         c_m_n_gpu_buf_ref.SetZero();
 
+        ADataType* d_A;
+        BDataType* d_B;
+        CDataType* d_C;
+
+        ck_tile::hip_check_error(hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType)));
+        ck_tile::hip_check_error(hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType)));
+        ck_tile::hip_check_error(hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType)));
+
+        ck_tile::hip_check_error(hipMemcpy(d_A,
+                                           a_m_k_dev_buf.GetDeviceBuffer(),
+                                           batch_count * M * K * sizeof(ADataType),
+                                           hipMemcpyHostToDevice));
+        ck_tile::hip_check_error(hipMemcpy(d_B,
+                                           b_k_n_dev_buf.GetDeviceBuffer(),
+                                           batch_count * N * K * sizeof(BDataType),
+                                           hipMemcpyHostToDevice));
+
         ck_tile::reference_batched_gemm_gpu<ADataType,
                                             BDataType,
                                             AccDataType,
                                             CDataType,
                                             ALayout,
                                             BLayout,
-                                            CLayout>(a_m_k_dev_buf,
-                                                     b_k_n_dev_buf,
-                                                     c_m_n_gpu_buf_ref,
+                                            CLayout>(d_A,
+                                                     d_B,
+                                                     d_C,
                                                      M,
                                                      N,
                                                      K,
@@ -208,6 +226,15 @@ int run_batched_gemm_example_with_layouts(int argc,
                                                      batch_stride_C,
                                                      batch_count);
 
+        ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
+                                           d_C,
+                                           batch_count * M * N * sizeof(CDataType),
+                                           hipMemcpyDeviceToHost));
+
+        ck_tile::hip_check_error(hipFree(d_A));
+        ck_tile::hip_check_error(hipFree(d_B));
+        ck_tile::hip_check_error(hipFree(d_C));
+
         c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
 
         pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
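The validation path now allocates plain device buffers, copies the test data in, runs the GPU reference on raw pointers, and copies the result back before freeing. A self-contained HIP sketch of that allocate / copy / launch / copy-back / free pattern, with a trivial kernel standing in for reference_batched_gemm_gpu:

#include <hip/hip_runtime.h>
#include <cstdio>
#include <vector>

__global__ void scale_kernel(const float* in, float* out, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        out[i] = 2.0f * in[i];
}

int main()
{
    const int n = 1024;
    std::vector<float> h_in(n, 1.0f), h_out(n, 0.0f);

    float *d_in = nullptr, *d_out = nullptr;
    (void)hipMalloc(&d_in, n * sizeof(float));   // allocate device buffers
    (void)hipMalloc(&d_out, n * sizeof(float));
    (void)hipMemcpy(d_in, h_in.data(), n * sizeof(float), hipMemcpyHostToDevice); // copy in

    scale_kernel<<<(n + 255) / 256, 256>>>(d_in, d_out, n); // run the reference

    (void)hipMemcpy(h_out.data(), d_out, n * sizeof(float), hipMemcpyDeviceToHost); // copy back
    (void)hipFree(d_in);  // free
    (void)hipFree(d_out);

    std::printf("out[0] = %f\n", h_out[0]);
    return 0;
}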
example/ck_tile/17_grouped_gemm/grouped_gemm.hpp

@@ -34,13 +34,19 @@ using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
 auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
-    arg_parser.insert("a_layout", "R", "A tensor data layout - Row by default")
-        .insert("b_layout", "R", "B tensor data layout - Row by default")
-        .insert("c_layout", "R", "C tensor data layout - Row by default")
-        .insert("validate", "1", "0. No validation, 1. Validation on CPU")
-        .insert("warmup", "10", "number of iterations before benchmark the kernel")
-        .insert("repeat", "100", "number of iterations to benchmark the kernel")
-        .insert("group_count", "16", "group count");
+    arg_parser.insert("Ms", "", "M dimensions - empty by default.")
+        .insert("Ns", "", "N dimensions - empty by default.")
+        .insert("Ks", "", "K dimensions - empty by default.")
+        .insert("stride_As", "", "Tensor A strides - it is empty by default.")
+        .insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
+        .insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
+        .insert("a_layout", "R", "A tensor data layout - Row by default.")
+        .insert("b_layout", "R", "B tensor data layout - Row by default.")
+        .insert("c_layout", "R", "C tensor data layout - Row by default.")
+        .insert("validate", "1", "0. No validation, 1. Validation on CPU.")
+        .insert("warmup", "10", "number of iterations before benchmark the kernel.")
+        .insert("repeat", "100", "number of iterations to benchmark the kernel.")
+        .insert("group_count", "16", "group count.");
 
     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc

@@ -53,17 +53,24 @@ int run_grouped_gemm_example_with_layouts(int argc,
         return -1;
     };
 
+    auto valid_input_data = [&](int group_count, const auto&... args) {
+        return !(args.empty() || ...) && group_count == (args.size() == ...);
+    };
+
     const int group_count = arg_parser.get_int("group_count");
     const int repeat      = arg_parser.get_int("repeat");
     const int warmup      = arg_parser.get_int("warmup");
 
-    std::vector<ck_tile::index_t> Ms;
-    std::vector<ck_tile::index_t> Ns;
-    std::vector<ck_tile::index_t> Ks;
-    std::vector<ck_tile::index_t> stride_As;
-    std::vector<ck_tile::index_t> stride_Bs;
-    std::vector<ck_tile::index_t> stride_Cs;
+    std::vector<ck_tile::index_t> Ms        = arg_parser.get_int_vec("Ms");
+    std::vector<ck_tile::index_t> Ns        = arg_parser.get_int_vec("Ns");
+    std::vector<ck_tile::index_t> Ks        = arg_parser.get_int_vec("Ks");
+    std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
+    std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
+    std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
 
+    if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs))
+    {
+        std::cout << "Please check the input data. Default values will be used." << std::endl;
         for(int i = 0; i < group_count; i++)
         {
             Ms.push_back(256 + 256 * i);
@@ -74,6 +81,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
             stride_Bs.push_back(Ks[i]);
             stride_Cs.push_back(Ns[i]);
         }
+    }
 
     std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
     std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
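valid_input_data uses C++17 fold expressions over a parameter pack of vectors to reject partially specified inputs in one shot. The predicate above is kept exactly as extracted from the page; the standalone example below uses an equivalent, more conventional formulation of the same check:

#include <iostream>
#include <vector>

// Accept the input only if no list is empty and every list has exactly
// group_count entries (illustrative reformulation of valid_input_data).
template <typename... Vecs>
bool valid_input_data(int group_count, const Vecs&... args)
{
    return !(args.empty() || ...) &&
           ((static_cast<int>(args.size()) == group_count) && ...);
}

int main()
{
    std::vector<int> Ms{256, 512}, Ns{128, 128}, Ks{64, 64};
    std::cout << std::boolalpha
              << valid_input_data(2, Ms, Ns, Ks) << "\n"  // true
              << valid_input_data(3, Ms, Ns, Ks) << "\n"; // false
    return 0;
}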
include/ck/config.h.in

@@ -111,6 +111,22 @@
 #cmakedefine CK_USE_WMMA @CK_USE_WMMA@
 #endif
 
+#ifndef CK_USE_GFX94
+#cmakedefine CK_USE_GFX94 @CK_USE_GFX94@
+#endif
+
+#ifndef CK_USE_OCP_FP8
+#cmakedefine CK_USE_OCP_FP8 @CK_USE_OCP_FP8@
+#endif
+
+#ifndef CK_USE_FNUZ_FP8
+#cmakedefine CK_USE_FNUZ_FP8 @CK_USE_FNUZ_FP8@
+#endif
+
+#ifndef CK_USE_FP8_ON_UNSUPPORTED_ARCH
+#cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@
+#endif
+
 // clang-format on
 
 #endif // CK_CONFIG_H_IN
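Each new block follows the file's existing #ifndef / #cmakedefine pattern: CMake's configure_file() turns the #cmakedefine line into a real #define (or a commented-out one) unless the macro was already forced on the compiler command line. A hedged sketch of what the generated header might contain and how library code can consume it; the value 1 is illustrative, the real value comes from the corresponding CMake option:

// Possible generated include/ck/config.h fragment after CMake configuration
// (illustrative value only):
#ifndef CK_USE_OCP_FP8
#define CK_USE_OCP_FP8 1
#endif

// Consumer side: compile-time dispatch on the configured FP8 flavour.
#if defined(CK_USE_OCP_FP8) && CK_USE_OCP_FP8
// ... use OCP (E4M3/E5M2) FP8 conversion paths ...
#else
// ... fall back to FNUZ FP8 or emulation ...
#endif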
include/ck/host_utility/flush_cache.hpp

@@ -4,7 +4,6 @@
 #pragma once
 
 #include <hip/hip_runtime.h>
-#include <hip/hip_ext.h>
 #include <set>
 #include <vector>
@@ -43,8 +42,8 @@ struct RotatingMemWrapperMultiD
     {
         {
             void* pADeviceBuf;
-            HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
-            HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pADeviceBuf),
+            hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
+            hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
                                       const_cast<void*>(p_a_grids[0]),
                                       size_a_,
                                       hipMemcpyDeviceToDevice));
@@ -53,8 +52,8 @@ struct RotatingMemWrapperMultiD
         {
             void* pBDeviceBuf;
-            HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
-            HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pBDeviceBuf),
+            hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
+            hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
                                       const_cast<void*>(p_b_grids[0]),
                                       size_b_,
                                       hipMemcpyDeviceToDevice));
@@ -66,8 +65,8 @@ struct RotatingMemWrapperMultiD
             DsGridPointer ds_buffer;
             static_for<0, NumDs, 1>{}([&](auto j) {
                 void* pDDeviceBuf;
-                HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
-                HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pDDeviceBuf),
+                hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
+                hip_check_error(hipMemcpy(static_cast<void*>(pDDeviceBuf),
                                           static_cast<const void*>(p_ds_grids[0][j]),
                                           size_ds_[j],
                                           hipMemcpyDeviceToDevice));
@@ -94,10 +93,8 @@ struct RotatingMemWrapperMultiD
     }
 
     void Print()
     {
-        std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b;
-        static_for<0, NumDs, 1>{}(
-            [&](auto j) { std::cout << ", size_d" << j.value << ": " << size_ds[j]; });
-        std::cout << ", rotating_count: " << rotating_count << "}" << std::endl;
+        std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b
+                  << ", rotating_count: " << rotating_count << "}" << std::endl;
     }
 
     ~RotatingMemWrapperMultiD()
     {
@@ -111,35 +108,13 @@ struct RotatingMemWrapperMultiD
         // free device mem
         for(size_t i = 1; i < rotating_count; i++)
        {
-            try
-            {
-                HIP_CHECK_ERROR(hipFree(const_cast<void*>(p_a_grids[i])));
-            }
-            catch(std::runtime_error& re)
-            {
-                std::cerr << re.what() << std::endl;
-            }
-            try
-            {
-                HIP_CHECK_ERROR(hipFree(const_cast<void*>(p_b_grids[i])));
-            }
-            catch(std::runtime_error& re)
-            {
-                std::cerr << re.what() << std::endl;
-            }
+            hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
+            hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
 
             static_for<0, NumDs, 1>{}([&](auto j) {
                 using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
-                try
-                {
-                    HIP_CHECK_ERROR(
-                        hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
-                }
-                catch(std::runtime_error& re)
-                {
-                    std::cerr << re.what() << std::endl;
-                }
+                hip_check_error(
+                    hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
             });
         }
     }
@@ -176,8 +151,8 @@ struct RotatingMemWrapper
     {
         {
             void* pADeviceBuf;
-            HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
-            HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pADeviceBuf),
+            hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
+            hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
                                       const_cast<void*>(p_a_grids[0]),
                                       size_a_,
                                       hipMemcpyDeviceToDevice));
@@ -186,8 +161,8 @@ struct RotatingMemWrapper
         {
             void* pBDeviceBuf;
-            HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
-            HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pBDeviceBuf),
+            hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
+            hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
                                       const_cast<void*>(p_b_grids[0]),
                                       size_b_,
                                       hipMemcpyDeviceToDevice));
@@ -221,23 +196,8 @@ struct RotatingMemWrapper
         // free device mem
         for(size_t i = 1; i < rotating_count; i++)
         {
-            try
-            {
-                HIP_CHECK_ERROR(hipFree(const_cast<void*>(p_a_grids[i])));
-            }
-            catch(std::runtime_error& re)
-            {
-                std::cerr << re.what() << std::endl;
-            }
-            try
-            {
-                HIP_CHECK_ERROR(hipFree(const_cast<void*>(p_b_grids[i])));
-            }
-            catch(std::runtime_error& re)
-            {
-                std::cerr << re.what() << std::endl;
-            }
+            hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
+            hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
         }
     }
@@ -255,20 +215,25 @@ struct RotatingMemWrapper
 inline void flush_icache()
 {
     hipDeviceProp_t deviceProps;
-    HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0));
+    hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
 
     int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
 
     ck::flush_icache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
-    HIP_CHECK_ERROR(hipGetLastError());
+    hip_check_error(hipGetLastError());
 }
 
 // if TimePrePress == false, return time does not include preprocess's time
-template <bool TimePreprocess, typename... Args, typename F, typename PreProcessFunc>
+template <bool TimePreprocess,
+          typename GemmArgs,
+          typename... Args,
+          typename F,
+          typename PreProcessFunc>
 float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                                              PreProcessFunc preprocess,
                                              F kernel,
                                              dim3 grid_dim,
                                              dim3 block_dim,
                                              std::size_t lds_byte,
+                                             GemmArgs& gemm_args,
                                              Args... args)
 {
 #if CK_TIME_KERNEL
@@ -291,8 +256,8 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
         // warm up
         for(int i = 0; i < stream_config.cold_niters_; ++i)
         {
-            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
-            HIP_CHECK_ERROR(hipGetLastError());
+            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
+            hip_check_error(hipGetLastError());
         }
 
         const int nrepeat = stream_config.nrepeat_;
@@ -312,36 +277,54 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
 #endif
 
         hipEvent_t start, stop;
 
-        HIP_CHECK_ERROR(hipEventCreate(&start));
-        HIP_CHECK_ERROR(hipEventCreate(&stop));
+        hip_check_error(hipEventCreate(&start));
+        hip_check_error(hipEventCreate(&stop));
+
+        hip_check_error(hipDeviceSynchronize());
+        hip_check_error(hipEventRecord(start, stream_config.stream_id_));
 
         for(int i = 0; i < nrepeat; ++i)
         {
             if constexpr(!TimePreprocess)
             {
                 preprocess();
             }
+
+            // hipEvent_t start, stop;
+            // hip_check_error(hipEventCreate(&start));
+            // hip_check_error(hipEventCreate(&stop));
+            // hip_check_error(hipDeviceSynchronize());
+            // hip_check_error(hipEventRecord(start, stream_config.stream_id_));
+
             // calculate preprocess time
             if constexpr(TimePreprocess)
             {
                 preprocess();
             }
             // run real kernel
-            hipExtLaunchKernelGGL(kernel,
-                                  grid_dim,
-                                  block_dim,
-                                  lds_byte,
-                                  stream_config.stream_id_,
-                                  start,
-                                  stop,
-                                  0,
-                                  args...);
-            HIP_CHECK_ERROR(hipGetLastError());
+            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
+            hip_check_error(hipGetLastError());
             // end real kernel
-            HIP_CHECK_ERROR(hipEventRecord(stop, stream_config.stream_id_));
-            HIP_CHECK_ERROR(hipEventSynchronize(stop));
-            float cur_time = 0;
-            HIP_CHECK_ERROR(hipEventElapsedTime(&cur_time, start, stop));
-#if MEDIAN
-            times.insert(cur_time);
-#else
-            total_time += cur_time;
-#endif
+
+            // hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
+            // hip_check_error(hipEventSynchronize(stop));
+            // float cur_time = 0;
+            // hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
+            // #if MEDIAN
+            // times.insert(cur_time);
+            // #else
+            // total_time += cur_time;
+            // #endif
+
+            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+            {
+                // std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
+                printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
+                       static_cast<const void*>(gemm_args.p_a_grid),
+                       static_cast<const void*>(gemm_args.p_b_grid));
+            }
         }
 
         hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
         hip_check_error(hipEventSynchronize(stop));
@@ -367,20 +350,24 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
             return (*mid + *mid_next) / 2;
         }
 #else
-        return total_time / nrepeat;
+        // return total_time / nrepeat;
+        hipDeviceProp_t deviceProps;
+        hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
+        float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01;
+        return (total_time - preprocess_offset * nrepeat) / nrepeat;
 #endif
     }
     else
     {
         preprocess();
-        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
-        HIP_CHECK_ERROR(hipGetLastError());
+        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
+        hip_check_error(hipGetLastError());
 
         return 0;
     }
 #else
-    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
-    HIP_CHECK_ERROR(hipGetLastError());
+    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
+    hip_check_error(hipGetLastError());
 
     return 0;
 #endif
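The timing strategy moves from per-iteration hipExtLaunchKernelGGL events to a single start/stop event pair around the whole repeat loop, with a fixed per-iteration preprocess offset subtracted afterwards. A minimal, self-contained HIP sketch of that whole-loop pattern; the trivial kernel and the names used here are placeholders, not CK code:

#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void dummy_kernel(float* p) { p[threadIdx.x] *= 2.0f; }

int main()
{
    float* d = nullptr;
    (void)hipMalloc(&d, 64 * sizeof(float));

    hipEvent_t start, stop;
    (void)hipEventCreate(&start);
    (void)hipEventCreate(&stop);

    const int nrepeat = 100;
    (void)hipDeviceSynchronize();
    (void)hipEventRecord(start, nullptr);        // one event pair around the whole loop
    for(int i = 0; i < nrepeat; ++i)
        dummy_kernel<<<1, 64, 0, nullptr>>>(d);  // launches stay asynchronous
    (void)hipEventRecord(stop, nullptr);
    (void)hipEventSynchronize(stop);

    float total_ms = 0.f;
    (void)hipEventElapsedTime(&total_ms, start, stop);
    std::printf("avg kernel time: %f ms\n", total_ms / nrepeat);

    (void)hipFree(d);
    return 0;
}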
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp

@@ -477,6 +477,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                         b_thread_buf_tail);
                 });
             });
+
+            HotLoopScheduler();
+            __builtin_amdgcn_sched_barrier(0);
         }
     }
@@ -692,6 +695,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                 });
             });
 
+            HotLoopScheduler();
+            __builtin_amdgcn_sched_barrier(0);
+
             static_for<0, KRepeat, 1>{}([&](auto k0) {
                 static_for<0, MRepeat, 1>{}([&](auto m0) {
                     static_for<0, NRepeat, 1>{}([&](auto n0) {
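Both insertions pair the pipeline's HotLoopScheduler() hints with __builtin_amdgcn_sched_barrier(0), an AMDGPU compiler builtin that keeps the backend instruction scheduler from moving anything across that point. A hedged sketch of the general pattern in standalone HIP device code, not the CK pipeline itself:

// Illustrative HIP device function: the builtin takes a mask of instruction
// kinds allowed to cross; 0 means nothing may be reordered across the barrier.
// Requires the ROCm/AMDGPU clang toolchain.
__device__ void scaled_copy(const float* src, float* dst, float s)
{
    float v = src[threadIdx.x];        // load phase, issued before the barrier
    __builtin_amdgcn_sched_barrier(0); // keep the compiler from mixing the two phases
    dst[threadIdx.x] = v * s;          // compute/store phase, scheduled after it
}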
include/ck/tensor_operation/gpu/device/device_base.hpp

@@ -5,6 +5,8 @@
 #include <string>
 #include <sstream>
+#include <regex>
+#include <optional>
 
 #include "ck/stream_config.hpp"
@@ -12,6 +14,34 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
+#define GET_OBJECT_NAME_IMLP                                                   \
+    std::optional<std::string> GetObjectName() const override                 \
+    {                                                                          \
+        std::string str = __PRETTY_FUNCTION__;                                 \
+        static std::regex obj_name_expr{"<std::string> (.*)::GetObjectName"};  \
+        std::smatch match;                                                     \
+        if(!std::regex_search(str, match, obj_name_expr))                      \
+        {                                                                      \
+            return str;                                                        \
+        }                                                                      \
+        return std::string(match[1]) + ';';                                    \
+    }
+
+#define GET_TEMPLATE_INFO_IMPL                                                 \
+    std::optional<std::string> GetTemplateInfo() const override                \
+    {                                                                          \
+        std::string str = __PRETTY_FUNCTION__;                                 \
+        static std::regex template_expr{"\\[(.*)\\]"};                         \
+        std::smatch match;                                                     \
+        if(!std::regex_search(str, match, template_expr))                      \
+        {                                                                      \
+            return std::nullopt;                                               \
+        }                                                                      \
+        return std::string(match[1]);                                          \
+    }
+
+#define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL
+
 struct BaseArgument
 {
     BaseArgument() = default;
@@ -48,6 +78,10 @@ struct BaseOperator
     virtual std::string GetTypeIdName() const { return typeid(*this).name(); }
 
+    virtual std::optional<std::string> GetObjectName() const { return std::nullopt; }
+
+    virtual std::optional<std::string> GetTemplateInfo() const { return std::nullopt; }
+
     virtual std::string GetTypeIdHashCode() const
     {
         std::ostringstream oss;
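The new macros recover a readable operator name by running a regex over __PRETTY_FUNCTION__ inside a member function. A small standalone illustration of that mechanism in plain C++ (GCC/Clang), independent of the CK base classes:

#include <iostream>
#include <optional>
#include <regex>
#include <string>

template <int BlockSize>
struct DeviceOpExample
{
    // Mirrors the idea behind GET_TEMPLATE_INFO_IMPL: __PRETTY_FUNCTION__ embeds
    // the deduced template arguments, and the regex pulls out the "[...]" part.
    std::optional<std::string> GetTemplateInfo() const
    {
        std::string str = __PRETTY_FUNCTION__;
        static std::regex template_expr{"\\[(.*)\\]"};
        std::smatch match;
        if(!std::regex_search(str, match, template_expr))
            return std::nullopt;
        return std::string(match[1]);
    }
};

int main()
{
    DeviceOpExample<256> op;
    // Prints something like "with int BlockSize = 256" on GCC.
    std::cout << op.GetTemplateInfo().value_or("<none>") << "\n";
    return 0;
}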
include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp

@@ -89,7 +89,8 @@ struct DeviceBatchedGemmV2MultiD : public BaseOperator
                         index_t BatchStrideE,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
-                        CDEElementwiseOperation cde_element_op) = 0;
+                        CDEElementwiseOperation cde_element_op,
+                        index_t KBatch) = 0;
 
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp

@@ -41,12 +41,15 @@ __global__ void
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
     const index_t g_idx = blockIdx.z % karg.Batch;
+    const index_t k_idx = blockIdx.z / karg.Batch;
+
     const auto a_batch_offset  = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
     const auto b_batch_offset  = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
     const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
     const auto c_batch_offset  = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
 
+    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
+
     // populate pointer, desc for Ds
     static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
         // D pointer
@@ -54,8 +57,8 @@ __global__ void
     });
 
     GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
-        karg.p_a_grid + a_batch_offset,
-        karg.p_b_grid + b_batch_offset,
+        karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
+        karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
         karg.p_ds_grid,
         karg.p_c_grid + c_batch_offset,
         p_shared,
@@ -87,12 +90,15 @@ __global__ void
     __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
     const index_t g_idx = blockIdx.z % karg.Batch;
+    const index_t k_idx = blockIdx.z / karg.Batch;
+
     const auto a_batch_offset  = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
     const auto b_batch_offset  = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
     const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
     const auto c_batch_offset  = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
 
+    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
+
     // populate pointer, desc for Ds
     static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
         // D pointer
@@ -100,8 +106,8 @@ __global__ void
     });
 
     GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
-        karg.p_a_grid + a_batch_offset,
-        karg.p_b_grid + b_batch_offset,
+        karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
+        karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
         karg.p_ds_grid,
        karg.p_c_grid + c_batch_offset,
         p_shared_0,
@@ -303,7 +309,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
                  index_t Batch_,
                  AElementwiseOperation a_element_op_,
                  BElementwiseOperation b_element_op_,
-                 CElementwiseOperation c_element_op_)
+                 CElementwiseOperation c_element_op_,
+                 index_t KBatch_)
             : GridwiseGemm::Argument{p_a_grid_,
                                      p_b_grid_,
                                      p_ds_grid_,
@@ -315,7 +322,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
                                      StrideB_,
                                      StrideDs_,
                                      StrideE_,
-                                     1,
+                                     KBatch_,
                                      a_element_op_,
                                      b_element_op_,
                                      c_element_op_},
@@ -336,13 +343,14 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
                 arg.Print();
             }
 
-            if(!GridwiseGemm::CheckValidity(arg) || arg.KBatch > 1)
+            if(!GridwiseGemm::CheckValidity(arg))
             {
                 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
             }
 
             index_t gdx, gdy, gdz;
-            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch);
+            std::tie(gdx, gdy, gdz) =
+                GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch * arg.KBatch);
 
             float ave_time = 0;
@@ -387,9 +395,10 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
                         rotating_mem.Next();
 
                         // clear c mem
                         if(arg_.KBatch > 1)
                             hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
                                                              0,
-                                                             arg_.M * arg_.N * sizeof(CDataType),
+                                                             arg.Batch * arg_.M * arg_.N * sizeof(CDataType),
                                                              stream_config.stream_id_));
                     };
@@ -889,7 +898,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
                         index_t BatchStrideE,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op)
+                        CElementwiseOperation c_element_op,
+                        index_t KBatch = 1)
     {
         return Argument{static_cast<const ADataType*>(p_a),
                         static_cast<const BDataType*>(p_b),
@@ -909,7 +919,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
                         Batch,
                         a_element_op,
                         b_element_op,
-                        c_element_op};
+                        c_element_op,
+                        KBatch};
     }
 
     static auto MakeInvoker() { return Invoker{}; }
@@ -934,7 +945,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
                         index_t BatchStrideE,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op) override
+                        CElementwiseOperation c_element_op,
+                        index_t KBatch = 1) override
     {
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
@@ -954,7 +966,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
                                           Batch,
                                           a_element_op,
                                           b_element_op,
-                                          c_element_op);
+                                          c_element_op,
+                                          KBatch);
     }
 
     // polymorphic
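With split-K enabled, the batched kernel is launched with Batch * KBatch blocks along z and recovers both indices from blockIdx.z (g_idx = blockIdx.z % Batch, k_idx = blockIdx.z / Batch). A tiny standalone check of that index decomposition:

#include <cassert>

int main()
{
    const int Batch = 4, KBatch = 3;
    // Enumerate the flattened z dimension the way the grid is now sized
    // (Batch * KBatch) and split it back into a batch index and a K-split index.
    for(int z = 0; z < Batch * KBatch; ++z)
    {
        int g_idx = z % Batch; // which GEMM in the batch
        int k_idx = z / Batch; // which K-slice of that GEMM
        assert(g_idx >= 0 && g_idx < Batch);
        assert(k_idx >= 0 && k_idx < KBatch);
    }
    return 0;
}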
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp

@@ -741,6 +741,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
         return str.str();
     }
+    REGISTER_EXTRA_PRINTING_METHODS
 };
 
 } // namespace device
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp

@@ -41,7 +41,7 @@ __global__ void
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
-    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
+    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
 
     GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
         karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
@@ -76,7 +76,7 @@ __global__ void
     __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
-    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
+    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
 
     GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
         karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
@@ -639,27 +639,27 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
     struct SplitKBatchOffset
     {
-        __device__ SplitKBatchOffset(Argument& karg)
+        __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
         {
             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
             {
-                a_k_split_offset = blockIdx.z * karg.KRead;
+                a_k_split_offset = k_id * karg.KRead;
             }
             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
             {
-                a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
+                a_k_split_offset = k_id * karg.KRead * karg.StrideA;
             }
 
             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
             {
-                b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
+                b_k_split_offset = k_id * karg.KRead * karg.StrideB;
             }
             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
             {
-                b_k_split_offset = blockIdx.z * karg.KRead;
+                b_k_split_offset = k_id * karg.KRead;
             }
 
-            if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
+            if(k_id < karg.KBatch - 1)
             {
                 karg.K = karg.KRead;
             }
include/ck/utility/amd_ck_fp8.hpp

@@ -18,6 +18,20 @@
 #define CK_USE_OCP_FP8 0
 #endif
 
+namespace {
+// https://en.cppreference.com/w/cpp/types/conditional
+template <bool B, class T, class F>
+struct conditional
+{
+    using type = T;
+};
+
+template <class T, class F>
+struct conditional<false, T, F>
+{
+    using type = F;
+};
+} // namespace
+
 namespace ck {
 
 using f8_fnuz_t = _BitInt(8);
@@ -191,11 +205,10 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
     }
 
-    typename __hip_internal::conditional<
+    typename conditional<
         sizeof(T) == 2,
         unsigned short int,
-        typename __hip_internal::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::
+        typename conditional<sizeof(T) == 4, unsigned int, unsigned long long>::
             type>::type retval;
 
     if constexpr(we == 5 && is_half && !is_fnuz)
     {
@@ -538,11 +551,10 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
     constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
-    using T_bitwise = typename __hip_internal::conditional<
+    using T_bitwise = typename conditional<
         sizeof(T) == 2,
         unsigned short int,
-        typename __hip_internal::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::
+        typename conditional<sizeof(T) == 4, unsigned int, unsigned long long>::
             type>::type;
 
     T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);
     unsigned long long x{x_bitwise};
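The header now carries its own tiny conditional trait in an anonymous namespace instead of relying on __hip_internal::conditional. It behaves like std::conditional; a standalone illustration of the trait and of the integer-type selection the fp8 cast helpers make with it:

#include <type_traits>

// Minimal re-implementation mirroring the trait added to amd_ck_fp8.hpp.
template <bool B, class T, class F>
struct conditional
{
    using type = T;
};

template <class T, class F>
struct conditional<false, T, F>
{
    using type = F;
};

// Pick an unsigned integer type wide enough to hold the bits of T,
// the same selection cast_from_f8 / cast_to_f8 perform.
template <class T>
using bitwise_t = typename conditional<
    sizeof(T) == 2,
    unsigned short int,
    typename conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;

static_assert(std::is_same_v<bitwise_t<float>, unsigned int>);
static_assert(std::is_same_v<bitwise_t<double>, unsigned long long>);

int main() { return 0; }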
include/ck_tile/core/arch/amd_buffer_addressing.hpp

@@ -1303,8 +1303,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
     static_assert(
         (std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
             (std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-            (std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-            (std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+            (std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+            (std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
            (std::is_same<T, int32_t>::value &&
             (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
include/ck_tile/core/container/meta_data_buffer.hpp

@@ -30,7 +30,7 @@ struct meta_data_buffer
     {
         constexpr index_t size = sizeof(T);
 
-        auto tmp = bit_cast<array<std::byte, size>>(data);
+        auto tmp = ck_tile::bit_cast<array<std::byte, size>>(data);
 
         for(int i = 0; i < size; i++)
         {
@@ -66,7 +66,7 @@ struct meta_data_buffer
                 pos++;
             }
 
-            data = bit_cast<T>(tmp);
+            data = ck_tile::bit_cast<T>(tmp);
         }
 
         return data;
@@ -86,7 +86,7 @@ struct meta_data_buffer
             pos++;
         }
 
-        auto data = bit_cast<T>(tmp);
+        auto data = ck_tile::bit_cast<T>(tmp);
 
         return data;
     }
include/ck_tile/core/tensor/static_distributed_tensor.hpp

@@ -29,6 +29,7 @@ struct static_distributed_tensor
         remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;
 
     static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();
+    static_assert(0 < kThreadElementSpaceSize, "Make sure tile distribution is valid");
 
     CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension()
     {
include/ck_tile/host/arg_parser.hpp

@@ -15,11 +15,14 @@
 namespace ck_tile {
 
 /*
- * a host side utility, arg parser for
+ * a host side utility, arg parser for, either
+ * -[key0] = [value0, value1, value2]
+ * or
  * -[key0]=[value0] -[key1]=[value1] ...
  */
 class ArgParser
 {
     public:
     class Arg
     {
@@ -187,6 +190,45 @@ class ArgParser
         return value;
     }
 
+    std::vector<std::string> get_string_vec(const std::string& name,
+                                            const std::string& delimiter = ",") const
+    {
+        if(get_str(name).empty())
+        {
+            return {};
+        }
+
+        std::string s = get_str(name);
+        std::vector<std::string> tokens;
+        size_t pos = 0;
+        std::string token;
+        while((pos = s.find(delimiter)) != std::string::npos)
+        {
+            token = s.substr(0, pos);
+            tokens.push_back(token);
+            s.erase(0, pos + delimiter.length());
+        }
+        tokens.push_back(s);
+        return tokens;
+    }
+
+    std::vector<int> get_int_vec(const std::string& name, const std::string& delimiter = ",") const
+    {
+        if(get_str(name).empty())
+        {
+            return {};
+        }
+
+        const std::vector<std::string> args = get_string_vec(name, delimiter);
+        std::vector<int> tokens;
+        tokens.reserve(static_cast<int>(args.size()));
+
+        for(const std::string& token : args)
+        {
+            int value = atoi(token.c_str());
+            tokens.push_back(value);
+        }
+        return tokens;
+    }
+
     private:
     std::unordered_map<std::string, Arg> input_map;
     std::vector<std::string> keys;
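get_int_vec simply splits the raw option string on a delimiter and converts each token, which is what lets the grouped GEMM example accept -Ms=256,512,1024 style lists. A standalone sketch of the same splitting logic in plain C++, outside the ArgParser class:

#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>

// Split "256,512,1024" into {256, 512, 1024}, mirroring
// ArgParser::get_string_vec / get_int_vec from the diff above.
std::vector<int> parse_int_list(std::string s, const std::string& delimiter = ",")
{
    std::vector<int> values;
    if(s.empty())
        return values;

    size_t pos = 0;
    while((pos = s.find(delimiter)) != std::string::npos)
    {
        values.push_back(std::atoi(s.substr(0, pos).c_str()));
        s.erase(0, pos + delimiter.length());
    }
    values.push_back(std::atoi(s.c_str())); // last (or only) token
    return values;
}

int main()
{
    for(int v : parse_int_list("256,512,1024"))
        std::cout << v << "\n";
    return 0;
}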
include/ck_tile/host/reference/reference_gemm.hpp

@@ -97,9 +97,9 @@ template <typename ADataType,
           typename LayoutA,
           typename LayoutB,
           typename LayoutC>
-void reference_gemm_gpu(DeviceMem& a_device,
-                        DeviceMem& b_device,
-                        DeviceMem& c_device,
+void reference_gemm_gpu(ADataType* a_ptr,
+                        BDataType* b_ptr,
+                        CDataType* c_ptr,
                         index_t M,
                         index_t N,
                         index_t K,
@@ -107,79 +107,13 @@ void reference_gemm_gpu(DeviceMem& a_device,
                         index_t stride_b,
                         index_t stride_c)
 {
-    ADataType* d_A;
-    BDataType* d_B;
-    CDataType* d_C;
-
-    hipError_t errA = hipMalloc(&d_A, M * K * sizeof(ADataType));
-    hipError_t errB = hipMalloc(&d_B, N * K * sizeof(BDataType));
-    hipError_t errC = hipMalloc(&d_C, M * N * sizeof(CDataType));
-    if(errA != hipSuccess)
-    {
-        std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA) << std::endl;
-        return; // Early exit on error
-    }
-    if(errB != hipSuccess)
-    {
-        std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB) << std::endl;
-        return; // Early exit on error
-    }
-    if(errC != hipSuccess)
-    {
-        std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC) << std::endl;
-        return; // Early exit on error
-    }
-
-    errA = hipMemcpy(d_A, a_device.GetDeviceBuffer(), M * K * sizeof(ADataType), hipMemcpyHostToDevice);
-    if(errA != hipSuccess)
-    {
-        std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
-    }
-    errB = hipMemcpy(d_B, b_device.GetDeviceBuffer(), N * K * sizeof(BDataType), hipMemcpyHostToDevice);
-    if(errB != hipSuccess)
-    {
-        std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
-    }
-
     int totalElements      = M * N;
     int numThreadsPerBlock = 256; // Common choice for threads per block
     int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
 
     naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
-        <<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
-
-    errC = hipMemcpy(c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
-    if(errC != hipSuccess)
-    {
-        std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
-    }
-
-    errA = hipFree(d_A);
-    if(errA != hipSuccess)
-    {
-        std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl;
-    }
-    errB = hipFree(d_B);
-    if(errB != hipSuccess)
-    {
-        std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl;
-    }
-    errC = hipFree(d_C);
-    if(errC != hipSuccess)
-    {
-        std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl;
-    }
+        <<<numBlocks, numThreadsPerBlock>>>(
+            a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
 
     return;
 }
@@ -191,9 +125,9 @@ template <typename ADataType,
           typename LayoutA,
           typename LayoutB,
           typename LayoutC>
-void reference_batched_gemm_gpu(DeviceMem& a_device,
-                                DeviceMem& b_device,
-                                DeviceMem& c_device,
+void reference_batched_gemm_gpu(ADataType* a_ptr,
+                                BDataType* b_ptr,
+                                CDataType* c_ptr,
                                 index_t M,
                                 index_t N,
                                 index_t K,
@@ -205,94 +139,20 @@ void reference_batched_gemm_gpu(DeviceMem& a_device,
                                 index_t batch_stride_C,
                                 index_t batch_count)
 {
-    ADataType* d_A;
-    BDataType* d_B;
-    CDataType* d_C;
-
-    hipError_t errA = hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType));
-    hipError_t errB = hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType));
-    hipError_t errC = hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType));
-    if(errA != hipSuccess)
-    {
-        std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA) << std::endl;
-        return; // Early exit on error
-    }
-    if(errB != hipSuccess)
-    {
-        std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB) << std::endl;
-        return; // Early exit on error
-    }
-    if(errC != hipSuccess)
-    {
-        std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC) << std::endl;
-        return; // Early exit on error
-    }
-
-    errA = hipMemcpy(d_A, a_device.GetDeviceBuffer(), batch_count * M * K * sizeof(ADataType), hipMemcpyHostToDevice);
-    if(errA != hipSuccess)
-    {
-        std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
-    }
-    errB = hipMemcpy(d_B, b_device.GetDeviceBuffer(), batch_count * N * K * sizeof(BDataType), hipMemcpyHostToDevice);
-    if(errB != hipSuccess)
-    {
-        std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
-    }
-
     int totalElements      = M * N;
     int numThreadsPerBlock = 256; // Common choice for threads per block
     int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
 
     for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
     {
-        ADataType* d_ATemp = d_A + batch_id * batch_stride_A;
-        BDataType* d_BTemp = d_B + batch_id * batch_stride_B;
-        CDataType* d_CTemp = d_C + batch_id * batch_stride_C;
+        ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
+        BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
+        CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
         naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
             <<<numBlocks, numThreadsPerBlock>>>(
                 d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
     }
-
-    errC = hipMemcpy(c_device.GetDeviceBuffer(), d_C, batch_count * M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
-    if(errC != hipSuccess)
-    {
-        std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
-    }
-
-    errA = hipFree(d_A);
-    if(errA != hipSuccess)
-    {
-        std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl;
-    }
-    errB = hipFree(d_B);
-    if(errB != hipSuccess)
-    {
-        std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl;
-    }
-    errC = hipFree(d_C);
-    if(errC != hipSuccess)
-    {
-        std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl;
-    }
 
     return;
 }
 
 } // namespace ck_tile
include/ck_tile/ops/flatmm.hpp

@@ -5,6 +5,7 @@
 #include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
 #include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
+#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
 #include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp (new file, mode 0 → 100644, +510 lines)

The diff of this new file is collapsed on the page and its contents are not reproduced here.