Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1970d162
Commit
1970d162
authored
Jul 15, 2023
by
fsx950223
Browse files
Merge remote-tracking branch 'origin/attn-train-develop-qloop' into skip_dropout
parents
5a3904c7
9b4c780a
Changes
38
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
54 additions
and
39 deletions
+54
-39
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v1.cpp
..._softmax_gemm/batched_multihead_attention_backward_v1.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp
..._softmax_gemm/batched_multihead_attention_backward_v2.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2_phased.cpp
...x_gemm/batched_multihead_attention_backward_v2_phased.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_v1.cpp
...e_softmax_gemm/batched_multihead_attention_forward_v1.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_v2.cpp
...e_softmax_gemm/batched_multihead_attention_forward_v2.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v1.cpp
...ale_softmax_gemm/batched_multihead_attention_train_v1.cpp
+2
-2
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp
...ale_softmax_gemm/batched_multihead_attention_train_v2.cpp
+2
-2
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v1.cpp
..._softmax_gemm/grouped_multihead_attention_backward_v1.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp
..._softmax_gemm/grouped_multihead_attention_backward_v2.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v1.cpp
...e_softmax_gemm/grouped_multihead_attention_forward_v1.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v2.cpp
...e_softmax_gemm/grouped_multihead_attention_forward_v2.cpp
+1
-1
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v1.cpp
...ale_softmax_gemm/grouped_multihead_attention_train_v1.cpp
+2
-2
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
...ale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
+2
-2
example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc
..._softmax_gemm/run_batched_multihead_attention_forward.inc
+1
-1
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
..._softmax_gemm/run_grouped_multihead_attention_forward.inc
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
+7
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp
+7
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_phased_v1.hpp
...l/device_batched_mha_bwd_xdl_cshuffle_qloop_phased_v1.hpp
+7
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+7
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+7
-4
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v1.cpp
View file @
1970d162
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
/*
/*
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2.cpp
View file @
1970d162
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
/*
/*
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_v2_phased.cpp
View file @
1970d162
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
/*
/*
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_v1.cpp
View file @
1970d162
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
/*
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward_v2.cpp
View file @
1970d162
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
/*
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v1.cpp
View file @
1970d162
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
/*
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
...
@@ -1396,7 +1396,7 @@ int run(int argc, char* argv[])
...
@@ -1396,7 +1396,7 @@ int run(int argc, char* argv[])
}
}
std
::
cout
<<
"Checking z:
\n
"
;
std
::
cout
<<
"Checking z:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
z_fwd_gs_ms_ns
.
mData
,
z_bwd_gs_ms_ns
.
mData
,
1
);
pass
&=
ck
::
utils
::
check_
integer_
err
(
z_fwd_gs_ms_ns
.
mData
,
z_bwd_gs_ms_ns
.
mData
,
1
);
std
::
cout
<<
"Checking y:
\n
"
;
std
::
cout
<<
"Checking y:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
pass
&=
ck
::
utils
::
check_err
(
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v2.cpp
View file @
1970d162
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
/*
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
...
@@ -969,7 +969,7 @@ int run(int argc, char* argv[])
...
@@ -969,7 +969,7 @@ int run(int argc, char* argv[])
}
}
std
::
cout
<<
"Checking z:
\n
"
;
std
::
cout
<<
"Checking z:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
z_fwd_gs_ms_ns
.
mData
,
z_bwd_gs_ms_ns
.
mData
,
1
);
pass
&=
ck
::
utils
::
check_
integer_
err
(
z_fwd_gs_ms_ns
.
mData
,
z_bwd_gs_ms_ns
.
mData
,
1
);
std
::
cout
<<
"Checking y:
\n
"
;
std
::
cout
<<
"Checking y:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
pass
&=
ck
::
utils
::
check_err
(
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v1.cpp
View file @
1970d162
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
/*
/*
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_backward_v2.cpp
View file @
1970d162
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
/*
/*
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
Backprop for Gemm + Softmax + Gemm fused operation, where forward prop is defined as:
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v1.cpp
View file @
1970d162
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
/*
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_forward_v2.cpp
View file @
1970d162
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
/*
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v1.cpp
View file @
1970d162
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
/*
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
...
@@ -1420,7 +1420,7 @@ int run(int argc, char* argv[])
...
@@ -1420,7 +1420,7 @@ int run(int argc, char* argv[])
}
}
std
::
cout
<<
"Checking z:
\n
"
;
std
::
cout
<<
"Checking z:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
z_fwd_tensors
[
i
].
mData
,
z_bwd_tensors
[
i
].
mData
,
1
);
pass
&=
ck
::
utils
::
check_
integer_
err
(
z_fwd_tensors
[
i
].
mData
,
z_bwd_tensors
[
i
].
mData
,
1
);
std
::
cout
<<
"Checking y:
\n
"
;
std
::
cout
<<
"Checking y:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
pass
&=
ck
::
utils
::
check_err
(
...
...
example/32_batched_gemm_scale_softmax_gemm/grouped_multihead_attention_train_v2.cpp
View file @
1970d162
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
/*
/*
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
...
@@ -994,7 +994,7 @@ int run(int argc, char* argv[])
...
@@ -994,7 +994,7 @@ int run(int argc, char* argv[])
}
}
std
::
cout
<<
"Checking z:
\n
"
;
std
::
cout
<<
"Checking z:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
z_fwd_tensors
[
i
].
mData
,
z_bwd_tensors
[
i
].
mData
,
1
);
pass
&=
ck
::
utils
::
check_
integer_
err
(
z_fwd_tensors
[
i
].
mData
,
z_bwd_tensors
[
i
].
mData
,
1
);
std
::
cout
<<
"Checking y:
\n
"
;
std
::
cout
<<
"Checking y:
\n
"
;
pass
&=
ck
::
utils
::
check_err
(
pass
&=
ck
::
utils
::
check_err
(
...
...
example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc
View file @
1970d162
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
int
run
(
int
argc
,
char
*
argv
[])
int
run
(
int
argc
,
char
*
argv
[])
{
{
...
...
example/32_batched_gemm_scale_softmax_gemm/run_grouped_multihead_attention_forward.inc
View file @
1970d162
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
int
run
(
int
argc
,
char
*
argv
[])
int
run
(
int
argc
,
char
*
argv
[])
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v1.hpp
View file @
1970d162
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -89,7 +89,8 @@ __global__ void
...
@@ -89,7 +89,8 @@ __global__ void
const
unsigned
long
long
seed
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
offset
)
const
unsigned
long
long
offset
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
@@ -641,7 +642,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
...
@@ -641,7 +642,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
};
};
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
<
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_
Kloop_
Xdl_CShuffle_V1
<
InputDataType
,
// TODO: distinguish A/B datatype
InputDataType
,
// TODO: distinguish A/B datatype
OutputDataType
,
OutputDataType
,
ZDataType
,
ZDataType
,
...
@@ -1050,7 +1051,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
...
@@ -1050,7 +1051,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V1
arg.Print();
arg.Print();
#endif
#endif
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_kloop_v2.hpp
View file @
1970d162
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -88,7 +88,8 @@ __global__ void
...
@@ -88,7 +88,8 @@ __global__ void
const
unsigned
long
long
seed
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
offset
)
const
unsigned
long
long
offset
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
@@ -640,7 +641,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
...
@@ -640,7 +641,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
};
};
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
<
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_
Kloop_
Xdl_CShuffle_V2
<
InputDataType
,
// TODO: distinguish A/B datatype
InputDataType
,
// TODO: distinguish A/B datatype
OutputDataType
,
OutputDataType
,
ZDataType
,
ZDataType
,
...
@@ -1051,7 +1052,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
...
@@ -1051,7 +1052,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
arg.Print();
arg.Print();
#endif
#endif
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_phased_v1.hpp
View file @
1970d162
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -87,7 +87,8 @@ __global__ void
...
@@ -87,7 +87,8 @@ __global__ void
const
unsigned
long
long
seed
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
offset
)
const
unsigned
long
long
offset
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
@@ -637,7 +638,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1
...
@@ -637,7 +638,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1
};
};
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
<
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_
Qloop_
Xdl_CShuffle_V1
<
InputDataType
,
// TODO: distinguish A/B datatype
InputDataType
,
// TODO: distinguish A/B datatype
OutputDataType
,
OutputDataType
,
ZDataType
,
ZDataType
,
...
@@ -1044,7 +1045,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1
...
@@ -1044,7 +1045,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Phased_Xdl_CShuffle_V1
arg.Print();
arg.Print();
#endif
#endif
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
1970d162
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -90,7 +90,8 @@ __global__ void
...
@@ -90,7 +90,8 @@ __global__ void
const
index_t
raw_m_padded
,
const
index_t
raw_m_padded
,
const
index_t
raw_n_padded
)
const
index_t
raw_n_padded
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
@@ -628,7 +629,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -628,7 +629,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
};
};
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
<
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_
Qloop_
Xdl_CShuffle_V1
<
InputDataType
,
// TODO: distinguish A/B datatype
InputDataType
,
// TODO: distinguish A/B datatype
OutputDataType
,
OutputDataType
,
ZDataType
,
ZDataType
,
...
@@ -1024,7 +1025,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
...
@@ -1024,7 +1025,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
1970d162
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -89,7 +89,8 @@ __global__ void
...
@@ -89,7 +89,8 @@ __global__ void
const
index_t
raw_m_padded
,
const
index_t
raw_m_padded
,
const
index_t
raw_n_padded
)
const
index_t
raw_n_padded
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
@@ -634,7 +635,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -634,7 +635,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
};
};
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
<
using
GridwiseGemm
=
GridwiseBatchedMultiheadAttentionBackward_
Qloop_
Xdl_CShuffle_V2
<
InputDataType
,
// TODO: distinguish A/B datatype
InputDataType
,
// TODO: distinguish A/B datatype
OutputDataType
,
OutputDataType
,
ZDataType
,
ZDataType
,
...
@@ -1055,7 +1056,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
...
@@ -1055,7 +1056,9 @@ struct DeviceBatchedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
))
{
{
return
false
;
return
false
;
}
}
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment