gaoqiong / flash-attention · Commits

Commit 393882bc
authored Mar 29, 2023 by Tri Dao

[LayerNorm] Implement LN with parallel residual, support dim 8k

parent 009a3e71
Changes: 46

Showing 20 changed files with 1132 additions and 8 deletions (+1132, -8)
csrc/layer_norm/ln_parallel_bwd_8192.cu              +15   -0
csrc/layer_norm/ln_parallel_fwd_1024.cu              +15   -0
csrc/layer_norm/ln_parallel_fwd_1280.cu              +15   -0
csrc/layer_norm/ln_parallel_fwd_1536.cu              +15   -0
csrc/layer_norm/ln_parallel_fwd_2048.cu              +15   -0
csrc/layer_norm/ln_parallel_fwd_256.cu               +15   -0
csrc/layer_norm/ln_parallel_fwd_2560.cu              +15   -0
csrc/layer_norm/ln_parallel_fwd_3072.cu              +15   -0
csrc/layer_norm/ln_parallel_fwd_4096.cu              +15   -0
csrc/layer_norm/ln_parallel_fwd_512.cu               +15   -0
csrc/layer_norm/ln_parallel_fwd_5120.cu              +15   -0
csrc/layer_norm/ln_parallel_fwd_6144.cu              +15   -0
csrc/layer_norm/ln_parallel_fwd_7168.cu              +15   -0
csrc/layer_norm/ln_parallel_fwd_768.cu               +15   -0
csrc/layer_norm/ln_parallel_fwd_8192.cu              +15   -0
csrc/layer_norm/ln_parallel_residual_bwd_kernels.cuh +540  -0
csrc/layer_norm/ln_parallel_residual_fwd_kernels.cuh +281  -0
csrc/layer_norm/ln_utils.cuh                         +33   -0
csrc/layer_norm/setup.py                             +32   -0
flash_attn/models/gpt.py                             +21   -8
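Before the per-file diffs, a rough sketch of what the new kernels compute may help orient the reader. This is not code from the commit: it is a minimal CPU reference, with illustrative names, of the forward pass for layer norm with a parallel residual — two branch inputs sharing one residual stream, summed and normalized once, with one (or, when untied, two) sets of affine weights. Dropout on the branches is omitted for brevity.

#include <cmath>
#include <cstddef>
#include <vector>

// Scalar reference (illustrative only; the real implementation is the CUDA code below).
void parallel_residual_ln_reference(const std::vector<float> &x0,
                                    const std::vector<float> &x1,
                                    const std::vector<float> &residual,
                                    const std::vector<float> &gamma0,
                                    const std::vector<float> &beta0,
                                    std::vector<float> &z0,
                                    float epsilon) {
    const std::size_t n = x0.size();
    std::vector<float> x(n);
    float mu = 0.f;
    for (std::size_t i = 0; i < n; ++i) {
        x[i] = x0[i] + x1[i] + residual[i];   // parallel branches share one residual stream
        mu += x[i];
    }
    mu /= n;
    float var = 0.f;
    for (std::size_t i = 0; i < n; ++i) { var += (x[i] - mu) * (x[i] - mu); }
    var /= n;
    const float rs = 1.f / std::sqrt(var + epsilon);
    for (std::size_t i = 0; i < n; ++i) {
        // An untied second output z1 would reuse the same (x[i] - mu) * rs with gamma1/beta1.
        z0[i] = gamma0[i] * ((x[i] - mu) * rs) + beta0[i];
    }
}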
csrc/layer_norm/ln_parallel_bwd_8192.cu  0 → 100644

#include "ln_parallel_residual_bwd_kernels.cuh"

// Create backward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL

REGISTER_PARALLEL_BWD_LAUNCHER(8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4);
REGISTER_PARALLEL_BWD_LAUNCHER(8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4);
\ No newline at end of file
csrc/layer_norm/ln_parallel_fwd_1024.cu  0 → 100644

#include "ln_parallel_residual_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_PARALLEL_FWD_LAUNCHER(1024, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1024, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1024, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1024, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1024, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1024, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1024, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1024, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1024, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1024, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
csrc/layer_norm/ln_parallel_fwd_1280.cu  0 → 100644

#include "ln_parallel_residual_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_PARALLEL_FWD_LAUNCHER(1280, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1280, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1280, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1280, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1280, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1280, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1280, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1280, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1280, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1280, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
csrc/layer_norm/ln_parallel_fwd_1536.cu  0 → 100644

#include "ln_parallel_residual_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_PARALLEL_FWD_LAUNCHER(1536, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1536, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1536, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1536, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1536, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1536, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1536, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1536, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1536, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(1536, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
csrc/layer_norm/ln_parallel_fwd_2048.cu  0 → 100644

#include "ln_parallel_residual_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_PARALLEL_FWD_LAUNCHER(2048, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(2048, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(2048, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(2048, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(2048, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(2048, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(2048, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(2048, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(2048, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(2048, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
csrc/layer_norm/ln_parallel_fwd_256.cu  0 → 100644

#include "ln_parallel_residual_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_PARALLEL_FWD_LAUNCHER(256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
\ No newline at end of file
csrc/layer_norm/ln_parallel_fwd_2560.cu  0 → 100644

#include "ln_parallel_residual_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_PARALLEL_FWD_LAUNCHER(2560, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(2560, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(2560, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(2560, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(2560, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(2560, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(2560, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(2560, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(2560, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(2560, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
csrc/layer_norm/ln_parallel_fwd_3072.cu  0 → 100644

#include "ln_parallel_residual_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_PARALLEL_FWD_LAUNCHER(3072, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(3072, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(3072, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(3072, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(3072, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(3072, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(3072, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(3072, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(3072, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(3072, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
csrc/layer_norm/ln_parallel_fwd_4096.cu  0 → 100644

#include "ln_parallel_residual_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_PARALLEL_FWD_LAUNCHER(4096, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(4096, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(4096, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(4096, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(4096, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(4096, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(4096, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(4096, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(4096, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(4096, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
csrc/layer_norm/ln_parallel_fwd_512.cu  0 → 100644

#include "ln_parallel_residual_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_PARALLEL_FWD_LAUNCHER(512, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(512, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(512, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(512, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(512, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(512, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(512, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(512, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(512, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(512, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
csrc/layer_norm/ln_parallel_fwd_5120.cu  0 → 100644

#include "ln_parallel_residual_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_PARALLEL_FWD_LAUNCHER(5120, fp32, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(5120, fp16, fp32, fp32, fp32, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(5120, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(5120, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(5120, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(5120, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(5120, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(5120, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(5120, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(5120, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
csrc/layer_norm/ln_parallel_fwd_6144.cu  0 → 100644

#include "ln_parallel_residual_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_PARALLEL_FWD_LAUNCHER(6144, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(6144, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(6144, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(6144, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(6144, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(6144, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(6144, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(6144, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(6144, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(6144, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
csrc/layer_norm/ln_parallel_fwd_7168.cu  0 → 100644

#include "ln_parallel_residual_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_PARALLEL_FWD_LAUNCHER(7168, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(7168, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(7168, fp32, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(7168, fp16, fp16, fp32, fp16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(7168, fp32, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(7168, fp32, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(7168, bf16, bf16, fp32, bf16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(7168, fp32, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(7168, fp16, fp16, fp16, fp16, fp32, 1, 1, 4, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(7168, bf16, bf16, bf16, bf16, fp32, 1, 1, 4, 16);
csrc/layer_norm/ln_parallel_fwd_768.cu  0 → 100644

#include "ln_parallel_residual_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_PARALLEL_FWD_LAUNCHER(768, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(768, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(768, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(768, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(768, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(768, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(768, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(768, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(768, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(768, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16);
csrc/layer_norm/ln_parallel_fwd_8192.cu  0 → 100644

#include "ln_parallel_residual_fwd_kernels.cuh"

// Create forward launch function and register. Macro signature:
// HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG

REGISTER_PARALLEL_FWD_LAUNCHER(8192, fp32, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(8192, fp16, fp32, fp32, fp32, fp32, 1, 1, 8, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(8192, fp32, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(8192, fp16, fp16, fp32, fp16, fp32, 1, 1, 8, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(8192, fp32, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(8192, fp32, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(8192, bf16, bf16, fp32, bf16, fp32, 1, 1, 8, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(8192, fp32, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(8192, fp16, fp16, fp16, fp16, fp32, 1, 1, 8, 16);
REGISTER_PARALLEL_FWD_LAUNCHER(8192, bf16, bf16, bf16, bf16, fp32, 1, 1, 8, 16);
csrc/layer_norm/ln_parallel_residual_bwd_kernels.cuh  0 → 100644

(This diff is collapsed in the page view and not reproduced here; the new header adds 540 lines.)
csrc/layer_norm/ln_parallel_residual_fwd_kernels.cuh  0 → 100644

#pragma once

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#include <ATen/cuda/detail/UnpackRaw.cuh>  // For at::cuda::philox::unpack
#include <curand_kernel.h>

#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
#include "static_switch.h"

namespace layer_norm {

template<typename Ktraits, bool Is_dropout, bool Tied_norm, bool Is_even_cols>
__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA)
void ln_parallel_residual_fwd_kernel(FwdParams params) {

    enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA };
    enum { WARPS_N = Ktraits::WARPS_N };
    enum { WARPS_M = Ktraits::WARPS_M };
    enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW };
    enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG };
    enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW };
    enum { LDGS = Ktraits::LDGS };
    enum { NUM_ELTS = Ktraits::NUM_ELTS };
    enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW };

    using input_t = typename Ktraits::input_t;
    using residual_t = typename Ktraits::residual_t;
    using output_t = typename Ktraits::output_t;
    using index_t = typename Ktraits::index_t;
    using compute_t = typename Ktraits::compute_t;
    using mask_t = typename Ktraits::mask_t;
    using Ivec = typename Ktraits::Ivec;
    using Rvec = typename Ktraits::Rvec;
    using Ovec = typename Ktraits::Ovec;
    using Wvec = typename Ktraits::Wvec;
    using Cvec = typename Ktraits::Cvec;
    using Mvec = typename Ktraits::Mvec;

    using Stats = typename Ktraits::Stats;
    using stats_t = typename Stats::stats_t;

    const bool has_residual = params.residual != nullptr;
    const bool has_x1 = params.x1 != nullptr;
    const bool save_x = has_residual || has_x1 || Is_dropout || !(std::is_same<input_t, residual_t>::value);

    extern __shared__ char smem_[];

    const index_t tidx = threadIdx.x;
    const index_t bidn = blockIdx.x % CTAS_PER_ROW;
    const index_t bidm = blockIdx.x / CTAS_PER_ROW;
    const index_t lane = tidx % THREADS_PER_WARP;
    const index_t warp = tidx / THREADS_PER_WARP;
    const index_t warp_m = warp / WARPS_N;
    const index_t warp_n = warp % WARPS_N;

    const index_t r = bidm * ROWS_PER_CTA + warp_m;
    const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane;

    Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_);

    compute_t *mu_ptr = static_cast<compute_t *>(params.mu);
    compute_t *rs_ptr = static_cast<compute_t *>(params.rs);

    // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu
    curandStatePhilox4_32_10_t state;
    if (Is_dropout) {
        auto seeds = at::cuda::philox::unpack(params.philox_args);
        const index_t tidx_global = blockIdx.x * blockDim.x + threadIdx.x;
        curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state);
    }

    const index_t num_valid_ldgs =
        ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG;

    Wvec gamma0[LDGS];
    Wvec beta0[LDGS];
    Wvec gamma1[LDGS];
    Wvec beta1[LDGS];
    index_t idx = c;
    #pragma unroll
    for (int it = 0; it < LDGS; it++) {
        if (Is_even_cols || (it < num_valid_ldgs)) {
            gamma0[it].load_from(params.gamma, idx);
            if (params.beta != nullptr) {
                beta0[it].load_from(params.beta, idx);
            } else {
                beta0[it].zero_();
            }
            if (!Tied_norm) {
                gamma1[it].load_from(params.gamma1, idx);
                if (params.beta1 != nullptr) {
                    beta1[it].load_from(params.beta1, idx);
                } else {
                    beta1[it].zero_();
                }
            }
            idx += VEC_COLS_PER_LDG;
        }
    }

    for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) {
        index_t idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
        compute_t xf[LDGS * NUM_ELTS];
        #pragma unroll
        for (int it = 0; it < LDGS; it++) {
            if (Is_even_cols || (it < num_valid_ldgs)) {
                Ivec x0;
                Ivec x1;
                Rvec residual;
                Rvec x;
                Mvec dmask0;
                Mvec dmask1;
                x0.load_from(params.x0, idx);
                if (has_x1) { x1.load_from(params.x1, idx); }
                if (has_residual) { residual.load_from(params.residual, idx); }
                #pragma unroll
                for (int jt = 0; jt < NUM_ELTS; jt++) {
                    // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
                    // the more efficient curand_uniform4.
                    compute_t x_ij;
                    mask_t keep0 = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
                    if (Is_dropout) { dmask0.data.elt[jt] = keep0; }
                    compute_t x0_ij = compute_t(x0.data.elt[jt]);
                    x0_ij = keep0 ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
                    if (has_x1) {
                        mask_t keep1 = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
                        if (Is_dropout) { dmask1.data.elt[jt] = keep1; }
                        compute_t x1_ij = compute_t(x1.data.elt[jt]);
                        x1_ij = keep1 ? (Is_dropout ? x1_ij * params.dropout_scale : x1_ij) : 0.0f;
                        x_ij = has_residual ? x0_ij + x1_ij + compute_t(residual.data.elt[jt]) : x0_ij + x1_ij;
                    } else {
                        x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij;
                    }
                    if (save_x) { x.data.elt[jt] = x_ij; }
                    xf[it * NUM_ELTS + jt] = x_ij;
                }
                if (save_x) { x.store_to(params.x, idx); }
                if (Is_dropout) {
                    dmask0.store_to(params.dmask, idx);
                    if (has_x1) { dmask1.store_to(params.dmask1, idx); }
                }
                idx += VEC_COLS_PER_LDG;
            }
        }

        static_assert(CTAS_PER_ROW == 1, "Don't support multiple CTAs per row for now");
        const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG;
        const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG;
        const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG;
        auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int {
            // Need to convert to int, otherwise the subtraction will wrap around.
            const index_t valid_partial_vecs_in_warp =
                std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)),
                         int(THREADS_PER_WARP));
            return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS;
        };
        stats_t s = stats.template compute<Is_even_cols>(
            xf, params.inverse_cols, valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS);

        compute_t mu = layer_norm::Get<0>::of<stats_t, compute_t>(s);
        compute_t m2 = layer_norm::Get<1>::of<stats_t, compute_t>(s);

        if (bidn == 0 && warp_n == 0 && lane == 0) {
            mu_ptr[row] = mu;
        }

        compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 0.f : mu * mu));

        if (bidn == 0 && warp_n == 0 && lane == 0) {
            rs_ptr[row] = rs;
        }

        idx = row * params.cols / Ktraits::ELTS_PER_LDG + c;
        #pragma unroll
        for (int it = 0; it < LDGS; it++) {
            if (Is_even_cols || (it < num_valid_ldgs)) {
                Ovec z0;
                Ovec z1;
                #pragma unroll
                for (int jt = 0; jt < NUM_ELTS; jt++) {
                    compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? mu : 0.f)));
                    compute_t g0_ij = gamma0[it].data.elt[jt];
                    compute_t b0_ij = beta0[it].data.elt[jt];
                    z0.data.elt[jt] = output_t(g0_ij * y_ij + b0_ij);
                    if (!Tied_norm) {
                        compute_t g1_ij = gamma1[it].data.elt[jt];
                        compute_t b1_ij = beta1[it].data.elt[jt];
                        z1.data.elt[jt] = output_t(g1_ij * y_ij + b1_ij);
                    }
                }
                z0.store_to(params.z, idx);
                if (!Tied_norm) { z1.store_to(params.z1, idx); }
                idx += VEC_COLS_PER_LDG;
            }
        }

    }

}

}  // namespace layer_norm

using namespace layer_norm;

template<typename weight_t,
         typename input_t,
         typename residual_t,
         typename output_t,
         typename compute_t,
         typename index_t,
         int HIDDEN_SIZE,
         int CTAS_PER_ROW,
         int WARPS_M,
         int WARPS_N,
         int BYTES_PER_LDG>
void launch_parallel_residual_(LaunchParams<FwdParams> &launch_params, const bool configure_params){
    using Kernel_traits = Kernel_traits<weight_t, input_t, residual_t, output_t, compute_t, index_t,
                                        HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>;
    bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE;
    bool tied_norm = launch_params.params.gamma1 == nullptr;
    BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
        BOOL_SWITCH(tied_norm, TiedNormConst, [&] {
            BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
                auto kernel = &ln_parallel_residual_fwd_kernel<Kernel_traits, IsDropoutConst, TiedNormConst, IsEvenColsConst>;
                if (configure_params) {
                    int ctas_per_sm;
                    CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                        &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
                    launch_params.params.ctas_per_col =
                        launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
                    const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
                    launch_params.elts_per_thread =
                        (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop
                        * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
                    launch_params.barrier_size = 0;
                    launch_params.workspace_bytes = 0;
                    if (Kernel_traits::CTAS_PER_ROW > 1) {
                        launch_params.barrier_size = 2 * launch_params.params.ctas_per_col;
                        launch_params.workspace_bytes = launch_params.params.ctas_per_col
                            * Kernel_traits::WARPS_M
                            * Kernel_traits::CTAS_PER_ROW
                            * sizeof(typename Kernel_traits::Stats::stats_t)
                            * 2;
                    }
                    return;
                }

                if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) {
                    CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                                                    Kernel_traits::SMEM_BYTES_FWD));
                }
                auto stream = launch_params.stream;
                auto ctas_per_col = launch_params.params.ctas_per_col;

                if (Kernel_traits::CTAS_PER_ROW == 1) {
                    kernel<<<ctas_per_col, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD, stream>>>(launch_params.params);
                } else {
                    dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
                    dim3 block(Kernel_traits::THREADS_PER_CTA);
                    void *params_ = (void *)&launch_params.params;
                    cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_,
                                                Kernel_traits::SMEM_BYTES_FWD, stream);
                }
            });
        });
    });
}
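For orientation (not part of the commit): the three nested BOOL_SWITCH calls above pick one of eight specializations of the kernel at runtime. Written out by hand for one combination, the dispatch reduces to something like the following sketch.

// Assuming dropout enabled, untied norm (separate gamma1/beta1), and an even column count:
auto kernel = &ln_parallel_residual_fwd_kernel<Kernel_traits,
                                               /*Is_dropout=*/true,
                                               /*Tied_norm=*/false,
                                               /*Is_even_cols=*/true>;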
csrc/layer_norm/ln_utils.cuh

...
@@ -64,6 +64,39 @@ inline void check_cuda_(cudaError_t status, const char *file, int line) {
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_PARALLEL_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \
void ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<FwdParams> &launch_params, \
const bool configure_params) { \
launch_parallel_residual_<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, uint32_t, HIDDEN_SIZE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG>( \
launch_params, configure_params); \
} \
static FwdParallelRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
#define REGISTER_PARALLEL_BWD_LAUNCHER( \
HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \
void ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams<BwdParams> &launch_params, \
const bool configure_params) { \
launch_parallel_residual_<WTYPE, \
ITYPE, \
RTYPE, \
OTYPE, \
CTYPE, \
uint32_t, \
HIDDEN_SIZE, \
CTAS_PER_ROW, \
WARPS_M, \
WARPS_N, \
BYTES_PER_LDG, \
BYTES_PER_LDG_FINALIZE>(launch_params, configure_params); \
} \
static BwdParallelRegistrar<WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, HIDDEN_SIZE> reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \
ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE)
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 operator+(const float2 & a, const float2 & b){
    return {a.x + b.x, a.y + b.y};
}
...
...
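As a rough illustration of what one registration in the .cu files above turns into — a sketch derived from the REGISTER_PARALLEL_FWD_LAUNCHER macro shown in this file, so the exact preprocessed output may differ in minor details — the call REGISTER_PARALLEL_FWD_LAUNCHER(256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16) from ln_parallel_fwd_256.cu expands to approximately:

void ln_parallel_residual_fwd_256_fp16_fp16_fp32_fp16_fp32(LaunchParams<FwdParams> &launch_params,
                                                           const bool configure_params) {
    // Instantiates the launcher template with the hidden size, type combination, and launch shape.
    launch_parallel_residual_<fp16, fp16, fp32, fp16, fp32, uint32_t, 256, 1, 4, 1, 16>(
        launch_params, configure_params);
}
// The static registrar makes the launcher discoverable by (types, hidden size) at startup.
static FwdParallelRegistrar<fp16, fp16, fp32, fp16, fp32, 256> reg_256_fp16_fp16_fp32_fp16_fp32(
    ln_parallel_residual_fwd_256_fp16_fp16_fp32_fp16_fp32);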
csrc/layer_norm/setup.py

...
@@ -139,6 +139,38 @@ ext_modules.append(
                "ln_bwd_5120.cu",
                "ln_fwd_6144.cu",
                "ln_bwd_6144.cu",
                "ln_fwd_7168.cu",
                "ln_bwd_7168.cu",
                "ln_fwd_8192.cu",
                "ln_bwd_8192.cu",
                "ln_parallel_fwd_256.cu",
                "ln_parallel_bwd_256.cu",
                "ln_parallel_fwd_512.cu",
                "ln_parallel_bwd_512.cu",
                "ln_parallel_fwd_768.cu",
                "ln_parallel_bwd_768.cu",
                "ln_parallel_fwd_1024.cu",
                "ln_parallel_bwd_1024.cu",
                "ln_parallel_fwd_1280.cu",
                "ln_parallel_bwd_1280.cu",
                "ln_parallel_fwd_1536.cu",
                "ln_parallel_bwd_1536.cu",
                "ln_parallel_fwd_2048.cu",
                "ln_parallel_bwd_2048.cu",
                "ln_parallel_fwd_2560.cu",
                "ln_parallel_bwd_2560.cu",
                "ln_parallel_fwd_3072.cu",
                "ln_parallel_bwd_3072.cu",
                "ln_parallel_fwd_4096.cu",
                "ln_parallel_bwd_4096.cu",
                "ln_parallel_fwd_5120.cu",
                "ln_parallel_bwd_5120.cu",
                "ln_parallel_fwd_6144.cu",
                "ln_parallel_bwd_6144.cu",
                "ln_parallel_fwd_7168.cu",
                "ln_parallel_bwd_7168.cu",
                "ln_parallel_fwd_8192.cu",
                "ln_parallel_bwd_8192.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3"] + generator_flag,
...
...
flash_attn/models/gpt.py

...
@@ -37,6 +37,11 @@ try:
except ImportError:
    dropout_add_layer_norm = None

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
except ImportError:
    dropout_add_layer_norm_parallel_residual = None

try:
    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense, sqrelu_fwd
except ImportError:
...
@@ -282,8 +287,10 @@ class GPTModel(GPTPreTrainedModel):
                                             for i in range(config.num_hidden_layers)])

        self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
            raise ImportError('dropout_add_layer_norm is not installed')
        if self.fused_dropout_add_ln:
            if ((not self.parallel_block and dropout_add_layer_norm is None)
                    or (self.parallel_block and dropout_add_layer_norm_parallel_residual is None)):
                raise ImportError('dropout_layer_norm is not installed')
        if self.prenorm:
            self.drop_f = nn.Dropout(config.resid_pdrop)
            self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon,
...
@@ -340,13 +347,19 @@
                        if residual is not None else dropped + dropped2)
            hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
        else:
            assert not self.parallel_block
            # Set prenorm=False here since we don't need the residual
            hidden_states = dropout_add_layer_norm(
                hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
                self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
                residual_in_fp32=self.residual_in_fp32)
            if not self.parallel_block:
                hidden_states = dropout_add_layer_norm(
                    hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
                    self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
                    residual_in_fp32=self.residual_in_fp32)
            else:
                hidden_states, _ = dropout_add_layer_norm_parallel_residual(
                    hidden_states, hidden_states2, residual, self.ln_f.weight, self.ln_f.bias,
                    None, None, self.drop_f.p if self.training else 0.0, self.ln_f.eps,
                    prenorm=False, residual_in_fp32=self.residual_in_fp32)
        return hidden_states
...
...