Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7402fcbe
"...composable_kernel_rocm.git" did not exist on "50530c17d60deddd9452e4e479764cb63ab12d46"
Commit
7402fcbe
authored
Mar 08, 2023
by
ltqin
Browse files
add client example for gemm_bias_gemm
parent
22e7a408
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
657 additions
and
27 deletions
+657
-27
client_example/08_fused_attention/CMakeLists.txt
client_example/08_fused_attention/CMakeLists.txt
+6
-0
client_example/08_fused_attention/fused_attention_bias_mask.cpp
..._example/08_fused_attention/fused_attention_bias_mask.cpp
+233
-0
client_example/08_fused_attention/fused_attention_mask.cpp
client_example/08_fused_attention/fused_attention_mask.cpp
+226
-0
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute_general.hpp
...nstance/gpu/batched_gemm_softmax_gemm_permute_general.hpp
+164
-0
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt
...ance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt
+1
-1
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instance.cpp
...ax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instance.cpp
+27
-26
No files found.
client_example/08_fused_attention/CMakeLists.txt
View file @
7402fcbe
...
@@ -3,3 +3,9 @@ target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_o
...
@@ -3,3 +3,9 @@ target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_o
add_executable
(
client_fused_attention_bias fused_attention_bias.cpp
)
add_executable
(
client_fused_attention_bias fused_attention_bias.cpp
)
target_link_libraries
(
client_fused_attention_bias PRIVATE composable_kernel::device_operations
)
target_link_libraries
(
client_fused_attention_bias PRIVATE composable_kernel::device_operations
)
add_executable
(
client_fused_attention_mask fused_attention_mask.cpp
)
target_link_libraries
(
client_fused_attention_mask PRIVATE composable_kernel::device_operations
)
add_executable
(
client_fused_attention_bias_mask fused_attention_bias_mask.cpp
)
target_link_libraries
(
client_fused_attention_bias_mask PRIVATE composable_kernel::device_operations
)
client_example/08_fused_attention/fused_attention_bias_mask.cpp
0 → 100644
View file @
7402fcbe
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute_general.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
B0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Acc0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
ScaleBiasMask
;
using
B1ElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
constexpr
static
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
using
ADataType
=
ck
::
half_t
;
using
B0DataType
=
ck
::
half_t
;
using
B1DataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
D00DataType
=
ck
::
half_t
;
using
D01DataType
=
int32_t
;
using
AccDataType
=
float
;
struct
SimpleDeviceMem
{
SimpleDeviceMem
()
=
delete
;
SimpleDeviceMem
(
std
::
size_t
mem_size
)
:
p_mem_
{}
{
(
void
)
hipMalloc
(
static_cast
<
void
**>
(
&
p_mem_
),
mem_size
);
}
void
*
GetDeviceBuffer
()
{
return
p_mem_
;
}
~
SimpleDeviceMem
()
{
(
void
)
hipFree
(
p_mem_
);
}
void
*
p_mem_
;
};
int
main
(
int
argc
,
char
*
argv
[])
{
int
G0
=
48
;
int
G1
=
16
;
int
M
=
1024
;
int
N
=
1024
;
int
K
=
64
;
int
O
=
64
;
// A layout [G0, M, G1, K]
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_strides
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
};
// B0 layout [G0, N, G1, K]
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_strides
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
};
// B1 layout [G0, N, G1, O]
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_strides
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
};
// C layout [G0, M, G1, O]
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_strides
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
};
// D00 layout [G0, M, G1, N]
std
::
vector
<
ck
::
index_t
>
d00_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d00_gs_ms_ns_strides
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
};
// D01 layout [G0, M, G1, N]
std
::
vector
<
ck
::
index_t
>
d01_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d01_gs_ms_ns_strides
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
};
SimpleDeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
G0
*
G1
*
M
*
K
);
SimpleDeviceMem
b0_device_buf
(
sizeof
(
B0DataType
)
*
G0
*
G1
*
N
*
K
);
SimpleDeviceMem
d00_device_buf
(
sizeof
(
D00DataType
)
*
G0
*
G1
*
M
*
N
);
SimpleDeviceMem
d01_device_buf
(
sizeof
(
D01DataType
)
*
G0
*
G1
*
M
*
N
);
SimpleDeviceMem
b1_device_buf
(
sizeof
(
B1DataType
)
*
G0
*
G1
*
O
*
N
);
SimpleDeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
G0
*
G1
*
M
*
O
);
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
ck
::
Tuple
<
D00DataType
,
D01DataType
>
,
ck
::
Tuple
<>
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
MaskingSpec
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
string
best_op_name
;
int
best_op_id
=
-
1
;
float
best_ave_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
// profile device op instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
a_device_buf
.
GetDeviceBuffer
(),
b0_device_buf
.
GetDeviceBuffer
(),
b1_device_buf
.
GetDeviceBuffer
(),
c_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
void
*
,
2
>
{
d00_device_buf
.
GetDeviceBuffer
(),
d01_device_buf
.
GetDeviceBuffer
()},
// p_acc0_biases
{},
// p_acc1_biases
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
b0_gs_ns_ks_lengths
,
b0_gs_ns_ks_strides
,
b1_gs_os_ns_lengths
,
b1_gs_os_ns_strides
,
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
2
>
{
d00_gs_ms_ns_lengths
,
d01_gs_ms_ns_lengths
},
// acc0_biases_gs_ms_ns_lengths
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
2
>
{
d01_gs_ms_ns_strides
,
d01_gs_ms_ns_strides
},
// acc0_biases_gs_ms_ns_strides
{},
// acc1_biases_gs_ms_os_lengths
{},
// acc1_biases_gs_ms_os_strides
AElementOp
{},
B0ElementOp
{},
Acc0ElementOp
{
1
/
sqrtf
(
K
),
0.1
},
B1ElementOp
{},
CElementOp
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
flop
=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
)
*
G0
*
G1
;
std
::
size_t
num_btype
=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
sizeof
(
B1DataType
)
*
N
*
O
+
sizeof
(
CDataType
)
*
M
*
O
+
sizeof
(
D00DataType
)
*
M
*
N
*
2
)
*
G0
*
G1
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
tflops
>
best_tflops
)
{
best_op_id
=
i
;
best_op_name
=
op_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
}
else
{
std
::
cout
<<
op_name
<<
" does not support this problem"
<<
std
::
endl
;
}
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
// run the best instance
{
auto
&
op_ptr
=
op_ptrs
[
best_op_id
];
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
a_device_buf
.
GetDeviceBuffer
(),
b0_device_buf
.
GetDeviceBuffer
(),
b1_device_buf
.
GetDeviceBuffer
(),
c_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
void
*
,
2
>
{
d00_device_buf
.
GetDeviceBuffer
(),
d01_device_buf
.
GetDeviceBuffer
()},
// p_acc0_biases
{},
// p_acc1_biases
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
b0_gs_ns_ks_lengths
,
b0_gs_ns_ks_strides
,
b1_gs_os_ns_lengths
,
b1_gs_os_ns_strides
,
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
2
>
{
d00_gs_ms_ns_lengths
,
d01_gs_ms_ns_lengths
},
// acc0_biases_gs_ms_ns_lengths
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
2
>
{
d01_gs_ms_ns_strides
,
d01_gs_ms_ns_strides
},
// acc0_biases_gs_ms_ns_strides
{},
// acc1_biases_gs_ms_os_lengths
{},
// acc1_biases_gs_ms_os_strides
AElementOp
{},
B0ElementOp
{},
Acc0ElementOp
{
1
/
sqrtf
(
K
),
0.1
},
B1ElementOp
{},
CElementOp
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
}
std
::
cout
<<
"Done"
<<
std
::
endl
;
}
return
0
;
}
client_example/08_fused_attention/fused_attention_mask.cpp
0 → 100644
View file @
7402fcbe
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute_general.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
B0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Acc0ElementOp
=
ck
::
tensor_operation
::
element_wise
::
ScaleMask
;
using
B1ElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
CElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
constexpr
static
auto
MaskingSpec
=
ck
::
tensor_operation
::
device
::
MaskingSpecialization
::
MaskDisabled
;
using
ADataType
=
ck
::
half_t
;
using
B0DataType
=
ck
::
half_t
;
using
B1DataType
=
ck
::
half_t
;
using
CDataType
=
ck
::
half_t
;
using
D0DataType
=
int32_t
;
using
AccDataType
=
float
;
struct
SimpleDeviceMem
{
SimpleDeviceMem
()
=
delete
;
SimpleDeviceMem
(
std
::
size_t
mem_size
)
:
p_mem_
{}
{
(
void
)
hipMalloc
(
static_cast
<
void
**>
(
&
p_mem_
),
mem_size
);
}
void
*
GetDeviceBuffer
()
{
return
p_mem_
;
}
~
SimpleDeviceMem
()
{
(
void
)
hipFree
(
p_mem_
);
}
void
*
p_mem_
;
};
int
main
(
int
argc
,
char
*
argv
[])
{
int
G0
=
48
;
int
G1
=
16
;
int
M
=
1024
;
int
N
=
1024
;
int
K
=
64
;
int
O
=
64
;
// A layout [G0, M, G1, K]
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_lengths
{
G0
,
G1
,
M
,
K
};
std
::
vector
<
ck
::
index_t
>
a_gs_ms_ks_strides
{
M
*
G1
*
K
,
K
,
G1
*
K
,
1
};
// B0 layout [G0, N, G1, K]
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_lengths
{
G0
,
G1
,
N
,
K
};
std
::
vector
<
ck
::
index_t
>
b0_gs_ns_ks_strides
{
N
*
G1
*
K
,
K
,
G1
*
K
,
1
};
// B1 layout [G0, N, G1, O]
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_lengths
{
G0
,
G1
,
O
,
N
};
std
::
vector
<
ck
::
index_t
>
b1_gs_os_ns_strides
{
N
*
G1
*
O
,
O
,
1
,
G1
*
O
};
// C layout [G0, M, G1, O]
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_lengths
{
G0
,
G1
,
M
,
O
};
std
::
vector
<
ck
::
index_t
>
c_gs_ms_os_strides
{
M
*
G1
*
O
,
O
,
G1
*
O
,
1
};
// D layout [G0, M, G1, N]
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_lengths
{
G0
,
G1
,
M
,
N
};
std
::
vector
<
ck
::
index_t
>
d0_gs_ms_ns_strides
{
M
*
G1
*
N
,
N
,
G1
*
N
,
1
};
SimpleDeviceMem
a_device_buf
(
sizeof
(
ADataType
)
*
G0
*
G1
*
M
*
K
);
SimpleDeviceMem
b0_device_buf
(
sizeof
(
B0DataType
)
*
G0
*
G1
*
N
*
K
);
SimpleDeviceMem
d0_device_buf
(
sizeof
(
D0DataType
)
*
G0
*
G1
*
M
*
N
);
SimpleDeviceMem
b1_device_buf
(
sizeof
(
B1DataType
)
*
G0
*
G1
*
O
*
N
);
SimpleDeviceMem
c_device_buf
(
sizeof
(
CDataType
)
*
G0
*
G1
*
M
*
O
);
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
ck
::
Tuple
<
D0DataType
>
,
ck
::
Tuple
<>
,
AElementOp
,
B0ElementOp
,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
MaskingSpec
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
string
best_op_name
;
int
best_op_id
=
-
1
;
float
best_ave_time
=
0
;
float
best_tflops
=
0
;
float
best_gb_per_sec
=
0
;
// profile device op instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
a_device_buf
.
GetDeviceBuffer
(),
b0_device_buf
.
GetDeviceBuffer
(),
b1_device_buf
.
GetDeviceBuffer
(),
c_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
void
*
,
1
>
{
d0_device_buf
.
GetDeviceBuffer
()},
// p_acc0_biases
{},
// p_acc1_biases
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
b0_gs_ns_ks_lengths
,
b0_gs_ns_ks_strides
,
b1_gs_os_ns_lengths
,
b1_gs_os_ns_strides
,
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d0_gs_ms_ns_lengths
},
// acc0_biases_gs_ms_ns_lengths
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d0_gs_ms_ns_strides
},
// acc0_biases_gs_ms_ns_strides
{},
// acc1_biases_gs_ms_os_lengths
{},
// acc1_biases_gs_ms_os_strides
AElementOp
{},
B0ElementOp
{},
Acc0ElementOp
{
1
/
sqrtf
(
K
),
0.1
},
B1ElementOp
{},
CElementOp
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
flop
=
(
size_t
(
M
)
*
N
*
K
*
2
+
size_t
(
M
)
*
N
*
O
*
2
)
*
G0
*
G1
;
std
::
size_t
num_btype
=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
B0DataType
)
*
K
*
N
+
sizeof
(
B1DataType
)
*
N
*
O
+
sizeof
(
CDataType
)
*
M
*
O
+
sizeof
(
D0DataType
)
*
M
*
N
)
*
G0
*
G1
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
tflops
>
best_tflops
)
{
best_op_id
=
i
;
best_op_name
=
op_name
;
best_tflops
=
tflops
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
}
else
{
std
::
cout
<<
op_name
<<
" does not support this problem"
<<
std
::
endl
;
}
}
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
// run the best instance
{
auto
&
op_ptr
=
op_ptrs
[
best_op_id
];
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
a_device_buf
.
GetDeviceBuffer
(),
b0_device_buf
.
GetDeviceBuffer
(),
b1_device_buf
.
GetDeviceBuffer
(),
c_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
void
*
,
1
>
{
d0_device_buf
.
GetDeviceBuffer
()},
// p_acc0_biases
{},
// p_acc1_biases
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
b0_gs_ns_ks_lengths
,
b0_gs_ns_ks_strides
,
b1_gs_os_ns_lengths
,
b1_gs_os_ns_strides
,
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
,
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d0_gs_ms_ns_lengths
},
// acc0_biases_gs_ms_ns_lengths
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d0_gs_ms_ns_strides
},
// acc0_biases_gs_ms_ns_strides
{},
// acc1_biases_gs_ms_os_lengths
{},
// acc1_biases_gs_ms_os_strides
AElementOp
{},
B0ElementOp
{},
Acc0ElementOp
{
1
/
sqrtf
(
K
),
0.1
},
B1ElementOp
{},
CElementOp
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
}
std
::
cout
<<
"Done"
<<
std
::
endl
;
}
return
0
;
}
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute_general.hpp
0 → 100644
View file @
7402fcbe
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<
int32_t
>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
element_wise
::
ScaleMask
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
);
void
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<
int32_t
>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
element_wise
::
ScaleMask
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>>>&
instances
);
void
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<
F16
,
int32_t
>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
element_wise
::
ScaleBiasMask
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
instances
);
void
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
1
,
1
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<
F16
,
int32_t
>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
element_wise
::
ScaleBiasMask
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>>>&
instances
);
template
<
index_t
NumDimG
,
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
index_t
NumDimO
,
typename
ADataType
,
typename
B0DataType
,
typename
B1DataType
,
typename
CDataType
,
typename
Acc0BiasDataType
,
typename
Acc1BiasDataType
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
C0DEElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
C1DEElementwiseOperation
,
MaskingSpecialization
MaskingSpec
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
B0ElementwiseOperation
,
C0DEElementwiseOperation
,
B1ElementwiseOperation
,
C1DEElementwiseOperation
,
MaskingSpec
>>
{
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemmPermute
<
NumDimG
,
NumDimM
,
NumDimN
,
NumDimK
,
NumDimO
,
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
B0ElementwiseOperation
,
C0DEElementwiseOperation
,
B1ElementwiseOperation
,
C1DEElementwiseOperation
,
MaskingSpec
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
op_ptrs
);
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt
View file @
7402fcbe
...
@@ -3,6 +3,6 @@ add_instance_library(device_batched_gemm_softmax_gemm_permute_instance
...
@@ -3,6 +3,6 @@ add_instance_library(device_batched_gemm_softmax_gemm_permute_instance
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_
bias
_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_
multiple_d
_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instance.cpp
)
)
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_
bias
_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instance.cpp
→
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_
multiple_d
_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instance.cpp
View file @
7402fcbe
...
@@ -24,8 +24,8 @@ template <ck::index_t... Is>
...
@@ -24,8 +24,8 @@ template <ck::index_t... Is>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Scale
Add
=
ck
::
tensor_operation
::
element_wise
::
Scale
Add
;
using
Scale
Mask
=
ck
::
tensor_operation
::
element_wise
::
Scale
Mask
;
using
Scale
=
ck
::
tensor_operation
::
element_wise
::
Scale
;
using
Scale
BiasMask
=
ck
::
tensor_operation
::
element_wise
::
Scale
BiasMask
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmPadded
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
static
constexpr
auto
GemmPadded
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKOPadding
;
...
@@ -68,8 +68,8 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo
...
@@ -68,8 +68,8 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo
// clang-format on
// clang-format on
>
;
>
;
// f16
PassThrough
masking
// f16
ScaleMask
masking
void
add_device_batched_gemm_
bias_masking
_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
void
add_device_batched_gemm_
mutiple_d
_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
...
@@ -80,11 +80,11 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_gmk_
...
@@ -80,11 +80,11 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_gmk_
F16
,
F16
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<
int32_t
>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ScaleMask
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
...
@@ -100,13 +100,13 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_gmk_
...
@@ -100,13 +100,13 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_gmk_
1
,
1
,
F16
,
F16
,
F32
,
F32
,
ck
::
Tuple
<>
,
ck
::
Tuple
<
int32_t
>
,
PassThrough
,
ScaleMask
,
MaskingSpecialization
::
MaskOutUpperTriangle
>
{});
MaskingSpecialization
::
MaskOutUpperTriangle
>
{});
}
}
// f16
PassThrough
disable masking
// f16
ScaleMask
disable masking
void
add_device_batched_gemm_
bias
_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
void
add_device_batched_gemm_
mutiple_d
_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
...
@@ -117,11 +117,11 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_
...
@@ -117,11 +117,11 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_
F16
,
F16
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<
int32_t
>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
ScaleMask
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>>>&
MaskingSpecialization
::
MaskDisabled
>>>&
...
@@ -137,13 +137,13 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_
...
@@ -137,13 +137,13 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_
1
,
1
,
F16
,
F16
,
F32
,
F32
,
ck
::
Tuple
<>
,
ck
::
Tuple
<
int32_t
>
,
PassThrough
,
ScaleMask
,
MaskingSpecialization
::
MaskDisabled
>
{});
MaskingSpecialization
::
MaskDisabled
>
{});
}
}
// f16
// f16
ScaleBiasMask masking
void
add_device_batched_gemm_
bias_masking
_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
void
add_device_batched_gemm_
mutiple_d
_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
unique_ptr
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
...
@@ -154,11 +154,11 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_gmk_
...
@@ -154,11 +154,11 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_gmk_
F16
,
F16
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<
F16
,
int32_t
>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Scale
,
Scale
BiasMask
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
MaskingSpecialization
::
MaskOutUpperTriangle
>>>&
...
@@ -174,12 +174,13 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_gmk_
...
@@ -174,12 +174,13 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_gmk_
1
,
1
,
F16
,
F16
,
F32
,
F32
,
ck
::
Tuple
<>
,
ck
::
Tuple
<
F16
,
int32_t
>
,
Scale
,
Scale
BiasMask
,
MaskingSpecialization
::
MaskOutUpperTriangle
>
{});
MaskingSpecialization
::
MaskOutUpperTriangle
>
{});
}
}
void
add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
// f16 ScaleBiasMask disable masking
void
add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances
(
std
::
vector
<
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
std
::
unique_ptr
<
DeviceBatchedGemmSoftmaxGemmPermute
<
2
,
1
,
1
,
...
@@ -190,11 +191,11 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_
...
@@ -190,11 +191,11 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_
F16
,
F16
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<
F16
,
int32_t
>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Scale
,
Scale
BiasMask
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskDisabled
>>>&
MaskingSpecialization
::
MaskDisabled
>>>&
...
@@ -210,8 +211,8 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_
...
@@ -210,8 +211,8 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_
1
,
1
,
F16
,
F16
,
F32
,
F32
,
ck
::
Tuple
<>
,
ck
::
Tuple
<
F16
,
int32_t
>
,
Scale
,
Scale
BiasMask
,
MaskingSpecialization
::
MaskDisabled
>
{});
MaskingSpecialization
::
MaskDisabled
>
{});
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment