gaoqiong / composable_kernel_ROCM

Commit 3c171550, authored Oct 29, 2024 by Aleksander Dudek

Batched gemm - messy validation check
Parent: 71eea17c

Showing 3 changed files with 66 additions and 28 deletions:

    example/ck_tile/05_batched_gemm/run_batched_gemm_example.inc   +17 -4
    include/ck_tile/host/reference/reference_gemm.hpp               +35 -17
    include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp         +14 -7
example/ck_tile/05_batched_gemm/run_batched_gemm_example.inc

@@ -96,11 +96,13 @@ int run_batched_gemm_example(int argc, char* argv[])
         [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
             if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
             {
-                return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
+                return ck_tile::HostTensorDescriptor({static_cast<size_t>(16), row, col},
+                                                     {row * col, stride, 1_uz});
             }
             else
             {
-                return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
+                return ck_tile::HostTensorDescriptor({static_cast<size_t>(16), row, col},
+                                                     {row * col, 1_uz, stride});
             }
         };
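The two added branches prepend a batch dimension of 16 and a batch stride of row * col to the previous 2-D descriptors. As a standalone illustration (plain C++, not CK code; the helper names are made up for this sketch), this is the index arithmetic such a {batch, row, col} descriptor implies; note that the batch stride row * col is only correct for packed matrices, i.e. when the leading dimension equals the minor extent.

#include <cassert>
#include <cstddef>

// Flat offset of element (b, r, c) for a row-major descriptor with shape
// {batch, rows, cols} and strides {rows * cols, ld, 1}.
// The batch stride rows * cols assumes packed storage (ld == cols).
std::size_t offset_row_major(std::size_t b, std::size_t r, std::size_t c,
                             std::size_t rows, std::size_t cols, std::size_t ld)
{
    return b * rows * cols + r * ld + c;
}

// Column-major counterpart with strides {rows * cols, 1, ld} (packed: ld == rows).
std::size_t offset_col_major(std::size_t b, std::size_t r, std::size_t c,
                             std::size_t rows, std::size_t cols, std::size_t ld)
{
    return b * rows * cols + r + c * ld;
}

int main()
{
    // Batch 2, element (3, 1) of a packed 4x5 matrix.
    assert(offset_row_major(2, 3, 1, 4, 5, 5) == 2 * 20 + 3 * 5 + 1);
    assert(offset_col_major(2, 3, 1, 4, 5, 4) == 2 * 20 + 3 + 1 * 4);
    return 0;
}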
@@ -194,8 +196,19 @@ int run_batched_gemm_example(int argc, char* argv[])
                                     CDataType,
                                     ALayout,
                                     BLayout,
-                                    CLayout>(
-            a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C);
+                                    CLayout>(a_m_k_dev_buf,
+                                             b_k_n_dev_buf,
+                                             c_m_n_gpu_buf_ref,
+                                             M,
+                                             N,
+                                             K,
+                                             stride_A,
+                                             stride_B,
+                                             stride_C,
+                                             batch_stride_A,
+                                             batch_stride_B,
+                                             batch_stride_C,
+                                             batch_count);
         c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
         pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
include/ck_tile/host/reference/reference_gemm.hpp

@@ -29,22 +29,22 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
     const std::size_t N = b_k_n.get_length(1);
     const std::size_t K = a_m_k.get_length(1);

-    auto f_mn = [&](auto m, auto n) {
+    auto f_mn = [&](auto m, auto n, auto b) {
         AccDataType v_acc = 0;
         for(std::size_t k = 0; k < K; ++k)
         {
-            ADataType v_a = a_element_op(a_m_k(m, k));
-            BDataType v_b = b_element_op(b_k_n(k, n));
+            ADataType v_a = a_element_op(a_m_k(b, m, k));
+            BDataType v_b = b_element_op(b_k_n(b, k, n));
             v_acc +=
                 ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
         }
-        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
+        c_m_n(b, m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
     };

-    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
+    make_ParallelTensorFunctor(f_mn, M, N, 16)(std::thread::hardware_concurrency());
 }

 template <typename ADataType,
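The host reference now indexes all three tensors with a leading batch coordinate and passes a hard-coded 16 as the extent of that new dimension to make_ParallelTensorFunctor. A self-contained sketch of the same computation in plain C++ (illustrative only, no CK types; packed row-major storage assumed, names made up for this sketch):

#include <cassert>
#include <cstddef>
#include <vector>

// C[b] = A[b] * B[b] for every batch b, with A stored as [batch, M, K],
// B as [batch, K, N], and C as [batch, M, N], all packed row-major.
void reference_batched_gemm(const std::vector<float>& a,
                            const std::vector<float>& b,
                            std::vector<float>& c,
                            std::size_t batch, std::size_t M, std::size_t N, std::size_t K)
{
    for(std::size_t ib = 0; ib < batch; ++ib)
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(std::size_t k = 0; k < K; ++k)
                    acc += a[ib * M * K + m * K + k] * b[ib * K * N + k * N + n];
                c[ib * M * N + m * N + n] = acc;
            }
}

int main()
{
    const std::size_t batch = 2, M = 2, N = 2, K = 3;
    std::vector<float> a(batch * M * K, 1.f), b(batch * K * N, 1.f), c(batch * M * N, 0.f);
    reference_batched_gemm(a, b, c, batch, M, N, K);
    assert(c[0] == static_cast<float>(K)); // all-ones inputs: every C element equals K
    return 0;
}

Each batch b reads A[b] and B[b] and writes C[b]; this is what the added b parameter threads through f_mn in the diff above.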
@@ -105,16 +105,20 @@ void reference_gemm_gpu(DeviceMem& a_device,
                         index_t K,
                         index_t stride_a,
                         index_t stride_b,
-                        index_t stride_c)
+                        index_t stride_c,
+                        index_t batch_stride_A,
+                        index_t batch_stride_B,
+                        index_t batch_stride_C,
+                        index_t batch_count)
 {
     ADataType* d_A;
     BDataType* d_B;
     CDataType* d_C;
-    hipError_t errA = hipMalloc(&d_A, M * K * sizeof(ADataType));
-    hipError_t errB = hipMalloc(&d_B, N * K * sizeof(BDataType));
-    hipError_t errC = hipMalloc(&d_C, M * N * sizeof(CDataType));
+    hipError_t errA = hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType));
+    hipError_t errB = hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType));
+    hipError_t errC = hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType));
     if(errA != hipSuccess)
     {
         std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
@@ -136,15 +140,19 @@ void reference_gemm_gpu(DeviceMem& a_device,
         return; // Early exit on error
     }

-    errA = hipMemcpy(
-        d_A, a_device.GetDeviceBuffer(), M * K * sizeof(ADataType), hipMemcpyHostToDevice);
+    errA = hipMemcpy(d_A,
+                     a_device.GetDeviceBuffer(),
+                     batch_count * M * K * sizeof(ADataType),
+                     hipMemcpyHostToDevice);
     if(errA != hipSuccess)
     {
         std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
     }

-    errB = hipMemcpy(
-        d_B, b_device.GetDeviceBuffer(), N * K * sizeof(BDataType), hipMemcpyHostToDevice);
+    errB = hipMemcpy(d_B,
+                     b_device.GetDeviceBuffer(),
+                     batch_count * N * K * sizeof(BDataType),
+                     hipMemcpyHostToDevice);
     if(errB != hipSuccess)
     {
         std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
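Each hipMalloc and hipMemcpy above is checked by hand against hipSuccess. For reference, a minimal standalone sketch of the same check factored into a helper; the HIP_CHECK macro is made up for this sketch and is not part of the diff or of the HIP runtime API.

#include <hip/hip_runtime.h>
#include <iostream>

// Illustrative helper: report the failing call together with its HIP error string.
#define HIP_CHECK(call)                                                          \
    do                                                                           \
    {                                                                            \
        hipError_t err_ = (call);                                                \
        if(err_ != hipSuccess)                                                   \
        {                                                                        \
            std::cerr << #call << " failed: " << hipGetErrorString(err_)         \
                      << std::endl;                                              \
        }                                                                        \
    } while(0)

int main()
{
    float* d_ptr = nullptr;
    HIP_CHECK(hipMalloc(&d_ptr, 16 * sizeof(float)));
    HIP_CHECK(hipFree(d_ptr));
    return 0;
}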
@@ -154,10 +162,20 @@ void reference_gemm_gpu(DeviceMem& a_device,
     int numThreadsPerBlock = 256; // Common choice for threads per block
     int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

-    naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
-        <<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
+    for(int i = 0; i < batch_count; ++i)
+    {
+        ADataType* d_ATemp = d_A + i * batch_stride_A;
+        BDataType* d_BTemp = d_B + i * batch_stride_B;
+        CDataType* d_CTemp = d_C + i * batch_stride_C;
+        naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
+            <<<numBlocks, numThreadsPerBlock>>>(
+                d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
+    }

-    errC = hipMemcpy(
-        c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
+    errC = hipMemcpy(c_device.GetDeviceBuffer(),
+                     d_C,
+                     batch_count * M * N * sizeof(CDataType),
+                     hipMemcpyDeviceToHost);
     if(errC != hipSuccess)
     {
         std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp

@@ -89,13 +89,20 @@ struct BatchedGemmKernel
     CK_TILE_DEVICE void operator()(BatchedGemmCommonKargs kargs) const
     {
         const auto [i_m, i_n] = TilePartitioner{}();
-        const auto i_k = blockIdx.z;
+        // const auto i_k = blockIdx.z;
         // options
-        const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr) +
-                                   __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_A);
-        const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr) +
-                                   __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_B);
+        const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
+        //+ __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_A);
+        const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
+        //+ __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_B);
         // Convert pointers to tensor views
+        // if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0)
+        // {
+        //     printf("__builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_A): %d\n",
+        //            __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_A));
+        //     printf("__builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_B): %d\n",
+        //            __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_B));
+        // }
         auto a_tensor_view = [&]() {
             if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
             {
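For context on the lines being toggled above: in the z-indexed batching scheme, blockIdx.z selects the batch and the A/B/C base pointers are advanced by batch_stride elements; __builtin_amdgcn_readfirstlane broadcasts the wave-uniform offset from the first active lane so it can be kept in a scalar register. A minimal standalone HIP sketch of that pattern (illustrative only, not CK code; kernel and parameter names are made up, and the per-batch work is a simple scale rather than a GEMM):

#include <hip/hip_runtime.h>
#include <cstdio>
#include <vector>

__global__ void batched_scale_kernel(const float* a, float* c, int m, int n,
                                     int batch_stride_a, int batch_stride_c)
{
    // One grid z-slice per batch; offsets are wave-uniform, so readfirstlane
    // lets the compiler keep them in scalar registers.
    const int i_batch  = blockIdx.z;
    const float* a_b   = a + __builtin_amdgcn_readfirstlane(i_batch * batch_stride_a);
    float* c_b         = c + __builtin_amdgcn_readfirstlane(i_batch * batch_stride_c);

    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if(idx < m * n)
        c_b[idx] = 2.0f * a_b[idx]; // stand-in for the per-batch GEMM work
}

int main()
{
    const int batch = 4, m = 8, n = 8, elems = batch * m * n;
    std::vector<float> h_a(elems, 1.0f), h_c(elems, 0.0f);
    float *d_a, *d_c;
    hipMalloc(&d_a, elems * sizeof(float));
    hipMalloc(&d_c, elems * sizeof(float));
    hipMemcpy(d_a, h_a.data(), elems * sizeof(float), hipMemcpyHostToDevice);

    dim3 grid((m * n + 255) / 256, 1, batch); // one z-slice of blocks per batch
    batched_scale_kernel<<<grid, 256>>>(d_a, d_c, m, n, m * n, m * n);

    hipMemcpy(h_c.data(), d_c, elems * sizeof(float), hipMemcpyDeviceToHost);
    std::printf("c[0] = %f\n", h_c[0]);
    hipFree(d_a);
    hipFree(d_c);
    return 0;
}

Launching with a batch-sized z grid dimension is the single-launch alternative to the per-batch host loop used in reference_gemm_gpu above, which instead offsets the pointers on the host before each launch.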
@@ -172,8 +179,8 @@ struct BatchedGemmKernel
         auto c_block_tile =
             GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);

-        CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr) +
-                             __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_C);
+        CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
+        //; + __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_C);
         auto c_tensor_view = [&]() {
             if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
             {