Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
522b7aee
Commit
522b7aee
authored
Jan 30, 2024
by
Adam Osewski
Browse files
Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2
parents
ff936fd6
84832fc4
Changes
130
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1066 additions
and
36 deletions
+1066
-36
example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp
...wd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp
+2
-3
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp
+2
-3
example/35_splitK_gemm/run_splitK_gemm_example.inc
example/35_splitK_gemm/run_splitK_gemm_example.inc
+1
-1
example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp
example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp
+1
-1
example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp
..._grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp
+2
-3
example/44_elementwise_permute/elementwise_permute.cpp
example/44_elementwise_permute/elementwise_permute.cpp
+3
-0
example/44_elementwise_permute/elementwise_permute_3d.cpp
example/44_elementwise_permute/elementwise_permute_3d.cpp
+9
-6
example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp
...le/44_elementwise_permute/elementwise_permute_4D_fp16.cpp
+3
-0
example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp
...44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp
+3
-0
example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp
...4_elementwise_permute/elementwise_permute_4D_fp16_col.cpp
+3
-0
example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp
...4_elementwise_permute/elementwise_permute_4D_fp16_row.cpp
+3
-0
example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp
...4_elementwise_permute/elementwise_permute_4D_fp32_col.cpp
+3
-0
example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp
...4_elementwise_permute/elementwise_permute_4D_fp32_row.cpp
+3
-0
example/48_pool3d_fwd/pool3d_fwd_common.hpp
example/48_pool3d_fwd/pool3d_fwd_common.hpp
+4
-0
example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp
example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp
+4
-0
include/ck/ck.hpp
include/ck/ck.hpp
+2
-2
include/ck/host_utility/hip_check_error.hpp
include/ck/host_utility/hip_check_error.hpp
+15
-13
include/ck/host_utility/kernel_launch.hpp
include/ck/host_utility/kernel_launch.hpp
+2
-2
include/ck/stream_config.hpp
include/ck/stream_config.hpp
+2
-2
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
...or_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
+999
-0
No files found.
example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_xdl_int4.cpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#include "common.hpp"
#include "common.hpp"
...
@@ -29,3 +27,4 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
...
@@ -29,3 +27,4 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
#include "run_grouped_conv_fwd_bias_relu_add_example.inc"
#include "run_grouped_conv_fwd_bias_relu_add_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_fwd_bias_relu_add_example
(
argc
,
argv
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
!
run_grouped_conv_fwd_bias_relu_add_example
(
argc
,
argv
);
}
#endif
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_int4.cpp
View file @
522b7aee
...
@@ -9,9 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
...
@@ -9,9 +9,7 @@ Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o
Gemm1
Gemm1
*/
*/
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#include <iostream>
#include <iostream>
#include <numeric>
#include <numeric>
...
@@ -144,3 +142,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
...
@@ -144,3 +142,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
#endif
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_batched_gemm_gemm_example
(
argc
,
argv
)
?
0
:
1
;
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_batched_gemm_gemm_example
(
argc
,
argv
)
?
0
:
1
;
}
#endif
example/35_splitK_gemm/run_splitK_gemm_example.inc
View file @
522b7aee
...
@@ -157,7 +157,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
...
@@ -157,7 +157,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
if
(
config
.
time_kernel
)
if
(
config
.
time_kernel
)
{
{
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
});
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
,
1
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
std
::
size_t
num_btype
=
...
...
example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp
View file @
522b7aee
...
@@ -42,7 +42,7 @@ using AElementOp = PassThrough;
...
@@ -42,7 +42,7 @@ using AElementOp = PassThrough;
using
BElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
KPadding
;
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdlSplitKCShuffle
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdlSplitKCShuffle
// clang-format off
// clang-format off
...
...
example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_int4.cpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#error Should compile this file with ck::int4_t support
#endif
#include <cstdlib>
#include <cstdlib>
#include <iostream>
#include <iostream>
...
@@ -120,3 +118,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
...
@@ -120,3 +118,4 @@ static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
#endif
#endif
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_grouped_conv_conv_fwd_example
(
argc
,
argv
)
?
0
:
1
;
}
int
main
(
int
argc
,
char
*
argv
[])
{
return
run_grouped_conv_conv_fwd_example
(
argc
,
argv
)
?
0
:
1
;
}
#endif
example/44_elementwise_permute/elementwise_permute.cpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iostream>
#include <cstdlib>
#include <cstdlib>
...
...
example/44_elementwise_permute/elementwise_permute_3d.cpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iostream>
#include <cstdlib>
#include <cstdlib>
...
@@ -14,8 +17,8 @@
...
@@ -14,8 +17,8 @@
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
F32
=
float
;
using
ADataType
=
F
16
;
using
ADataType
=
F
32
;
using
BDataType
=
F
16
;
using
BDataType
=
F
32
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DeviceElementwisePermuteInstance
=
using
DeviceElementwisePermuteInstance
=
...
@@ -25,10 +28,10 @@ using DeviceElementwisePermuteInstance =
...
@@ -25,10 +28,10 @@ using DeviceElementwisePermuteInstance =
2
,
// NumDim_m, {N, C}
2
,
// NumDim_m, {N, C}
2
,
// NumDim_n, {H, W}
2
,
// NumDim_n, {H, W}
1
,
// NumDim_k, {D}
1
,
// NumDim_k, {D}
8
,
// MPerThread
4
,
// MPerThread
8
,
// NPerThread
4
,
// NPerThread
8
,
// KPerThread
4
,
// KPerThread
ck
::
Sequence
<
8
>
,
// InScalarPerVectorSeq
ck
::
Sequence
<
4
>
,
// InScalarPerVectorSeq
ck
::
Sequence
<
4
>>
;
// OutScalarPerVectorSeq
ck
::
Sequence
<
4
>>
;
// OutScalarPerVectorSeq
template
<
typename
HostTensorA
,
typename
HostTensorB
,
typename
Functor
>
template
<
typename
HostTensorA
,
typename
HostTensorB
,
typename
Functor
>
...
...
example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iostream>
#include <cstdlib>
#include <cstdlib>
...
...
example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iostream>
#include <cstdlib>
#include <cstdlib>
...
...
example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iostream>
#include <cstdlib>
#include <cstdlib>
#include <random>
#include <random>
...
...
example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iostream>
#include <cstdlib>
#include <cstdlib>
...
...
example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iostream>
#include <cstdlib>
#include <cstdlib>
...
...
example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <iostream>
#include <cstdlib>
#include <cstdlib>
...
...
example/48_pool3d_fwd/pool3d_fwd_common.hpp
View file @
522b7aee
...
@@ -32,6 +32,8 @@ std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
...
@@ -32,6 +32,8 @@ std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
return
{
C_
*
D
*
H
*
W
,
D
*
H
*
W
,
H
*
W
,
W
,
1
_uz
};
return
{
C_
*
D
*
H
*
W
,
D
*
H
*
W
,
H
*
W
,
W
,
1
_uz
};
else
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NDHWC
>::
value
)
else
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NDHWC
>::
value
)
return
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
};
return
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
};
throw
std
::
runtime_error
(
"Pool3d_fwd: problem with layout. "
);
return
{
0
,
0
,
0
,
0
,
0
};
};
};
template
<
typename
TensorLayout
>
template
<
typename
TensorLayout
>
...
@@ -53,6 +55,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
...
@@ -53,6 +55,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
return
HostTensorDescriptor
({
N_
,
C_
,
D
,
H
,
W
},
return
HostTensorDescriptor
({
N_
,
C_
,
D
,
H
,
W
},
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
});
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
});
}
}
throw
std
::
runtime_error
(
"Pool3d_fwd: problem with layout. "
);
return
HostTensorDescriptor
({
0
,
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
,
0
});
};
};
template
<
typename
DevicePoolFwdInstance
,
template
<
typename
DevicePoolFwdInstance
,
...
...
example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp
View file @
522b7aee
...
@@ -26,6 +26,8 @@ std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
...
@@ -26,6 +26,8 @@ std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
return
{
C_
*
D
*
H
*
W
,
D
*
H
*
W
,
H
*
W
,
W
,
1
_uz
};
return
{
C_
*
D
*
H
*
W
,
D
*
H
*
W
,
H
*
W
,
W
,
1
_uz
};
else
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NDHWC
>::
value
)
else
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NDHWC
>::
value
)
return
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
};
return
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
};
throw
std
::
runtime_error
(
"Avgpool3d_bwd: problem with layout. "
);
return
{
0
,
0
,
0
,
0
,
0
};
};
};
template
<
typename
TensorLayout
>
template
<
typename
TensorLayout
>
...
@@ -47,6 +49,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
...
@@ -47,6 +49,8 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
return
HostTensorDescriptor
({
N_
,
C_
,
D
,
H
,
W
},
return
HostTensorDescriptor
({
N_
,
C_
,
D
,
H
,
W
},
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
});
{
D
*
C_
*
H
*
W
,
1
_uz
,
C_
*
H
*
W
,
W
*
C_
,
C_
});
}
}
throw
std
::
runtime_error
(
"Avgpool3d_bwd: problem with layout. "
);
return
HostTensorDescriptor
({
0
,
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
,
0
});
};
};
template
<
typename
DevicePoolBwdInstance
,
template
<
typename
DevicePoolBwdInstance
,
...
...
include/ck/ck.hpp
View file @
522b7aee
...
@@ -213,12 +213,12 @@
...
@@ -213,12 +213,12 @@
#define CK_WORKAROUND_SWDEV_388832 1
#define CK_WORKAROUND_SWDEV_388832 1
// flag to enable (1) or disable (0) the debugging output in some kernels
// flag to enable (1) or disable (0) the debugging output in some kernels
#define DEBUG_LOG
1
#define DEBUG_LOG
0
// denorm test fix, required to work around dissue
// denorm test fix, required to work around dissue
#ifndef CK_WORKAROUND_DENORM_FIX
#ifndef CK_WORKAROUND_DENORM_FIX
#define CK_WORKAROUND_DENORM_FIX 0
#define CK_WORKAROUND_DENORM_FIX 0
#el
if
#el
se
// enable only on MI200
// enable only on MI200
#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
#endif // CK_WORKAROUND_DENORM_FIX
#endif // CK_WORKAROUND_DENORM_FIX
...
...
include/ck/host_utility/hip_check_error.hpp
View file @
522b7aee
...
@@ -12,21 +12,23 @@ inline void hip_check_error(hipError_t x)
...
@@ -12,21 +12,23 @@ inline void hip_check_error(hipError_t x)
if
(
x
!=
hipSuccess
)
if
(
x
!=
hipSuccess
)
{
{
std
::
ostringstream
ss
;
std
::
ostringstream
ss
;
ss
<<
"HIP runtime error: "
<<
hipGetErrorString
(
x
)
<<
". "
<<
__FILE__
<<
": "
<<
__LINE__
ss
<<
"HIP runtime error: "
<<
hipGetErrorString
(
x
)
<<
". "
<<
"in function: "
<<
__func__
;
<<
"hip_check_error.hpp"
<<
": "
<<
__LINE__
<<
"in function: "
<<
__func__
;
throw
std
::
runtime_error
(
ss
.
str
());
throw
std
::
runtime_error
(
ss
.
str
());
}
}
}
}
#define HIP_CHECK_ERROR(retval_or_funcall) \
#define HIP_CHECK_ERROR(retval_or_funcall) \
do \
do \
{ \
{ \
hipError_t _tmpVal = retval_or_funcall; \
hipError_t _tmpVal = retval_or_funcall; \
if(_tmpVal != hipSuccess) \
if(_tmpVal != hipSuccess) \
{ \
{ \
std::ostringstream ostr; \
std::ostringstream ostr; \
ostr << "HIP Function Failed (" << __FILE__ << "," << __LINE__ << ") " \
ostr << "HIP Function Failed (" \
<< hipGetErrorString(_tmpVal); \
<< "hip_check_error.hpp" \
throw std::runtime_error(ostr.str()); \
<< "," << __LINE__ << ") " << hipGetErrorString(_tmpVal); \
} \
throw std::runtime_error(ostr.str()); \
} \
} while(0)
} while(0)
include/ck/host_utility/kernel_launch.hpp
View file @
522b7aee
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -30,7 +30,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
...
@@ -30,7 +30,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
block_dim
.
y
,
block_dim
.
y
,
block_dim
.
z
);
block_dim
.
z
);
printf
(
"Warm up
1
time
\n
"
);
printf
(
"Warm up
%d
time
s
\n
"
,
stream_config
.
cold_niters_
);
#endif
#endif
// warm up
// warm up
for
(
int
i
=
0
;
i
<
stream_config
.
cold_niters_
;
++
i
)
for
(
int
i
=
0
;
i
<
stream_config
.
cold_niters_
;
++
i
)
...
...
include/ck/stream_config.hpp
View file @
522b7aee
...
@@ -11,6 +11,6 @@ struct StreamConfig
...
@@ -11,6 +11,6 @@ struct StreamConfig
hipStream_t
stream_id_
=
nullptr
;
hipStream_t
stream_id_
=
nullptr
;
bool
time_kernel_
=
false
;
bool
time_kernel_
=
false
;
int
log_level_
=
0
;
int
log_level_
=
0
;
int
cold_niters_
=
1
;
int
cold_niters_
=
5
;
int
nrepeat_
=
1
0
;
int
nrepeat_
=
5
0
;
};
};
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
0 → 100644
View file @
522b7aee
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment