Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
393882bc
Commit
393882bc
authored
Mar 29, 2023
by
Tri Dao
Browse files
[LayerNorm] Implement LN with parallel residual, support dim 8k
parent
009a3e71
Changes
46
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1132 additions
and
8 deletions
+1132
-8
csrc/layer_norm/ln_parallel_bwd_8192.cu
csrc/layer_norm/ln_parallel_bwd_8192.cu
+15
-0
csrc/layer_norm/ln_parallel_fwd_1024.cu
csrc/layer_norm/ln_parallel_fwd_1024.cu
+15
-0
csrc/layer_norm/ln_parallel_fwd_1280.cu
csrc/layer_norm/ln_parallel_fwd_1280.cu
+15
-0
csrc/layer_norm/ln_parallel_fwd_1536.cu
csrc/layer_norm/ln_parallel_fwd_1536.cu
+15
-0
csrc/layer_norm/ln_parallel_fwd_2048.cu
csrc/layer_norm/ln_parallel_fwd_2048.cu
+15
-0
csrc/layer_norm/ln_parallel_fwd_256.cu
csrc/layer_norm/ln_parallel_fwd_256.cu
+15
-0
csrc/layer_norm/ln_parallel_fwd_2560.cu
csrc/layer_norm/ln_parallel_fwd_2560.cu
+15
-0
csrc/layer_norm/ln_parallel_fwd_3072.cu
csrc/layer_norm/ln_parallel_fwd_3072.cu
+15
-0
csrc/layer_norm/ln_parallel_fwd_4096.cu
csrc/layer_norm/ln_parallel_fwd_4096.cu
+15
-0
csrc/layer_norm/ln_parallel_fwd_512.cu
csrc/layer_norm/ln_parallel_fwd_512.cu
+15
-0
csrc/layer_norm/ln_parallel_fwd_5120.cu
csrc/layer_norm/ln_parallel_fwd_5120.cu
+15
-0
csrc/layer_norm/ln_parallel_fwd_6144.cu
csrc/layer_norm/ln_parallel_fwd_6144.cu
+15
-0
csrc/layer_norm/ln_parallel_fwd_7168.cu
csrc/layer_norm/ln_parallel_fwd_7168.cu
+15
-0
csrc/layer_norm/ln_parallel_fwd_768.cu
csrc/layer_norm/ln_parallel_fwd_768.cu
+15
-0
csrc/layer_norm/ln_parallel_fwd_8192.cu
csrc/layer_norm/ln_parallel_fwd_8192.cu
+15
-0
csrc/layer_norm/ln_parallel_residual_bwd_kernels.cuh
csrc/layer_norm/ln_parallel_residual_bwd_kernels.cuh
+540
-0
csrc/layer_norm/ln_parallel_residual_fwd_kernels.cuh
csrc/layer_norm/ln_parallel_residual_fwd_kernels.cuh
+281
-0
csrc/layer_norm/ln_utils.cuh
csrc/layer_norm/ln_utils.cuh
+33
-0
csrc/layer_norm/setup.py
csrc/layer_norm/setup.py
+32
-0
flash_attn/models/gpt.py
flash_attn/models/gpt.py
+21
-8
No files found.
csrc/layer_norm/ln_parallel_bwd_8192.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_bwd_kernels.cuh"
// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL
REGISTER_PARALLEL_BWD_LAUNCHER
(
8192
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
8192
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
8192
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
8192
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
8192
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
8192
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
8192
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
8192
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
8192
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
REGISTER_PARALLEL_BWD_LAUNCHER
(
8192
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
8
,
16
,
4
);
\ No newline at end of file
csrc/layer_norm/ln_parallel_fwd_1024.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_PARALLEL_FWD_LAUNCHER
(
1024
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1024
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1024
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1024
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1024
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1024
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1024
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1024
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1024
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1024
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
csrc/layer_norm/ln_parallel_fwd_1280.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_PARALLEL_FWD_LAUNCHER
(
1280
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1280
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1280
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1280
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1280
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1280
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1280
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1280
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1280
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1280
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
csrc/layer_norm/ln_parallel_fwd_1536.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_PARALLEL_FWD_LAUNCHER
(
1536
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1536
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1536
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1536
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1536
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1536
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1536
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1536
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1536
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
1536
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
csrc/layer_norm/ln_parallel_fwd_2048.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_PARALLEL_FWD_LAUNCHER
(
2048
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
2048
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
2048
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
2048
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
2048
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
2048
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
2048
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
2048
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
2048
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
2048
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
csrc/layer_norm/ln_parallel_fwd_256.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_PARALLEL_FWD_LAUNCHER
(
256
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
256
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
256
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
256
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
256
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
256
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
256
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
256
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
256
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
256
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
\ No newline at end of file
csrc/layer_norm/ln_parallel_fwd_2560.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_PARALLEL_FWD_LAUNCHER
(
2560
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
2560
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
2560
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
2560
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
2560
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
2560
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
2560
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
2560
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
2560
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
2560
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
csrc/layer_norm/ln_parallel_fwd_3072.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_PARALLEL_FWD_LAUNCHER
(
3072
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
3072
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
3072
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
3072
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
3072
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
3072
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
3072
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
3072
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
3072
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
3072
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
16
);
csrc/layer_norm/ln_parallel_fwd_4096.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_PARALLEL_FWD_LAUNCHER
(
4096
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
4096
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
4096
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
4096
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
4096
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
4096
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
4096
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
4096
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
4096
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
4096
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
16
);
csrc/layer_norm/ln_parallel_fwd_512.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_PARALLEL_FWD_LAUNCHER
(
512
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
512
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
512
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
512
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
512
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
512
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
512
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
512
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
512
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
512
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
csrc/layer_norm/ln_parallel_fwd_5120.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_PARALLEL_FWD_LAUNCHER
(
5120
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
5120
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
5120
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
5120
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
5120
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
5120
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
5120
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
5120
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
5120
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
5120
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
16
);
csrc/layer_norm/ln_parallel_fwd_6144.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_PARALLEL_FWD_LAUNCHER
(
6144
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
6144
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
6144
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
6144
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
6144
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
6144
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
6144
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
6144
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
6144
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
6144
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
8
,
16
);
csrc/layer_norm/ln_parallel_fwd_7168.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_PARALLEL_FWD_LAUNCHER
(
7168
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
7168
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
7168
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
7168
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
7168
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
7168
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
7168
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
7168
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
7168
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
4
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
7168
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
4
,
16
);
csrc/layer_norm/ln_parallel_fwd_768.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_PARALLEL_FWD_LAUNCHER
(
768
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
768
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
768
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
768
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
768
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
768
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
768
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
768
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
768
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
4
,
1
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
768
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
4
,
1
,
16
);
csrc/layer_norm/ln_parallel_fwd_8192.cu
0 → 100644
View file @
393882bc
#include "ln_parallel_residual_fwd_kernels.cuh"
// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG
REGISTER_PARALLEL_FWD_LAUNCHER
(
8192
,
fp32
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
8192
,
fp16
,
fp32
,
fp32
,
fp32
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
8192
,
fp32
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
8192
,
fp16
,
fp16
,
fp32
,
fp16
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
8192
,
fp32
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
8192
,
fp32
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
8192
,
bf16
,
bf16
,
fp32
,
bf16
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
8192
,
fp32
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
8192
,
fp16
,
fp16
,
fp16
,
fp16
,
fp32
,
1
,
1
,
8
,
16
);
REGISTER_PARALLEL_FWD_LAUNCHER
(
8192
,
bf16
,
bf16
,
bf16
,
bf16
,
fp32
,
1
,
1
,
8
,
16
);
csrc/layer_norm/ln_parallel_residual_bwd_kernels.cuh
0 → 100644
View file @
393882bc
#pragma once
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "static_switch.h"
#include "ln_bwd_kernels.cuh"
namespace
layer_norm
{
template
<
typename
Ktraits
,
bool
Is_dropout
,
bool
Tied_norm
,
bool
Is_even_cols
>
__global__
__launch_bounds__
(
Ktraits
::
THREADS_PER_CTA
)
void
ln_parallel_residual_bwd_kernel
(
layer_norm
::
BwdParams
params
)
{
enum
{
ROWS_PER_CTA
=
Ktraits
::
ROWS_PER_CTA
};
enum
{
WARPS_M
=
Ktraits
::
WARPS_M
};
enum
{
WARPS_N
=
Ktraits
::
WARPS_N
};
enum
{
THREADS_PER_ROW
=
Ktraits
::
THREADS_PER_ROW
};
enum
{
COLS
=
Ktraits
::
COLS
};
enum
{
BYTES_PER_ROW
=
Ktraits
::
BYTES_PER_ROW
};
enum
{
LDGS
=
Ktraits
::
LDGS
};
enum
{
NUM_ELTS
=
Ktraits
::
ELTS_PER_LDG
};
enum
{
THREADS_PER_WARP
=
Ktraits
::
THREADS_PER_WARP
};
enum
{
CTAS_PER_ROW
=
Ktraits
::
CTAS_PER_ROW
};
using
input_t
=
typename
Ktraits
::
input_t
;
using
compute_t
=
typename
Ktraits
::
compute_t
;
using
index_t
=
typename
Ktraits
::
index_t
;
using
mask_t
=
typename
Ktraits
::
mask_t
;
using
Ivec
=
typename
Ktraits
::
Ivec
;
using
Rvec
=
typename
Ktraits
::
Rvec
;
using
Ovec
=
typename
Ktraits
::
Ovec
;
using
Wvec
=
typename
Ktraits
::
Wvec
;
using
Cvec
=
typename
Ktraits
::
Cvec
;
using
Mvec
=
typename
Ktraits
::
Mvec
;
using
Reducer
=
typename
Ktraits
::
Reducer
;
using
reduce_t
=
typename
Reducer
::
Type
;
extern
__shared__
char
smem_
[];
const
bool
has_residual
=
params
.
dresidual
!=
nullptr
;
const
bool
has_x1
=
params
.
dx1
!=
nullptr
;
const
bool
prenorm
=
params
.
dx
!=
nullptr
;
const
index_t
tidx
=
threadIdx
.
x
;
const
index_t
bidn
=
blockIdx
.
x
%
CTAS_PER_ROW
;
const
index_t
bidm
=
blockIdx
.
x
/
CTAS_PER_ROW
;
const
index_t
lane
=
tidx
%
THREADS_PER_WARP
;
const
index_t
warp
=
tidx
/
THREADS_PER_WARP
;
const
index_t
warp_m
=
warp
/
Ktraits
::
WARPS_N
;
const
index_t
warp_n
=
warp
%
Ktraits
::
WARPS_N
;
const
index_t
tid_r
=
warp_n
*
THREADS_PER_WARP
+
lane
;
const
index_t
r
=
bidm
*
Ktraits
::
ROWS_PER_CTA
+
warp_m
;
const
index_t
c
=
bidn
*
THREADS_PER_ROW
+
warp_n
*
THREADS_PER_WARP
+
lane
;
static_assert
(
COLS
==
THREADS_PER_ROW
*
LDGS
*
NUM_ELTS
*
CTAS_PER_ROW
);
Cvec
dz0y_sum
[
LDGS
];
Cvec
dz0_sum
[
LDGS
];
Cvec
dz1y_sum
[
LDGS
];
Cvec
dz1_sum
[
LDGS
];
memset
(
dz0y_sum
,
0
,
sizeof
(
dz0y_sum
));
memset
(
dz0_sum
,
0
,
sizeof
(
dz0_sum
));
if
(
!
Tied_norm
)
{
memset
(
dz1y_sum
,
0
,
sizeof
(
dz1y_sum
));
memset
(
dz1_sum
,
0
,
sizeof
(
dz1_sum
));
}
compute_t
*
smem_wgrad
=
reinterpret_cast
<
compute_t
*>
(
smem_
);
char
*
smem_dgrad
=
smem_
+
Ktraits
::
SMEM_BYTES_WGRAD
;
Reducer
reducer
(
params
,
bidm
,
bidn
,
warp_m
,
warp_n
,
lane
,
smem_dgrad
);
Sum
<
reduce_t
>
sum
;
const
index_t
num_valid_ldgs
=
((
params
.
cols
/
Ktraits
::
ELTS_PER_LDG
)
-
1
-
c
+
Ktraits
::
VEC_COLS_PER_LDG
)
/
Ktraits
::
VEC_COLS_PER_LDG
;
Wvec
gamma0
[
LDGS
];
Wvec
gamma1
[
LDGS
];
index_t
idx
=
c
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
if
(
Is_even_cols
||
(
it
<
num_valid_ldgs
))
{
gamma0
[
it
].
load_from
(
params
.
gamma
,
idx
);
if
(
!
Tied_norm
)
{
gamma1
[
it
].
load_from
(
params
.
gamma1
,
idx
);
}
idx
+=
Ktraits
::
VEC_COLS_PER_LDG
;
}
}
// TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the
// last blocks with syncthreads!
// grid stride over rows
#pragma unroll 1
for
(
int
row
=
r
;
row
<
params
.
rows
;
row
+=
params
.
ctas_per_col
*
ROWS_PER_CTA
)
{
const
compute_t
mu_r
=
static_cast
<
const
compute_t
*>
(
params
.
mu
)[
row
];
const
compute_t
rs_r
=
static_cast
<
const
compute_t
*>
(
params
.
rs
)[
row
];
Mvec
dmask0
[
LDGS
],
dmask1
[
LDGS
];
Rvec
dx
[
LDGS
];
compute_t
dy
[
LDGS
*
NUM_ELTS
];
compute_t
y
[
LDGS
*
NUM_ELTS
];
compute_t
mdy_local
=
0.
f
;
compute_t
mdyy_local
=
0.
f
;
index_t
idx
=
row
*
params
.
cols
/
Ktraits
::
ELTS_PER_LDG
+
c
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
if
(
Is_even_cols
||
(
it
<
num_valid_ldgs
))
{
Rvec
x
;
Ovec
dz0
,
dz1
;
dz0
.
load_from
(
params
.
dz
,
idx
);
if
(
!
Tied_norm
)
{
dz1
.
load_from
(
params
.
dz1
,
idx
);
}
if
(
prenorm
)
{
dx
[
it
].
load_from
(
params
.
dx
,
idx
);
}
x
.
load_from
(
params
.
x
,
idx
);
if
(
Is_dropout
)
{
dmask0
[
it
].
load_from
(
params
.
dmask
,
idx
);
if
(
has_x1
)
{
dmask1
[
it
].
load_from
(
params
.
dmask1
,
idx
);
}
}
idx
+=
Ktraits
::
VEC_COLS_PER_LDG
;
#pragma unroll
for
(
int
jt
=
0
;
jt
<
NUM_ELTS
;
jt
++
)
{
compute_t
x_tmp
=
x
.
data
.
elt
[
jt
];
compute_t
y_tmp
=
rs_r
*
(
x_tmp
-
(
!
params
.
is_rms_norm
?
mu_r
:
0.
f
));
compute_t
dy_tmp
=
compute_t
(
gamma0
[
it
].
data
.
elt
[
jt
])
*
compute_t
(
dz0
.
data
.
elt
[
jt
]);
if
(
!
Tied_norm
)
{
dy_tmp
+=
compute_t
(
gamma1
[
it
].
data
.
elt
[
jt
])
*
compute_t
(
dz1
.
data
.
elt
[
jt
]);
}
compute_t
dz0_tmp
=
dz0
.
data
.
elt
[
jt
];
compute_t
dz1_tmp
;
if
(
!
Tied_norm
)
{
dz1_tmp
=
dz1
.
data
.
elt
[
jt
];
}
mdy_local
+=
dy_tmp
;
mdyy_local
+=
dy_tmp
*
y_tmp
;
dy
[
it
*
NUM_ELTS
+
jt
]
=
dy_tmp
;
y
[
it
*
NUM_ELTS
+
jt
]
=
y_tmp
;
dz0y_sum
[
it
].
data
.
elt
[
jt
]
+=
dz0_tmp
*
y_tmp
;
dz0_sum
[
it
].
data
.
elt
[
jt
]
+=
dz0_tmp
;
if
(
!
Tied_norm
)
{
dz1y_sum
[
it
].
data
.
elt
[
jt
]
+=
dz1_tmp
*
y_tmp
;
dz1_sum
[
it
].
data
.
elt
[
jt
]
+=
dz1_tmp
;
}
}
}
}
reduce_t
result
=
reducer
.
allreduce
({
mdy_local
,
mdyy_local
},
sum
);
mdy_local
=
layer_norm
::
Get
<
0
>::
of
<
reduce_t
,
compute_t
>
(
result
)
*
params
.
inverse_cols
;
mdyy_local
=
layer_norm
::
Get
<
1
>::
of
<
reduce_t
,
compute_t
>
(
result
)
*
params
.
inverse_cols
;
idx
=
row
*
params
.
cols
/
Ktraits
::
ELTS_PER_LDG
+
c
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
if
(
Is_even_cols
||
(
it
<
num_valid_ldgs
))
{
Ivec
dx0
,
dx1
;
Rvec
dresidual
;
#pragma unroll
for
(
int
jt
=
0
;
jt
<
NUM_ELTS
;
jt
++
)
{
compute_t
dx_tmp_res
;
compute_t
dy_tmp
=
dy
[
it
*
NUM_ELTS
+
jt
];
compute_t
y_tmp
=
y
[
it
*
NUM_ELTS
+
jt
];
compute_t
dx_tmp
=
rs_r
*
(
dy_tmp
-
(
mdyy_local
*
y_tmp
+
(
!
params
.
is_rms_norm
?
mdy_local
:
0.
f
)));
dx_tmp_res
=
prenorm
?
dx_tmp
+
compute_t
(
dx
[
it
].
data
.
elt
[
jt
])
:
dx_tmp
;
if
(
has_residual
)
{
dresidual
.
data
.
elt
[
jt
]
=
dx_tmp_res
;
}
if
(
Is_dropout
)
{
dx0
.
data
.
elt
[
jt
]
=
dmask0
[
it
].
data
.
elt
[
jt
]
?
dx_tmp_res
*
params
.
dropout_scale
:
0.
f
;
if
(
has_x1
)
{
dx1
.
data
.
elt
[
jt
]
=
dmask1
[
it
].
data
.
elt
[
jt
]
?
dx_tmp_res
*
params
.
dropout_scale
:
0.
f
;
}
}
else
{
dx0
.
data
.
elt
[
jt
]
=
dx_tmp_res
;
if
(
has_x1
)
{
dx1
.
data
.
elt
[
jt
]
=
dx_tmp_res
;
}
}
}
if
(
has_residual
)
{
dresidual
.
store_to
(
params
.
dresidual
,
idx
);
}
dx0
.
store_to
(
params
.
dx0
,
idx
);
if
(
has_x1
)
{
dx1
.
store_to
(
params
.
dx1
,
idx
);
}
idx
+=
Ktraits
::
VEC_COLS_PER_LDG
;
}
}
}
// end: grid stride loop
if
(
WARPS_M
==
1
)
{
idx
=
r
*
params
.
cols
/
Ktraits
::
ELTS_PER_LDG
+
c
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
if
(
Is_even_cols
||
(
it
<
num_valid_ldgs
))
{
dz0_sum
[
it
].
store_to
(
params
.
dbeta_part
,
idx
);
dz0y_sum
[
it
].
store_to
(
params
.
dgamma_part
,
idx
);
if
(
!
Tied_norm
)
{
dz1_sum
[
it
].
store_to
(
params
.
dbeta1_part
,
idx
);
dz1y_sum
[
it
].
store_to
(
params
.
dgamma1_part
,
idx
);
}
idx
+=
Ktraits
::
VEC_COLS_PER_LDG
;
}
}
}
else
{
static_assert
(
WARPS_M
==
1
||
Ktraits
::
CTAS_PER_ROW
==
1
,
"Multiple rows per CTA not supported for Multi-CTA."
);
// Finalize reduction of part dgamma and dbeta for this CTA
// by reducing over the rows held across the WARPS_M warps
// Assumption: blockSize divides hidden size.
enum
{
NUM_RES
=
COLS
/
Ktraits
::
THREADS_PER_CTA
};
static_assert
(
NUM_RES
*
Ktraits
::
THREADS_PER_CTA
==
COLS
,
""
);
idx
=
warp_m
*
Ktraits
::
VEC_COLS
+
tid_r
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
dz0_sum
[
it
].
store_to
(
smem_wgrad
,
idx
);
idx
+=
THREADS_PER_ROW
;
}
__syncthreads
();
compute_t
cta_dz0_sum
[
NUM_RES
];
memset
(
cta_dz0_sum
,
0
,
sizeof
(
compute_t
)
*
NUM_RES
);
for
(
int
it
=
0
;
it
<
ROWS_PER_CTA
;
it
++
)
{
for
(
int
jt
=
0
;
jt
<
NUM_RES
;
jt
++
)
{
cta_dz0_sum
[
jt
]
+=
smem_wgrad
[
it
*
COLS
+
tidx
+
jt
*
Ktraits
::
THREADS_PER_CTA
];
}
}
__syncthreads
();
idx
=
warp_m
*
Ktraits
::
VEC_COLS
+
tid_r
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
dz0y_sum
[
it
].
store_to
(
smem_wgrad
,
idx
);
idx
+=
THREADS_PER_ROW
;
}
__syncthreads
();
compute_t
cta_dz0y_sum
[
NUM_RES
];
memset
(
cta_dz0y_sum
,
0
,
sizeof
(
compute_t
)
*
NUM_RES
);
for
(
int
it
=
0
;
it
<
ROWS_PER_CTA
;
it
++
)
{
for
(
int
jt
=
0
;
jt
<
NUM_RES
;
jt
++
)
{
cta_dz0y_sum
[
jt
]
+=
smem_wgrad
[
it
*
COLS
+
tidx
+
jt
*
Ktraits
::
THREADS_PER_CTA
];
}
}
compute_t
cta_dz1_sum
[
NUM_RES
],
cta_dz1y_sum
[
NUM_RES
];
if
(
!
Tied_norm
)
{
__syncthreads
();
idx
=
warp_m
*
Ktraits
::
VEC_COLS
+
tid_r
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
dz1_sum
[
it
].
store_to
(
smem_wgrad
,
idx
);
idx
+=
THREADS_PER_ROW
;
}
__syncthreads
();
memset
(
cta_dz1_sum
,
0
,
sizeof
(
compute_t
)
*
NUM_RES
);
for
(
int
it
=
0
;
it
<
ROWS_PER_CTA
;
it
++
)
{
for
(
int
jt
=
0
;
jt
<
NUM_RES
;
jt
++
)
{
cta_dz1_sum
[
jt
]
+=
smem_wgrad
[
it
*
COLS
+
tidx
+
jt
*
Ktraits
::
THREADS_PER_CTA
];
}
}
__syncthreads
();
idx
=
warp_m
*
Ktraits
::
VEC_COLS
+
tid_r
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
dz1y_sum
[
it
].
store_to
(
smem_wgrad
,
idx
);
idx
+=
THREADS_PER_ROW
;
}
__syncthreads
();
memset
(
cta_dz1y_sum
,
0
,
sizeof
(
compute_t
)
*
NUM_RES
);
for
(
int
it
=
0
;
it
<
ROWS_PER_CTA
;
it
++
)
{
for
(
int
jt
=
0
;
jt
<
NUM_RES
;
jt
++
)
{
cta_dz1y_sum
[
jt
]
+=
smem_wgrad
[
it
*
COLS
+
tidx
+
jt
*
Ktraits
::
THREADS_PER_CTA
];
}
}
}
const
index_t
num_valid_writes
=
(
params
.
cols
-
1
-
tidx
+
Ktraits
::
THREADS_PER_CTA
)
/
Ktraits
::
THREADS_PER_CTA
;
compute_t
*
dgamma0_part
=
static_cast
<
compute_t
*>
(
params
.
dgamma_part
)
+
bidm
*
params
.
cols
+
tidx
;
compute_t
*
dbeta0_part
=
static_cast
<
compute_t
*>
(
params
.
dbeta_part
)
+
bidm
*
params
.
cols
+
tidx
;
compute_t
*
dgamma1_part
=
!
Tied_norm
?
static_cast
<
compute_t
*>
(
params
.
dgamma1_part
)
+
bidm
*
params
.
cols
+
tidx
:
nullptr
;
compute_t
*
dbeta1_part
=
!
Tied_norm
?
static_cast
<
compute_t
*>
(
params
.
dbeta1_part
)
+
bidm
*
params
.
cols
+
tidx
:
nullptr
;
for
(
int
jt
=
0
;
jt
<
NUM_RES
;
jt
++
)
{
if
(
Is_even_cols
||
(
jt
<
num_valid_writes
))
{
*
dgamma0_part
=
cta_dz0y_sum
[
jt
];
dgamma0_part
+=
Ktraits
::
THREADS_PER_CTA
;
*
dbeta0_part
=
cta_dz0_sum
[
jt
];
dbeta0_part
+=
Ktraits
::
THREADS_PER_CTA
;
if
(
!
Tied_norm
)
{
*
dgamma1_part
=
cta_dz1y_sum
[
jt
];
dgamma1_part
+=
Ktraits
::
THREADS_PER_CTA
;
*
dbeta1_part
=
cta_dz1_sum
[
jt
];
dbeta1_part
+=
Ktraits
::
THREADS_PER_CTA
;
}
}
}
}
}
template
<
typename
Kernel_traits
,
bool
Is_even_cols
>
__global__
__launch_bounds__
(
Kernel_traits
::
THREADS_PER_CTA
)
void
ln_parallel_residual_bwd_finalize_kernel
(
BwdParams
params
)
{
using
compute_t
=
typename
Kernel_traits
::
compute_t
;
using
weight_t
=
typename
Kernel_traits
::
weight_t
;
using
index_t
=
typename
Kernel_traits
::
index_t
;
using
Reducer
=
typename
Kernel_traits
::
Reducer
;
using
reduce_t
=
typename
Reducer
::
Type
;
Sum
<
reduce_t
>
sum
;
enum
{
NUM_ELT
=
Kernel_traits
::
ELTS_PER_LDG
};
enum
{
THREADS_PER_WARP
=
Kernel_traits
::
THREADS_PER_WARP
};
// Multiplying by 2 since we have both gamma0 and gamma1
__shared__
char
smem_
[
2
*
Kernel_traits
::
SMEM_BYTES_PER_CTA
];
constexpr
uint32_t
bidm
=
0
;
const
uint32_t
bidn
=
blockIdx
.
x
;
const
uint32_t
tidx
=
threadIdx
.
x
;
const
uint32_t
warp
=
tidx
/
THREADS_PER_WARP
;
const
uint32_t
lane
=
tidx
%
THREADS_PER_WARP
;
Reducer
reducer
(
params
,
bidm
,
bidn
,
0
,
0
,
lane
,
smem_
);
const
uint32_t
c
=
bidn
*
THREADS_PER_WARP
+
lane
;
const
uint32_t
c_out
=
bidn
*
THREADS_PER_WARP
/
2
+
lane
;
constexpr
uint32_t
COL_STRIDE
=
Kernel_traits
::
CTAS
*
THREADS_PER_WARP
;
for
(
uint32_t
col
=
c
,
col_out
=
c_out
;
col
<
Kernel_traits
::
COLS
;
col
+=
COL_STRIDE
,
col_out
+=
COL_STRIDE
/
2
)
{
// Each thread sums over NUM_ELT columns.
Vec
<
compute_t
,
NUM_ELT
>
dbeta0_local
,
dgamma0_local
,
dbeta1_local
,
dgamma1_local
;
memset
(
&
dgamma0_local
,
0
,
sizeof
(
dgamma0_local
));
memset
(
&
dbeta0_local
,
0
,
sizeof
(
dbeta0_local
));
memset
(
&
dgamma1_local
,
0
,
sizeof
(
dgamma1_local
));
memset
(
&
dbeta1_local
,
0
,
sizeof
(
dbeta1_local
));
if
(
Is_even_cols
||
col
<
params
.
cols
)
{
for
(
uint32_t
row
=
warp
;
row
<
params
.
ctas_per_col
;
row
+=
Kernel_traits
::
ROWS_PER_CTA
)
{
index_t
idx
=
row
*
params
.
cols
+
col
;
Vec
<
compute_t
,
NUM_ELT
>
dbeta0_part
,
dgamma0_part
,
dbeta1_part
,
dgamma1_part
;
dbeta0_part
.
load_from
(
params
.
dbeta_part
,
idx
);
dgamma0_part
.
load_from
(
params
.
dgamma_part
,
idx
);
dbeta1_part
.
load_from
(
params
.
dbeta1_part
,
idx
);
dgamma1_part
.
load_from
(
params
.
dgamma1_part
,
idx
);
#pragma unroll
for
(
int
it
=
0
;
it
<
NUM_ELT
;
it
++
)
{
dgamma0_local
.
data
.
elt
[
it
]
+=
dgamma0_part
.
data
.
elt
[
it
];
dbeta0_local
.
data
.
elt
[
it
]
+=
dbeta0_part
.
data
.
elt
[
it
];
dgamma1_local
.
data
.
elt
[
it
]
+=
dgamma1_part
.
data
.
elt
[
it
];
dbeta1_local
.
data
.
elt
[
it
]
+=
dbeta1_part
.
data
.
elt
[
it
];
}
}
}
void
*
smem_gamma0
=
smem_
;
void
*
smem_beta0
=
&
smem_
[
Kernel_traits
::
SMEM_BYTES_TRANSPOSE
];
void
*
smem_gamma1
=
&
smem_
[
2
*
Kernel_traits
::
SMEM_BYTES_TRANSPOSE
];
void
*
smem_beta1
=
&
smem_
[
3
*
Kernel_traits
::
SMEM_BYTES_TRANSPOSE
];
const
int
write_row
=
warp
;
const
int
write_col
=
lane
^
write_row
;
const
int
write_idx
=
write_row
*
THREADS_PER_WARP
+
write_col
;
dgamma0_local
.
store_to
(
smem_gamma0
,
write_idx
);
dbeta0_local
.
store_to
(
smem_beta0
,
write_idx
);
dgamma1_local
.
store_to
(
smem_gamma1
,
write_idx
);
dbeta1_local
.
store_to
(
smem_beta1
,
write_idx
);
__syncthreads
();
// It would be probably safe to reuse the first row of smem_beta0 and smem_gamma0
void
*
smem_gamma0_out
=
&
smem_
[
4
*
Kernel_traits
::
SMEM_BYTES_TRANSPOSE
];
void
*
smem_beta0_out
=
&
smem_
[
4
*
Kernel_traits
::
SMEM_BYTES_TRANSPOSE
+
Kernel_traits
::
SMEM_BYTES_OUTPUT
];
void
*
smem_gamma1_out
=
&
smem_
[
4
*
Kernel_traits
::
SMEM_BYTES_TRANSPOSE
+
2
*
Kernel_traits
::
SMEM_BYTES_OUTPUT
];
void
*
smem_beta1_out
=
&
smem_
[
4
*
Kernel_traits
::
SMEM_BYTES_TRANSPOSE
+
3
*
Kernel_traits
::
SMEM_BYTES_OUTPUT
];
// More than one iter iff ROWS_PER_CTA < 32.
for
(
int
w
=
warp
;
w
<
THREADS_PER_WARP
;
w
+=
Kernel_traits
::
ROWS_PER_CTA
)
{
const
int
read_row
=
lane
;
const
int
read_col
=
w
^
read_row
;
const
int
read_idx
=
read_row
*
THREADS_PER_WARP
+
read_col
;
memset
(
&
dbeta0_local
,
0
,
sizeof
(
dbeta0_local
));
memset
(
&
dgamma0_local
,
0
,
sizeof
(
dgamma0_local
));
memset
(
&
dbeta1_local
,
0
,
sizeof
(
dbeta1_local
));
memset
(
&
dgamma1_local
,
0
,
sizeof
(
dgamma1_local
));
// Load beta and gamma transposed
if
(
read_row
<
Kernel_traits
::
ROWS_PER_CTA
){
dbeta0_local
.
load_from
(
smem_beta0
,
read_idx
);
dgamma0_local
.
load_from
(
smem_gamma0
,
read_idx
);
dbeta1_local
.
load_from
(
smem_beta1
,
read_idx
);
dgamma1_local
.
load_from
(
smem_gamma1
,
read_idx
);
}
// Call reducer on the loaded value(s) and convert.
#pragma unroll
for
(
int
it
=
0
;
it
<
NUM_ELT
;
it
++
)
{
compute_t
b0_i
=
dbeta0_local
.
data
.
elt
[
it
];
compute_t
g0_i
=
dgamma0_local
.
data
.
elt
[
it
];
compute_t
b1_i
=
dbeta1_local
.
data
.
elt
[
it
];
compute_t
g1_i
=
dgamma1_local
.
data
.
elt
[
it
];
b0_i
=
reducer
.
allreduce
(
b0_i
,
sum
);
g0_i
=
reducer
.
allreduce
(
g0_i
,
sum
);
b1_i
=
reducer
.
allreduce
(
b1_i
,
sum
);
g1_i
=
reducer
.
allreduce
(
g1_i
,
sum
);
dgamma0_local
.
data
.
elt
[
it
]
=
g0_i
;
dbeta0_local
.
data
.
elt
[
it
]
=
b0_i
;
dgamma1_local
.
data
.
elt
[
it
]
=
g1_i
;
dbeta1_local
.
data
.
elt
[
it
]
=
b1_i
;
}
// Leader stores the result at the current column.
if
(
lane
==
0
){
dgamma0_local
.
store_to
(
smem_gamma0_out
,
w
);
dbeta0_local
.
store_to
(
smem_beta0_out
,
w
);
dgamma1_local
.
store_to
(
smem_gamma1_out
,
w
);
dbeta1_local
.
store_to
(
smem_beta1_out
,
w
);
}
}
// All writes done.
__syncthreads
();
// Pack and store: 2-wide stores with half the threads.
if
(
Is_even_cols
||
col_out
*
2
<
params
.
cols
)
{
if
(
warp
==
Kernel_traits
::
ROWS_PER_CTA
-
1
&&
lane
<
THREADS_PER_WARP
/
2
)
{
using
src_t
=
typename
TypeToVec2
<
compute_t
>::
Type
;
using
dst_t
=
typename
TypeToVec2
<
weight_t
>::
Type
;
Vec
<
src_t
,
NUM_ELT
>
dbeta0_vec2
,
dgamma0_vec2
,
dbeta1_vec2
,
dgamma1_vec2
;
Vec
<
dst_t
,
NUM_ELT
>
dbeta0_out2
,
dgamma0_out2
,
dbeta1_out2
,
dgamma1_out2
;
dgamma0_vec2
.
load_from
(
smem_gamma0_out
,
lane
);
dbeta0_vec2
.
load_from
(
smem_beta0_out
,
lane
);
dgamma1_vec2
.
load_from
(
smem_gamma1_out
,
lane
);
dbeta1_vec2
.
load_from
(
smem_beta1_out
,
lane
);
#pragma unroll
for
(
int
it
=
0
;
it
<
NUM_ELT
;
it
++
)
{
dgamma0_out2
.
data
.
elt
[
it
]
=
Converter
<
src_t
,
dst_t
>::
convert
(
dgamma0_vec2
.
data
.
elt
[
it
]);
dbeta0_out2
.
data
.
elt
[
it
]
=
Converter
<
src_t
,
dst_t
>::
convert
(
dbeta0_vec2
.
data
.
elt
[
it
]);
dgamma1_out2
.
data
.
elt
[
it
]
=
Converter
<
src_t
,
dst_t
>::
convert
(
dgamma1_vec2
.
data
.
elt
[
it
]);
dbeta1_out2
.
data
.
elt
[
it
]
=
Converter
<
src_t
,
dst_t
>::
convert
(
dbeta1_vec2
.
data
.
elt
[
it
]);
}
dgamma0_out2
.
store_to
(
params
.
dgamma
,
col_out
);
dbeta0_out2
.
store_to
(
params
.
dbeta
,
col_out
);
dgamma1_out2
.
store_to
(
params
.
dgamma1
,
col_out
);
dbeta1_out2
.
store_to
(
params
.
dbeta1
,
col_out
);
}
}
}
}
}
// namespace layer_norm
using
namespace
layer_norm
;
template
<
typename
weight_t
,
typename
input_t
,
typename
residual_t
,
typename
output_t
,
typename
compute_t
,
typename
index_t
,
int
HIDDEN_SIZE
,
int
CTAS_PER_ROW
,
int
WARPS_M
,
int
WARPS_N
,
int
BYTES_PER_LDG_MAIN
,
int
BYTES_PER_LDG_FINAL
>
void
launch_parallel_residual_
(
LaunchParams
<
BwdParams
>
&
launch_params
,
const
bool
configure_params
){
using
Kernel_traits
=
Kernel_traits
<
weight_t
,
input_t
,
residual_t
,
output_t
,
compute_t
,
index_t
,
HIDDEN_SIZE
,
CTAS_PER_ROW
,
WARPS_M
,
WARPS_N
,
BYTES_PER_LDG_MAIN
>
;
bool
is_dropout
=
launch_params
.
params
.
dropout_keep_p
<
1.
f
;
bool
tied_norm
=
launch_params
.
params
.
gamma1
==
nullptr
;
bool
is_even_cols
=
launch_params
.
params
.
cols
==
HIDDEN_SIZE
;
BOOL_SWITCH
(
is_dropout
,
IsDropoutConst
,
[
&
]
{
BOOL_SWITCH
(
tied_norm
,
TiedNormConst
,
[
&
]
{
BOOL_SWITCH
(
is_even_cols
,
IsEvenColsConst
,
[
&
]
{
auto
kernel
=
&
ln_parallel_residual_bwd_kernel
<
Kernel_traits
,
IsDropoutConst
,
TiedNormConst
,
IsEvenColsConst
>
;
if
(
configure_params
)
{
int
ctas_per_sm
;
CHECK_CUDA
(
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
ctas_per_sm
,
kernel
,
Kernel_traits
::
THREADS_PER_CTA
,
Kernel_traits
::
SMEM_BYTES
));
launch_params
.
params
.
ctas_per_col
=
launch_params
.
props
->
multiProcessorCount
*
ctas_per_sm
/
Kernel_traits
::
CTAS_PER_ROW
;
launch_params
.
barrier_size
=
0
;
launch_params
.
workspace_bytes
=
0
;
if
(
Kernel_traits
::
CTAS_PER_ROW
>
1
)
{
launch_params
.
barrier_size
=
2
*
launch_params
.
params
.
ctas_per_col
;
launch_params
.
workspace_bytes
=
launch_params
.
params
.
ctas_per_col
*
Kernel_traits
::
WARPS_M
*
Kernel_traits
::
CTAS_PER_ROW
*
sizeof
(
typename
Kernel_traits
::
reduce_t
)
*
2
;
}
return
;
}
if
(
Kernel_traits
::
SMEM_BYTES
>=
48
*
1024
)
{
CHECK_CUDA
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
Kernel_traits
::
SMEM_BYTES
));
}
auto
stream
=
launch_params
.
stream
;
auto
ctas_per_col
=
launch_params
.
params
.
ctas_per_col
;
if
(
Kernel_traits
::
CTAS_PER_ROW
==
1
)
{
kernel
<<<
ctas_per_col
,
Kernel_traits
::
THREADS_PER_CTA
,
Kernel_traits
::
SMEM_BYTES
,
stream
>>>
(
launch_params
.
params
);
}
else
{
dim3
grid
(
Kernel_traits
::
CTAS_PER_ROW
*
ctas_per_col
);
dim3
block
(
Kernel_traits
::
THREADS_PER_CTA
);
void
*
params_
=
(
void
*
)
&
launch_params
.
params
;
cudaLaunchCooperativeKernel
((
void
*
)
kernel
,
grid
,
block
,
(
void
**
)
&
params_
,
Kernel_traits
::
SMEM_BYTES
,
stream
);
}
using
Kernel_traits_f
=
layer_norm
::
Kernel_traits_finalize
<
HIDDEN_SIZE
,
weight_t
,
input_t
,
residual_t
,
output_t
,
compute_t
,
index_t
,
/*HasColscaleConst=*/
false
,
32
*
32
,
// THREADS_PER_CTA
BYTES_PER_LDG_FINAL
>
;
auto
kernel_f
=
!
TiedNormConst
?
&
layer_norm
::
ln_parallel_residual_bwd_finalize_kernel
<
Kernel_traits_f
,
IsEvenColsConst
>
:
&
layer_norm
::
ln_bwd_finalize_kernel
<
Kernel_traits_f
,
/*HasColscaleConst=*/
false
,
IsEvenColsConst
>
;
kernel_f
<<<
Kernel_traits_f
::
CTAS
,
Kernel_traits_f
::
THREADS_PER_CTA
,
0
,
stream
>>>
(
launch_params
.
params
);
});
});
});
}
csrc/layer_norm/ln_parallel_residual_fwd_kernels.cuh
0 → 100644
View file @
393882bc
#pragma once
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/detail/UnpackRaw.cuh> // For at::cuda::philox::unpack
#include <curand_kernel.h>
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "static_switch.h"
namespace
layer_norm
{
template
<
typename
Ktraits
,
bool
Is_dropout
,
bool
Tied_norm
,
bool
Is_even_cols
>
__global__
__launch_bounds__
(
Ktraits
::
THREADS_PER_CTA
)
void
ln_parallel_residual_fwd_kernel
(
FwdParams
params
)
{
enum
{
ROWS_PER_CTA
=
Ktraits
::
ROWS_PER_CTA
};
enum
{
WARPS_N
=
Ktraits
::
WARPS_N
};
enum
{
WARPS_M
=
Ktraits
::
WARPS_M
};
enum
{
THREADS_PER_ROW
=
Ktraits
::
THREADS_PER_ROW
};
enum
{
VEC_COLS_PER_LDG
=
Ktraits
::
VEC_COLS_PER_LDG
};
enum
{
BYTES_PER_ROW
=
Ktraits
::
BYTES_PER_ROW
};
enum
{
LDGS
=
Ktraits
::
LDGS
};
enum
{
NUM_ELTS
=
Ktraits
::
NUM_ELTS
};
enum
{
CTAS_PER_ROW
=
Ktraits
::
CTAS_PER_ROW
};
using
input_t
=
typename
Ktraits
::
input_t
;
using
residual_t
=
typename
Ktraits
::
residual_t
;
using
output_t
=
typename
Ktraits
::
output_t
;
using
index_t
=
typename
Ktraits
::
index_t
;
using
compute_t
=
typename
Ktraits
::
compute_t
;
using
mask_t
=
typename
Ktraits
::
mask_t
;
using
Ivec
=
typename
Ktraits
::
Ivec
;
using
Rvec
=
typename
Ktraits
::
Rvec
;
using
Ovec
=
typename
Ktraits
::
Ovec
;
using
Wvec
=
typename
Ktraits
::
Wvec
;
using
Cvec
=
typename
Ktraits
::
Cvec
;
using
Mvec
=
typename
Ktraits
::
Mvec
;
using
Stats
=
typename
Ktraits
::
Stats
;
using
stats_t
=
typename
Stats
::
stats_t
;
const
bool
has_residual
=
params
.
residual
!=
nullptr
;
const
bool
has_x1
=
params
.
x1
!=
nullptr
;
const
bool
save_x
=
has_residual
||
has_x1
||
Is_dropout
||
!
(
std
::
is_same
<
input_t
,
residual_t
>::
value
);
extern
__shared__
char
smem_
[];
const
index_t
tidx
=
threadIdx
.
x
;
const
index_t
bidn
=
blockIdx
.
x
%
CTAS_PER_ROW
;
const
index_t
bidm
=
blockIdx
.
x
/
CTAS_PER_ROW
;
const
index_t
lane
=
tidx
%
THREADS_PER_WARP
;
const
index_t
warp
=
tidx
/
THREADS_PER_WARP
;
const
index_t
warp_m
=
warp
/
WARPS_N
;
const
index_t
warp_n
=
warp
%
WARPS_N
;
const
index_t
r
=
bidm
*
ROWS_PER_CTA
+
warp_m
;
const
index_t
c
=
bidn
*
THREADS_PER_ROW
+
warp_n
*
THREADS_PER_WARP
+
lane
;
Stats
stats
(
params
,
bidm
,
bidn
,
warp_m
,
warp_n
,
lane
,
smem_
);
compute_t
*
mu_ptr
=
static_cast
<
compute_t
*>
(
params
.
mu
);
compute_t
*
rs_ptr
=
static_cast
<
compute_t
*>
(
params
.
rs
);
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu
curandStatePhilox4_32_10_t
state
;
if
(
Is_dropout
)
{
auto
seeds
=
at
::
cuda
::
philox
::
unpack
(
params
.
philox_args
);
const
index_t
tidx_global
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
curand_init
(
std
::
get
<
0
>
(
seeds
),
tidx_global
,
std
::
get
<
1
>
(
seeds
),
&
state
);
}
const
index_t
num_valid_ldgs
=
((
params
.
cols
/
Ktraits
::
ELTS_PER_LDG
)
-
1
-
c
+
VEC_COLS_PER_LDG
)
/
VEC_COLS_PER_LDG
;
Wvec
gamma0
[
LDGS
];
Wvec
beta0
[
LDGS
];
Wvec
gamma1
[
LDGS
];
Wvec
beta1
[
LDGS
];
index_t
idx
=
c
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
if
(
Is_even_cols
||
(
it
<
num_valid_ldgs
))
{
gamma0
[
it
].
load_from
(
params
.
gamma
,
idx
);
if
(
params
.
beta
!=
nullptr
)
{
beta0
[
it
].
load_from
(
params
.
beta
,
idx
);
}
else
{
beta0
[
it
].
zero_
();
}
if
(
!
Tied_norm
)
{
gamma1
[
it
].
load_from
(
params
.
gamma1
,
idx
);
if
(
params
.
beta1
!=
nullptr
)
{
beta1
[
it
].
load_from
(
params
.
beta1
,
idx
);
}
else
{
beta1
[
it
].
zero_
();
}
}
idx
+=
VEC_COLS_PER_LDG
;
}
}
for
(
int
row
=
r
;
row
<
params
.
rows
;
row
+=
params
.
ctas_per_col
*
ROWS_PER_CTA
)
{
index_t
idx
=
row
*
params
.
cols
/
Ktraits
::
ELTS_PER_LDG
+
c
;
compute_t
xf
[
LDGS
*
NUM_ELTS
];
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
if
(
Is_even_cols
||
(
it
<
num_valid_ldgs
))
{
Ivec
x0
;
Ivec
x1
;
Rvec
residual
;
Rvec
x
;
Mvec
dmask0
;
Mvec
dmask1
;
x0
.
load_from
(
params
.
x0
,
idx
);
if
(
has_x1
)
{
x1
.
load_from
(
params
.
x1
,
idx
);
}
if
(
has_residual
)
{
residual
.
load_from
(
params
.
residual
,
idx
);
}
#pragma unroll
for
(
int
jt
=
0
;
jt
<
NUM_ELTS
;
jt
++
)
{
// TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
// the more efficient curand_uniform4.
compute_t
x_ij
;
mask_t
keep0
=
!
Is_dropout
?
true
:
curand_uniform
(
&
state
)
<=
params
.
dropout_keep_p
;
if
(
Is_dropout
)
{
dmask0
.
data
.
elt
[
jt
]
=
keep0
;
}
compute_t
x0_ij
=
compute_t
(
x0
.
data
.
elt
[
jt
]);
x0_ij
=
keep0
?
(
Is_dropout
?
x0_ij
*
params
.
dropout_scale
:
x0_ij
)
:
0.0
f
;
if
(
has_x1
)
{
mask_t
keep1
=
!
Is_dropout
?
true
:
curand_uniform
(
&
state
)
<=
params
.
dropout_keep_p
;
if
(
Is_dropout
)
{
dmask1
.
data
.
elt
[
jt
]
=
keep1
;
}
compute_t
x1_ij
=
compute_t
(
x1
.
data
.
elt
[
jt
]);
x1_ij
=
keep1
?
(
Is_dropout
?
x1_ij
*
params
.
dropout_scale
:
x1_ij
)
:
0.0
f
;
x_ij
=
has_residual
?
x0_ij
+
x1_ij
+
compute_t
(
residual
.
data
.
elt
[
jt
])
:
x0_ij
+
x1_ij
;
}
else
{
x_ij
=
has_residual
?
x0_ij
+
compute_t
(
residual
.
data
.
elt
[
jt
])
:
x0_ij
;
}
if
(
save_x
)
{
x
.
data
.
elt
[
jt
]
=
x_ij
;
}
xf
[
it
*
NUM_ELTS
+
jt
]
=
x_ij
;
}
if
(
save_x
)
{
x
.
store_to
(
params
.
x
,
idx
);
}
if
(
Is_dropout
)
{
dmask0
.
store_to
(
params
.
dmask
,
idx
);
if
(
has_x1
)
{
dmask1
.
store_to
(
params
.
dmask1
,
idx
);
}
}
idx
+=
VEC_COLS_PER_LDG
;
}
}
static_assert
(
CTAS_PER_ROW
==
1
,
"Don't support multiple CTAs per row for now"
);
const
index_t
num_vecs
=
params
.
cols
/
Ktraits
::
ELTS_PER_LDG
;
const
index_t
num_full_ldgs
=
num_vecs
/
Ktraits
::
VEC_COLS_PER_LDG
;
const
index_t
remaining_vecs
=
num_vecs
%
Ktraits
::
VEC_COLS_PER_LDG
;
auto
valid_elts_in_warp_fn
=
[
num_full_ldgs
,
remaining_vecs
]
(
int
warp_n
)
->
int
{
// Need to convert to int, otherwise the subtraction will wrap around.
const
index_t
valid_partial_vecs_in_warp
=
std
::
min
(
std
::
max
(
int
(
remaining_vecs
)
-
int
(
warp_n
*
THREADS_PER_WARP
),
int
(
0
)),
int
(
THREADS_PER_WARP
));
return
(
num_full_ldgs
*
THREADS_PER_WARP
+
valid_partial_vecs_in_warp
)
*
NUM_ELTS
;
};
stats_t
s
=
stats
.
template
compute
<
Is_even_cols
>(
xf
,
params
.
inverse_cols
,
valid_elts_in_warp_fn
,
num_valid_ldgs
*
NUM_ELTS
);
compute_t
mu
=
layer_norm
::
Get
<
0
>::
of
<
stats_t
,
compute_t
>
(
s
);
compute_t
m2
=
layer_norm
::
Get
<
1
>::
of
<
stats_t
,
compute_t
>
(
s
);
if
(
bidn
==
0
&&
warp_n
==
0
&&
lane
==
0
)
{
mu_ptr
[
row
]
=
mu
;
}
compute_t
rs
=
rsqrtf
(
m2
*
params
.
inverse_cols
+
params
.
epsilon
+
(
!
params
.
is_rms_norm
?
0.
f
:
mu
*
mu
));
if
(
bidn
==
0
&&
warp_n
==
0
&&
lane
==
0
)
{
rs_ptr
[
row
]
=
rs
;
}
idx
=
row
*
params
.
cols
/
Ktraits
::
ELTS_PER_LDG
+
c
;
#pragma unroll
for
(
int
it
=
0
;
it
<
LDGS
;
it
++
)
{
if
(
Is_even_cols
||
(
it
<
num_valid_ldgs
))
{
Ovec
z0
;
Ovec
z1
;
#pragma unroll
for
(
int
jt
=
0
;
jt
<
NUM_ELTS
;
jt
++
)
{
compute_t
y_ij
=
compute_t
(
rs
*
(
xf
[
it
*
NUM_ELTS
+
jt
]
-
(
!
params
.
is_rms_norm
?
mu
:
0.
f
)));
compute_t
g0_ij
=
gamma0
[
it
].
data
.
elt
[
jt
];
compute_t
b0_ij
=
beta0
[
it
].
data
.
elt
[
jt
];
z0
.
data
.
elt
[
jt
]
=
output_t
(
g0_ij
*
y_ij
+
b0_ij
);
if
(
!
Tied_norm
)
{
compute_t
g1_ij
=
gamma1
[
it
].
data
.
elt
[
jt
];
compute_t
b1_ij
=
beta1
[
it
].
data
.
elt
[
jt
];
z1
.
data
.
elt
[
jt
]
=
output_t
(
g1_ij
*
y_ij
+
b1_ij
);
}
}
z0
.
store_to
(
params
.
z
,
idx
);
if
(
!
Tied_norm
)
{
z1
.
store_to
(
params
.
z1
,
idx
);
}
idx
+=
VEC_COLS_PER_LDG
;
}
}
}
}
}
// namespace layer_norm
using
namespace
layer_norm
;
template
<
typename
weight_t
,
typename
input_t
,
typename
residual_t
,
typename
output_t
,
typename
compute_t
,
typename
index_t
,
int
HIDDEN_SIZE
,
int
CTAS_PER_ROW
,
int
WARPS_M
,
int
WARPS_N
,
int
BYTES_PER_LDG
>
void
launch_parallel_residual_
(
LaunchParams
<
FwdParams
>
&
launch_params
,
const
bool
configure_params
){
using
Kernel_traits
=
Kernel_traits
<
weight_t
,
input_t
,
residual_t
,
output_t
,
compute_t
,
index_t
,
HIDDEN_SIZE
,
CTAS_PER_ROW
,
WARPS_M
,
WARPS_N
,
BYTES_PER_LDG
>
;
bool
is_even_cols
=
launch_params
.
params
.
cols
==
HIDDEN_SIZE
;
bool
tied_norm
=
launch_params
.
params
.
gamma1
==
nullptr
;
BOOL_SWITCH
(
launch_params
.
params
.
dropout_keep_p
<
1.
f
,
IsDropoutConst
,
[
&
]
{
BOOL_SWITCH
(
tied_norm
,
TiedNormConst
,
[
&
]
{
BOOL_SWITCH
(
is_even_cols
,
IsEvenColsConst
,
[
&
]
{
auto
kernel
=
&
ln_parallel_residual_fwd_kernel
<
Kernel_traits
,
IsDropoutConst
,
TiedNormConst
,
IsEvenColsConst
>
;
if
(
configure_params
)
{
int
ctas_per_sm
;
CHECK_CUDA
(
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
ctas_per_sm
,
kernel
,
Kernel_traits
::
THREADS_PER_CTA
,
Kernel_traits
::
SMEM_BYTES_FWD
));
launch_params
.
params
.
ctas_per_col
=
launch_params
.
props
->
multiProcessorCount
*
ctas_per_sm
/
Kernel_traits
::
CTAS_PER_ROW
;
const
size_t
rows_per_loop
=
launch_params
.
params
.
ctas_per_col
*
Kernel_traits
::
ROWS_PER_CTA
;
launch_params
.
elts_per_thread
=
(
launch_params
.
params
.
rows
+
rows_per_loop
-
1
)
/
rows_per_loop
*
Kernel_traits
::
LDGS
*
Kernel_traits
::
NUM_ELTS
;
launch_params
.
barrier_size
=
0
;
launch_params
.
workspace_bytes
=
0
;
if
(
Kernel_traits
::
CTAS_PER_ROW
>
1
)
{
launch_params
.
barrier_size
=
2
*
launch_params
.
params
.
ctas_per_col
;
launch_params
.
workspace_bytes
=
launch_params
.
params
.
ctas_per_col
*
Kernel_traits
::
WARPS_M
*
Kernel_traits
::
CTAS_PER_ROW
*
sizeof
(
typename
Kernel_traits
::
Stats
::
stats_t
)
*
2
;
}
return
;
}
if
(
Kernel_traits
::
SMEM_BYTES_FWD
>=
48
*
1024
)
{
CHECK_CUDA
(
cudaFuncSetAttribute
(
kernel
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
Kernel_traits
::
SMEM_BYTES_FWD
));
}
auto
stream
=
launch_params
.
stream
;
auto
ctas_per_col
=
launch_params
.
params
.
ctas_per_col
;
if
(
Kernel_traits
::
CTAS_PER_ROW
==
1
)
{
kernel
<<<
ctas_per_col
,
Kernel_traits
::
THREADS_PER_CTA
,
Kernel_traits
::
SMEM_BYTES_FWD
,
stream
>>>
(
launch_params
.
params
);
}
else
{
dim3
grid
(
Kernel_traits
::
CTAS_PER_ROW
*
ctas_per_col
);
dim3
block
(
Kernel_traits
::
THREADS_PER_CTA
);
void
*
params_
=
(
void
*
)
&
launch_params
.
params
;
cudaLaunchCooperativeKernel
((
void
*
)
kernel
,
grid
,
block
,
(
void
**
)
&
params_
,
Kernel_traits
::
SMEM_BYTES_FWD
,
stream
);
}
});
});
});
}
csrc/layer_norm/ln_utils.cuh
View file @
393882bc
...
...
@@ -64,6 +64,39 @@ inline void check_cuda_(cudaError_t status, const char *file, int line) {
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_PARALLEL_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \
void ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams> &launch_params, \
const bool configure_params) { \
launch_parallel_residual_<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>( \
launch_params, configure_params); \
} \
static FwdParallelRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_PARALLEL_BWD_LAUNCHER( \
HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
void ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params, \
const bool configure_params) { \
launch_parallel_residual_<WTYPE, \
ITYPE, \
RTYPE, \
OTYPE, \
CTYPE, \
uint32_t, \
HIDDEN_SIZE, \
CTAS_PER_ROW, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \
static BwdParallelRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
operator
+
(
const
float2
&
a
,
const
float2
&
b
){
return
{
a
.
x
+
b
.
x
,
a
.
y
+
b
.
y
};
}
...
...
csrc/layer_norm/setup.py
View file @
393882bc
...
...
@@ -139,6 +139,38 @@ ext_modules.append(
"ln_bwd_5120.cu"
,
"ln_fwd_6144.cu"
,
"ln_bwd_6144.cu"
,
"ln_fwd_7168.cu"
,
"ln_bwd_7168.cu"
,
"ln_fwd_8192.cu"
,
"ln_bwd_8192.cu"
,
"ln_parallel_fwd_256.cu"
,
"ln_parallel_bwd_256.cu"
,
"ln_parallel_fwd_512.cu"
,
"ln_parallel_bwd_512.cu"
,
"ln_parallel_fwd_768.cu"
,
"ln_parallel_bwd_768.cu"
,
"ln_parallel_fwd_1024.cu"
,
"ln_parallel_bwd_1024.cu"
,
"ln_parallel_fwd_1280.cu"
,
"ln_parallel_bwd_1280.cu"
,
"ln_parallel_fwd_1536.cu"
,
"ln_parallel_bwd_1536.cu"
,
"ln_parallel_fwd_2048.cu"
,
"ln_parallel_bwd_2048.cu"
,
"ln_parallel_fwd_2560.cu"
,
"ln_parallel_bwd_2560.cu"
,
"ln_parallel_fwd_3072.cu"
,
"ln_parallel_bwd_3072.cu"
,
"ln_parallel_fwd_4096.cu"
,
"ln_parallel_bwd_4096.cu"
,
"ln_parallel_fwd_5120.cu"
,
"ln_parallel_bwd_5120.cu"
,
"ln_parallel_fwd_6144.cu"
,
"ln_parallel_bwd_6144.cu"
,
"ln_parallel_fwd_7168.cu"
,
"ln_parallel_bwd_7168.cu"
,
"ln_parallel_fwd_8192.cu"
,
"ln_parallel_bwd_8192.cu"
,
],
extra_compile_args
=
{
"cxx"
:
[
"-O3"
]
+
generator_flag
,
...
...
flash_attn/models/gpt.py
View file @
393882bc
...
...
@@ -37,6 +37,11 @@ try:
except
ImportError
:
dropout_add_layer_norm
=
None
try
:
from
flash_attn.ops.layer_norm
import
dropout_add_layer_norm_parallel_residual
except
ImportError
:
dropout_add_layer_norm_parallel_residual
=
None
try
:
from
flash_attn.ops.triton.mlp
import
FusedDenseSqreluDense
,
sqrelu_fwd
except
ImportError
:
...
...
@@ -282,8 +287,10 @@ class GPTModel(GPTPreTrainedModel):
for
i
in
range
(
config
.
num_hidden_layers
)])
self
.
fused_dropout_add_ln
=
getattr
(
config
,
'fused_dropout_add_ln'
,
False
)
if
self
.
fused_dropout_add_ln
and
dropout_add_layer_norm
is
None
:
raise
ImportError
(
'dropout_add_layer_norm is not installed'
)
if
self
.
fused_dropout_add_ln
:
if
((
not
self
.
parallel_block
and
dropout_add_layer_norm
is
None
)
or
(
self
.
parallel_block
and
dropout_add_layer_norm_parallel_residual
is
None
)):
raise
ImportError
(
'dropout_layer_norm is not installed'
)
if
self
.
prenorm
:
self
.
drop_f
=
nn
.
Dropout
(
config
.
resid_pdrop
)
self
.
ln_f
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
,
...
...
@@ -340,13 +347,19 @@ class GPTModel(GPTPreTrainedModel):
if
residual
is
not
None
else
dropped
+
dropped2
)
hidden_states
=
self
.
ln_f
(
residual
.
to
(
dtype
=
self
.
ln_f
.
weight
.
dtype
))
else
:
assert
not
self
.
parallel_block
# Set prenorm=False here since we don't need the residual
hidden_states
=
dropout_add_layer_norm
(
hidden_states
,
residual
,
self
.
ln_f
.
weight
,
self
.
ln_f
.
bias
,
self
.
drop_f
.
p
if
self
.
training
else
0.0
,
self
.
ln_f
.
eps
,
prenorm
=
False
,
residual_in_fp32
=
self
.
residual_in_fp32
)
if
not
self
.
parallel_block
:
hidden_states
=
dropout_add_layer_norm
(
hidden_states
,
residual
,
self
.
ln_f
.
weight
,
self
.
ln_f
.
bias
,
self
.
drop_f
.
p
if
self
.
training
else
0.0
,
self
.
ln_f
.
eps
,
prenorm
=
False
,
residual_in_fp32
=
self
.
residual_in_fp32
)
else
:
hidden_states
,
_
=
dropout_add_layer_norm_parallel_residual
(
hidden_states
,
hidden_states2
,
residual
,
self
.
ln_f
.
weight
,
self
.
ln_f
.
bias
,
None
,
None
,
self
.
drop_f
.
p
if
self
.
training
else
0.0
,
self
.
ln_f
.
eps
,
prenorm
=
False
,
residual_in_fp32
=
self
.
residual_in_fp32
)
return
hidden_states
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment