Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
c70aacd3
Commit
c70aacd3
authored
Jul 31, 2024
by
Jing Zhang
Browse files
format
parent
f9b8a5d0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
154 additions
and
176 deletions
+154
-176
example/01_gemm/run_gemm_example_v2.inc
example/01_gemm/run_gemm_example_v2.inc
+2
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
...operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
+137
-154
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+7
-16
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+8
-5
No files found.
example/01_gemm/run_gemm_example_v2.inc
View file @
c70aacd3
...
@@ -272,7 +272,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -272,7 +272,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
if
(
config
.
time_kernel
)
if
(
config
.
time_kernel
)
{
{
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
,
0
,
20
,
50
,
true
,
50
});
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
config
.
time_kernel
,
0
,
20
,
50
,
true
,
50
});
std
::
size_t
flop
=
2_
uz
*
M
*
N
*
K
;
std
::
size_t
flop
=
2_
uz
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
std
::
size_t
num_btype
=
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp
View file @
c70aacd3
...
@@ -168,14 +168,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
...
@@ -168,14 +168,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// rotating mem
// rotating mem
rotating_mem
.
Next
();
rotating_mem
.
Next
();
// clear c mem
// clear c mem
{
if
(
arg_
.
KBatch
>
1
)
if
(
arg_
.
KBatch
>
1
)
hipGetErrorString
(
hipMemsetAsync
(
arg_
.
p_c_grid
,
hipGetErrorString
(
0
,
hipMemsetAsync
(
arg_
.
p_c_grid
,
arg_
.
M
*
arg_
.
N
*
sizeof
(
CDataType
),
0
,
stream_config
.
stream_id_
));
arg_
.
M
*
arg_
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
}
};
};
ave_time
=
ck
::
utility
::
launch_and_time_kernel_with_preprocess
<
false
>
(
ave_time
=
ck
::
utility
::
launch_and_time_kernel_with_preprocess
<
false
>
(
...
@@ -189,13 +186,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
...
@@ -189,13 +186,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
}
}
else
else
{
{
{
if
(
arg
.
KBatch
>
1
)
if
(
arg
.
KBatch
>
1
)
hipGetErrorString
(
hipMemsetAsync
(
arg
.
p_c_grid
,
hipGetErrorString
(
hipMemsetAsync
(
arg
.
p_c_grid
,
0
,
0
,
arg
.
M
*
arg
.
N
*
sizeof
(
CDataType
),
arg
.
M
*
arg
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
stream_config
.
stream_id_
));
}
ave_time
=
launch_and_time_kernel
(
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
arg
);
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
arg
);
...
@@ -213,14 +208,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
...
@@ -213,14 +208,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{
{
if
(
arg
.
KBatch
>
1
)
if
(
arg
.
KBatch
>
1
)
{
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
>
;
minimum_occupancy
>
;
Run
(
kernel
);
Run
(
kernel
);
}
}
}
else
else
{
{
...
@@ -237,117 +230,113 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
...
@@ -237,117 +230,113 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{
{
if
(
arg
.
KBatch
>
1
)
if
(
arg
.
KBatch
>
1
)
{
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
One
)
{
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
One
)
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
One
>
;
Run
(
kernel
);
}
else
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Full
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Full
>
;
Run
(
kernel
);
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
2
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Two
)
{
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
GridwiseGemm
,
true
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
minimum_occupancy
,
TailNumber
::
One
>
;
TailNumber
::
Two
>
;
Run
(
kernel
);
Run
(
kernel
);
}
}
else
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
}
TailNumber
::
Full
)
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
3
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Three
)
{
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
GridwiseGemm
,
true
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
minimum_occupancy
,
TailNumber
::
Full
>
;
TailNumber
::
Three
>
;
Run
(
kernel
);
Run
(
kernel
);
}
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
2
)
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
4
)
{
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Two
)
TailNumber
::
Four
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Two
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
3
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Three
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
,
TailNumber
::
Three
>
;
Run
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
4
)
{
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
TailNumber
::
Four
)
GridwiseGemm
,
{
true
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
InMemoryDataOperationEnum
::
AtomicAdd
,
GridwiseGemm
,
minimum_occupancy
,
true
,
TailNumber
::
Four
>
;
InMemoryDataOperationEnum
::
AtomicAdd
,
Run
(
kernel
);
minimum_occupancy
,
TailNumber
::
Four
>
;
Run
(
kernel
);
}
}
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
5
)
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
5
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Five
)
{
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
TailNumber
::
Five
)
GridwiseGemm
,
{
true
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
InMemoryDataOperationEnum
::
AtomicAdd
,
GridwiseGemm
,
minimum_occupancy
,
true
,
TailNumber
::
Five
>
;
InMemoryDataOperationEnum
::
AtomicAdd
,
Run
(
kernel
);
minimum_occupancy
,
TailNumber
::
Five
>
;
Run
(
kernel
);
}
}
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
6
)
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
6
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Six
)
{
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
TailNumber
::
Six
)
GridwiseGemm
,
{
true
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
InMemoryDataOperationEnum
::
AtomicAdd
,
GridwiseGemm
,
minimum_occupancy
,
true
,
TailNumber
::
Six
>
;
InMemoryDataOperationEnum
::
AtomicAdd
,
Run
(
kernel
);
minimum_occupancy
,
TailNumber
::
Six
>
;
Run
(
kernel
);
}
}
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
7
)
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
7
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Seven
)
{
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
TailNumber
::
Seven
)
GridwiseGemm
,
{
true
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
InMemoryDataOperationEnum
::
AtomicAdd
,
GridwiseGemm
,
minimum_occupancy
,
true
,
TailNumber
::
Seven
>
;
InMemoryDataOperationEnum
::
AtomicAdd
,
Run
(
kernel
);
minimum_occupancy
,
TailNumber
::
Seven
>
;
Run
(
kernel
);
}
}
}
}
}
}
}
...
@@ -469,27 +458,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
...
@@ -469,27 +458,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{
{
if
(
arg
.
KBatch
>
1
)
if
(
arg
.
KBatch
>
1
)
{
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
{
GridwiseGemm
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
true
,
GridwiseGemm
,
InMemoryDataOperationEnum
::
AtomicAdd
,
true
,
minimum_occupancy
,
InMemoryDataOperationEnum
::
AtomicAdd
,
TailNumber
::
Odd
>
;
minimum_occupancy
,
Run
(
kernel
);
TailNumber
::
Odd
>
;
}
Run
(
kernel
);
else
}
{
else
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
{
GridwiseGemm
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
true
,
GridwiseGemm
,
InMemoryDataOperationEnum
::
AtomicAdd
,
true
,
minimum_occupancy
,
InMemoryDataOperationEnum
::
AtomicAdd
,
TailNumber
::
Even
>
;
minimum_occupancy
,
Run
(
kernel
);
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
}
}
}
else
else
...
@@ -520,27 +507,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
...
@@ -520,27 +507,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{
{
if
(
arg
.
KBatch
>
1
)
if
(
arg
.
KBatch
>
1
)
{
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
const
auto
kernel
=
{
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
true
,
GridwiseGemm
,
InMemoryDataOperationEnum
::
AtomicAdd
,
true
,
minimum_occupancy
,
InMemoryDataOperationEnum
::
AtomicAdd
,
TailNumber
::
Odd
>
;
minimum_occupancy
,
Run
(
kernel
);
TailNumber
::
Odd
>
;
}
Run
(
kernel
);
else
}
{
else
const
auto
kernel
=
{
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
true
,
GridwiseGemm
,
InMemoryDataOperationEnum
::
AtomicAdd
,
true
,
minimum_occupancy
,
InMemoryDataOperationEnum
::
AtomicAdd
,
TailNumber
::
Even
>
;
minimum_occupancy
,
Run
(
kernel
);
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
}
}
}
else
else
...
@@ -576,14 +561,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
...
@@ -576,14 +561,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
if
(
arg
.
KBatch
>
1
)
if
(
arg
.
KBatch
>
1
)
{
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
false
,
false
,
InMemoryDataOperationEnum
::
AtomicAdd
,
InMemoryDataOperationEnum
::
AtomicAdd
,
minimum_occupancy
>
;
minimum_occupancy
>
;
Run
(
kernel
);
Run
(
kernel
);
}
}
}
else
else
{
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
c70aacd3
...
@@ -29,7 +29,7 @@ template <typename GridwiseGemm,
...
@@ -29,7 +29,7 @@ template <typename GridwiseGemm,
TailNumber
TailNum
=
TailNumber
::
Full
>
TailNumber
TailNum
=
TailNumber
::
Full
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
MinimumOccupancy
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
MinimumOccupancy
)
#endif
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3
(
typename
GridwiseGemm
::
Argument
karg
)
kernel_gemm_xdl_cshuffle_v3
(
typename
GridwiseGemm
::
Argument
karg
)
...
@@ -57,7 +57,7 @@ template <typename GridwiseGemm,
...
@@ -57,7 +57,7 @@ template <typename GridwiseGemm,
TailNumber
TailNum
=
TailNumber
::
Full
>
TailNumber
TailNum
=
TailNumber
::
Full
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
MinimumOccupancy
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
MinimumOccupancy
)
#endif
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3_2lds
(
typename
GridwiseGemm
::
Argument
karg
)
kernel_gemm_xdl_cshuffle_v3_2lds
(
typename
GridwiseGemm
::
Argument
karg
)
...
@@ -485,20 +485,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -485,20 +485,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
__host__
void
Print
()
const
__host__
void
Print
()
const
{
{
std
::
cout
<<
"problem {"
std
::
cout
<<
"problem {"
<<
"M:"
<<
M
<<
", "
<<
"N:"
<<
N
<<
", "
<<
"K:"
<<
K
<<
", "
<<
"M:"
<<
M
<<
", "
<<
"SA:"
<<
StrideA
<<
", "
<<
"SB:"
<<
StrideB
<<
", "
<<
"SC:"
<<
StrideC
<<
"N:"
<<
N
<<
", "
<<
", "
<<
"MP:"
<<
MPadded
<<
", "
<<
"NP:"
<<
NPadded
<<
", "
<<
"K:"
<<
K
<<
", "
<<
"KRead:"
<<
KRead
<<
", "
<<
"KP:"
<<
KPadded
<<
", "
<<
"AK0:"
<<
AK0
<<
"SA:"
<<
StrideA
<<
", "
<<
", "
<<
"BK0:"
<<
BK0
<<
", "
<<
"MBlock: "
<<
MBlock
<<
", "
<<
"SB:"
<<
StrideB
<<
", "
<<
"SC:"
<<
StrideC
<<
", "
<<
"MP:"
<<
MPadded
<<
", "
<<
"NP:"
<<
NPadded
<<
", "
<<
"KRead:"
<<
KRead
<<
", "
<<
"KP:"
<<
KPadded
<<
", "
<<
"AK0:"
<<
AK0
<<
", "
<<
"BK0:"
<<
BK0
<<
", "
<<
"MBlock: "
<<
MBlock
<<
", "
<<
"NBlock: "
<<
NBlock
<<
"}"
<<
std
::
endl
;
<<
"NBlock: "
<<
NBlock
<<
"}"
<<
std
::
endl
;
}
}
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
c70aacd3
...
@@ -571,7 +571,8 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::typ
...
@@ -571,7 +571,8 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::typ
static_assert
(
N
%
2
==
0
,
""
);
static_assert
(
N
%
2
==
0
,
""
);
vector_type
<
half_t
,
N
>
tmp
{
src_thread_data
};
vector_type
<
half_t
,
N
>
tmp
{
src_thread_data
};
static_for
<
0
,
N
/
2
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
N
/
2
,
1
>
{}([
&
](
auto
i
)
{
__builtin_amdgcn_global_atomic_fadd_v2f16
(
bit_cast
<
half2_t
*>
(
addr
)
+
i
,
tmp
.
template
AsType
<
half2_t
>()[
i
]);
__builtin_amdgcn_global_atomic_fadd_v2f16
(
bit_cast
<
half2_t
*>
(
addr
)
+
i
,
tmp
.
template
AsType
<
half2_t
>()[
i
]);
});
});
}
}
else
if
constexpr
(
is_same
<
T
,
bhalf_t
>::
value
)
else
if
constexpr
(
is_same
<
T
,
bhalf_t
>::
value
)
...
@@ -579,7 +580,8 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::typ
...
@@ -579,7 +580,8 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::typ
static_assert
(
N
%
2
==
0
,
""
);
static_assert
(
N
%
2
==
0
,
""
);
vector_type
<
bhalf_t
,
N
>
tmp
{
src_thread_data
};
vector_type
<
bhalf_t
,
N
>
tmp
{
src_thread_data
};
static_for
<
0
,
N
/
2
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
N
/
2
,
1
>
{}([
&
](
auto
i
)
{
__builtin_amdgcn_global_atomic_fadd_v2bf16
(
bit_cast
<
bhalf2_t
*>
(
addr
)
+
i
,
tmp
.
template
AsType
<
bhalf2_t
>()[
i
]);
__builtin_amdgcn_global_atomic_fadd_v2bf16
(
bit_cast
<
bhalf2_t
*>
(
addr
)
+
i
,
tmp
.
template
AsType
<
bhalf2_t
>()[
i
]);
});
});
}
}
}
}
...
@@ -939,9 +941,10 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
...
@@ -939,9 +941,10 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
{
{
ignore
=
dst_wave_buffer_resource
;
ignore
=
dst_wave_buffer_resource
;
ignore
=
dst_thread_addr_offset
;
ignore
=
dst_thread_addr_offset
;
//amd_buffer_atomic_add_impl<scalar_t, vector_size>(
// amd_buffer_atomic_add_impl<scalar_t, vector_size>(
//src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
// src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
amd_global_atomic_add_impl
<
scalar_t
,
vector_size
>
(
src_thread_data
,
p_dst_wave
+
dst_thread_element_offset
);
amd_global_atomic_add_impl
<
scalar_t
,
vector_size
>
(
src_thread_data
,
p_dst_wave
+
dst_thread_element_offset
);
}
}
#endif
#endif
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment