Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
965b7ba4
Unverified
Commit
965b7ba4
authored
Dec 02, 2024
by
Illia Silin
Committed by
GitHub
Dec 02, 2024
Browse files
Merge pull request #229 from ROCm/promote_ocp_fp8
Promote ocp fp8
parents
5dff1b14
62e3c582
Changes
63
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
63 additions
and
43 deletions
+63
-43
CMakeLists.txt
CMakeLists.txt
+10
-1
client_example/24_grouped_conv_activation/CMakeLists.txt
client_example/24_grouped_conv_activation/CMakeLists.txt
+2
-2
client_example/CMakeLists.txt
client_example/CMakeLists.txt
+9
-1
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+1
-1
example/01_gemm/common.hpp
example/01_gemm/common.hpp
+1
-1
example/01_gemm/run_gemm_example.inc
example/01_gemm/run_gemm_example.inc
+2
-2
example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
..._grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
+4
-4
example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp
example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp
+4
-4
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
...e/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
+4
-4
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
+3
-3
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp
...le/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp
+3
-3
example/15_grouped_gemm/run_grouped_gemm_example.inc
example/15_grouped_gemm/run_grouped_gemm_example.inc
+5
-2
example/18_batched_gemm_reduce/CMakeLists.txt
example/18_batched_gemm_reduce/CMakeLists.txt
+1
-1
example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp
...layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp
+3
-3
example/31_batched_gemm_gemm/CMakeLists.txt
example/31_batched_gemm_gemm/CMakeLists.txt
+1
-1
example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc
...le/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc
+2
-2
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc
...cale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc
+2
-2
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc
...tmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc
+2
-2
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc
...gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc
+2
-2
example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc
...ched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc
+2
-2
No files found.
CMakeLists.txt
View file @
965b7ba4
...
...
@@ -185,13 +185,22 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx9")
add_definitions
(
-DCK_USE_XDL
)
endif
()
if
(
SUPPORTED_GPU_TARGETS MATCHES
"gfx94"
)
message
(
"Enabling FP8 gemms
i
n
ckProfiler
"
)
message
(
"Enabling FP8 gemms
o
n
native architectures
"
)
add_definitions
(
-DCK_USE_GFX94
)
endif
()
if
(
SUPPORTED_GPU_TARGETS MATCHES
"gfx11"
OR SUPPORTED_GPU_TARGETS MATCHES
"gfx12"
)
message
(
"Enabling WMMA instances"
)
add_definitions
(
-DCK_USE_WMMA
)
endif
()
if
(
SUPPORTED_GPU_TARGETS MATCHES
"gfx12"
)
add_definitions
(
-DCK_USE_OCP_FP8
)
set
(
CK_USE_OCP_FP8
"ON"
)
endif
()
if
(
SUPPORTED_GPU_TARGETS MATCHES
"gfx90a"
OR SUPPORTED_GPU_TARGETS MATCHES
"gfx94"
)
add_definitions
(
-DCK_USE_FNUZ_FP8
)
set
(
CK_USE_FNUZ_FP8
"ON"
)
endif
()
option
(
CK_USE_FP8_ON_UNSUPPORTED_ARCH
"Enable FP8 GEMM instances on older architectures"
OFF
)
if
(
CK_USE_FP8_ON_UNSUPPORTED_ARCH
AND
(
SUPPORTED_GPU_TARGETS MATCHES
"gfx90a"
OR SUPPORTED_GPU_TARGETS MATCHES
"gfx908"
))
add_definitions
(
-DCK_USE_FP8_ON_UNSUPPORTED_ARCH
)
...
...
client_example/24_grouped_conv_activation/CMakeLists.txt
View file @
965b7ba4
...
...
@@ -54,7 +54,7 @@ target_link_libraries(client_conv3d_fwd_convscale_relu_amax_fp8
PRIVATE composable_kernel::device_conv_operations
composable_kernel::device_other_operations
composable_kernel::device_reduction_operations
utility
)
composable_kernel::
utility
)
# Fwd convscale + AMAX
add_executable
(
client_conv3d_fwd_convscale_amax_fp8
grouped_convnd_fwd_convscale_reduce/conv3d_fwd_convscale_amax_fp8.cpp
)
...
...
@@ -62,7 +62,7 @@ target_link_libraries(client_conv3d_fwd_convscale_amax_fp8
PRIVATE composable_kernel::device_conv_operations
composable_kernel::device_other_operations
composable_kernel::device_reduction_operations
utility
)
composable_kernel::
utility
)
# Fwd convscale
add_executable
(
client_conv3d_fwd_convscale_fp8
grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp
)
...
...
client_example/CMakeLists.txt
View file @
965b7ba4
...
...
@@ -56,13 +56,21 @@ if (GPU_TARGETS)
add_definitions
(
-DCK_USE_WMMA
)
set
(
CK_USE_WMMA
"ON"
)
endif
()
if
(
GPU_TARGETS MATCHES
"gfx12"
)
add_definitions
(
-DCK_USE_OCP_FP8
)
set
(
CK_USE_OCP_FP8
"ON"
)
endif
()
if
(
GPU_TARGETS MATCHES
"gfx90a"
OR GPU_TARGETS MATCHES
"gfx94"
)
add_definitions
(
-DCK_USE_FNUZ_FP8
)
set
(
CK_USE_FNUZ_FP8
"ON"
)
endif
()
else
()
add_definitions
(
-DCK_USE_WMMA -DCK_USE_XDL
)
set
(
CK_USE_XDL
"ON"
)
set
(
CK_USE_WMMA
"ON"
)
endif
()
find_package
(
composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations
)
find_package
(
composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations
utility
)
if
(
GPU_TARGETS MATCHES
"gfx9"
)
find_package
(
composable_kernel COMPONENTS device_contraction_operations
)
endif
()
...
...
example/01_gemm/CMakeLists.txt
View file @
965b7ba4
...
...
@@ -58,7 +58,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
add_example_executable
(
example_gemm_xdl_streamk gemm_xdl_streamk.cpp
)
list
(
APPEND gpu_list gfx90a gfx940 gfx941 gfx942
gfx950
)
list
(
APPEND gpu_list gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
...
...
example/01_gemm/common.hpp
View file @
965b7ba4
...
...
@@ -76,7 +76,7 @@ struct ProblemSizeSplitK final
struct
ExecutionConfig
final
{
// 0 - no verification, 1 - CPU, 2 - GPU, 3 - CPU + GPU
int
do_verification
=
3
;
int
do_verification
=
1
;
int
init_method
=
2
;
bool
time_kernel
=
false
;
};
...
...
example/01_gemm/run_gemm_example.inc
View file @
965b7ba4
...
...
@@ -143,8 +143,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
switch
(
config
.
init_method
)
{
case
0
:
ck
::
utils
::
FillConstant
<
ADataType
>
{
static_cas
t
<
ADataType
>
(
1.
f
)}(
a_m_k
);
ck
::
utils
::
FillConstant
<
BDataType
>
{
static_cas
t
<
BDataType
>
(
1.
f
)}(
b_k_n
);
ck
::
utils
::
FillConstant
<
ADataType
>
{
ck
::
type_conver
t
<
ADataType
>
(
1.
f
)}(
a_m_k
);
ck
::
utils
::
FillConstant
<
BDataType
>
{
ck
::
type_conver
t
<
BDataType
>
(
1.
f
)}(
b_k_n
);
break
;
case
1
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
5.
f
,
5.
f
}(
a_m_k
);
...
...
example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp
View file @
965b7ba4
...
...
@@ -186,15 +186,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
for
(
int
j
=
0
;
j
<
NumDMatrices
;
++
j
)
{
d_tensors
[
i
][
j
].
GenerateTensorValue
(
GeneratorTensor_3
<
A
DataType
>
{
0.0
,
1.0
});
d_tensors
[
i
][
j
].
GenerateTensorValue
(
GeneratorTensor_3
<
D
DataType
>
{
0.0
,
1.0
});
}
break
;
default:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
ADataType
,
0
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
BDataType
,
1
>
{});
for
(
int
j
=
0
;
j
<
NumDMatrices
;
++
j
)
{
d_tensors
[
i
][
j
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
d_tensors
[
i
][
j
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
DDataType
,
0
>
{});
}
}
}
...
...
example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp
View file @
965b7ba4
...
...
@@ -190,15 +190,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
for
(
int
j
=
0
;
j
<
NumDs
;
++
j
)
{
d_tensors
[
i
][
j
].
GenerateTensorValue
(
GeneratorTensor_3
<
A
DataType
>
{
0.0
,
1.0
});
d_tensors
[
i
][
j
].
GenerateTensorValue
(
GeneratorTensor_3
<
D
DataType
>
{
0.0
,
1.0
});
}
break
;
default:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
ADataType
,
0
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
BDataType
,
1
>
{});
for
(
int
j
=
0
;
j
<
NumDs
;
++
j
)
{
d_tensors
[
i
][
j
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
d_tensors
[
i
][
j
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
DDataType
,
0
>
{});
}
}
}
...
...
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
View file @
965b7ba4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
...
...
@@ -167,11 +167,11 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
break
;
default:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
ADataType
,
0
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
BDataType
,
1
>
{});
}
d0_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
d0_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
D0DataType
,
1
>
{});
}
using
GroupedGemmKernelArgument
=
ck
::
tensor_operation
::
device
::
GroupedGemmKernelArgument
<
1
>
;
...
...
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
View file @
965b7ba4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
...
...
@@ -157,8 +157,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
break
;
default:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
ADataType
,
0
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
BDataType
,
1
>
{});
}
}
...
...
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp
View file @
965b7ba4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
...
...
@@ -158,8 +158,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
break
;
default:
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
a_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
ADataType
,
0
>
{});
b_tensors
[
i
].
GenerateTensorValue
(
GeneratorTensor_Sequential
<
BDataType
,
1
>
{});
}
}
...
...
example/15_grouped_gemm/run_grouped_gemm_example.inc
View file @
965b7ba4
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
struct
ProblemSize
final
...
...
@@ -124,8 +127,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
b_tensors
[
i
]
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
break
;
default
:
a_tensors
[
i
]
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
b_tensors
[
i
]
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
a_tensors
[
i
]
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
ADataType
,
0
>
{});
b_tensors
[
i
]
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
BDataType
,
1
>
{});
}
}
...
...
example/18_batched_gemm_reduce/CMakeLists.txt
View file @
965b7ba4
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
gfx950
)
list
(
APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
...
...
example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp
View file @
965b7ba4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
...
...
@@ -175,8 +175,8 @@ int main(int argc, char* argv[])
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
break
;
default:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
0
>
{});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
ADataType
,
0
>
{});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
BDataType
,
1
>
{});
}
c0_n_bias
.
GenerateTensorValue
(
GeneratorTensor_2
<
C0DataType
>
{
-
5
,
5
});
...
...
example/31_batched_gemm_gemm/CMakeLists.txt
View file @
965b7ba4
...
...
@@ -5,6 +5,6 @@ if(USE_BITINT_EXTENSION_INT4)
add_example_executable
(
example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp
)
endif
(
USE_BITINT_EXTENSION_INT4
)
if
(
NOT GPU_TARGETS MATCHES
"gfx94"
AND NOT GPU_TARGETS MATCHES
"gfx95"
AND NOT GPU_TARGETS MATCHES
"gfx1"
)
if
(
NOT GPU_TARGETS MATCHES
"gfx94"
AND NOT GPU_TARGETS MATCHES
"gfx1"
)
add_example_executable
(
example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp
)
endif
()
example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc
View file @
965b7ba4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -150,7 +150,7 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[])
break
;
default
:
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
B0DataType
,
1
>
{});
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
}
...
...
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc
View file @
965b7ba4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
int
run
(
int
argc
,
char
*
argv
[])
{
...
...
@@ -157,7 +157,7 @@ int run(int argc, char* argv[])
break
;
default
:
a_g_m_k
.
GenerateTensorValue
(
GeneratorTensor_1
<
ADataType
>
{
1
});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
1
>
{});
b0_g_k_n
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
B0DataType
,
1
>
{});
b1_g_n_o
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
}
...
...
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc
View file @
965b7ba4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
int
run
(
int
argc
,
char
*
argv
[])
{
...
...
@@ -118,7 +118,7 @@ int run(int argc, char* argv[])
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
break
;
default
:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
ADataType
,
2
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
}
...
...
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc
View file @
965b7ba4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
int
run
(
int
argc
,
char
*
argv
[])
{
...
...
@@ -153,7 +153,7 @@ int run(int argc, char* argv[])
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
break
;
default
:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
ADataType
,
2
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
}
...
...
example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc
View file @
965b7ba4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
int
run
(
int
argc
,
char
*
argv
[])
{
...
...
@@ -178,7 +178,7 @@ int run(int argc, char* argv[])
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
break
;
default
:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
2
>
{});
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_Sequential
<
ADataType
,
2
>
{});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B0DataType
>
{});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_Diagonal
<
B1DataType
>
{});
}
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment