Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
13fe6e95
Commit
13fe6e95
authored
Dec 18, 2024
by
Aleksander Dudek
Browse files
Merge branch 'develop' into ck_tile_gemmkernel_reuse
parents
b85e1128
1c1b3363
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
71 additions
and
159 deletions
+71
-159
example/ck_tile/03_gemm/run_gemm_example.inc
example/ck_tile/03_gemm/run_gemm_example.inc
+27
-2
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
+30
-3
include/ck_tile/core/container/meta_data_buffer.hpp
include/ck_tile/core/container/meta_data_buffer.hpp
+3
-3
include/ck_tile/host/reference/reference_gemm.hpp
include/ck_tile/host/reference/reference_gemm.hpp
+11
-151
No files found.
example/ck_tile/03_gemm/run_gemm_example.inc
View file @
13fe6e95
...
@@ -161,14 +161,39 @@ int run_gemm_example_with_layouts(int argc,
...
@@ -161,14 +161,39 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_gpu_ref
.
SetZero
();
c_m_n_gpu_ref
.
SetZero
();
c_m_n_gpu_buf_ref
.
SetZero
();
c_m_n_gpu_buf_ref
.
SetZero
();
ADataType
*
d_A
;
BDataType
*
d_B
;
CDataType
*
d_C
;
ck_tile
::
hip_check_error
(
hipMalloc
(
&
d_A
,
M
*
K
*
sizeof
(
ADataType
)));
ck_tile
::
hip_check_error
(
hipMalloc
(
&
d_B
,
N
*
K
*
sizeof
(
BDataType
)));
ck_tile
::
hip_check_error
(
hipMalloc
(
&
d_C
,
M
*
N
*
sizeof
(
CDataType
)));
ck_tile
::
hip_check_error
(
hipMemcpy
(
d_A
,
a_m_k_dev_buf
.
GetDeviceBuffer
(),
M
*
K
*
sizeof
(
ADataType
),
hipMemcpyHostToDevice
));
ck_tile
::
hip_check_error
(
hipMemcpy
(
d_B
,
b_k_n_dev_buf
.
GetDeviceBuffer
(),
N
*
K
*
sizeof
(
BDataType
),
hipMemcpyHostToDevice
));
ck_tile
::
reference_gemm_gpu
<
ADataType
,
ck_tile
::
reference_gemm_gpu
<
ADataType
,
BDataType
,
BDataType
,
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
ALayout
,
ALayout
,
BLayout
,
BLayout
,
CLayout
>
(
CLayout
>
(
d_A
,
d_B
,
d_C
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
);
a_m_k_dev_buf
,
b_k_n_dev_buf
,
c_m_n_gpu_buf_ref
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
);
ck_tile
::
hip_check_error
(
hipMemcpy
(
c_m_n_gpu_buf_ref
.
GetDeviceBuffer
(),
d_C
,
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
));
ck_tile
::
hip_check_error
(
hipFree
(
d_A
));
ck_tile
::
hip_check_error
(
hipFree
(
d_B
));
ck_tile
::
hip_check_error
(
hipFree
(
d_C
));
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
);
...
...
example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc
View file @
13fe6e95
...
@@ -188,15 +188,33 @@ int run_batched_gemm_example_with_layouts(int argc,
...
@@ -188,15 +188,33 @@ int run_batched_gemm_example_with_layouts(int argc,
c_m_n_gpu_ref
.
SetZero
();
c_m_n_gpu_ref
.
SetZero
();
c_m_n_gpu_buf_ref
.
SetZero
();
c_m_n_gpu_buf_ref
.
SetZero
();
ADataType
*
d_A
;
BDataType
*
d_B
;
CDataType
*
d_C
;
ck_tile
::
hip_check_error
(
hipMalloc
(
&
d_A
,
batch_count
*
M
*
K
*
sizeof
(
ADataType
)));
ck_tile
::
hip_check_error
(
hipMalloc
(
&
d_B
,
batch_count
*
N
*
K
*
sizeof
(
BDataType
)));
ck_tile
::
hip_check_error
(
hipMalloc
(
&
d_C
,
batch_count
*
M
*
N
*
sizeof
(
CDataType
)));
ck_tile
::
hip_check_error
(
hipMemcpy
(
d_A
,
a_m_k_dev_buf
.
GetDeviceBuffer
(),
batch_count
*
M
*
K
*
sizeof
(
ADataType
),
hipMemcpyHostToDevice
));
ck_tile
::
hip_check_error
(
hipMemcpy
(
d_B
,
b_k_n_dev_buf
.
GetDeviceBuffer
(),
batch_count
*
N
*
K
*
sizeof
(
BDataType
),
hipMemcpyHostToDevice
));
ck_tile
::
reference_batched_gemm_gpu
<
ADataType
,
ck_tile
::
reference_batched_gemm_gpu
<
ADataType
,
BDataType
,
BDataType
,
AccDataType
,
AccDataType
,
CDataType
,
CDataType
,
ALayout
,
ALayout
,
BLayout
,
BLayout
,
CLayout
>
(
a_m_k_dev_buf
,
CLayout
>
(
d_A
,
b_k_n_dev_buf
,
d_B
,
c_m_n_gpu_buf_ref
,
d_C
,
M
,
M
,
N
,
N
,
K
,
K
,
...
@@ -208,6 +226,15 @@ int run_batched_gemm_example_with_layouts(int argc,
...
@@ -208,6 +226,15 @@ int run_batched_gemm_example_with_layouts(int argc,
batch_stride_C
,
batch_stride_C
,
batch_count
);
batch_count
);
ck_tile
::
hip_check_error
(
hipMemcpy
(
c_m_n_gpu_buf_ref
.
GetDeviceBuffer
(),
d_C
,
batch_count
*
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
));
ck_tile
::
hip_check_error
(
hipFree
(
d_A
));
ck_tile
::
hip_check_error
(
hipFree
(
d_B
));
ck_tile
::
hip_check_error
(
hipFree
(
d_C
));
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
c_m_n_gpu_buf_ref
.
FromDevice
(
c_m_n_gpu_ref
.
data
());
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
);
pass
=
ck_tile
::
check_err
(
c_m_n_dev_result
,
c_m_n_gpu_ref
);
...
...
include/ck_tile/core/container/meta_data_buffer.hpp
View file @
13fe6e95
...
@@ -30,7 +30,7 @@ struct meta_data_buffer
...
@@ -30,7 +30,7 @@ struct meta_data_buffer
{
{
constexpr
index_t
size
=
sizeof
(
T
);
constexpr
index_t
size
=
sizeof
(
T
);
auto
tmp
=
bit_cast
<
array
<
std
::
byte
,
size
>>
(
data
);
auto
tmp
=
ck_tile
::
bit_cast
<
array
<
std
::
byte
,
size
>>
(
data
);
for
(
int
i
=
0
;
i
<
size
;
i
++
)
for
(
int
i
=
0
;
i
<
size
;
i
++
)
{
{
...
@@ -66,7 +66,7 @@ struct meta_data_buffer
...
@@ -66,7 +66,7 @@ struct meta_data_buffer
pos
++
;
pos
++
;
}
}
data
=
bit_cast
<
T
>
(
tmp
);
data
=
ck_tile
::
bit_cast
<
T
>
(
tmp
);
}
}
return
data
;
return
data
;
...
@@ -86,7 +86,7 @@ struct meta_data_buffer
...
@@ -86,7 +86,7 @@ struct meta_data_buffer
pos
++
;
pos
++
;
}
}
auto
data
=
bit_cast
<
T
>
(
tmp
);
auto
data
=
ck_tile
::
bit_cast
<
T
>
(
tmp
);
return
data
;
return
data
;
}
}
...
...
include/ck_tile/host/reference/reference_gemm.hpp
View file @
13fe6e95
...
@@ -97,9 +97,9 @@ template <typename ADataType,
...
@@ -97,9 +97,9 @@ template <typename ADataType,
typename
LayoutA
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutB
,
typename
LayoutC
>
typename
LayoutC
>
void
reference_gemm_gpu
(
DeviceMem
&
a_device
,
void
reference_gemm_gpu
(
ADataType
*
a_ptr
,
DeviceMem
&
b_device
,
BDataType
*
b_ptr
,
DeviceMem
&
c_device
,
CDataType
*
c_ptr
,
index_t
M
,
index_t
M
,
index_t
N
,
index_t
N
,
index_t
K
,
index_t
K
,
...
@@ -107,79 +107,13 @@ void reference_gemm_gpu(DeviceMem& a_device,
...
@@ -107,79 +107,13 @@ void reference_gemm_gpu(DeviceMem& a_device,
index_t
stride_b
,
index_t
stride_b
,
index_t
stride_c
)
index_t
stride_c
)
{
{
ADataType
*
d_A
;
BDataType
*
d_B
;
CDataType
*
d_C
;
hipError_t
errA
=
hipMalloc
(
&
d_A
,
M
*
K
*
sizeof
(
ADataType
));
hipError_t
errB
=
hipMalloc
(
&
d_B
,
N
*
K
*
sizeof
(
BDataType
));
hipError_t
errC
=
hipMalloc
(
&
d_C
,
M
*
N
*
sizeof
(
CDataType
));
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for A: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for B: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for C: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
return
;
// Early exit on error
}
errA
=
hipMemcpy
(
d_A
,
a_device
.
GetDeviceBuffer
(),
M
*
K
*
sizeof
(
ADataType
),
hipMemcpyHostToDevice
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying A to device: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipMemcpy
(
d_B
,
b_device
.
GetDeviceBuffer
(),
N
*
K
*
sizeof
(
BDataType
),
hipMemcpyHostToDevice
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying B to device: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
int
totalElements
=
M
*
N
;
int
totalElements
=
M
*
N
;
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
d_A
,
d_B
,
d_C
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
errC
=
hipMemcpy
(
a_ptr
,
b_ptr
,
c_ptr
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
c_device
.
GetDeviceBuffer
(),
d_C
,
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying C to device: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
errA
=
hipFree
(
d_A
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the A memory: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipFree
(
d_B
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the B memory: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
errC
=
hipFree
(
d_C
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the C memory: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
return
;
return
;
}
}
...
@@ -191,9 +125,9 @@ template <typename ADataType,
...
@@ -191,9 +125,9 @@ template <typename ADataType,
typename
LayoutA
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutB
,
typename
LayoutC
>
typename
LayoutC
>
void
reference_batched_gemm_gpu
(
DeviceMem
&
a_device
,
void
reference_batched_gemm_gpu
(
ADataType
*
a_ptr
,
DeviceMem
&
b_device
,
BDataType
*
b_ptr
,
DeviceMem
&
c_device
,
CDataType
*
c_ptr
,
index_t
M
,
index_t
M
,
index_t
N
,
index_t
N
,
index_t
K
,
index_t
K
,
...
@@ -205,94 +139,20 @@ void reference_batched_gemm_gpu(DeviceMem& a_device,
...
@@ -205,94 +139,20 @@ void reference_batched_gemm_gpu(DeviceMem& a_device,
index_t
batch_stride_C
,
index_t
batch_stride_C
,
index_t
batch_count
)
index_t
batch_count
)
{
{
ADataType
*
d_A
;
BDataType
*
d_B
;
CDataType
*
d_C
;
hipError_t
errA
=
hipMalloc
(
&
d_A
,
batch_count
*
M
*
K
*
sizeof
(
ADataType
));
hipError_t
errB
=
hipMalloc
(
&
d_B
,
batch_count
*
N
*
K
*
sizeof
(
BDataType
));
hipError_t
errC
=
hipMalloc
(
&
d_C
,
batch_count
*
M
*
N
*
sizeof
(
CDataType
));
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for A: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for B: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for C: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
return
;
// Early exit on error
}
errA
=
hipMemcpy
(
d_A
,
a_device
.
GetDeviceBuffer
(),
batch_count
*
M
*
K
*
sizeof
(
ADataType
),
hipMemcpyHostToDevice
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying A to device: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipMemcpy
(
d_B
,
b_device
.
GetDeviceBuffer
(),
batch_count
*
N
*
K
*
sizeof
(
BDataType
),
hipMemcpyHostToDevice
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying B to device: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
int
totalElements
=
M
*
N
;
int
totalElements
=
M
*
N
;
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
for
(
index_t
batch_id
=
0
;
batch_id
<
batch_count
;
++
batch_id
)
for
(
index_t
batch_id
=
0
;
batch_id
<
batch_count
;
++
batch_id
)
{
{
ADataType
*
d_ATemp
=
d_A
+
batch_id
*
batch_stride_A
;
ADataType
*
d_ATemp
=
a_ptr
+
batch_id
*
batch_stride_A
;
BDataType
*
d_BTemp
=
d_B
+
batch_id
*
batch_stride_B
;
BDataType
*
d_BTemp
=
b_ptr
+
batch_id
*
batch_stride_B
;
CDataType
*
d_CTemp
=
d_C
+
batch_id
*
batch_stride_C
;
CDataType
*
d_CTemp
=
c_ptr
+
batch_id
*
batch_stride_C
;
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
d_ATemp
,
d_BTemp
,
d_CTemp
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
d_ATemp
,
d_BTemp
,
d_CTemp
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
}
}
errC
=
hipMemcpy
(
c_device
.
GetDeviceBuffer
(),
d_C
,
batch_count
*
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying C to device: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
errA
=
hipFree
(
d_A
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the A memory: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipFree
(
d_B
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the B memory: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
errC
=
hipFree
(
d_C
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the C memory: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
return
;
return
;
}
}
}
// namespace ck_tile
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment