Unverified Commit 8a1b7ee2 authored by Charlene Yang's avatar Charlene Yang Committed by GitHub
Browse files

[PyTorch] Fix detection of 3 in 3hd/h3d layouts (#1187)



* fix detection of 3 in 3hd/h3d layouts
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* error out when invalid layout group is provided
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent c4a5cb85
......@@ -95,9 +95,21 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
auto qkv_sizes = QKV.sizes().vec();
std::vector<size_t> qkv_shape{qkv_sizes.begin(), qkv_sizes.end()};
std::vector<size_t> q_shape;
for (auto i : qkv_shape) {
if (i != 3) {
q_shape.push_back(i);
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
int loc_3 = 0;
switch (layout_group) {
case NVTE_3HD:
loc_3 = qkv_sizes.size() - 3;
break;
case NVTE_H3D:
loc_3 = qkv_sizes.size() - 2;
break;
default:
NVTE_ERROR("Invalid QKV layout group.");
}
for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) {
if (it - qkv_shape.begin() != loc_3) {
q_shape.push_back(*it);
}
}
std::vector<int64_t> o_shape{q_shape.begin(), q_shape.end()};
......@@ -252,9 +264,21 @@ std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
auto qkv_sizes = QKV.sizes().vec();
std::vector<size_t> qkv_shape{qkv_sizes.begin(), qkv_sizes.end()};
std::vector<size_t> q_shape;
for (auto i : qkv_shape) {
if (i != 3) {
q_shape.push_back(i);
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
int loc_3 = 0;
switch (layout_group) {
case NVTE_3HD:
loc_3 = qkv_sizes.size() - 3;
break;
case NVTE_H3D:
loc_3 = qkv_sizes.size() - 2;
break;
default:
NVTE_ERROR("Invalid QKV layout group.");
}
for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) {
if (it - qkv_shape.begin() != loc_3) {
q_shape.push_back(*it);
}
}
auto h = q_shape[q_shape.size() - 2];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment