From c2b62b7ffe6aecc6dde4ecf90ebb7ee5e64db565 Mon Sep 17 00:00:00 2001 From: jinrong <3413506624@qq.com> Date: Thu, 13 Mar 2025 16:18:02 +0800 Subject: [PATCH] delete origin files --- Dockerfile | 7 - LICENSE | 11 - README.md | 69 - README_ORIGIN.md | 162 - apex/RNN/README.md | 1 - apex/RNN/RNNBackend.py | 365 -- apex/RNN/__init__.py | 3 - apex/RNN/cells.py | 84 - apex/RNN/models.py | 54 - apex/__init__.py | 57 - apex/_autocast_utils.py | 23 - apex/amp/README.md | 72 - apex/amp/__init__.py | 5 - apex/amp/__version__.py | 2 - apex/amp/_amp_state.py | 59 - apex/amp/_initialize.py | 267 -- apex/amp/_process_optimizer.py | 489 -- apex/amp/amp.py | 198 - apex/amp/compat.py | 47 - apex/amp/frontend.py | 509 -- apex/amp/handle.py | 281 -- apex/amp/lists/__init__.py | 0 apex/amp/lists/functional_overrides.py | 91 - apex/amp/lists/tensor_overrides.py | 67 - apex/amp/lists/torch_overrides.py | 136 - apex/amp/opt.py | 103 - apex/amp/rnn_compat.py | 53 - apex/amp/scaler.py | 226 - apex/amp/utils.py | 232 - apex/amp/wrap.py | 286 -- apex/contrib/__init__.py | 0 apex/contrib/bottleneck/__init__.py | 2 - apex/contrib/bottleneck/bottleneck.py | 749 --- .../bottleneck/bottleneck_module_test.py | 254 - apex/contrib/bottleneck/halo_exchangers.py | 171 - apex/contrib/bottleneck/test.py | 71 - apex/contrib/clip_grad/__init__.py | 1 - apex/contrib/clip_grad/clip_grad.py | 128 - apex/contrib/conv_bias_relu/__init__.py | 2 - apex/contrib/conv_bias_relu/conv_bias_relu.py | 81 - apex/contrib/csrc/bottleneck/bottleneck.cpp | 4073 ----------------- .../csrc/conv_bias_relu/conv_bias_relu.cpp | 1639 ------- apex/contrib/csrc/cudnn-frontend | 1 - apex/contrib/csrc/fmha/fmha_api.cpp | 361 -- apex/contrib/csrc/fmha/src/fmha.h | 163 - apex/contrib/csrc/fmha/src/fmha/gemm.h | 314 -- apex/contrib/csrc/fmha/src/fmha/gmem_tile.h | 456 -- .../csrc/fmha/src/fmha/kernel_traits.h | 97 - apex/contrib/csrc/fmha/src/fmha/mask.h | 81 - apex/contrib/csrc/fmha/src/fmha/smem_tile.h | 1286 ------ apex/contrib/csrc/fmha/src/fmha/softmax.h | 395 -- apex/contrib/csrc/fmha/src/fmha/utils.h | 1038 ----- .../src/fmha_dgrad_fp16_128_64_kernel.sm80.cu | 60 - .../src/fmha_dgrad_fp16_256_64_kernel.sm80.cu | 60 - .../src/fmha_dgrad_fp16_384_64_kernel.sm80.cu | 60 - .../src/fmha_dgrad_fp16_512_64_kernel.sm80.cu | 105 - .../fmha/src/fmha_dgrad_kernel_1xN_reload.h | 558 --- .../src/fmha_dgrad_kernel_1xN_reload_nl.h | 569 --- .../src/fmha_fprop_fp16_128_64_kernel.sm80.cu | 84 - .../src/fmha_fprop_fp16_256_64_kernel.sm80.cu | 84 - .../src/fmha_fprop_fp16_384_64_kernel.sm80.cu | 84 - .../src/fmha_fprop_fp16_512_64_kernel.sm80.cu | 137 - .../csrc/fmha/src/fmha_fprop_kernel_1xN.h | 531 --- apex/contrib/csrc/fmha/src/fmha_kernel.h | 179 - .../csrc/fmha/src/fmha_noloop_reduce.cu | 177 - apex/contrib/csrc/fmha/src/fmha_utils.h | 92 - .../csrc/focal_loss/focal_loss_cuda.cpp | 70 - .../csrc/focal_loss/focal_loss_cuda_kernel.cu | 267 -- apex/contrib/csrc/groupbn/batch_norm.cu | 342 -- apex/contrib/csrc/groupbn/batch_norm.h | 901 ---- .../csrc/groupbn/batch_norm_add_relu.cu | 353 -- .../csrc/groupbn/batch_norm_add_relu.h | 816 ---- apex/contrib/csrc/groupbn/cuda_utils.h | 28 - apex/contrib/csrc/groupbn/dnn.h | 26 - apex/contrib/csrc/groupbn/interface.cpp | 175 - apex/contrib/csrc/groupbn/ipc.cu | 129 - .../csrc/groupbn/nhwc_batch_norm_kernel.h | 3021 ------------ .../csrc/index_mul_2d/index_mul_2d_cuda.cpp | 139 - .../index_mul_2d/index_mul_2d_cuda_kernel.cu | 492 -- apex/contrib/csrc/layer_norm/ln.h | 210 - apex/contrib/csrc/layer_norm/ln_api.cpp | 246 - .../csrc/layer_norm/ln_bwd_kernels.cuh | 315 -- .../layer_norm/ln_bwd_semi_cuda_kernel.cu | 250 - .../csrc/layer_norm/ln_fwd_cuda_kernel.cu | 235 - .../csrc/layer_norm/ln_fwd_kernels.cuh | 114 - .../csrc/layer_norm/ln_kernel_traits.h | 159 - apex/contrib/csrc/layer_norm/ln_utils.cuh | 793 ---- .../additive_masked_softmax_dropout_cuda.cu | 113 - apex/contrib/csrc/multihead_attn/dropout.cuh | 272 -- .../encdec_multihead_attn_cuda.cu | 611 --- .../encdec_multihead_attn_norm_add_cuda.cu | 690 --- .../csrc/multihead_attn/layer_norm.cuh | 649 --- .../masked_softmax_dropout_cuda.cu | 124 - .../multihead_attn_frontend.cpp | 836 ---- apex/contrib/csrc/multihead_attn/philox.cuh | 96 - ..._multihead_attn_bias_additive_mask_cuda.cu | 504 -- .../self_multihead_attn_bias_cuda.cu | 504 -- .../self_multihead_attn_cuda.cu | 509 -- .../self_multihead_attn_norm_add_cuda.cu | 580 --- apex/contrib/csrc/multihead_attn/softmax.cuh | 3149 ------------- .../multihead_attn/strided_batched_gemm.cuh | 135 - apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp | 25 - apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu | 215 - apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh | 45 - .../csrc/optimizers/fused_adam_cuda.cpp | 86 - .../csrc/optimizers/fused_adam_cuda_kernel.cu | 1037 ----- .../csrc/optimizers/fused_lamb_cuda.cpp | 21 - .../csrc/optimizers/fused_lamb_cuda_kernel.cu | 294 -- .../optimizers/multi_tensor_distopt_adam.cpp | 20 - .../multi_tensor_distopt_adam_kernel.cu | 228 - .../optimizers/multi_tensor_distopt_lamb.cpp | 36 - .../multi_tensor_distopt_lamb_kernel.cu | 506 -- apex/contrib/csrc/peer_memory/peer_memory.cpp | 29 - .../csrc/peer_memory/peer_memory_cuda.cu | 750 --- .../csrc/peer_memory/peer_memory_cuda.cuh | 50 - .../csrc/transducer/transducer_joint.cpp | 98 - .../transducer/transducer_joint_kernel.cu | 985 ---- .../csrc/transducer/transducer_loss.cpp | 109 - .../csrc/transducer/transducer_loss_kernel.cu | 767 ---- apex/contrib/csrc/xentropy/interface.cpp | 52 - apex/contrib/csrc/xentropy/xentropy_kernel.cu | 726 --- .../func_test_multihead_attn.py | 108 - .../perf_test_multihead_attn.py | 115 - apex/contrib/fmha/__init__.py | 1 - apex/contrib/fmha/fmha.py | 76 - apex/contrib/focal_loss/__init__.py | 9 - apex/contrib/focal_loss/focal_loss.py | 60 - apex/contrib/groupbn/__init__.py | 9 - apex/contrib/groupbn/batch_norm.py | 260 -- apex/contrib/index_mul_2d/__init__.py | 1 - apex/contrib/index_mul_2d/index_mul_2d.py | 144 - apex/contrib/layer_norm/__init__.py | 1 - apex/contrib/layer_norm/layer_norm.py | 53 - apex/contrib/multihead_attn/MHA_bwd.png | Bin 86630 -> 0 bytes apex/contrib/multihead_attn/MHA_fwd.png | Bin 84392 -> 0 bytes apex/contrib/multihead_attn/README.md | 60 - apex/contrib/multihead_attn/__init__.py | 3 - .../multihead_attn/encdec_multihead_attn.py | 190 - .../encdec_multihead_attn_func.py | 357 -- .../fast_encdec_multihead_attn_func.py | 121 - ...ast_encdec_multihead_attn_norm_add_func.py | 159 - .../fast_self_multihead_attn_func.py | 243 - .../fast_self_multihead_attn_norm_add_func.py | 135 - .../mask_softmax_dropout_func.py | 64 - .../multihead_attn/self_multihead_attn.py | 255 -- .../self_multihead_attn_func.py | 308 -- apex/contrib/optimizers/__init__.py | 3 - .../optimizers/distributed_fused_adam.py | 1280 ------ .../optimizers/distributed_fused_lamb.py | 722 --- apex/contrib/optimizers/fp16_optimizer.py | 243 - apex/contrib/optimizers/fused_adam.py | 206 - apex/contrib/optimizers/fused_lamb.py | 208 - apex/contrib/optimizers/fused_sgd.py | 211 - apex/contrib/peer_memory/__init__.py | 3 - .../peer_halo_exchange_module_tests.py | 164 - .../peer_memory/peer_halo_exchanger_1d.py | 65 - apex/contrib/peer_memory/peer_memory.py | 87 - apex/contrib/sparsity/README.md | 134 - apex/contrib/sparsity/__init__.py | 2 - apex/contrib/sparsity/asp.py | 312 -- apex/contrib/sparsity/permutation_lib.py | 925 ---- .../permutation_search_kernels.cu | 371 -- .../permutation_search_kernels/__init__.py | 2 - .../call_permutation_search_kernels.py | 74 - .../exhaustive_search.py | 371 -- .../permutation_utilities.py | 113 - apex/contrib/sparsity/sparse_masklib.py | 184 - .../sparsity/test/checkpointing_test_part1.py | 94 - .../sparsity/test/checkpointing_test_part2.py | 79 - .../test/checkpointing_test_reference.py | 96 - apex/contrib/sparsity/test/toy_problem.py | 87 - apex/contrib/test/clip_grad/test_clip_grad.py | 162 - .../conv_bias_relu/test_conv_bias_relu.py | 105 - apex/contrib/test/fmha/test_fmha.py | 136 - .../test/focal_loss/test_focal_loss.py | 69 - .../test/fused_dense/test_fused_dense.py | 44 - apex/contrib/test/groupbn/test_groupbn.py | 185 - .../test/groupbn/test_groupbn_channel_last.py | 194 - .../test/index_mul_2d/test_index_mul_2d.py | 106 - .../test/layer_norm/test_fast_layer_norm.py | 277 -- .../test_encdec_multihead_attn.py | 136 - .../test_encdec_multihead_attn_norm_add.py | 78 - .../test_fast_self_multihead_attn_bias.py | 77 - .../multihead_attn/test_mha_fused_softmax.py | 42 - .../test_self_multihead_attn.py | 130 - .../test_self_multihead_attn_norm_add.py | 73 - .../contrib/test/optimizers/test_dist_adam.py | 391 -- apex/contrib/test/run_rocm_extensions.py | 26 - apex/contrib/test/test_label_smoothing.py | 128 - .../test/transducer/test_transducer_joint.py | 163 - .../test/transducer/test_transducer_loss.py | 133 - .../contrib/test/transducer/transducer_ref.py | 112 - apex/contrib/transducer/__init__.py | 2 - apex/contrib/transducer/transducer.py | 195 - apex/contrib/xentropy/__init__.py | 9 - apex/contrib/xentropy/softmax_xentropy.py | 28 - apex/fp16_utils/README.md | 16 - apex/fp16_utils/__init__.py | 16 - apex/fp16_utils/fp16_optimizer.py | 554 --- apex/fp16_utils/fp16util.py | 187 - apex/fp16_utils/loss_scaler.py | 186 - apex/fused_dense/__init__.py | 1 - apex/fused_dense/fused_dense.py | 85 - apex/mlp/__init__.py | 1 - apex/mlp/mlp.py | 79 - apex/multi_tensor_apply/__init__.py | 5 - apex/multi_tensor_apply/multi_tensor_apply.py | 30 - apex/normalization/__init__.py | 1 - apex/normalization/fused_layer_norm.py | 437 -- apex/optimizers/__init__.py | 7 - apex/optimizers/fused_adagrad.py | 122 - apex/optimizers/fused_adam.py | 193 - apex/optimizers/fused_lamb.py | 215 - apex/optimizers/fused_lars.py | 224 - apex/optimizers/fused_mixed_precision_lamb.py | 256 -- apex/optimizers/fused_novograd.py | 214 - apex/optimizers/fused_sgd.py | 264 -- apex/parallel/LARC.py | 107 - apex/parallel/README.md | 66 - apex/parallel/__init__.py | 95 - apex/parallel/distributed.py | 640 --- apex/parallel/multiproc.py | 35 - apex/parallel/optimized_sync_batchnorm.py | 85 - .../optimized_sync_batchnorm_kernel.py | 119 - apex/parallel/sync_batchnorm.py | 134 - apex/parallel/sync_batchnorm_kernel.py | 87 - apex/testing/__init__.py | 0 apex/testing/common_utils.py | 33 - apex/transformer/README.md | 81 - apex/transformer/__init__.py | 23 - apex/transformer/_data/__init__.py | 8 - apex/transformer/_data/_batchsampler.py | 180 - apex/transformer/amp/__init__.py | 6 - apex/transformer/amp/grad_scaler.py | 119 - apex/transformer/enums.py | 35 - apex/transformer/functional/__init__.py | 5 - apex/transformer/functional/fused_softmax.py | 211 - apex/transformer/layers/__init__.py | 11 - apex/transformer/layers/layer_norm.py | 99 - apex/transformer/log_util.py | 18 - apex/transformer/microbatches.py | 195 - apex/transformer/parallel_state.py | 682 --- .../transformer/pipeline_parallel/__init__.py | 8 - apex/transformer/pipeline_parallel/_timers.py | 83 - .../pipeline_parallel/p2p_communication.py | 578 --- .../pipeline_parallel/schedules/__init__.py | 35 - .../pipeline_parallel/schedules/common.py | 398 -- .../schedules/fwd_bwd_no_pipelining.py | 132 - .../fwd_bwd_pipelining_with_interleaving.py | 415 -- ...fwd_bwd_pipelining_without_interleaving.py | 489 -- apex/transformer/pipeline_parallel/utils.py | 357 -- apex/transformer/tensor_parallel/__init__.py | 75 - .../tensor_parallel/cross_entropy.py | 103 - apex/transformer/tensor_parallel/data.py | 122 - apex/transformer/tensor_parallel/layers.py | 780 ---- apex/transformer/tensor_parallel/mappings.py | 304 -- apex/transformer/tensor_parallel/memory.py | 151 - apex/transformer/tensor_parallel/random.py | 311 -- apex/transformer/tensor_parallel/utils.py | 64 - apex/transformer/testing/__init__.py | 0 apex/transformer/testing/arguments.py | 971 ---- apex/transformer/testing/commons.py | 297 -- .../testing/distributed_test_base.py | 133 - apex/transformer/testing/global_vars.py | 270 -- apex/transformer/testing/standalone_bert.py | 255 -- apex/transformer/testing/standalone_gpt.py | 111 - .../testing/standalone_transformer_lm.py | 1574 ------- apex/transformer/utils.py | 48 - csrc/amp_C_frontend.cpp | 194 - csrc/compat.h | 9 - csrc/flatten_unflatten.cpp | 18 - csrc/fused_dense.cpp | 192 - csrc/fused_dense_cuda.cu | 1525 ------ csrc/layer_norm_cuda.cpp | 442 -- csrc/layer_norm_cuda_kernel.cu | 1229 ----- csrc/megatron/fused_weight_gradient_dense.cpp | 21 - ...d_weight_gradient_dense_16bit_prec_cuda.cu | 155 - .../fused_weight_gradient_dense_cuda.cu | 195 - csrc/megatron/scaled_masked_softmax.cpp | 96 - csrc/megatron/scaled_masked_softmax.h | 505 -- csrc/megatron/scaled_masked_softmax_cuda.cu | 117 - .../scaled_upper_triang_masked_softmax.cpp | 71 - .../scaled_upper_triang_masked_softmax.h | 513 --- ...scaled_upper_triang_masked_softmax_cuda.cu | 98 - csrc/mlp.cpp | 166 - csrc/mlp_cuda.cu | 1783 -------- csrc/multi_tensor_adagrad.cu | 100 - csrc/multi_tensor_adam.cu | 171 - csrc/multi_tensor_apply.cuh | 147 - csrc/multi_tensor_apply_base.cuh | 147 - csrc/multi_tensor_axpby_kernel.cu | 157 - csrc/multi_tensor_l2norm_kernel.cu | 456 -- csrc/multi_tensor_l2norm_kernel_mp.cu | 220 - csrc/multi_tensor_l2norm_scale_kernel.cu | 326 -- csrc/multi_tensor_lamb.cu | 413 -- csrc/multi_tensor_lamb_mp.cu | 496 -- csrc/multi_tensor_lamb_stage_1.cu | 151 - csrc/multi_tensor_lamb_stage_2.cu | 125 - csrc/multi_tensor_lars.cu | 354 -- csrc/multi_tensor_novograd.cu | 188 - csrc/multi_tensor_scale_kernel.cu | 136 - csrc/multi_tensor_sgd_kernel.cu | 322 -- csrc/syncbn.cpp | 109 - csrc/type_shim.h | 491 -- csrc/utils.h | 27 - csrc/welford.cu | 1550 ------- docs/Makefile | 32 - docs/source/_static/css/pytorch_theme.css | 118 - docs/source/_static/img/nv-pytorch2.png | Bin 6502 -> 0 bytes docs/source/_templates/layout.html | 51 - docs/source/advanced.rst | 219 - docs/source/amp.rst | 288 -- docs/source/conf.py | 248 - docs/source/fp16_utils.rst | 59 - docs/source/index.rst | 53 - docs/source/layernorm.rst | 17 - docs/source/optimizers.rst | 23 - docs/source/parallel.rst | 25 - examples/README.md | 4 - examples/dcgan/README.md | 41 - examples/dcgan/main_amp.py | 274 -- examples/docker/Dockerfile | 16 - examples/docker/README.md | 40 - examples/imagenet/README.md | 183 - examples/imagenet/main_amp.py | 543 --- examples/simple/distributed/README.md | 13 - .../distributed/distributed_data_parallel.py | 65 - examples/simple/distributed/run.sh | 2 - get_version.py | 67 - pyproject.toml | 7 - requirements.txt | 6 - requirements_dev.txt | 3 - setup.py | 709 --- tests/L0/run_amp/__init__.py | 0 tests/L0/run_amp/test_add_param_group.py | 159 - tests/L0/run_amp/test_basic_casts.py | 258 -- tests/L0/run_amp/test_cache.py | 158 - tests/L0/run_amp/test_checkpointing.py | 273 -- tests/L0/run_amp/test_fused_sgd.py | 793 ---- tests/L0/run_amp/test_larc.py | 53 - tests/L0/run_amp/test_multi_tensor_axpby.py | 183 - tests/L0/run_amp/test_multi_tensor_l2norm.py | 87 - tests/L0/run_amp/test_multi_tensor_scale.py | 129 - .../test_multiple_models_optimizers_losses.py | 762 --- tests/L0/run_amp/test_promotion.py | 112 - tests/L0/run_amp/test_rnn.py | 121 - tests/L0/run_amp/utils.py | 27 - tests/L0/run_fp16util/__init__.py | 0 tests/L0/run_fp16util/test_fp16util.py | 75 - .../test_fused_layer_norm.py | 298 -- tests/L0/run_mlp/test_mlp.py | 222 - tests/L0/run_optimizers/__init__.py | 0 .../L0/run_optimizers/test_fused_novograd.py | 170 - .../L0/run_optimizers/test_fused_optimizer.py | 310 -- .../test_fused_optimizer_channels_last.py | 112 - tests/L0/run_optimizers/test_lamb.py | 337 -- tests/L0/run_rocm.sh | 2 - tests/L0/run_test.py | 72 - tests/L0/run_transformer/__init__.py | 0 tests/L0/run_transformer/gpt_scaling_test.py | 116 - .../run_transformer/run_bert_minimal_test.py | 260 -- .../run_dynamic_batchsize_test.py | 202 - .../run_transformer/run_gpt_minimal_test.py | 223 - .../L0/run_transformer/test_batch_sampler.py | 142 - .../L0/run_transformer/test_cross_entropy.py | 94 - tests/L0/run_transformer/test_data.py | 64 - .../L0/run_transformer/test_fused_softmax.py | 220 - tests/L0/run_transformer/test_layers.py | 558 --- tests/L0/run_transformer/test_mapping.py | 89 - tests/L0/run_transformer/test_microbatches.py | 85 - tests/L0/run_transformer/test_p2p_comm.py | 122 - .../L0/run_transformer/test_parallel_state.py | 185 - .../test_pipeline_parallel_fwd_bwd.py | 447 -- tests/L0/run_transformer/test_random.py | 120 - .../test_transformer_module.py | 107 - .../run_transformer/test_transformer_utils.py | 40 - tests/L1/common/compare.py | 64 - tests/L1/common/main_amp.py | 526 --- tests/L1/common/run_test.sh | 144 - tests/L1/cross_product/run.sh | 6 - tests/L1/cross_product_distributed/run.sh | 4 - .../pipeline_parallel_fwd_bwd_ucc_async.py | 219 - .../DDP/ddp_race_condition_test.py | 69 - tests/distributed/DDP/run_race_test.sh | 3 - .../amp_master_params/amp_master_params.py | 71 - .../distributed/amp_master_params/compare.py | 31 - tests/distributed/amp_master_params/run.sh | 4 - tests/distributed/run_rocm_distributed.sh | 46 - .../python_single_gpu_unit_test.py | 112 - .../synced_batchnorm/single_gpu_unit_test.py | 162 - .../synced_batchnorm/test_batchnorm1d.py | 18 - .../synced_batchnorm/test_groups.py | 189 - .../two_gpu_test_different_batch_size.py | 158 - .../synced_batchnorm/two_gpu_unit_test.py | 182 - .../distributed/synced_batchnorm/unit_test.sh | 8 - tests/docker_extension_builds/run.sh | 73 - 396 files changed, 94431 deletions(-) delete mode 100644 Dockerfile delete mode 100644 LICENSE delete mode 100644 README.md delete mode 100644 README_ORIGIN.md delete mode 100644 apex/RNN/README.md delete mode 100644 apex/RNN/RNNBackend.py delete mode 100644 apex/RNN/__init__.py delete mode 100644 apex/RNN/cells.py delete mode 100644 apex/RNN/models.py delete mode 100644 apex/__init__.py delete mode 100644 apex/_autocast_utils.py delete mode 100644 apex/amp/README.md delete mode 100644 apex/amp/__init__.py delete mode 100644 apex/amp/__version__.py delete mode 100644 apex/amp/_amp_state.py delete mode 100644 apex/amp/_initialize.py delete mode 100644 apex/amp/_process_optimizer.py delete mode 100644 apex/amp/amp.py delete mode 100644 apex/amp/compat.py delete mode 100644 apex/amp/frontend.py delete mode 100644 apex/amp/handle.py delete mode 100644 apex/amp/lists/__init__.py delete mode 100644 apex/amp/lists/functional_overrides.py delete mode 100644 apex/amp/lists/tensor_overrides.py delete mode 100644 apex/amp/lists/torch_overrides.py delete mode 100644 apex/amp/opt.py delete mode 100644 apex/amp/rnn_compat.py delete mode 100644 apex/amp/scaler.py delete mode 100644 apex/amp/utils.py delete mode 100644 apex/amp/wrap.py delete mode 100644 apex/contrib/__init__.py delete mode 100644 apex/contrib/bottleneck/__init__.py delete mode 100644 apex/contrib/bottleneck/bottleneck.py delete mode 100644 apex/contrib/bottleneck/bottleneck_module_test.py delete mode 100644 apex/contrib/bottleneck/halo_exchangers.py delete mode 100644 apex/contrib/bottleneck/test.py delete mode 100644 apex/contrib/clip_grad/__init__.py delete mode 100644 apex/contrib/clip_grad/clip_grad.py delete mode 100644 apex/contrib/conv_bias_relu/__init__.py delete mode 100644 apex/contrib/conv_bias_relu/conv_bias_relu.py delete mode 100644 apex/contrib/csrc/bottleneck/bottleneck.cpp delete mode 100644 apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp delete mode 160000 apex/contrib/csrc/cudnn-frontend delete mode 100644 apex/contrib/csrc/fmha/fmha_api.cpp delete mode 100644 apex/contrib/csrc/fmha/src/fmha.h delete mode 100644 apex/contrib/csrc/fmha/src/fmha/gemm.h delete mode 100644 apex/contrib/csrc/fmha/src/fmha/gmem_tile.h delete mode 100644 apex/contrib/csrc/fmha/src/fmha/kernel_traits.h delete mode 100644 apex/contrib/csrc/fmha/src/fmha/mask.h delete mode 100644 apex/contrib/csrc/fmha/src/fmha/smem_tile.h delete mode 100644 apex/contrib/csrc/fmha/src/fmha/softmax.h delete mode 100644 apex/contrib/csrc/fmha/src/fmha/utils.h delete mode 100644 apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu delete mode 100644 apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu delete mode 100644 apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu delete mode 100644 apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu delete mode 100644 apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h delete mode 100644 apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h delete mode 100644 apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu delete mode 100644 apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu delete mode 100644 apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu delete mode 100644 apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu delete mode 100644 apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h delete mode 100644 apex/contrib/csrc/fmha/src/fmha_kernel.h delete mode 100644 apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu delete mode 100644 apex/contrib/csrc/fmha/src/fmha_utils.h delete mode 100644 apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp delete mode 100644 apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu delete mode 100644 apex/contrib/csrc/groupbn/batch_norm.cu delete mode 100644 apex/contrib/csrc/groupbn/batch_norm.h delete mode 100644 apex/contrib/csrc/groupbn/batch_norm_add_relu.cu delete mode 100644 apex/contrib/csrc/groupbn/batch_norm_add_relu.h delete mode 100644 apex/contrib/csrc/groupbn/cuda_utils.h delete mode 100644 apex/contrib/csrc/groupbn/dnn.h delete mode 100644 apex/contrib/csrc/groupbn/interface.cpp delete mode 100644 apex/contrib/csrc/groupbn/ipc.cu delete mode 100644 apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h delete mode 100644 apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp delete mode 100644 apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu delete mode 100644 apex/contrib/csrc/layer_norm/ln.h delete mode 100644 apex/contrib/csrc/layer_norm/ln_api.cpp delete mode 100644 apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh delete mode 100644 apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu delete mode 100644 apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu delete mode 100644 apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh delete mode 100644 apex/contrib/csrc/layer_norm/ln_kernel_traits.h delete mode 100644 apex/contrib/csrc/layer_norm/ln_utils.cuh delete mode 100644 apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu delete mode 100644 apex/contrib/csrc/multihead_attn/dropout.cuh delete mode 100644 apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu delete mode 100644 apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu delete mode 100644 apex/contrib/csrc/multihead_attn/layer_norm.cuh delete mode 100644 apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu delete mode 100644 apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp delete mode 100644 apex/contrib/csrc/multihead_attn/philox.cuh delete mode 100644 apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu delete mode 100644 apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu delete mode 100644 apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu delete mode 100644 apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu delete mode 100644 apex/contrib/csrc/multihead_attn/softmax.cuh delete mode 100644 apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh delete mode 100644 apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp delete mode 100644 apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu delete mode 100644 apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh delete mode 100644 apex/contrib/csrc/optimizers/fused_adam_cuda.cpp delete mode 100644 apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu delete mode 100644 apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp delete mode 100644 apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu delete mode 100644 apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp delete mode 100644 apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu delete mode 100644 apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp delete mode 100644 apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu delete mode 100644 apex/contrib/csrc/peer_memory/peer_memory.cpp delete mode 100644 apex/contrib/csrc/peer_memory/peer_memory_cuda.cu delete mode 100644 apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh delete mode 100755 apex/contrib/csrc/transducer/transducer_joint.cpp delete mode 100755 apex/contrib/csrc/transducer/transducer_joint_kernel.cu delete mode 100644 apex/contrib/csrc/transducer/transducer_loss.cpp delete mode 100755 apex/contrib/csrc/transducer/transducer_loss_kernel.cu delete mode 100644 apex/contrib/csrc/xentropy/interface.cpp delete mode 100644 apex/contrib/csrc/xentropy/xentropy_kernel.cu delete mode 100644 apex/contrib/examples/multihead_attn/func_test_multihead_attn.py delete mode 100644 apex/contrib/examples/multihead_attn/perf_test_multihead_attn.py delete mode 100644 apex/contrib/fmha/__init__.py delete mode 100644 apex/contrib/fmha/fmha.py delete mode 100644 apex/contrib/focal_loss/__init__.py delete mode 100644 apex/contrib/focal_loss/focal_loss.py delete mode 100644 apex/contrib/groupbn/__init__.py delete mode 100644 apex/contrib/groupbn/batch_norm.py delete mode 100644 apex/contrib/index_mul_2d/__init__.py delete mode 100644 apex/contrib/index_mul_2d/index_mul_2d.py delete mode 100644 apex/contrib/layer_norm/__init__.py delete mode 100644 apex/contrib/layer_norm/layer_norm.py delete mode 100644 apex/contrib/multihead_attn/MHA_bwd.png delete mode 100644 apex/contrib/multihead_attn/MHA_fwd.png delete mode 100644 apex/contrib/multihead_attn/README.md delete mode 100644 apex/contrib/multihead_attn/__init__.py delete mode 100644 apex/contrib/multihead_attn/encdec_multihead_attn.py delete mode 100644 apex/contrib/multihead_attn/encdec_multihead_attn_func.py delete mode 100644 apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py delete mode 100644 apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py delete mode 100644 apex/contrib/multihead_attn/fast_self_multihead_attn_func.py delete mode 100644 apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py delete mode 100644 apex/contrib/multihead_attn/mask_softmax_dropout_func.py delete mode 100644 apex/contrib/multihead_attn/self_multihead_attn.py delete mode 100644 apex/contrib/multihead_attn/self_multihead_attn_func.py delete mode 100644 apex/contrib/optimizers/__init__.py delete mode 100644 apex/contrib/optimizers/distributed_fused_adam.py delete mode 100644 apex/contrib/optimizers/distributed_fused_lamb.py delete mode 100755 apex/contrib/optimizers/fp16_optimizer.py delete mode 100644 apex/contrib/optimizers/fused_adam.py delete mode 100644 apex/contrib/optimizers/fused_lamb.py delete mode 100644 apex/contrib/optimizers/fused_sgd.py delete mode 100644 apex/contrib/peer_memory/__init__.py delete mode 100644 apex/contrib/peer_memory/peer_halo_exchange_module_tests.py delete mode 100644 apex/contrib/peer_memory/peer_halo_exchanger_1d.py delete mode 100644 apex/contrib/peer_memory/peer_memory.py delete mode 100644 apex/contrib/sparsity/README.md delete mode 100644 apex/contrib/sparsity/__init__.py delete mode 100644 apex/contrib/sparsity/asp.py delete mode 100644 apex/contrib/sparsity/permutation_lib.py delete mode 100644 apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu delete mode 100644 apex/contrib/sparsity/permutation_search_kernels/__init__.py delete mode 100644 apex/contrib/sparsity/permutation_search_kernels/call_permutation_search_kernels.py delete mode 100644 apex/contrib/sparsity/permutation_search_kernels/exhaustive_search.py delete mode 100644 apex/contrib/sparsity/permutation_search_kernels/permutation_utilities.py delete mode 100644 apex/contrib/sparsity/sparse_masklib.py delete mode 100644 apex/contrib/sparsity/test/checkpointing_test_part1.py delete mode 100644 apex/contrib/sparsity/test/checkpointing_test_part2.py delete mode 100644 apex/contrib/sparsity/test/checkpointing_test_reference.py delete mode 100644 apex/contrib/sparsity/test/toy_problem.py delete mode 100644 apex/contrib/test/clip_grad/test_clip_grad.py delete mode 100644 apex/contrib/test/conv_bias_relu/test_conv_bias_relu.py delete mode 100644 apex/contrib/test/fmha/test_fmha.py delete mode 100644 apex/contrib/test/focal_loss/test_focal_loss.py delete mode 100644 apex/contrib/test/fused_dense/test_fused_dense.py delete mode 100644 apex/contrib/test/groupbn/test_groupbn.py delete mode 100644 apex/contrib/test/groupbn/test_groupbn_channel_last.py delete mode 100644 apex/contrib/test/index_mul_2d/test_index_mul_2d.py delete mode 100644 apex/contrib/test/layer_norm/test_fast_layer_norm.py delete mode 100644 apex/contrib/test/multihead_attn/test_encdec_multihead_attn.py delete mode 100644 apex/contrib/test/multihead_attn/test_encdec_multihead_attn_norm_add.py delete mode 100644 apex/contrib/test/multihead_attn/test_fast_self_multihead_attn_bias.py delete mode 100644 apex/contrib/test/multihead_attn/test_mha_fused_softmax.py delete mode 100644 apex/contrib/test/multihead_attn/test_self_multihead_attn.py delete mode 100644 apex/contrib/test/multihead_attn/test_self_multihead_attn_norm_add.py delete mode 100644 apex/contrib/test/optimizers/test_dist_adam.py delete mode 100644 apex/contrib/test/run_rocm_extensions.py delete mode 100644 apex/contrib/test/test_label_smoothing.py delete mode 100755 apex/contrib/test/transducer/test_transducer_joint.py delete mode 100755 apex/contrib/test/transducer/test_transducer_loss.py delete mode 100755 apex/contrib/test/transducer/transducer_ref.py delete mode 100755 apex/contrib/transducer/__init__.py delete mode 100755 apex/contrib/transducer/transducer.py delete mode 100644 apex/contrib/xentropy/__init__.py delete mode 100644 apex/contrib/xentropy/softmax_xentropy.py delete mode 100644 apex/fp16_utils/README.md delete mode 100644 apex/fp16_utils/__init__.py delete mode 100755 apex/fp16_utils/fp16_optimizer.py delete mode 100644 apex/fp16_utils/fp16util.py delete mode 100644 apex/fp16_utils/loss_scaler.py delete mode 100644 apex/fused_dense/__init__.py delete mode 100644 apex/fused_dense/fused_dense.py delete mode 100644 apex/mlp/__init__.py delete mode 100644 apex/mlp/mlp.py delete mode 100644 apex/multi_tensor_apply/__init__.py delete mode 100644 apex/multi_tensor_apply/multi_tensor_apply.py delete mode 100644 apex/normalization/__init__.py delete mode 100644 apex/normalization/fused_layer_norm.py delete mode 100644 apex/optimizers/__init__.py delete mode 100644 apex/optimizers/fused_adagrad.py delete mode 100644 apex/optimizers/fused_adam.py delete mode 100644 apex/optimizers/fused_lamb.py delete mode 100644 apex/optimizers/fused_lars.py delete mode 100644 apex/optimizers/fused_mixed_precision_lamb.py delete mode 100644 apex/optimizers/fused_novograd.py delete mode 100644 apex/optimizers/fused_sgd.py delete mode 100644 apex/parallel/LARC.py delete mode 100644 apex/parallel/README.md delete mode 100644 apex/parallel/__init__.py delete mode 100644 apex/parallel/distributed.py delete mode 100644 apex/parallel/multiproc.py delete mode 100644 apex/parallel/optimized_sync_batchnorm.py delete mode 100644 apex/parallel/optimized_sync_batchnorm_kernel.py delete mode 100644 apex/parallel/sync_batchnorm.py delete mode 100644 apex/parallel/sync_batchnorm_kernel.py delete mode 100644 apex/testing/__init__.py delete mode 100644 apex/testing/common_utils.py delete mode 100644 apex/transformer/README.md delete mode 100644 apex/transformer/__init__.py delete mode 100644 apex/transformer/_data/__init__.py delete mode 100644 apex/transformer/_data/_batchsampler.py delete mode 100644 apex/transformer/amp/__init__.py delete mode 100644 apex/transformer/amp/grad_scaler.py delete mode 100644 apex/transformer/enums.py delete mode 100644 apex/transformer/functional/__init__.py delete mode 100644 apex/transformer/functional/fused_softmax.py delete mode 100644 apex/transformer/layers/__init__.py delete mode 100644 apex/transformer/layers/layer_norm.py delete mode 100644 apex/transformer/log_util.py delete mode 100644 apex/transformer/microbatches.py delete mode 100644 apex/transformer/parallel_state.py delete mode 100644 apex/transformer/pipeline_parallel/__init__.py delete mode 100644 apex/transformer/pipeline_parallel/_timers.py delete mode 100644 apex/transformer/pipeline_parallel/p2p_communication.py delete mode 100644 apex/transformer/pipeline_parallel/schedules/__init__.py delete mode 100644 apex/transformer/pipeline_parallel/schedules/common.py delete mode 100644 apex/transformer/pipeline_parallel/schedules/fwd_bwd_no_pipelining.py delete mode 100644 apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_with_interleaving.py delete mode 100644 apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_without_interleaving.py delete mode 100644 apex/transformer/pipeline_parallel/utils.py delete mode 100644 apex/transformer/tensor_parallel/__init__.py delete mode 100644 apex/transformer/tensor_parallel/cross_entropy.py delete mode 100644 apex/transformer/tensor_parallel/data.py delete mode 100644 apex/transformer/tensor_parallel/layers.py delete mode 100644 apex/transformer/tensor_parallel/mappings.py delete mode 100644 apex/transformer/tensor_parallel/memory.py delete mode 100644 apex/transformer/tensor_parallel/random.py delete mode 100644 apex/transformer/tensor_parallel/utils.py delete mode 100644 apex/transformer/testing/__init__.py delete mode 100644 apex/transformer/testing/arguments.py delete mode 100644 apex/transformer/testing/commons.py delete mode 100644 apex/transformer/testing/distributed_test_base.py delete mode 100644 apex/transformer/testing/global_vars.py delete mode 100644 apex/transformer/testing/standalone_bert.py delete mode 100644 apex/transformer/testing/standalone_gpt.py delete mode 100644 apex/transformer/testing/standalone_transformer_lm.py delete mode 100644 apex/transformer/utils.py delete mode 100644 csrc/amp_C_frontend.cpp delete mode 100644 csrc/compat.h delete mode 100644 csrc/flatten_unflatten.cpp delete mode 100644 csrc/fused_dense.cpp delete mode 100644 csrc/fused_dense_cuda.cu delete mode 100644 csrc/layer_norm_cuda.cpp delete mode 100644 csrc/layer_norm_cuda_kernel.cu delete mode 100644 csrc/megatron/fused_weight_gradient_dense.cpp delete mode 100644 csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu delete mode 100644 csrc/megatron/fused_weight_gradient_dense_cuda.cu delete mode 100644 csrc/megatron/scaled_masked_softmax.cpp delete mode 100644 csrc/megatron/scaled_masked_softmax.h delete mode 100644 csrc/megatron/scaled_masked_softmax_cuda.cu delete mode 100644 csrc/megatron/scaled_upper_triang_masked_softmax.cpp delete mode 100644 csrc/megatron/scaled_upper_triang_masked_softmax.h delete mode 100644 csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu delete mode 100644 csrc/mlp.cpp delete mode 100644 csrc/mlp_cuda.cu delete mode 100644 csrc/multi_tensor_adagrad.cu delete mode 100644 csrc/multi_tensor_adam.cu delete mode 100644 csrc/multi_tensor_apply.cuh delete mode 100644 csrc/multi_tensor_apply_base.cuh delete mode 100644 csrc/multi_tensor_axpby_kernel.cu delete mode 100644 csrc/multi_tensor_l2norm_kernel.cu delete mode 100644 csrc/multi_tensor_l2norm_kernel_mp.cu delete mode 100644 csrc/multi_tensor_l2norm_scale_kernel.cu delete mode 100644 csrc/multi_tensor_lamb.cu delete mode 100644 csrc/multi_tensor_lamb_mp.cu delete mode 100644 csrc/multi_tensor_lamb_stage_1.cu delete mode 100644 csrc/multi_tensor_lamb_stage_2.cu delete mode 100644 csrc/multi_tensor_lars.cu delete mode 100644 csrc/multi_tensor_novograd.cu delete mode 100644 csrc/multi_tensor_scale_kernel.cu delete mode 100644 csrc/multi_tensor_sgd_kernel.cu delete mode 100644 csrc/syncbn.cpp delete mode 100644 csrc/type_shim.h delete mode 100644 csrc/utils.h delete mode 100644 csrc/welford.cu delete mode 100644 docs/Makefile delete mode 100644 docs/source/_static/css/pytorch_theme.css delete mode 100644 docs/source/_static/img/nv-pytorch2.png delete mode 100644 docs/source/_templates/layout.html delete mode 100644 docs/source/advanced.rst delete mode 100644 docs/source/amp.rst delete mode 100644 docs/source/conf.py delete mode 100644 docs/source/fp16_utils.rst delete mode 100644 docs/source/index.rst delete mode 100644 docs/source/layernorm.rst delete mode 100644 docs/source/optimizers.rst delete mode 100644 docs/source/parallel.rst delete mode 100644 examples/README.md delete mode 100644 examples/dcgan/README.md delete mode 100644 examples/dcgan/main_amp.py delete mode 100644 examples/docker/Dockerfile delete mode 100644 examples/docker/README.md delete mode 100644 examples/imagenet/README.md delete mode 100644 examples/imagenet/main_amp.py delete mode 100644 examples/simple/distributed/README.md delete mode 100644 examples/simple/distributed/distributed_data_parallel.py delete mode 100644 examples/simple/distributed/run.sh delete mode 100644 get_version.py delete mode 100644 pyproject.toml delete mode 100644 requirements.txt delete mode 100644 requirements_dev.txt delete mode 100644 setup.py delete mode 100644 tests/L0/run_amp/__init__.py delete mode 100644 tests/L0/run_amp/test_add_param_group.py delete mode 100644 tests/L0/run_amp/test_basic_casts.py delete mode 100644 tests/L0/run_amp/test_cache.py delete mode 100644 tests/L0/run_amp/test_checkpointing.py delete mode 100644 tests/L0/run_amp/test_fused_sgd.py delete mode 100644 tests/L0/run_amp/test_larc.py delete mode 100644 tests/L0/run_amp/test_multi_tensor_axpby.py delete mode 100644 tests/L0/run_amp/test_multi_tensor_l2norm.py delete mode 100644 tests/L0/run_amp/test_multi_tensor_scale.py delete mode 100644 tests/L0/run_amp/test_multiple_models_optimizers_losses.py delete mode 100644 tests/L0/run_amp/test_promotion.py delete mode 100644 tests/L0/run_amp/test_rnn.py delete mode 100644 tests/L0/run_amp/utils.py delete mode 100644 tests/L0/run_fp16util/__init__.py delete mode 100644 tests/L0/run_fp16util/test_fp16util.py delete mode 100644 tests/L0/run_fused_layer_norm/test_fused_layer_norm.py delete mode 100644 tests/L0/run_mlp/test_mlp.py delete mode 100644 tests/L0/run_optimizers/__init__.py delete mode 100755 tests/L0/run_optimizers/test_fused_novograd.py delete mode 100644 tests/L0/run_optimizers/test_fused_optimizer.py delete mode 100644 tests/L0/run_optimizers/test_fused_optimizer_channels_last.py delete mode 100644 tests/L0/run_optimizers/test_lamb.py delete mode 100755 tests/L0/run_rocm.sh delete mode 100644 tests/L0/run_test.py delete mode 100644 tests/L0/run_transformer/__init__.py delete mode 100644 tests/L0/run_transformer/gpt_scaling_test.py delete mode 100644 tests/L0/run_transformer/run_bert_minimal_test.py delete mode 100644 tests/L0/run_transformer/run_dynamic_batchsize_test.py delete mode 100644 tests/L0/run_transformer/run_gpt_minimal_test.py delete mode 100644 tests/L0/run_transformer/test_batch_sampler.py delete mode 100644 tests/L0/run_transformer/test_cross_entropy.py delete mode 100644 tests/L0/run_transformer/test_data.py delete mode 100644 tests/L0/run_transformer/test_fused_softmax.py delete mode 100644 tests/L0/run_transformer/test_layers.py delete mode 100644 tests/L0/run_transformer/test_mapping.py delete mode 100644 tests/L0/run_transformer/test_microbatches.py delete mode 100644 tests/L0/run_transformer/test_p2p_comm.py delete mode 100644 tests/L0/run_transformer/test_parallel_state.py delete mode 100644 tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py delete mode 100644 tests/L0/run_transformer/test_random.py delete mode 100644 tests/L0/run_transformer/test_transformer_module.py delete mode 100644 tests/L0/run_transformer/test_transformer_utils.py delete mode 100644 tests/L1/common/compare.py delete mode 100644 tests/L1/common/main_amp.py delete mode 100644 tests/L1/common/run_test.sh delete mode 100644 tests/L1/cross_product/run.sh delete mode 100644 tests/L1/cross_product_distributed/run.sh delete mode 100644 tests/L1/transformer/pipeline_parallel_fwd_bwd_ucc_async.py delete mode 100644 tests/distributed/DDP/ddp_race_condition_test.py delete mode 100644 tests/distributed/DDP/run_race_test.sh delete mode 100644 tests/distributed/amp_master_params/amp_master_params.py delete mode 100644 tests/distributed/amp_master_params/compare.py delete mode 100644 tests/distributed/amp_master_params/run.sh delete mode 100644 tests/distributed/run_rocm_distributed.sh delete mode 100644 tests/distributed/synced_batchnorm/python_single_gpu_unit_test.py delete mode 100644 tests/distributed/synced_batchnorm/single_gpu_unit_test.py delete mode 100644 tests/distributed/synced_batchnorm/test_batchnorm1d.py delete mode 100644 tests/distributed/synced_batchnorm/test_groups.py delete mode 100755 tests/distributed/synced_batchnorm/two_gpu_test_different_batch_size.py delete mode 100644 tests/distributed/synced_batchnorm/two_gpu_unit_test.py delete mode 100755 tests/distributed/synced_batchnorm/unit_test.sh delete mode 100644 tests/docker_extension_builds/run.sh diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index 8bf9a17..0000000 --- a/Dockerfile +++ /dev/null @@ -1,7 +0,0 @@ -ARG FROM_IMAGE=lcskrishna/rocm-pytorch:rocm3.3_ubuntu16.04_py3.6_pytorch_bfloat16_mgpu - -FROM ${FROM_IMAGE} -RUN \ - git clone --recursive https://github.com/ROCmSoftwarePlatform/apex.git && \ - cd apex && \ - python3.6 setup.py install --cpp_ext --cuda_ext diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 3d1e945..0000000 --- a/LICENSE +++ /dev/null @@ -1,11 +0,0 @@ -All rights reserved. - -Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/README.md b/README.md deleted file mode 100644 index a53d6fd..0000000 --- a/README.md +++ /dev/null @@ -1,69 +0,0 @@ -# APEX - -## 介绍 - -[Introduction](README_ORIGIN.md) - -## 安装 - -### System Requirements - -- Linux. - -- Python 3.7, 3.8, 3.9 - -- (**推荐**) Upgrade pip - - ``` - python3 -m pip install --upgrade pip #--user - ``` - -### 使用pip安装(以dtk-23.04版本为例) -可以在光合[光合开发者社区](https://developer.hpccube.com/tool/#sdk) AI 生态包中获取最新的 apex Release 版本(需对应 DCU Toolkit 版本与 python 版本) -```bash -python3 -m pip install apex-0.1+git2d8b360.abi0.dtk2304-cp37-cp37m-linux_x86_64.whl -``` - -### 使用源码安装 - -#### 编译环境准备(以dtk-23.04版本为例) - -- 拉取 apex 代码 - - ``` - git clone -b dtk-23.04 http://developer.hpccube.com/codes/aicomponent/apex.git - ``` - -- 在[开发者社区](https://developer.hpccube.com/tool/#sdk) DCU Toolkit 中下载 DTK-23.04 解压至 /opt/ 路径下,并建立软链接 - - ``` - cd /opt && ln -s dtk-23.04 dtk - ``` - -- 在光合[光合开发者社区](https://developer.hpccube.com/tool/#sdk) AI 生态包中获取对应的 pytorch Release 版本(需对应 DCU Toolkit 版本与 python 版本) - ```bash - python3 -m pip install torch-1.13.1a0+git4c8a1fe.abi0.dtk2304-cp37-cp37m-linux_x86_64.whl - ``` - -- 导入环境变量以及安装必要依赖库 - - ```bash - source /opt/dtk/env.sh - - export PYTORCH_ROCM_ARCH="gfx906;gfx926" - - MAX_JOBS=16 - pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn - pip3 install wheel -i https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn - ``` - - -#### 编译安装 - -- 执行编译命令 - ```shell - cd apex - CXX=hipcc CC=hipcc python3 setup.py --cpp_ext --cuda_ext bdist_wheel - pip install dist/apex* - ``` - diff --git a/README_ORIGIN.md b/README_ORIGIN.md deleted file mode 100644 index 5dd33f9..0000000 --- a/README_ORIGIN.md +++ /dev/null @@ -1,162 +0,0 @@ -# Introduction - -This repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training in Pytorch. -Some of the code here will be included in upstream Pytorch eventually. -The intent of Apex is to make up-to-date utilities available to users as quickly as possible. - -## Full API Documentation: [https://nvidia.github.io/apex](https://nvidia.github.io/apex) - -## [GTC 2019](https://github.com/mcarilli/mixed_precision_references/tree/master/GTC_2019) and [Pytorch DevCon 2019](https://github.com/mcarilli/mixed_precision_references/tree/master/Pytorch_Devcon_2019) Slides - -# Contents - -## 1. Amp: Automatic Mixed Precision - -`apex.amp` is a tool to enable mixed precision training by changing only 3 lines of your script. -Users can easily experiment with different pure and mixed precision training modes by supplying -different flags to `amp.initialize`. - -[Webinar introducing Amp](https://info.nvidia.com/webinar-mixed-precision-with-pytorch-reg-page.html) -(The flag `cast_batchnorm` has been renamed to `keep_batchnorm_fp32`). - -[API Documentation](https://nvidia.github.io/apex/amp.html) - -[Comprehensive Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet) - -[DCGAN example coming soon...](https://github.com/NVIDIA/apex/tree/master/examples/dcgan) - -[Moving to the new Amp API](https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users) (for users of the deprecated "Amp" and "FP16_Optimizer" APIs) - -## 2. Distributed Training - -`apex.parallel.DistributedDataParallel` is a module wrapper, similar to -`torch.nn.parallel.DistributedDataParallel`. It enables convenient multiprocess distributed training, -optimized for NVIDIA's NCCL communication library. - -[API Documentation](https://nvidia.github.io/apex/parallel.html) - -[Python Source](https://github.com/NVIDIA/apex/tree/master/apex/parallel) - -[Example/Walkthrough](https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed) - -The [Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet) -shows use of `apex.parallel.DistributedDataParallel` along with `apex.amp`. - -### Synchronized Batch Normalization - -`apex.parallel.SyncBatchNorm` extends `torch.nn.modules.batchnorm._BatchNorm` to -support synchronized BN. -It allreduces stats across processes during multiprocess (DistributedDataParallel) training. -Synchronous BN has been used in cases where only a small -local minibatch can fit on each GPU. -Allreduced stats increase the effective batch size for the BN layer to the -global batch size across all processes (which, technically, is the correct -formulation). -Synchronous BN has been observed to improve converged accuracy in some of our research models. - -### Checkpointing - -To properly save and load your `amp` training, we introduce the `amp.state_dict()`, which contains all `loss_scalers` and their corresponding unskipped steps, -as well as `amp.load_state_dict()` to restore these attributes. - -In order to get bitwise accuracy, we recommend the following workflow: -```python -# Initialization -opt_level = 'O1' -model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level) - -# Train your model -... -with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() -... - -# Save checkpoint -checkpoint = { - 'model': model.state_dict(), - 'optimizer': optimizer.state_dict(), - 'amp': amp.state_dict() -} -torch.save(checkpoint, 'amp_checkpoint.pt') -... - -# Restore -model = ... -optimizer = ... -checkpoint = torch.load('amp_checkpoint.pt') - -model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level) -model.load_state_dict(checkpoint['model']) -optimizer.load_state_dict(checkpoint['optimizer']) -amp.load_state_dict(checkpoint['amp']) - -# Continue training -... -``` - -Note that we recommend restoring the model using the same `opt_level`. Also note that we recommend calling the `load_state_dict` methods after `amp.initialize`. - -# Installation - -## Containers -NVIDIA PyTorch Containers are available on NGC: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch. -The containers come with all the custom extensions available at the moment. - -See [the NGC documentation](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html) for details such as: -- how to pull a container -- how to run a pulled container -- release notes - -## From Source - -To install Apex from source, we recommend using the nightly Pytorch obtainable from https://github.com/pytorch/pytorch. - -The latest stable release obtainable from https://pytorch.org should also work. - -### Rocm -Apex on ROCm supports both python only build and extension build. -Note: Pytorch version recommended is >=1.5 for extension build. - -### To install using python only build use the following command in apex folder: -``` -python setup.py install -``` - -### To install using extensions enabled use the following command in apex folder: -``` -# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key... -pip install -v --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ -# otherwise -python setup.py install --cpp_ext --cuda_ext - -``` -Note that using --cuda_ext flag to install Apex will also enable all the extensions supported on ROCm including "--distributed_adam", "--distributed_lamb", "--bnp", "--xentropy", "--deprecated_fused_adam", "--deprecated_fused_lamb", and "--fast_multihead_attn". - -### Linux -For performance and full functionality, we recommend installing Apex with -CUDA and C++ extensions via -```bash -git clone https://github.com/NVIDIA/apex -cd apex -# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key... -pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ -# otherwise -pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./ -``` - -Apex also supports a Python-only build via -```bash -pip install -v --disable-pip-version-check --no-build-isolation --no-cache-dir ./ -``` -A Python-only build omits: -- Fused kernels required to use `apex.optimizers.FusedAdam`. -- Fused kernels required to use `apex.normalization.FusedLayerNorm` and `apex.normalization.FusedRMSNorm`. -- Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`. -- Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`. -`DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower. - - -### [Experimental] Windows -`pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .` may work if you were able to build Pytorch from source -on your system. A Python-only build via `pip install -v --no-cache-dir .` is more likely to work. -If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment. \ No newline at end of file diff --git a/apex/RNN/README.md b/apex/RNN/README.md deleted file mode 100644 index 9e86fd8..0000000 --- a/apex/RNN/README.md +++ /dev/null @@ -1 +0,0 @@ -Under construction... diff --git a/apex/RNN/RNNBackend.py b/apex/RNN/RNNBackend.py deleted file mode 100644 index a9382e6..0000000 --- a/apex/RNN/RNNBackend.py +++ /dev/null @@ -1,365 +0,0 @@ -import torch -import torch.nn as nn -from torch.autograd import Variable - -import torch.nn.functional as F - -import math - - -def is_iterable(maybe_iterable): - return isinstance(maybe_iterable, list) or isinstance(maybe_iterable, tuple) - - -def flatten_list(tens_list): - """ - flatten_list - """ - if not is_iterable(tens_list): - return tens_list - - return torch.cat(tens_list, dim=0).view(len(tens_list), *tens_list[0].size() ) - - -#These modules always assumes batch_first -class bidirectionalRNN(nn.Module): - """ - bidirectionalRNN - """ - def __init__(self, inputRNN, num_layers=1, dropout = 0): - super(bidirectionalRNN, self).__init__() - self.dropout = dropout - self.fwd = stackedRNN(inputRNN, num_layers=num_layers, dropout = dropout) - self.bckwrd = stackedRNN(inputRNN.new_like(), num_layers=num_layers, dropout = dropout) - self.rnns = nn.ModuleList([self.fwd, self.bckwrd]) - - #collect hidden option will return all hidden/cell states from entire RNN - def forward(self, input, collect_hidden=False): - """ - forward() - """ - seq_len = input.size(0) - bsz = input.size(1) - - fwd_out, fwd_hiddens = list(self.fwd(input, collect_hidden = collect_hidden)) - bckwrd_out, bckwrd_hiddens = list(self.bckwrd(input, reverse=True, collect_hidden = collect_hidden)) - - output = torch.cat( [fwd_out, bckwrd_out], -1 ) - hiddens = tuple( torch.cat(hidden, -1) for hidden in zip( fwd_hiddens, bckwrd_hiddens) ) - - return output, hiddens - - def reset_parameters(self): - """ - reset_parameters() - """ - for rnn in self.rnns: - rnn.reset_parameters() - - def init_hidden(self, bsz): - """ - init_hidden() - """ - for rnn in self.rnns: - rnn.init_hidden(bsz) - - def detach_hidden(self): - """ - detach_hidden() - """ - for rnn in self.rnns: - rnn.detachHidden() - - def reset_hidden(self, bsz): - """ - reset_hidden() - """ - for rnn in self.rnns: - rnn.reset_hidden(bsz) - - def init_inference(self, bsz): - """ - init_inference() - """ - for rnn in self.rnns: - rnn.init_inference(bsz) - - -#assumes hidden_state[0] of inputRNN is output hidden state -#constructor either takes an RNNCell or list of RNN layers -class stackedRNN(nn.Module): - """ - stackedRNN - """ - def __init__(self, inputRNN, num_layers=1, dropout=0): - super(stackedRNN, self).__init__() - - self.dropout = dropout - - if isinstance(inputRNN, RNNCell): - self.rnns = [inputRNN] - for i in range(num_layers-1): - self.rnns.append(inputRNN.new_like(inputRNN.output_size)) - elif isinstance(inputRNN, list): - assert len(inputRNN) == num_layers, "RNN list length must be equal to num_layers" - self.rnns=inputRNN - else: - raise RuntimeError() - - self.nLayers = len(self.rnns) - - self.rnns = nn.ModuleList(self.rnns) - - - ''' - Returns output as hidden_state[0] Tensor([sequence steps][batch size][features]) - If collect hidden will also return Tuple( - [n_hidden_states][sequence steps] Tensor([layer][batch size][features]) - ) - If not collect hidden will also return Tuple( - [n_hidden_states] Tensor([layer][batch size][features]) - ''' - def forward(self, input, collect_hidden=False, reverse=False): - """ - forward() - """ - seq_len = input.size(0) - bsz = input.size(1) - inp_iter = reversed(range(seq_len)) if reverse else range(seq_len) - - hidden_states = [[] for i in range(self.nLayers)] - outputs = [] - - for seq in inp_iter: - for layer in range(self.nLayers): - - if layer == 0: - prev_out = input[seq] - - outs = self.rnns[layer](prev_out) - - if collect_hidden: - hidden_states[layer].append(outs) - elif seq == seq_len-1: - hidden_states[layer].append(outs) - - prev_out = outs[0] - - outputs.append(prev_out) - - if reverse: - outputs = list(reversed(outputs)) - ''' - At this point outputs is in format: - list( [seq_length] x Tensor([bsz][features]) ) - need to convert it to: - list( Tensor([seq_length][bsz][features]) ) - ''' - output = flatten_list(outputs) - - ''' - hidden_states at this point is in format: - list( [layer][seq_length][hidden_states] x Tensor([bsz][features]) ) - need to convert it to: - For not collect hidden: - list( [hidden_states] x Tensor([layer][bsz][features]) ) - For collect hidden: - list( [hidden_states][seq_length] x Tensor([layer][bsz][features]) ) - ''' - if not collect_hidden: - seq_len = 1 - n_hid = self.rnns[0].n_hidden_states - new_hidden = [ [ [ None for k in range(self.nLayers)] for j in range(seq_len) ] for i in range(n_hid) ] - - - for i in range(n_hid): - for j in range(seq_len): - for k in range(self.nLayers): - new_hidden[i][j][k] = hidden_states[k][j][i] - - hidden_states = new_hidden - #Now in format list( [hidden_states][seq_length][layer] x Tensor([bsz][features]) ) - #Reverse seq_length if reverse - if reverse: - hidden_states = list( list(reversed(list(entry))) for entry in hidden_states) - - #flatten layer dimension into tensor - hiddens = list( list( - flatten_list(seq) for seq in hidden ) - for hidden in hidden_states ) - - #Now in format list( [hidden_states][seq_length] x Tensor([layer][bsz][features]) ) - #Remove seq_length dimension if not collect_hidden - if not collect_hidden: - hidden_states = list( entry[0] for entry in hidden_states) - return output, hidden_states - - def reset_parameters(self): - """ - reset_parameters() - """ - for rnn in self.rnns: - rnn.reset_parameters() - - def init_hidden(self, bsz): - """ - init_hidden() - """ - for rnn in self.rnns: - rnn.init_hidden(bsz) - - def detach_hidden(self): - """ - detach_hidden() - """ - for rnn in self.rnns: - rnn.detach_hidden() - - def reset_hidden(self, bsz): - """ - reset_hidden() - """ - for rnn in self.rnns: - rnn.reset_hidden(bsz) - - def init_inference(self, bsz): - """ - init_inference() - """ - for rnn in self.rnns: - rnn.init_inference(bsz) - -class RNNCell(nn.Module): - """ - RNNCell - gate_multiplier is related to the architecture you're working with - For LSTM-like it will be 4 and GRU-like will be 3. - Always assumes input is NOT batch_first. - Output size that's not hidden size will use output projection - Hidden_states is number of hidden states that are needed for cell - if one will go directly to cell as tensor, if more will go as list - """ - def __init__(self, gate_multiplier, input_size, hidden_size, cell, n_hidden_states = 2, bias = False, output_size = None): - super(RNNCell, self).__init__() - - self.gate_multiplier = gate_multiplier - self.input_size = input_size - self.hidden_size = hidden_size - self.cell = cell - self.bias = bias - self.output_size = output_size - if output_size is None: - self.output_size = hidden_size - - self.gate_size = gate_multiplier * self.hidden_size - self.n_hidden_states = n_hidden_states - - self.w_ih = nn.Parameter(torch.empty(self.gate_size, self.input_size)) - self.w_hh = nn.Parameter(torch.empty(self.gate_size, self.output_size)) - - #Check if there's recurrent projection - if(self.output_size != self.hidden_size): - self.w_ho = nn.Parameter(torch.empty(self.output_size, self.hidden_size)) - - self.b_ih = self.b_hh = None - if self.bias: - self.b_ih = nn.Parameter(torch.empty(self.gate_size)) - self.b_hh = nn.Parameter(torch.empty(self.gate_size)) - - #hidden states for forward - self.hidden = [ None for states in range(self.n_hidden_states)] - - self.reset_parameters() - - def new_like(self, new_input_size=None): - """ - new_like() - """ - if new_input_size is None: - new_input_size = self.input_size - - return type(self)(self.gate_multiplier, - new_input_size, - self.hidden_size, - self.cell, - self.n_hidden_states, - self.bias, - self.output_size) - - - #Use xavier where we can (weights), otherwise use uniform (bias) - def reset_parameters(self, gain=1): - """ - reset_parameters() - """ - stdev = 1.0 / math.sqrt(self.hidden_size) - for param in self.parameters(): - param.data.uniform_(-stdev, stdev) - ''' - Xavier reset: - def reset_parameters(self, gain=1): - stdv = 1.0 / math.sqrt(self.gate_size) - - for param in self.parameters(): - if (param.dim() > 1): - torch.nn.init.xavier_normal(param, gain) - else: - param.data.uniform_(-stdv, stdv) - ''' - def init_hidden(self, bsz): - """ - init_hidden() - """ - for param in self.parameters(): - if param is not None: - a_param = param - break - - for i, _ in enumerate(self.hidden): - if(self.hidden[i] is None or self.hidden[i].data.size()[0] != bsz): - - if i==0: - hidden_size = self.output_size - else: - hidden_size = self.hidden_size - - tens = a_param.data.new(bsz, hidden_size).zero_() - self.hidden[i] = Variable(tens, requires_grad=False) - - - def reset_hidden(self, bsz): - """ - reset_hidden() - """ - for i, _ in enumerate(self.hidden): - self.hidden[i] = None - self.init_hidden(bsz) - - def detach_hidden(self): - """ - detach_hidden() - """ - for i, _ in enumerate(self.hidden): - if self.hidden[i] is None: - raise RuntimeError("Must initialize hidden state before you can detach it") - for i, _ in enumerate(self.hidden): - self.hidden[i] = self.hidden[i].detach() - - def forward(self, input): - """ - forward() - if not inited or bsz has changed this will create hidden states - """ - self.init_hidden(input.size()[0]) - - hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden - self.hidden = self.cell(input, hidden_state, self.w_ih, self.w_hh, b_ih=self.b_ih, b_hh=self.b_hh) - if(self.n_hidden_states > 1): - self.hidden = list(self.hidden) - else: - self.hidden=[self.hidden] - - if self.output_size != self.hidden_size: - self.hidden[0] = F.linear(self.hidden[0], self.w_ho) - - return tuple(self.hidden) diff --git a/apex/RNN/__init__.py b/apex/RNN/__init__.py deleted file mode 100644 index d706746..0000000 --- a/apex/RNN/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .models import LSTM, GRU, ReLU, Tanh, mLSTM - -__all__ = ['models'] diff --git a/apex/RNN/cells.py b/apex/RNN/cells.py deleted file mode 100644 index 09b0858..0000000 --- a/apex/RNN/cells.py +++ /dev/null @@ -1,84 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .RNNBackend import RNNCell - -from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend - -import math - - -class mLSTMRNNCell(RNNCell): - """ - mLSTMRNNCell - """ - - def __init__(self, input_size, hidden_size, bias = False, output_size = None): - gate_multiplier = 4 - super(mLSTMRNNCell, self).__init__(gate_multiplier, input_size, hidden_size, mLSTMCell, n_hidden_states = 2, bias = bias, output_size = output_size) - - self.w_mih = nn.Parameter(torch.empty(self.output_size, self.input_size)) - self.w_mhh = nn.Parameter(torch.empty(self.output_size, self.output_size)) - - self.reset_parameters() - - def forward(self, input): - """ - mLSTMRNNCell.forward() - """ - #if not inited or bsz has changed this will create hidden states - self.init_hidden(input.size()[0]) - - hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden - - self.hidden = list( - self.cell(input, hidden_state, self.w_ih, self.w_hh, self.w_mih, self.w_mhh, - b_ih=self.b_ih, b_hh=self.b_hh) - ) - - if self.output_size != self.hidden_size: - self.hidden[0] = F.linear(self.hidden[0], self.w_ho) - return tuple(self.hidden) - - - def new_like(self, new_input_size=None): - if new_input_size is None: - new_input_size = self.input_size - - return type(self)( - new_input_size, - self.hidden_size, - self.bias, - self.output_size) - -def mLSTMCell(input, hidden, w_ih, w_hh, w_mih, w_mhh, b_ih=None, b_hh=None): - """ - mLSTMCell - """ - - if input.is_cuda: - igates = F.linear(input, w_ih) - m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh) - hgates = F.linear(m, w_hh) - - state = fusedBackend.LSTMFused.apply - return state(igates, hgates, hidden[1], b_ih, b_hh) - - hx, cx = hidden - - m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh) - gates = F.linear(input, w_ih, b_ih) + F.linear(m, w_hh, b_hh) - - ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) - - ingate = F.sigmoid(ingate) - forgetgate = F.sigmoid(forgetgate) - cellgate = F.tanh(cellgate) - outgate = F.sigmoid(outgate) - - cy = (forgetgate * cx) + (ingate * cellgate) - hy = outgate * F.tanh(cy) - - return hy, cy - diff --git a/apex/RNN/models.py b/apex/RNN/models.py deleted file mode 100644 index dd7adce..0000000 --- a/apex/RNN/models.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch - -from torch.nn._functions.rnn import LSTMCell, RNNReLUCell, RNNTanhCell, GRUCell - -from .RNNBackend import bidirectionalRNN, stackedRNN, RNNCell -from .cells import mLSTMRNNCell, mLSTMCell - -def toRNNBackend(inputRNN, num_layers, bidirectional=False, dropout = 0): - """ - :class:`toRNNBackend` - """ - - if bidirectional: - return bidirectionalRNN(inputRNN, num_layers, dropout = dropout) - else: - return stackedRNN(inputRNN, num_layers, dropout = dropout) - - -def LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None): - """ - :class:`LSTM` - """ - inputRNN = RNNCell(4, input_size, hidden_size, LSTMCell, 2, bias, output_size) - return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout) - -def GRU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None): - """ - :class:`GRU` - """ - inputRNN = RNNCell(3, input_size, hidden_size, GRUCell, 1, bias, output_size) - return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout) - -def ReLU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None): - """ - :class:`ReLU` - """ - inputRNN = RNNCell(1, input_size, hidden_size, RNNReLUCell, 1, bias, output_size) - return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout) - -def Tanh(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None): - """ - :class:`Tanh` - """ - inputRNN = RNNCell(1, input_size, hidden_size, RNNTanhCell, 1, bias, output_size) - return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout) - -def mLSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None): - """ - :class:`mLSTM` - """ - inputRNN = mLSTMRNNCell(input_size, hidden_size, bias=bias, output_size=output_size) - return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout) - - diff --git a/apex/__init__.py b/apex/__init__.py deleted file mode 100644 index 47231a1..0000000 --- a/apex/__init__.py +++ /dev/null @@ -1,57 +0,0 @@ -import logging -import warnings - -# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten -import torch - - -if torch.distributed.is_available(): - from . import parallel - -from . import amp -from . import fp16_utils - -# For optimizers and normalization there is no Python fallback. -# Absence of cuda backend is a hard error. -# I would like the errors from importing fused_adam_cuda or fused_layer_norm_cuda -# to be triggered lazily, because if someone has installed with --cpp_ext and --cuda_ext -# so they expect those backends to be available, but for some reason they actually aren't -# available (for example because they built improperly in a way that isn't revealed until -# load time) the error message is timely and visible. -from . import optimizers -from . import normalization -from . import transformer - - -# Logging utilities for apex.transformer module -class RankInfoFormatter(logging.Formatter): - - def format(self, record): - from apex.transformer.parallel_state import get_rank_info - record.rank_info = get_rank_info() - return super().format(record) - - -_library_root_logger = logging.getLogger(__name__) -handler = logging.StreamHandler() -handler.setFormatter(RankInfoFormatter("%(asctime)s - PID:%(process)d - rank:%(rank_info)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s", "%y-%m-%d %H:%M:%S")) -_library_root_logger.addHandler(handler) -_library_root_logger.propagate = False - - -def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool: - cudnn_available = torch.backends.cudnn.is_available() - cudnn_version = torch.backends.cudnn.version() if cudnn_available else None - if not (cudnn_available and (cudnn_version >= required_cudnn_version)): - warnings.warn( - f"`{global_option}` depends on cuDNN {required_cudnn_version} or later, " - f"but {'cuDNN is not available' if not cudnn_available else cudnn_version}" - ) - return False - return True - -try: - from .version import version, git_hash, git_branch, dtk, abi, torch_version, dcu_version # noqa: F401 - __version__, __dcu_version__ = version, dcu_version -except ImportError: - pass diff --git a/apex/_autocast_utils.py b/apex/_autocast_utils.py deleted file mode 100644 index e86c6c6..0000000 --- a/apex/_autocast_utils.py +++ /dev/null @@ -1,23 +0,0 @@ -from typing import Optional, Sequence - -import torch - - -def _get_autocast_dtypes() -> Sequence[torch.dtype]: - if torch.cuda.is_bf16_supported(): - return [torch.half, torch.bfloat16] - return [torch.half] - - -def _get_current_dtype(dtype: Optional[torch.dtype] = None) -> torch.dtype: - if not torch.is_autocast_enabled(): - return torch.float or dtype - else: - return torch.get_autocast_gpu_dtype() - - -def _cast_if_autocast_enabled(*args): - if not torch.is_autocast_enabled(): - return args - else: - return torch.cuda.amp.autocast_mode._cast(args, torch.get_autocast_gpu_dtype()) diff --git a/apex/amp/README.md b/apex/amp/README.md deleted file mode 100644 index a87b501..0000000 --- a/apex/amp/README.md +++ /dev/null @@ -1,72 +0,0 @@ -# amp: Automatic Mixed Precision - -## Annotating User Functions - -Nearly all PyTorch user code needs nothing more than the two steps -above to use amp. After all, custom layers are built out of simpler -PyTorch components, and amp already can see those. - -However, any custom C++ or CUDA code is outside of amp's (default) -view of things. For example, suppose I implemented a new recurrent -cell called a "forgetful recurrent unit" that calls directly into a -CUDA backend: - -```python -from backend import FRUBackend - -def fru(input, hidden, weight, bias): - # call to CUDA code - FRUBackend(input, hidden, weight, bias) -``` - -In this case, it is possible to get a runtime type mismatch. For -example, you might have `input` in fp16, and `weight` in fp32, and amp -doesn't have the visibility to insert an appropriate cast. - -amp exposes two ways to handle "invisible" backend code: function -annotations and explicit registration. - -#### Function annotation - -The first way to handle backend code is a set of function annotations: - -- `@amp.half_function` -- `@amp.float_function` -- `@amp.promote_function` - -These correspond to: - -- Cast all arguments to fp16 -- Cast all argumnets fo fp32 -- If there are any type mismatches, cast everything to the widest type - -In our example, we believe that the FRU unit is fp16-safe and will get -performance gains from casting its arguments to fp16, so we write: - -```python -@amp.half_function -def fru(input, hidden, weight, bias): - #... -``` - -#### Explicit registration - -The other way to handle backend code is with explicit function -registration: - -- `amp.register_half_function(module, function_name)` -- `amp.register_float_function(module, function_name)` -- `amp.register_promote_function(module, function_name)` - -When using this API, `module` is the containing class or module for -the function, and `function_name` is the _string_ name of the -function. Note that the function must be registered before the call to -`amp.initalize()`. - -For our FRU unit, we can register the backend function directly: - -```python -import backend - -amp.register_half_function(backend, 'FRUBackend') -``` diff --git a/apex/amp/__init__.py b/apex/amp/__init__.py deleted file mode 100644 index b4f81cd..0000000 --- a/apex/amp/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .amp import init, half_function, bfloat16_function, float_function, promote_function,\ - register_half_function, register_bfloat16_function, register_float_function, register_promote_function -from .handle import scale_loss, disable_casts -from .frontend import initialize, state_dict, load_state_dict -from ._amp_state import master_params, _amp_state diff --git a/apex/amp/__version__.py b/apex/amp/__version__.py deleted file mode 100644 index 3a83701..0000000 --- a/apex/amp/__version__.py +++ /dev/null @@ -1,2 +0,0 @@ -VERSION = (0, 1, 0) -__version__ = '.'.join(map(str, VERSION)) diff --git a/apex/amp/_amp_state.py b/apex/amp/_amp_state.py deleted file mode 100644 index 7e8a329..0000000 --- a/apex/amp/_amp_state.py +++ /dev/null @@ -1,59 +0,0 @@ -# This is a "header object" that allows different amp modules to communicate. -# I'm a C++ guy, not a python guy. I decided this approach because it seemed most C++-like. -# But apparently it's ok: -# http://effbot.org/pyfaq/how-do-i-share-global-variables-across-modules.htm -import torch - - -class AmpState(object): - def __init__(self): - self.hard_override=False - self.allow_incoming_model_not_fp32 = False - self.verbosity=1 - - -# Attribute stash. Could also just stash things as global module attributes. -_amp_state = AmpState() - - -def warn_or_err(msg): - if _amp_state.hard_override: - print("Warning: " + msg) - else: - raise RuntimeError(msg) - # I'm not sure if allowing hard_override is a good idea. - # + " If you're sure you know what you're doing, supply " + - # "hard_override=True to amp.initialize.") - - -def maybe_print(msg, rank0=False): - distributed = torch.distributed.is_available() and \ - torch.distributed.is_initialized() and \ - torch.distributed.get_world_size() > 1 - if _amp_state.verbosity > 0: - if rank0: - if distributed: - if torch.distributed.get_rank() == 0: - print(msg) - else: - print(msg) - else: - print(msg) - - -# def iter_params(param_groups): -# for group in param_groups: -# for p in group['params']: -# yield p - - -def master_params(optimizer): - """ - Generator expression that iterates over the params owned by ``optimizer``. - - Args: - optimizer: An optimizer previously returned from ``amp.initialize``. - """ - for group in optimizer.param_groups: - for p in group['params']: - yield p diff --git a/apex/amp/_initialize.py b/apex/amp/_initialize.py deleted file mode 100644 index 641451f..0000000 --- a/apex/amp/_initialize.py +++ /dev/null @@ -1,267 +0,0 @@ -import collections.abc as container_abcs -from types import MethodType -import functools -import sys -import warnings - -import numpy as np -import torch - -from ._amp_state import _amp_state, warn_or_err -from .handle import disable_casts -from .scaler import LossScaler -from ._process_optimizer import _process_optimizer -from apex.fp16_utils import convert_network -from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general -from ..contrib.optimizers import FP16_Optimizer as FP16_Optimizer_for_fused - -if torch.distributed.is_available(): - from ..parallel import DistributedDataParallel as apex_DDP - from ..parallel.LARC import LARC - - -def to_type(dtype, t): - if isinstance(t, torch.Tensor): - if not t.is_cuda: - # This should not be a hard error, since it may be legitimate. - warnings.warn("An input tensor was not cuda.") - # GANs require this. - # if t.requires_grad: - # warn_or_err("input data requires grad. Since input data is not a model parameter,\n" - # "its gradients will not be properly allreduced by DDP.") - if t.is_floating_point(): - return t.to(dtype) - return t - else: - # Trust the user's custom batch type, that's all I can do here. - return t.to(dtype) - - -# Modified from torch.optim.optimizer.py. This is a bit more general than casted_args in utils.py. -def applier(value, fn): - if isinstance(value, torch.Tensor): - return fn(value) - elif isinstance(value, str): - return value - elif isinstance(value, np.ndarray): - return value - elif hasattr(value, "to"): # Allow handling of custom batch classes - return fn(value) - elif isinstance(value, container_abcs.Mapping): - return {applier(k, fn) : applier(v, fn) for k, v in value.items()} - elif isinstance(value, container_abcs.Iterable): - return type(value)(applier(v, fn) for v in value) - else: - # Do I want this to fire off even if someone chooses to pass something ordinary like - # an int or float? May be more annoying than it's worth. - # print("Warning: unrecognized type in applier. If your input data is a custom class, " - # "provide it with a .to(dtype) method which converts its floating-point Tensors to dtype. " - # "Amp will check for your custom to() and invoke it to cast the batch's " - # "floating-point Tensors to the appropriate type. " - # "Also, if your data is a custom class, it is your responsibility to ensure that " - # "any Tensors you want to be cuda are already cuda." - return value - - -def check_models(models): - for model in models: - parallel_type = None - if isinstance(model, torch.nn.parallel.DistributedDataParallel): - parallel_type = "torch.nn.parallel.DistributedDataParallel" - if ('apex_DDP' in sys.modules) and isinstance(model, apex_DDP): - parallel_type = "apex.parallel.DistributedDataParallel" - if isinstance(model, torch.nn.parallel.DataParallel): - parallel_type = "torch.nn.parallel.DataParallel" - if parallel_type is not None: - raise RuntimeError("Incoming model is an instance of {}. ".format(parallel_type) + - "Parallel wrappers should only be applied to the model(s) AFTER \n" - "the model(s) have been returned from amp.initialize.") - - -def check_params_fp32(models): - for model in models: - for name, param in model.named_parameters(): - if param.is_floating_point(): - if 'Half' in param.type() or 'BFloat16' in param.type(): - warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n" - "When using amp.initialize, you do not need to call .half() or .bfloat16()\n" - "on your model before passing it, no matter what optimization level you choose.".format( - name, param.type())) - elif not param.is_cuda: - warn_or_err("Found param {} with type {}, expected torch.cuda.FloatTensor.\n" - "When using amp.initialize, you need to provide a model with parameters\n" - "located on a CUDA device before passing it no matter what optimization level\n" - "you chose. Use model.to('cuda') to use the default device.".format( - name, param.type())) - - # Backward compatibility for PyTorch 0.4 - if hasattr(model, 'named_buffers'): - buf_iter = model.named_buffers() - else: - buf_iter = model._buffers - for obj in buf_iter: - if type(obj)==tuple: - name, buf = obj - else: - name, buf = obj, buf_iter[obj] - if buf.is_floating_point(): - if 'Half' in buf.type(): - warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n" - "When using amp.initialize, you do not need to call .half() on your model\n" - "before passing it, no matter what optimization level you choose.".format( - name, buf.type())) - elif not buf.is_cuda: - warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n" - "When using amp.initialize, you need to provide a model with buffers\n" - "located on a CUDA device before passing it no matter what optimization level\n" - "you chose. Use model.to('cuda') to use the default device.".format( - name, buf.type())) - - -def check_optimizers(optimizers): - for optim in optimizers: - bad_optim_type = None - if isinstance(optim, FP16_Optimizer_general): - bad_optim_type = "apex.fp16_utils.FP16_Optimizer" - if isinstance(optim, FP16_Optimizer_for_fused): - bad_optim_type = "apex.optimizers.FP16_Optimizer" - if bad_optim_type is not None: - raise RuntimeError("An incoming optimizer is an instance of {}. ".format(bad_optim_type) + - "The optimizer(s) passed to amp.initialize() must be bare \n" - "instances of either ordinary Pytorch optimizers, or Apex fused \n" - "optimizers.\n") - - -class O2StateDictHook(object): - def __init__(self, fn): - self.fn = fn - - def __call__(self, module, state_dict, prefix, local_metadata): - for key in state_dict: - param = state_dict[key] - if 'Half' in param.type() or 'BFloat16' in param.type(): - param = param.to(torch.float32) - state_dict[key] = param - - -def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None): - from .amp import init as amp_init - - optimizers_was_list = False - if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in globals() and isinstance(optimizers, LARC)): - optimizers = [optimizers] - elif optimizers is None: - optimizers = [] - elif isinstance(optimizers, list): - optimizers_was_list = True - check_optimizers(optimizers) - else: - check_optimizers([optimizers]) - raise TypeError("optimizers must be either a single optimizer or a list of optimizers.") - - if isinstance(models, torch.nn.Module): - models_was_list = False - models = [models] - elif isinstance(models, list): - models_was_list = True - else: - raise TypeError("models must be either a single model or a list of models.") - - check_models(models) - - if not _amp_state.allow_incoming_model_not_fp32: - check_params_fp32(models) - - # In the future, when FP16_Optimizer can be deprecated and master weights can - # become an attribute, remember to stash master weights before casting the model. - - if properties.cast_model_type: - if properties.keep_batchnorm_fp32: - for model in models: - convert_network(model, properties.cast_model_type) - else: - for model in models: - model.to(properties.cast_model_type) - - input_caster = functools.partial(to_type, properties.cast_model_type) - if cast_model_outputs is not None: - output_caster = functools.partial(to_type, cast_model_outputs) - else: - output_caster = functools.partial(to_type, torch.float32) - - for model in models: - # Patch the forward method to cast incoming data to the correct type, and - # outgoing data to float32, so "the user never needs to call .half()/.bfloat16()." - # I like writing things explicitly more than decorators. - def patch_forward(old_fwd): - def new_fwd(*args, **kwargs): - output = old_fwd(*applier(args, input_caster), - **applier(kwargs, input_caster)) - return applier(output, output_caster) - return new_fwd - - model.forward = patch_forward(model.forward) - - # State dict trick to recast any preexisting per-param state tensors - for optimizer in optimizers: - optimizer.load_state_dict(optimizer.state_dict()) - - # patch model.state_dict() to return float32 params - for model in models: - for module in model.modules(): - module._register_state_dict_hook(O2StateDictHook(functools.partial(to_type, torch.float32))) - - elif cast_model_outputs is not None: - output_caster = functools.partial(to_type, cast_model_outputs) - - for model in models: - def patch_forward(old_fwd): - def new_fwd(*args, **kwargs): - output = old_fwd(*args, **kwargs) - return applier(output, output_caster) - return new_fwd - - model.forward = patch_forward(model.forward) - - for i, optimizer in enumerate(optimizers): - optimizers[i] = _process_optimizer(optimizer, properties) - - _amp_state.loss_scalers = [] - for _ in range(num_losses): - _amp_state.loss_scalers.append(LossScaler(properties.loss_scale, - min_loss_scale=_amp_state.min_loss_scale, - max_loss_scale=_amp_state.max_loss_scale)) - - if properties.patch_torch_functions: - # handle is unused here. It's accessible later through a global value anyway. - handle = amp_init(loss_scale=properties.loss_scale, - patch_type=properties.patch_torch_functions_type, - verbose=(_amp_state.verbosity == 2)) - for optimizer in optimizers: - # Disable Amp casting for the optimizer step, because it should only be - # applied to FP32 master params anyway. - def patch_step(old_step): - def new_step(self, *args, **kwargs): - with disable_casts(): - output = old_step(*args, **kwargs) - return output - return new_step - - optimizer.step = MethodType(patch_step(optimizer.step), optimizer) - - if optimizers_was_list: - if models_was_list: - return models, optimizers - else: - return models[0], optimizers - else: - if models_was_list: - if len(optimizers) == 0: - return models - else: - return models, optimizers[0] - else: - if len(optimizers) == 0: - return models[0] - else: - return models[0], optimizers[0] diff --git a/apex/amp/_process_optimizer.py b/apex/amp/_process_optimizer.py deleted file mode 100644 index 390d918..0000000 --- a/apex/amp/_process_optimizer.py +++ /dev/null @@ -1,489 +0,0 @@ -import types -from ..fp16_utils import master_params_to_model_params -from ..multi_tensor_apply import multi_tensor_applier -from ._amp_state import maybe_print, _amp_state -import torch -from ..optimizers import FusedSGD - - -class AmpOptimizerState(object): - def __init__(self): - pass - - -def _master_params_to_model_params(self): - stash = self._amp_stash - if multi_tensor_applier.available: - if len(stash.all_fp16_params) > 0: - multi_tensor_applier( - stash.multi_tensor_scale, - stash.dummy_overflow_buf, - [stash.all_fp32_from_fp16_params, stash.all_fp16_params], - 1.0) - else: - for fp16_group, fp32_from_fp16_group in zip(stash.fp16_groups, stash.fp32_from_fp16_groups): - master_params_to_model_params(fp16_group, fp32_from_fp16_group) - - -def lazy_init_with_master_weights(self): - stash = self._amp_stash - stash.fp16_groups = [] - stash.fp32_from_fp16_groups = [] - stash.fp32_from_fp32_groups = [] - for i, param_group in enumerate(self.param_groups): - # maybe_print("FP16_Optimizer processing param group {}:".format(i)) - fp16_params_this_group = [] - fp32_params_this_group = [] - fp32_from_fp16_params_this_group = [] - for i, param in enumerate(param_group['params']): - if param.requires_grad: - if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}: - # maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}" - # .format(param.size())) - fp16_params_this_group.append(param) - master_param = param.detach().clone().float() - master_param.requires_grad = True - param_group['params'][i] = master_param - fp32_from_fp16_params_this_group.append(master_param) - # Reset existing state dict key to the new master param. - # We still need to recast per-param state tensors, if any, to FP32. - if param in self.state: - self.state[master_param] = self.state.pop(param) - elif param.type() == 'torch.cuda.FloatTensor': - # maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}" - # .format(param.size())) - fp32_params_this_group.append(param) - param_group['params'][i] = param - else: - raise TypeError("Optimizer's parameters must one of " - "torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. " - "Received {}".format(param.type())) - - stash.fp16_groups.append(fp16_params_this_group) - stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group) - stash.fp32_from_fp32_groups.append(fp32_params_this_group) - - stash.all_fp16_params = [] - for group in stash.fp16_groups: - stash.all_fp16_params += group - - stash.all_fp32_from_fp16_params = [] - for group in stash.fp32_from_fp16_groups: - stash.all_fp32_from_fp16_params += group - - stash.all_fp32_from_fp32_params = [] - for group in stash.fp32_from_fp32_groups: - stash.all_fp32_from_fp32_params += group - - # all_fp16_grad_stash is only needed for fused optimizers. - stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params] - # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params] - stash.all_fp32_from_fp32_grad_stash = [None for _ in stash.all_fp32_from_fp32_params] - - for param in stash.all_fp32_from_fp16_params: - param.grad = None - - for param in stash.all_fp32_from_fp32_params: - param.grad = None - - # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors - self.load_state_dict(self.state_dict()) - - -def post_backward_models_are_masters(scaler, params, stashed_grads, scale_override=None): - grads_have_scale, stashed_have_scale, out_scale = scaler.loss_scale(), 1.0, 1.0 - - # not much to do if scale == 1.0 and static scaling - if scaler.loss_scale() == 1.0 and not scaler.dynamic: - # Clear the stash. - for i in range(len(stashed_grads)): - stashed_grads[i] = None - return - - if scale_override is not None: - grads_have_scale, stashed_have_scale, out_scale = scale_override - - # This is a lot of python overhead... - grads_needing_unscale = [] - grads_needing_unscale_with_stash = [] - stashed = [] - for param, stashed_grad in zip(params, stashed_grads): - if param.grad is None and stashed_grad is not None: - param.grad = stashed_grad - elif param.grad is not None and stashed_grad is None: - grads_needing_unscale.append(param.grad) - elif param.grad is not None and stashed_grad is not None: - grads_needing_unscale_with_stash.append(param.grad) - stashed.append(stashed_grad) - else: # param.grad is None and stashed_grad is None - continue - - # unscale() implements grads*(1/scale), so "scale" should be grads_have_scale/out_scale. - if len(grads_needing_unscale) > 0: - scaler.unscale( - grads_needing_unscale, - grads_needing_unscale, - None, # unused_scale, currently present to avoid API breakage elsewhere - models_are_masters=True, - scale_override=grads_have_scale/out_scale) - - if len(grads_needing_unscale_with_stash) > 0: - scaler.unscale_with_stashed( - grads_needing_unscale_with_stash, - stashed, - grads_needing_unscale_with_stash, - scale_override=(grads_have_scale, stashed_have_scale, out_scale)) - - # Clear the stash. - for i in range(len(stashed_grads)): - stashed_grads[i] = None - - -def prepare_backward_with_master_weights(self): - stash = self._amp_stash - - self._amp_lazy_init() - - for i, param in enumerate(stash.all_fp16_params): - # Set up to leverage grad copy elision. - # This may behave differently from an unpatched optimizer if zero_grad is used and the param is unused. - param.grad = None - - # for i, param in enumerate(stash.all_fp32_from_fp16_params): - # stash.all_fp32_from_fp16_grad_stash[i] = param.grad - - for i, param in enumerate(stash.all_fp32_from_fp32_params): - stash.all_fp32_from_fp32_grad_stash[i] = param.grad - # Set up to leverage grad copy elision: - param.grad = None - - -def post_backward_with_master_weights(self, scaler): - stash = self._amp_stash - - self._amp_lazy_init() - - # This is a lot of python overhead... - fp16_grads_needing_unscale = [] - new_fp32_grads = [] - fp16_grads_needing_unscale_with_stash = [] - preexisting_fp32_grads = [] - for fp16_param, fp32_param in zip(stash.all_fp16_params, - stash.all_fp32_from_fp16_params): - if fp16_param.grad is None and fp32_param.grad is not None: - continue - elif fp16_param.grad is not None and fp32_param.grad is None: - fp32_param.grad = torch.empty_like(fp32_param) - fp16_grads_needing_unscale.append(fp16_param.grad) - new_fp32_grads.append(fp32_param.grad) - elif fp16_param.grad is not None and fp32_param.grad is not None: - fp16_grads_needing_unscale_with_stash.append(fp16_param.grad) - preexisting_fp32_grads.append(fp32_param.grad) - else: # fp16_param.grad is None and fp32_param.grad is None: - continue - - if len(fp16_grads_needing_unscale) > 0: - scaler.unscale( - fp16_grads_needing_unscale, - new_fp32_grads, - scaler.loss_scale(), - models_are_masters=False) - - if len(fp16_grads_needing_unscale_with_stash) > 0: - scaler.unscale_with_stashed( - fp16_grads_needing_unscale_with_stash, - preexisting_fp32_grads, - preexisting_fp32_grads) - - # fp32 params can be treated as they would be in the "no_master_weights" case. - post_backward_models_are_masters( - scaler, - stash.all_fp32_from_fp32_params, - stash.all_fp32_from_fp32_grad_stash) - - -def lazy_init_no_master_weights(self): - stash = self._amp_stash - stash.all_fp16_params = [] - stash.all_fp32_params = [] - for i, param_group in enumerate(self.param_groups): - for i, param in enumerate(param_group['params']): - if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}: - stash.all_fp16_params.append(param) - elif param.type() == 'torch.cuda.FloatTensor': - stash.all_fp32_params.append(param) - else: - raise TypeError("Optimizer's parameters must be one of " - "torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.BFloat16Tensor. " - "Received {}".format(param.type())) - - stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params] - stash.all_fp32_grad_stash = [None for _ in stash.all_fp32_params] - - -def prepare_backward_no_master_weights(self): - stash = self._amp_stash - - self._amp_lazy_init() - - for i, param in enumerate(stash.all_fp16_params): - stash.all_fp16_grad_stash[i] = param.grad - # Set up to leverage grad copy elision: - param.grad = None - - for i, param in enumerate(stash.all_fp32_params): - stash.all_fp32_grad_stash[i] = param.grad - # Set up to leverage grad copy elision: - param.grad = None - - -def post_backward_no_master_weights(self, scaler): - stash = self._amp_stash - - self._amp_lazy_init() - - split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash), - (stash.all_fp32_params, stash.all_fp32_grad_stash)) - - for params, stashed_grads in split_types: - post_backward_models_are_masters(scaler, params, stashed_grads) - - -##################################################################################### -# FusedSGD versions -##################################################################################### - -# FusedSGD never explicitly materializes the fp32 gradients for "fp32 from fp16" master params -# outside the kernel, so we must accumulate directly into the model grads. -def prepare_backward_with_master_weights_FusedSGD(self): - if self.materialize_master_grads: - prepare_backward_with_master_weights(self) - else: - stash = self._amp_stash - - self._amp_lazy_init() - - for i, param in enumerate(stash.all_fp16_params): - stash.all_fp16_grad_stash[i] = param.grad - # Set up to leverage grad copy elision: - param.grad = None - - for i, param in enumerate(stash.all_fp32_from_fp32_params): - stash.all_fp32_from_fp32_grad_stash[i] = param.grad - # Set up to leverage grad copy elision: - param.grad = None - - -def post_backward_with_master_weights_FusedSGD(self, scaler): - if self.materialize_master_grads: - post_backward_with_master_weights(self, scaler) - else: - stash = self._amp_stash - - self._amp_lazy_init() - - grads_have_scale = scaler.loss_scale() - stashed_have_scale = self.most_recent_scale - out_scale = grads_have_scale - if self.scale_set_by_backward: - out_scale = min(grads_have_scale, self.most_recent_scale) - - split_types = ((stash.all_fp16_params, stash.all_fp16_grad_stash), - (stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash)) - - - # unscale_with_stashed() implements grads*1/scale + stashed_grads*1. - # stashed_grads are scaled by self.most_recent_scale. - for params, stashed_grads in split_types: - post_backward_models_are_masters(scaler, params, stashed_grads, - (grads_have_scale, stashed_have_scale, out_scale)) - - self.most_recent_scale = out_scale - self.scale_set_by_backward = True - - -def prepare_backward_no_master_weights_FusedSGD(self): - prepare_backward_no_master_weights(self) - - -def post_backward_no_master_weights_FusedSGD(self, scaler): - post_backward_no_master_weights(self, scaler) - - -def _amp_lazy_init(self): - stash = self._amp_stash - - if not stash.lazy_init_called: - self._lazy_init_maybe_master_weights() - stash.lazy_init_called = True - - -def _process_optimizer(optimizer, properties): - if hasattr(optimizer, "_amp_stash"): - raise RuntimeError("A given optimizer should only be passed through amp.initialize once.") - else: - optimizer._amp_stash = AmpOptimizerState() - - optimizer._amp_stash.lazy_init_called = False - optimizer._amp_stash.already_patched = False - optimizer._amp_stash.params_have_scaled_gradients = False - - for name in ("_lazy_init_maybe_master_weights", - "_master_params_to_model_params", - "_prepare_amp_backward", - "_post_amp_backward", - "_amp_lazy_init"): - if hasattr(optimizer, name): - raise RuntimeError("Incoming optimizer already has {} defined.".format(name)) - - # TODO: Centralize exposure and import error checking for the C backend. - if multi_tensor_applier.available: - import amp_C - optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale - optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm - optimizer._amp_stash.dummy_overflow_buf = torch.cuda.IntTensor([0]); - - if properties.master_weights: - optimizer._lazy_init_maybe_master_weights = types.MethodType( - lazy_init_with_master_weights, optimizer) - - optimizer._master_params_to_model_params = types.MethodType( - _master_params_to_model_params, optimizer) - - old_step = optimizer.step - def new_step(self, closure=None): - if closure is not None: - raise RuntimeError("Currently, Amp does not support closure use with optimizers.") - retval = old_step() - if not isinstance(self, FusedSGD): - self._master_params_to_model_params() - # Clear the master grads that wouldn't be zeroed by model.zero_grad() - for param in self._amp_stash.all_fp32_from_fp16_params: - param.grad = None - return retval - optimizer.step = types.MethodType(new_step, optimizer) - - old_zero_grad = optimizer.zero_grad - def new_zero_grad(self): - stash = self._amp_stash - self._amp_lazy_init() - # Zero the model grads. - for param in stash.all_fp16_params: - if param.grad is not None: - param.grad.detach_() - param.grad.zero_() - for param in stash.all_fp32_from_fp32_params: - if param.grad is not None: - param.grad.detach_() - param.grad.zero_() - # Clear the master grads that are independent of model grads - for param in self._amp_stash.all_fp32_from_fp16_params: - param.grad = None - optimizer.zero_grad = types.MethodType(new_zero_grad, optimizer) - - if isinstance(optimizer, FusedSGD): - optimizer._prepare_amp_backward = types.MethodType( - prepare_backward_with_master_weights_FusedSGD, optimizer) - optimizer._post_amp_backward = types.MethodType( - post_backward_with_master_weights_FusedSGD, optimizer) - else: - optimizer._prepare_amp_backward = types.MethodType( - prepare_backward_with_master_weights, optimizer) - optimizer._post_amp_backward = types.MethodType( - post_backward_with_master_weights, optimizer) - else: - optimizer._lazy_init_maybe_master_weights = types.MethodType( - lazy_init_no_master_weights, optimizer) - - if isinstance(optimizer, FusedSGD): - optimizer._prepare_amp_backward = types.MethodType( - prepare_backward_no_master_weights_FusedSGD, optimizer) - optimizer._post_amp_backward = types.MethodType( - post_backward_no_master_weights_FusedSGD, optimizer) - else: - optimizer._prepare_amp_backward = types.MethodType( - prepare_backward_no_master_weights, optimizer) - optimizer._post_amp_backward = types.MethodType( - post_backward_no_master_weights, optimizer) - - optimizer._amp_lazy_init = types.MethodType(_amp_lazy_init, optimizer) - - old_add_param_group = optimizer.add_param_group - - def new_add_param_group(self, new_group): - stash = self._amp_stash - - if not stash.lazy_init_called: - self._lazy_init_maybe_master_weights() - stash.lazy_init_called = True - - assert isinstance(new_group, dict), "param group must be a dict" - - new_params = new_group['params'] - if isinstance(new_params, torch.Tensor): - new_group['params'] = [new_params] - elif isinstance(new_params, set): - raise TypeError('optimizer parameters need to be organized in ordered collections, but ' - 'the ordering of tensors in sets will change between runs. Please use a list instead.') - else: - new_group['params'] = list(new_params) - - if properties.master_weights: - # Mutate new_group in-place to use FP32 master params - fp16_params_this_group = [] - fp32_params_this_group = [] - fp32_from_fp16_params_this_group = [] - for i, param in enumerate(new_group['params']): - if param.requires_grad: - if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}: - fp16_params_this_group.append(param) - master_param = param.detach().clone().float() - master_param.requires_grad = True - new_group['params'][i] = master_param - fp32_from_fp16_params_this_group.append(master_param) - elif param.type() == 'torch.cuda.FloatTensor': - fp32_params_this_group.append(param) - new_group['params'][i] = param - else: - raise TypeError("Optimizer's parameters must be one of " - "torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. " - "Received {}".format(param.type())) - - stash.fp16_groups.append(fp16_params_this_group) - stash.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group) - stash.fp32_from_fp32_groups.append(fp32_params_this_group) - - stash.all_fp16_params += fp16_params_this_group - stash.all_fp32_from_fp16_params += fp32_from_fp16_params_this_group - stash.all_fp32_from_fp32_params += fp32_params_this_group - - # stash.all_fp32_from_fp16_grad_stash = [None for _ in stash.all_fp32_from_fp16_params] - stash.all_fp32_from_fp32_grad_stash += [None for _ in fp32_params_this_group] - - # It should be ok to let params be added with existing .grad attributes. - # for param in fp16_params_this_group: - # param.grad = None - - # for param in fp32_from_fp16_params_this_group: - # param.grad = None - - # for param in stash.fp32_params_this_group: - # param.grad = None - else: - for param in new_group['params']: - if param.type() in {'torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor'}: - stash.all_fp16_params.append(param) - stash.all_fp16_grad_stash.append(None) - elif param.type() == 'torch.cuda.FloatTensor': - stash.all_fp32_params.append(param) - stash.all_fp32_grad_stash.append(None) - else: - raise TypeError("Optimizer's parameters must one of " - "torch.cuda.FloatTensor, torch.cuda.HalfTensor, torch.cuda.BFloat16Tensor. " - "Received {}".format(param.type())) - - old_add_param_group(new_group) - - optimizer.add_param_group = types.MethodType(new_add_param_group, optimizer) - - return optimizer diff --git a/apex/amp/amp.py b/apex/amp/amp.py deleted file mode 100644 index b438b3f..0000000 --- a/apex/amp/amp.py +++ /dev/null @@ -1,198 +0,0 @@ -from . import compat, rnn_compat, utils, wrap -from .handle import AmpHandle, NoOpHandle -from .lists import functional_overrides, torch_overrides, tensor_overrides -from ._amp_state import _amp_state -from .frontend import * - -import functools -import itertools - -import torch - -_DECORATOR_HANDLE = None -_USER_CAST_REGISTRY = set() -_USER_PROMOTE_REGISTRY = set() - - -def _decorator_helper(orig_fn, cast_fn, wrap_fn): - def wrapper(*args, **kwargs): - handle = _DECORATOR_HANDLE - if handle is None or not handle.is_active(): - return orig_fn(*args, **kwargs) - inner_cast_fn = utils.verbosify(cast_fn, orig_fn.__name__, - handle.verbose) - return wrap_fn(orig_fn, inner_cast_fn, handle)(*args, **kwargs) - return wrapper - - -# Decorator form -def half_function(fn): - wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True) - return _decorator_helper(fn, utils.maybe_half, wrap_fn) - -def bfloat16_function(fn): - wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=True) - return _decorator_helper(fn, utils.maybe_bfloat16, wrap_fn) - -def float_function(fn): - wrap_fn = functools.partial(wrap.make_cast_wrapper, try_caching=False) - return _decorator_helper(fn, utils.maybe_float, wrap_fn) - - -def promote_function(fn): - wrap_fn = functools.partial(wrap.make_promote_wrapper) - return _decorator_helper(fn, utils.maybe_float, wrap_fn) - - -# Registry form -def register_half_function(module, name): - if not hasattr(module, name): - raise ValueError('No function named {} in module {}.'.format( - name, module)) - _USER_CAST_REGISTRY.add((module, name, utils.maybe_half)) - -def register_bfloat16_function(module, name): - if not hasattr(module, name): - raise ValueError('No function named {} in module {}.'.format( - name, module)) - _USER_CAST_REGISTRY.add((module, name, utils.maybe_bfloat16)) - -def register_float_function(module, name): - if not hasattr(module, name): - raise ValueError('No function named {} in module {}.'.format( - name, module)) - _USER_CAST_REGISTRY.add((module, name, utils.maybe_float)) - - -def register_promote_function(module, name): - if not hasattr(module, name): - raise ValueError('No function named {} in module {}.'.format( - name, module)) - _USER_PROMOTE_REGISTRY.add((module, name)) - - -# Top-level function to insert _all_ the hooks. -def init(enabled=True, loss_scale="dynamic", patch_type=torch.float16, enable_caching=True, verbose=False, allow_banned=False): - global _DECORATOR_HANDLE - - if not enabled: - handle = NoOpHandle() - _DECORATOR_HANDLE = handle - return handle - - handle = AmpHandle(loss_scale, enable_caching, verbose) - - # 0) Force-{fp16, fp32} for user-annotated functions - for mod, fn, cast_fn in _USER_CAST_REGISTRY: - try_caching = (cast_fn == utils.maybe_half) - wrap.cached_cast(mod, fn, cast_fn, handle, - try_caching, verbose) - _USER_CAST_REGISTRY.clear() - - # 0.5) Force-promote for user-annotated functions - for mod, fn in _USER_PROMOTE_REGISTRY: - wrap.promote(mod, fn, handle, verbose) - _USER_PROMOTE_REGISTRY.clear() - - # conditionally choose between fp16 and bfloat16 functions list to cache - if patch_type == torch.float16: - low_prec_funcs = 'FP16_FUNCS' - maybe_low_prec = utils.maybe_half - low_prec_tensor = torch.cuda.HalfTensor - elif patch_type == torch.bfloat16: - low_prec_funcs = 'BFLOAT16_FUNCS' - maybe_low_prec = utils.maybe_bfloat16 - low_prec_tensor = torch.cuda.BFloat16Tensor - else: - raise RuntimeError("Unsupported patch_torch_functions_type passed to initialize." + - "Supported types are: torch.float16 and torch.bfloat16.") - - # 1) Force-{fp16, fp32} on white- / black-list functions - override_modules = [functional_overrides, - torch_overrides, - tensor_overrides] - cast_table = [(low_prec_funcs, maybe_low_prec), - ('FP32_FUNCS', utils.maybe_float)] - - for module, (list_name, cast_fn) in itertools.product(override_modules, - cast_table): - for fn in getattr(module, list_name): - try_caching = (cast_fn == maybe_low_prec) - wrap.cached_cast(module.MODULE, fn, cast_fn, handle, - try_caching, verbose) - - # 1.5) Pre-0.4, put the blacklist methods on HalfTensor and whitelist - # methods on FloatTensor, since they're distinct types. - if compat.tensor_is_float_tensor(): - for fn in tensor_overrides.FP16_FUNCS: - wrap.cached_cast(torch.cuda.FloatTensor, fn, utils.maybe_half, - handle, try_caching=True, verbose=verbose) - for fn in tensor_overrides.FP32_FUNCS: - wrap.cached_cast(torch.cuda.HalfTensor, fn, utils.maybe_float, - handle, try_caching=False, verbose=verbose) - - # 2) Enable type-promotion on multi-arg functions and methods. - # NB: special handling for sequence fns (e.g. `torch.cat`). - promote_modules = [torch_overrides, tensor_overrides] - promote_table = [('CASTS', wrap.promote), - ('SEQUENCE_CASTS', wrap.sequence_promote)] - for promote_mod, (list_name, promote_fn) in itertools.product(promote_modules, - promote_table): - for fn in getattr(promote_mod, list_name): - promote_fn(promote_mod.MODULE, fn, handle, verbose) - - # 2.5) Pre-0.4, add blacklist methods directly to HalfTensor and FloatTensor types - if compat.tensor_is_float_tensor(): - for cls, (list_name, promote_fn) in itertools.product([torch.cuda.FloatTensor, - torch.cuda.HalfTensor], - promote_table): - for fn in getattr(tensor_overrides, list_name): - promote_fn(cls, fn, handle, verbose) - - # 3) For any in-place version of a blacklist function, error if any input is fp16/bfloat16. - # NB: this is overly conservative. - for fn in utils.as_inplace(torch_overrides.FP32_FUNCS): - wrap.err_if_any_half(torch_overrides.MODULE, fn, handle) - - # 3.5) For any in-place blacklist method, error if called on fp16/bfloat16 tensor - for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS): - wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, handle, verbose) - if compat.tensor_is_float_tensor(): - wrap.err_if_arg0_half(torch.cuda.HalfTensor, fn, handle, verbose) - - # 4) For other in-place methods, match the type of self tensor - for fn in utils.as_inplace(itertools.chain( - getattr(tensor_overrides, low_prec_funcs), - tensor_overrides.CASTS)): - wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose) - if compat.tensor_is_float_tensor(): - wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, handle, verbose) - wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, handle, verbose) - - # 5) RNNs + RNN cells are whitelisted specially - if rnn_compat.has_old_rnns(): - wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', handle, verbose) - if not rnn_compat.has_old_rnns(): - # Patch in our own indirection of `_VF` in modules/rnn s.t. it is mutable. - torch.nn.modules.rnn._VF = rnn_compat.VariableFunctionsShim() - # Wrap all the rnns - for x in rnn_compat.RNN_NAMES: - wrap.new_rnn_cast(x.upper(), maybe_low_prec, handle, verbose) - - # Wrap all the RNN cells - rnn_compat.whitelist_rnn_cells(maybe_low_prec, handle, verbose) - - # 6) Place error+print message on banned functions. - # Or, if allow_banned, then cast to FP32. - for fn, err_msg in functional_overrides.BANNED_FUNCS: - if allow_banned: - wrap.cached_cast(functional_overrides.MODULE, fn, utils.maybe_float, - handle, try_caching=True, verbose=verbose) - else: - wrap.err_if_any_half(functional_overrides.MODULE, fn, handle, err_msg) - - _DECORATOR_HANDLE = handle - - _amp_state.handle = handle - - return handle diff --git a/apex/amp/compat.py b/apex/amp/compat.py deleted file mode 100644 index 2725fa8..0000000 --- a/apex/amp/compat.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch - -# True for post-0.4, when Variables/Tensors merged. -def variable_is_tensor(): - v = torch.autograd.Variable() - return isinstance(v, torch.Tensor) - -def tensor_is_variable(): - x = torch.Tensor() - return type(x) == torch.autograd.Variable - -# False for post-0.4 -def tensor_is_float_tensor(): - x = torch.Tensor() - return type(x) == torch.FloatTensor - -# Akin to `torch.is_tensor`, but returns True for Variable -# objects in pre-0.4. -def is_tensor_like(x): - return torch.is_tensor(x) or isinstance(x, torch.autograd.Variable) - -# Wraps `torch.is_floating_point` if present, otherwise checks -# the suffix of `x.type()`. -def is_floating_point(x): - if hasattr(torch, 'is_floating_point'): - return torch.is_floating_point(x) - try: - torch_type = x.type() - return torch_type.endswith('FloatTensor') or \ - torch_type.endswith('HalfTensor') or \ - torch_type.endswith('DoubleTensor') or \ - torch_type.endswith('BFloat16Tensor') - except AttributeError: - return False - -def scalar_python_val(x): - if hasattr(x, 'item'): - return x.item() - else: - if isinstance(x, torch.autograd.Variable): - return x.data[0] - else: - return x[0] - -# Accounts for the possibility that some ops may be removed from a namespace. -def filter_attrs(module, attrs): - return list(attrname for attrname in attrs if hasattr(module, attrname)) diff --git a/apex/amp/frontend.py b/apex/amp/frontend.py deleted file mode 100644 index cbaf139..0000000 --- a/apex/amp/frontend.py +++ /dev/null @@ -1,509 +0,0 @@ -import torch -from ._initialize import _initialize -from ._amp_state import _amp_state, warn_or_err, maybe_print -from collections import OrderedDict - - -class Properties(object): - """ - This class has two purposes: to establish a set of default properties, - and to route setting of these attributes through __setattr__ so that (in theory) - they can be checked for consistency with other existing args. - """ - def __init__(self): - self.options = { - "enabled" : False, - "opt_level" : None, - "cast_model_type" : None, - "patch_torch_functions" : False, - # TODO: patch_torch_functions_type could probably be unified with - # patch_torch_functions. Currently introducing a new attribute - # to be on the safer side and not break stuff. - "patch_torch_functions_type" : None, - "keep_batchnorm_fp32" : None, - "master_weights" : None, - "loss_scale" : 1.0, - # Reserved for future functionality - # "fused_optimizer" : False, - # "enable_ddp_interop" : False, - } - - """ - This function allows updating several options at a time without routing through - __setattr__ checks, to avoid "you can't get there from here" scenarios. - Currently not intended to be exposed; users are expected to select an opt_level - and apply consistent modifications. - """ - def _update_options_dict(self, new_options): - for k, v in new_options: - if k in self.options: - self.options[k] = v - else: - raise ValueError("Tried to set unexpected option {}".format(k)) - """ - The members of "options" are not direct attributes of self, so access attempts - will roll down to __getattr__. This borrows from the logic in torch.nn.Module. - """ - def __getattr__(self, name): - if "options" in self.__dict__: - options = self.__dict__["options"] - if name in options: - return options[name] - raise AttributeError("'{}' object has no attribute '{}'".format( - type(self).__name__, name)) - - def __setattr__(self, name, value): - if "options" in self.__dict__: - if name in self.options: - # print("setting {} {}".format(name, value)) - if name == "cast_model_type": - if self.opt_level in {"O1", "O4"} and value is not None: - if value is not False: - if value is not torch.float32: - warn_or_err("O1 inserts casts around Torch functions rather than " - "model weights, so with O1, the model weights themselves " - "should remain FP32. If you wish to cast the model to a " - "different type, use opt_level='O2' or 'O3'. " + - "cast_model_type was {}".format(value)) - self.options[name] = value - elif name == "patch_torch_functions": - if self.opt_level not in {"O1", "O4"} and value: - warn_or_err("Currently, patch_torch_functions=True should only be set by " - "selecting opt_level='O1' or 'O4'.") - self.options[name] = value - elif name == "patch_torch_functions_type": - if self.opt_level not in {"O1", "O4"} and value is not None: - warn_or_err("Currently, patch_torch_functions_type should only be set by " - "selecting opt_level='O1' or 'O4'.") - elif self.opt_level == "O1" and value != torch.float16: - warn_or_err("patch_torch_functions_type should only be set to torch.float16 " - "for opt_level='O1.") - elif self.opt_level == "O4" and value != torch.bfloat16: - warn_or_err("patch_torch_functions_type should only be set to torch.bfloat16 " - "for opt_level='O4.") - else: - self.options[name] = value - elif name == "keep_batchnorm_fp32": - if self.opt_level in {"O1", "O4"} and value is not None: - warn_or_err("With opt_level O1 or O4, batchnorm functions are automatically patched " - "to run in FP32, so keep_batchnorm_fp32 should be None." + - " keep_batchnorm_fp32 was {}".format(value)) - if value == "False": - self.options[name] = False - elif value == "True": - self.options[name] = True - else: - assert (value is True or value is False or value is None),\ - "keep_batchnorm_fp32 must be a boolean, the string 'True' or 'False', "\ - "or None, found keep_batchnorm_fp32={}".format(value) - self.options[name] = value - elif name == "master_weights": - if self.opt_level in {"O1", "O4"} and value is not None: - warn_or_err("It doesn't make sense to use master_weights with O1 and O4 . " - "With O1 and O4, your model weights themselves should be FP32.") - self.options[name] = value - elif name == "loss_scale": - if value == "dynamic": - self.options[name] = value - else: - self.options[name] = float(value) - else: - self.options[name] = value - else: - super(Properties, self).__setattr__(name, value) - - -""" O0-O3 are convenience wrappers to establish defaults for typically used mixed precision options. """ - -class O3: - brief = "O3: Pure FP16 training." - more = "Calls .half() on your model, converting the entire model to FP16.\n"\ - "A casting operation is also inserted to cast incoming Tensors to FP16,\n"\ - "so you don't need to change your data pipeline.\n"\ - "This mode is useful for establishing a performance ceiling.\n"\ - "It's also possible training may 'just work' in this mode.\n"\ - "If not, try other optimization levels." - - def __call__(self, properties): - properties.enabled = True - properties.opt_level = "O3" - properties.cast_model_type = torch.float16 - properties.patch_torch_functions = False - properties.patch_torch_functions_type = None - properties.keep_batchnorm_fp32 = False - properties.master_weights = False - properties.loss_scale = 1.0 - # properties.fused_optimizer = False - # properties.enable_ddp_interop = False - return properties # modified in place so this isn't really necessary - - -class O2: - brief = "O2: FP16 training with FP32 batchnorm and FP32 master weights.\n" - more = "Calls .half() on your model, converting the entire model (except for batchnorms)\n"\ - "to FP16. Batchnorms are retained in FP32 for additional stability.\n"\ - "The forward pass is patched to cast incoming Tensors to FP16, so you don't need to change\n"\ - "your data pipeline.\n"\ - "O2 creates FP32 master weights outside the model and patches any optimizers to update\n"\ - "these master weights, then copy the master weights into the FP16 model weights.\n"\ - "Master weights can also improve convergence and stability." - - def __call__(self, properties): - properties.enabled = True - properties.opt_level = "O2" - properties.cast_model_type = torch.float16 - properties.patch_torch_functions = False - properties.patch_torch_functions_type = None - properties.keep_batchnorm_fp32 = True - properties.master_weights = True - properties.loss_scale = "dynamic" - # properties.fused_optimizer = False - # properties.enable_ddp_interop = False - return properties # modified in place so this isn't really necessary - - -class O1: - brief = "O1: Insert automatic casts around Pytorch functions and Tensor methods.\n" - more = "The type of your model's weights is not altered. However, internally,\n"\ - "Pytorch functions are patched to cast any Tensor Core-friendly ops to FP16 for speed,\n"\ - "while operations that might benefit from the additional stability of FP32 are patched\n"\ - "to cast their inputs to fp32.\n"\ - "O1 is the safest way to try mixed precision training, and is recommended when\n"\ - "trying mixed precision training for the first time." - - def __call__(self, properties): - properties.enabled = True - properties.opt_level = "O1" - properties.cast_model_type = None - properties.patch_torch_functions = True - properties.patch_torch_functions_type = torch.float16 - properties.keep_batchnorm_fp32 = None - properties.master_weights = None - properties.loss_scale = "dynamic" - # properties.fused_optimizer = False - # properties.enable_ddp_interop = False - return properties # modified in place so this isn't really necessary - - -class O0: - brief = "O0: Pure FP32 training.\n" - more = "Your models are checked to make sure parameters are FP32, but otherwise the\n"\ - "types of weights and internal Pytorch operations are not altered. This mode disables any\n"\ - "FP16 arithmetic, although other optimizations like DDP interop may still be requested.\n" - - def __call__(self, properties): - properties.enabled = True - properties.opt_level = "O0" - properties.cast_model_type = torch.float32 - properties.patch_torch_functions = False - properties.patch_torch_functions_type = None - properties.keep_batchnorm_fp32 = None - properties.master_weights = False - properties.loss_scale = 1.0 - # properties.fused_optimizer = False - # properties.enable_ddp_interop = False - return properties # modified in place so this isn't really necessary - -class O4: - brief = "O4: Insert automatic casts around Pytorch functions and Tensor methods.\n" - more = "The type of your model's weights is not altered. However, internally,\n"\ - "Pytorch functions are patched to cast any Tensor Core-friendly ops to BFLOAT16 for speed,\n"\ - "while operations that might benefit from the additional stability of FP32 are patched\n"\ - "to cast their inputs to fp32.\n"\ - "Loss scaling is not required in O4 mode since bflaot16 has the same dynamic range as fp32." - - def __call__(self, properties): - properties.enabled = True - properties.opt_level = "O4" - properties.cast_model_type = None - properties.patch_torch_functions = True - properties.patch_torch_functions_type = torch.bfloat16 - properties.keep_batchnorm_fp32 = None - properties.master_weights = None - properties.loss_scale = 1 - return properties # modified in place so this isn't really necessary - -class O5: - brief = "O5: BFLOAT16 training with FP32 batchnorm and FP32 master weights.\n" - more = "Calls .bfloat16() on your model, converting the entire model (except for batchnorms)\n"\ - "to BFLOAT16. Batchnorms are retained in FP32 for additional stability.\n"\ - "The forward pass is patched to cast incoming Tensors to BFLOAT16, so you don't need to change\n"\ - "your data pipeline.\n"\ - "O5 creates FP32 master weights outside the model and patches any optimizers to update\n"\ - "these master weights, then copy the master weights into the BFLOAT16 model weights.\n"\ - "Master weights can also improve convergence and stability." - - def __call__(self, properties): - properties.enabled = True - properties.opt_level = "O5" - properties.cast_model_type = torch.bfloat16 - properties.patch_torch_functions = False - properties.patch_torch_functions = None - properties.patch_torch_functions_type = None - properties.keep_batchnorm_fp32 = True - properties.master_weights = True - properties.loss_scale = 1 - return properties # modified in place so this isn't really necessary - - -opt_levels = {"O3": O3(), - "O2": O2(), - "O1": O1(), - "O0": O0(), - "O4": O4(), - "O5": O5()} - - -# allow user to directly pass Properties struct as well? -def initialize( - models, - optimizers=None, - enabled=True, - opt_level="O1", - cast_model_type=None, - patch_torch_functions=None, - patch_torch_functions_type=None, - keep_batchnorm_fp32=None, - master_weights=None, - loss_scale=None, - cast_model_outputs=None, - num_losses=1, - verbosity=1, - min_loss_scale=None, - max_loss_scale=2.**24 - ): - """ - Initialize your models, optimizers, and the Torch tensor and functional namespace according to the - chosen ``opt_level`` and overridden properties, if any. - - ``amp.initialize`` should be called **after** you have finished - constructing your model(s) and - optimizer(s), but **before** you send your model through any DistributedDataParallel wrapper. - See `Distributed training`_ in the Imagenet example. - - Currently, ``amp.initialize`` should only be called **once**, - although it can process an arbitrary number of - models and optimizers (see the corresponding `Advanced Amp Usage topic`_). - If you think your use case requires ``amp.initialize`` to be called more than once, - `let us know`_. - - Any property keyword argument that is not ``None`` will be interpreted as a manual override. - - To prevent having to rewrite anything else in your script, name the returned models/optimizers - to replace the passed models/optimizers, as in the code sample below. - - Args: - models (torch.nn.Module or list of torch.nn.Modules): Models to modify/cast. - optimizers (optional, torch.optim.Optimizer or list of torch.optim.Optimizers): Optimizers to modify/cast. - REQUIRED for training, optional for inference. - enabled (bool, optional, default=True): If False, renders all Amp calls no-ops, so your script - should run as if Amp were not present. - opt_level (str, optional, default="O1"): Pure or mixed precision optimization level. Accepted values are - "O0", "O1", "O2", "O3", "O4" and "O5", explained in detail above. - cast_model_type (``torch.dtype``, optional, default=None): Optional property override, see - above. - patch_torch_functions (bool, optional, default=None): Optional property override. - patch_torch_functions_type (``torch.dtype``, optional, default=None): Optional property override - keep_batchnorm_fp32 (bool or str, optional, default=None): Optional property override. If - passed as a string, must be the string "True" or "False". - master_weights (bool, optional, default=None): Optional property override. - loss_scale (float or str, optional, default=None): Optional property override. If passed as a string, - must be a string representing a number, e.g., "128.0", or the string "dynamic". - cast_model_outputs (torch.dtype, optional, default=None): Option to ensure that the outputs - of your model(s) are always cast to a particular type regardless of ``opt_level``. - num_losses (int, optional, default=1): Option to tell Amp in advance how many losses/backward - passes you plan to use. When used in conjunction with the ``loss_id`` argument to - ``amp.scale_loss``, enables Amp to use a different loss scale per loss/backward pass, - which can improve stability. See "Multiple models/optimizers/losses" - under `Advanced Amp Usage`_ for examples. If ``num_losses`` is left to 1, Amp will still - support multiple losses/backward passes, but use a single global loss scale - for all of them. - verbosity (int, default=1): Set to 0 to suppress Amp-related output. - min_loss_scale (float, default=None): Sets a floor for the loss scale values that can be chosen by dynamic - loss scaling. The default value of None means that no floor is imposed. - If dynamic loss scaling is not used, `min_loss_scale` is ignored. - max_loss_scale (float, default=2.**24): Sets a ceiling for the loss scale values that can be chosen by - dynamic loss scaling. If dynamic loss scaling is not used, `max_loss_scale` is ignored. - - Returns: - Model(s) and optimizer(s) modified according to the ``opt_level``. - If either the ``models`` or ``optimizers`` args were lists, the corresponding return value will - also be a list. - - Permissible invocations:: - - model, optim = amp.initialize(model, optim,...) - model, [optim1, optim2] = amp.initialize(model, [optim1, optim2],...) - [model1, model2], optim = amp.initialize([model1, model2], optim,...) - [model1, model2], [optim1, optim2] = amp.initialize([model1, model2], [optim1, optim2],...) - - # This is not an exhaustive list of the cross product of options that are possible, - # just a set of examples. - model, optim = amp.initialize(model, optim, opt_level="O0") - model, optim = amp.initialize(model, optim, opt_level="O0", loss_scale="dynamic"|128.0|"128.0") - - model, optim = amp.initialize(model, optim, opt_level="O1") # uses "loss_scale="dynamic" default - model, optim = amp.initialize(model, optim, opt_level="O1", loss_scale=128.0|"128.0") - - model, optim = amp.initialize(model, optim, opt_level="O2") # uses "loss_scale="dynamic" default - model, optim = amp.initialize(model, optim, opt_level="O2", loss_scale=128.0|"128.0") - model, optim = amp.initialize(model, optim, opt_level="O2", keep_batchnorm_fp32=True|False|"True"|"False") - - model, optim = amp.initialize(model, optim, opt_level="O3") # uses loss_scale=1.0 default - model, optim = amp.initialize(model, optim, opt_level="O3", loss_scale="dynamic"|128.0|"128.0") - model, optim = amp.initialize(model, optim, opt_level="O3", keep_batchnorm_fp32=True|False|"True"|"False") - - The `Imagenet example`_ demonstrates live use of various opt_levels and overrides. - - .. _`Distributed training`: - https://github.com/NVIDIA/apex/tree/master/examples/imagenet#distributed-training - - .. _`Imagenet example`: - https://github.com/NVIDIA/apex/tree/master/examples/imagenet - - .. _`Advanced Amp Usage`: - https://nvidia.github.io/apex/advanced.html - - .. _`Advanced Amp Usage topic`: - https://nvidia.github.io/apex/advanced.html#multiple-models-optimizers-losses - - .. _`let us know`: - https://github.com/NVIDIA/apex/issues - """ - _amp_state.opt_properties = Properties() - _amp_state.verbosity = verbosity - - if not enabled: - if optimizers is None: - return models - else: - return models, optimizers - - if not torch.backends.cudnn.enabled: - raise RuntimeError( - "Amp requires torch.backends.cudnn.enabled = True") - - if opt_level not in opt_levels: - raise RuntimeError( - "Unexpected optimization level {}. ".format(opt_level) + - "Options are 'O0', 'O1', 'O2', 'O3', 'O4', 'O5'. Note that in `O0`, `O1`, etc., the prefix O is the letter O, " + - "not the number zero.") - else: - _amp_state.opt_properties = opt_levels[opt_level](_amp_state.opt_properties) - maybe_print("Selected optimization level {}".format(opt_levels[opt_level].brief), True) - maybe_print("Defaults for this optimization level are:", True) - for k, v in _amp_state.opt_properties.options.items(): - maybe_print("{:26} : {}".format(k, v), True) - - _amp_state.min_loss_scale = min_loss_scale - _amp_state.max_loss_scale = max_loss_scale - - maybe_print("Processing user overrides (additional kwargs that are not None)...", True) - # I chose to have the keyword arguments listed directly in the argument list, - # instead of **kwargs, so I can't use kwargs.items() here. - if enabled is not None: - _amp_state.opt_properties.enabled = enabled - if opt_level is not None: - _amp_state.opt_properties.opt_level = opt_level - if cast_model_type is not None: - _amp_state.opt_properties.cast_model_type = cast_model_type - if patch_torch_functions is not None: - _amp_state.opt_properties.patch_torch_functions = patch_torch_functions - if patch_torch_functions_type is not None: - _amp_state.opt_properties.patch_torch_functions_type = patch_torch_functions_type - if keep_batchnorm_fp32 is not None: - _amp_state.opt_properties.keep_batchnorm_fp32 = keep_batchnorm_fp32 - if master_weights is not None: - _amp_state.opt_properties.master_weights = master_weights - if loss_scale is not None: - _amp_state.opt_properties.loss_scale = loss_scale - - maybe_print("After processing overrides, optimization options are:", True) - for k, v in _amp_state.opt_properties.options.items(): - maybe_print("{:26} : {}".format(k, v), True) - - return _initialize(models, optimizers, _amp_state.opt_properties, num_losses, cast_model_outputs) - - -def state_dict(destination=None): - if destination is None: - destination = OrderedDict() - - for idx, loss_scaler in enumerate(_amp_state.loss_scalers): - destination['loss_scaler%d' % idx] = { - 'loss_scale': loss_scaler.loss_scale(), - 'unskipped': loss_scaler._unskipped, - } - return destination - - -def load_state_dict(state_dict): - # Check if state_dict containes the same number of loss_scalers as current setup - if len(state_dict) != len(_amp_state.loss_scalers): - print('Warning: state_dict contains {} entries, while {} loss_scalers are used'.format( - len(state_dict), len(_amp_state.loss_scalers))) - - state_dict = state_dict.copy() - - nb_loss_scalers = len(_amp_state.loss_scalers) - unexpected_keys = [] - # Initialize idx outside, since unexpected_keys will increase it if enumerate is used - idx = 0 - for key in state_dict: - if 'loss_scaler' not in key: - unexpected_keys.append(key) - else: - if idx > (nb_loss_scalers - 1): - print('Skipping loss_scaler[{}], since num_losses was set to {}'.format( - idx, nb_loss_scalers)) - break - _amp_state.loss_scalers[idx]._loss_scale = state_dict[key]['loss_scale'] - _amp_state.loss_scalers[idx]._unskipped = state_dict[key]['unskipped'] - idx += 1 - - if len(unexpected_keys) > 0: - raise RuntimeError( - 'Error(s) in loading state_dict. Unexpected key(s) in state_dict: {}. '.format( - ', '.join('"{}"'.format(k) for k in unexpected_keys))) - - -# TODO: is this necessary/useful? -# def check_option_consistency(enabled=True, -# opt_level=None, -# cast_model_type=None, -# patch_torch_functions=None, -# keep_batchnorm_fp32=None, -# master_weights=None, -# loss_scale=None, -# enable_ddp_interop=None, -# hard_override=False): -# """ -# Utility function that enables users to quickly check if the option combination they intend -# to use is permitted. ``check_option_consistency`` does not require models or optimizers -# to be constructed, and can be called at any point in the script. ``check_option_consistency`` -# is totally self-contained; it does not set any amp global state or affect anything outside -# of itself. -# """ -# -# if not enabled: -# return -# -# if opt_level not in opt_levels: -# raise RuntimeError("Unexpected optimization level. Options are 'O0', 'O1', 'O2', 'O3'.") -# else: -# opt_properties = opt_levels[opt_level](Properties()) -# print("Selected optimization level {}", opt_levels[opt_level].brief) -# print("Defaults for this optimization level are:") -# for k, v in opt_properties.options: -# print("{:22} : {}".format(k, v)) -# -# print("Processing user overrides (additional kwargs that are not None)...") -# for k, v in kwargs: -# if k not in _amp_state.opt_properties.options: -# raise RuntimeError("Unexpected kwarg {}".format(k)) -# if v is not None: -# setattr(opt_properties, k, v) -# -# print("After processing overrides, optimization options are:") -# for k, v in opt_properties.options: -# print("{:22} : {}".format(k, v)) diff --git a/apex/amp/handle.py b/apex/amp/handle.py deleted file mode 100644 index 0be567c..0000000 --- a/apex/amp/handle.py +++ /dev/null @@ -1,281 +0,0 @@ -import contextlib -import warnings -import sys -import torch - -from . import utils -from .opt import OptimWrapper -from .scaler import LossScaler -from ._amp_state import _amp_state, master_params, maybe_print - -if torch.distributed.is_available(): - from ..parallel.LARC import LARC - - -# There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls. -@contextlib.contextmanager -def scale_loss(loss, - optimizers, - loss_id=0, - model=None, - delay_unscale=False, - delay_overflow_check=False): - """ - On context manager entrance, creates ``scaled_loss = (loss.float())*current loss scale``. - ``scaled_loss`` is yielded so that the user can call ``scaled_loss.backward()``:: - - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - - On context manager exit (if ``delay_unscale=False``), the gradients are checked for infs/NaNs - and unscaled, so that ``optimizer.step()`` can be called. - - .. note:: - If Amp is using explicit FP32 master params (which is the default for ``opt_level=O2``, and - can also be manually enabled by supplying ``master_weights=True`` to ``amp.initialize``) - any FP16 gradients are copied to FP32 master gradients before being unscaled. - ``optimizer.step()`` will then apply the unscaled master gradients to the master params. - - .. warning:: - If Amp is using explicit FP32 master params, only the FP32 master gradients will be - unscaled. The direct ``.grad`` attributes of any FP16 - model params will remain scaled after context manager exit. - This subtlety affects gradient clipping. See "Gradient clipping" under - `Advanced Amp Usage`_ for best practices. - - Args: - loss(Tensor): Typically a scalar Tensor. The ``scaled_loss`` that the context - manager yields is simply ``loss.float()*loss_scale``, so in principle - ``loss`` could have more than one element, as long as you call - ``backward()`` on ``scaled_loss`` appropriately within the context manager body. - optimizers: All optimizer(s) for which the current backward pass is creating gradients. - Must be an optimizer or list of optimizers returned from an earlier call - to ``amp.initialize``. For example use with multiple optimizers, see - "Multiple models/optimizers/losses" under `Advanced Amp Usage`_. - loss_id(int, optional, default=0): When used in conjunction with the ``num_losses`` argument - to ``amp.initialize``, enables Amp to use a different loss scale per loss. ``loss_id`` - must be an integer between 0 and ``num_losses`` that tells Amp which loss is - being used for the current backward pass. See "Multiple models/optimizers/losses" - under `Advanced Amp Usage`_ for examples. If ``loss_id`` is left unspecified, Amp - will use the default global loss scaler for this backward pass. - model(torch.nn.Module, optional, default=None): Currently unused, reserved to enable future - optimizations. - delay_unscale(bool, optional, default=False): ``delay_unscale`` is never necessary, and - the default value of ``False`` is strongly recommended. - If ``True``, Amp will not unscale the gradients or perform model->master - gradient copies on context manager exit. - ``delay_unscale=True`` is a minor ninja performance optimization and can result - in weird gotchas (especially with multiple models/optimizers/losses), - so only use it if you know what you're doing. - "Gradient accumulation across iterations" under `Advanced Amp Usage`_ - illustrates a situation where this CAN (but does not need to) be used. - - .. warning:: - If ``delay_unscale`` is ``True`` for a given backward pass, ``optimizer.step()`` cannot be - called yet after context manager exit, and must wait for another, later backward context - manager invocation with ``delay_unscale`` left to False. - - .. _`Advanced Amp Usage`: - https://nvidia.github.io/apex/advanced.html - """ - if not hasattr(_amp_state, "opt_properties"): - raise RuntimeError("Invoked 'with amp.scale_loss`, but internal Amp state has not been initialized. " - "model, optimizer = amp.initialize(model, optimizer, opt_level=...) must be called " - "before `with amp.scale_loss`.") - - if not _amp_state.opt_properties.enabled: - yield loss - return - - if isinstance(optimizers, torch.optim.Optimizer) or ('LARC' in globals() and isinstance(optimizers, LARC)): - optimizers = [optimizers] - - loss_scaler = _amp_state.loss_scalers[loss_id] - loss_scale = loss_scaler.loss_scale() - - if ((not _amp_state.opt_properties.master_weights) - and (not loss_scaler.dynamic) - and loss_scale == 1.0): - yield loss.float() - # Needing to drop the cache here as well is an ugly gotcha. - # But for now I think it's necessary to short-circuit. - # Probably ok to skip this if not delay_unscale - if _amp_state.opt_properties.patch_torch_functions: - _amp_state.handle._clear_cache() - return - - if not delay_unscale: - if isinstance(optimizers, list): - for optimizer in optimizers: - if not optimizer._amp_stash.params_have_scaled_gradients: - optimizer._prepare_amp_backward() - - yield (loss.float())*loss_scale - - if delay_unscale: - for optimizer in optimizers: - optimizer._amp_stash.params_have_scaled_gradients = True - else: - # FusedSGD may take care of unscaling as part of their step() methods. - # if not isinstance(optimizers, FP16_Optimizer_for_fused): - loss_scaler.clear_overflow_state() - for optimizer in optimizers: - optimizer._post_amp_backward(loss_scaler) - optimizer._amp_stash.params_have_scaled_gradients = False - # For future fused optimizers that enable sync-free dynamic loss scaling, - # should_skip will always be False. - should_skip = False if delay_overflow_check else loss_scaler.update_scale() - if should_skip: - for optimizer in optimizers: - if not optimizer._amp_stash.already_patched: - # Close on loss_scaler and loss_id as well, to be safe. Probably not - # necessary because amp.scale_loss is already creating a temporary scope. - def patch_step(opt, loss_scaler, loss_id): - opt_step = opt.step - def skip_step(closure=None): - if closure is not None: - raise RuntimeError("Currently, Amp does not support closure use with optimizers.") - maybe_print(("Gradient overflow. Skipping step, loss scaler " + - "{} reducing loss scale to {}").format(loss_id, - loss_scaler.loss_scale())) - # TODO: I don't like the special casing for different optimizer implementations. - # Maybe skip should delegate to a method owned by the optimizers themselves. - if hasattr(opt._amp_stash, "all_fp32_from_fp16_params"): - # Clear the master grads that wouldn't be zeroed by model.zero_grad() - for param in opt._amp_stash.all_fp32_from_fp16_params: - param.grad = None - if hasattr(opt, "most_recent_scale"): - opt.most_recent_scale = 1.0 - opt.scale_set_by_backward = False - opt.step = opt_step - opt._amp_stash.already_patched = False - return skip_step - optimizer.step = patch_step(optimizer, loss_scaler, loss_id) - optimizer._amp_stash.already_patched = True - - # Probably ok to skip this if not delay_unscale - if _amp_state.opt_properties.patch_torch_functions: - _amp_state.handle._clear_cache() - - -# Free function version of AmpHandle.disable_casts, another step on the -# path to removing the concept of "AmpHandle" -@contextlib.contextmanager -def disable_casts(): - _amp_state.handle._is_active = False - yield - _amp_state.handle._is_active = True - - -class AmpHandle(object): - def __init__(self, loss_scale="dynamic", enable_caching=True, verbose=False): - self._enable_caching = enable_caching - self._verbose = verbose - self._cache = dict() - self._default_scaler = LossScaler(loss_scale) - self._is_active = True - self._all_wrappers = [] - - def is_active(self): - return self._is_active - - @contextlib.contextmanager - def _disable_casts(self): - self._is_active = False - yield - self._is_active = True - - def wrap_optimizer(self, optimizer, num_loss=1): - self._default_scaler = None - return OptimWrapper(optimizer, self, num_loss) - - @contextlib.contextmanager - def scale_loss(self, loss, optimizer): - raise RuntimeError("The old Amp API is no longer supported. Please move to the new API, " - "documented here: https://nvidia.github.io/apex/amp.html. Transition guide: " - "https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users") - - if not self.is_active(): - yield loss - return - - if self._default_scaler is None: - raise RuntimeError( - 'After calling `handle.wrap_optimizer()`, you must explicitly ' + - 'use `optimizer.scale_loss(loss)`.') - - # TODO: this code block is duplicated here and `opt.py`. Unify. - loss_scale = self._default_scaler.loss_scale() - yield loss * loss_scale - - self._default_scaler.clear_overflow_state() - self._default_scaler.unscale( - master_params(optimizer), - master_params(optimizer), - loss_scale) - should_skip = self._default_scaler.update_scale() - if should_skip: - optimizer_step = optimizer.step - def skip_step(): - maybe_print('Gradient overflow, skipping update') - optimizer.step = optimizer_step - optimizer.step = skip_step - - self._clear_cache() - - def _clear_cache(self): - self._cache.clear() - - # Experimental support for saving / restoring uncasted versions of functions - def _save_func(self, mod, fn, func): - self._all_wrappers.append((mod, fn, func)) - - def _deactivate(self): - for mod, fn, func in self._all_wrappers: - utils.set_func(mod, fn, func) - self._all_wrappers = [] - - @property - def has_cache(self): - return self._enable_caching - - @property - def cache(self): - return self._cache - - def remove_cache(self, param): - if self.has_cache and param in self.cache: - del self.cache[param] - - @property - def verbose(self): - return self._verbose - -class NoOpHandle(object): - def is_active(self): - return False - - @contextlib.contextmanager - def _disable_casts(self): - yield - - def wrap_optimizer(self, optimizer, num_loss=1): - return OptimWrapper(optimizer, self, num_loss) - - @contextlib.contextmanager - def scale_loss(self, loss, optimizer): - yield loss - - @property - def has_cache(self): - return False - - @property - def verbose(self): - return False - - def _clear_cache(self): - pass - - def _deactivate(self): - pass diff --git a/apex/amp/lists/__init__.py b/apex/amp/lists/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/apex/amp/lists/functional_overrides.py b/apex/amp/lists/functional_overrides.py deleted file mode 100644 index 9ecdf09..0000000 --- a/apex/amp/lists/functional_overrides.py +++ /dev/null @@ -1,91 +0,0 @@ - -# TODO: think about the following two. They do weird things. -# - torch.nn.utils.clip_grad (but it should always be fp32 anyway) -# - torch.nn.utils.weight_norm - -# Notes: -# F.instance_norm uses batch_norm internally. Which correctly handles -# fp16 in/out with fp32 weights. So we shouldn't do anything for -# either of these. -# F.normalize calls `input.norm()` internally, so it's redundant, but -# kept here in case impl. changes. -# F.cosine_similarity is same: calls `x.norm()` internally. - -import torch.nn.functional - -MODULE = torch.nn.functional - -FP16_FUNCS = [ - 'conv1d', - 'conv2d', - 'conv3d', - 'conv_transpose1d', - 'conv_transpose2d', - 'conv_transpose3d', - 'conv_tbc', # Undocumented / maybe new? - 'linear', -] - -BFLOAT16_FUNCS = [ - 'conv1d', - 'conv2d', - 'conv3d', - 'conv_transpose1d', - 'conv_transpose2d', - 'conv_transpose3d', - 'conv_tbc', # Undocumented / maybe new? - 'linear', -] - -FP32_FUNCS = [ - - # Interpolation/Upsampling TODO: Remove for 1.2 - 'interpolate', - 'grid_sample', - - # Pointwise - 'softplus', - 'softmin', - 'log_softmax', - 'softmax', - 'gelu', - - # Normalization - 'layer_norm', - 'group_norm', - 'local_response_norm', - 'normalize', - 'cosine_similarity', - - # Loss functions - # TODO: which of these can be fp16? - 'poisson_nll_loss', - 'cosine_embedding_loss', - 'cross_entropy', - 'hinge_embedding_loss', - 'kl_div', - 'l1_loss', - 'mse_loss', - 'margin_ranking_loss', - 'multilabel_margin_loss', - 'multilabel_soft_margin_loss', - 'multi_margin_loss', - 'nll_loss', - 'binary_cross_entropy_with_logits', - 'smooth_l1_loss', - 'soft_margin_loss', - 'triplet_margin_loss', - 'ctc_loss' -] - -BANNED_FUNCS = [ - ('binary_cross_entropy', - ("\namp does not work out-of-the-box with `F.binary_cross_entropy` or `torch.nn.BCELoss.` " - "It requires that the output of the previous function be already a FloatTensor. \n\n" - "Most models have a Sigmoid right before BCELoss. In that case, you can use\n" - " torch.nn.BCEWithLogitsLoss\nto combine Sigmoid+BCELoss into a single layer " - "that is compatible with amp.\nAnother option is to add\n" - " amp.register_float_function(torch, 'sigmoid')\nbefore calling `amp.init()`.\n" - "If you _really_ know what you are doing, you can disable this warning by passing " - "allow_banned=True to `amp.init()`.")) -] diff --git a/apex/amp/lists/tensor_overrides.py b/apex/amp/lists/tensor_overrides.py deleted file mode 100644 index d2783ce..0000000 --- a/apex/amp/lists/tensor_overrides.py +++ /dev/null @@ -1,67 +0,0 @@ -from .. import compat -from . import torch_overrides - -import importlib - -import torch - -# if compat.variable_is_tensor() and not compat.tensor_is_variable(): -MODULE = torch.Tensor -# else: -# MODULE = torch.autograd.Variable - - -FP16_FUNCS = compat.filter_attrs(MODULE, [ - '__matmul__', -]) - -BFLOAT16_FUNCS = [ - '__matmul__', -] - -FP32_FUNCS = compat.filter_attrs(MODULE, [ - '__ipow__', - '__pow__', - '__rpow__', - - # Cast to fp32 before transfer to CPU - 'cpu', -]) - -CASTS = compat.filter_attrs(MODULE, [ - '__add__', - '__div__', - '__eq__', - '__ge__', - '__gt__', - '__iadd__', - '__idiv__', - '__imul__', - '__isub__', - '__itruediv__', - '__le__', - '__lt__', - '__mul__', - '__ne__', - '__radd__', - '__rdiv__', - '__rmul__', - '__rsub__', - '__rtruediv__', - '__sub__', - '__truediv__', -]) - -# None of these, but here to make code cleaner. -SEQUENCE_CASTS = [] - -# We need to grab all the methods from torch_overrides and add them to -# the Tensor lists as well, as almost all methods are duplicated -# between `torch` and `torch.Tensor` (and check with `hasattr`, -# because a few random ones aren't defined on Tensor) -_self_mod = importlib.import_module(__name__) -for attrname in ['FP16_FUNCS', 'BFLOAT16_FUNCS', 'FP32_FUNCS', 'CASTS', 'SEQUENCE_CASTS']: - lst = getattr(_self_mod, attrname) - for fn in getattr(torch_overrides, attrname): - if hasattr(MODULE, fn): - lst.append(fn) diff --git a/apex/amp/lists/torch_overrides.py b/apex/amp/lists/torch_overrides.py deleted file mode 100644 index 0998870..0000000 --- a/apex/amp/lists/torch_overrides.py +++ /dev/null @@ -1,136 +0,0 @@ -import torch - -from .. import utils - -MODULE = torch - -FP16_FUNCS = [ - # Low level functions wrapped by torch.nn layers. - # The wrapper layers contain the weights which are then passed in as a parameter - # to these functions. - 'conv1d', - 'conv2d', - 'conv3d', - 'conv_transpose1d', - 'conv_transpose2d', - 'conv_transpose3d', - 'conv_tbc', - 'prelu', - - # BLAS - 'addmm', - 'addmv', - 'addr', - 'matmul', - 'mm', - 'mv', -] - -BFLOAT16_FUNCS = [ - # Low level functions wrapped by torch.nn layers. - # The wrapper layers contain the weights which are then passed in as a parameter - # to these functions. - 'conv1d', - 'conv2d', - 'conv3d', - 'conv_transpose1d', - 'conv_transpose2d', - 'conv_transpose3d', - 'conv_tbc', - - # BLAS - 'addmm', - 'addmv', - 'addr', - 'matmul', - 'mm', - 'mv', -] - -FP32_FUNCS = [ - # Pointwise - 'acos', - 'asin', - 'cosh', - 'erfinv', - 'exp', - 'expm1', - 'log', - 'log10', - 'log2', - 'reciprocal', - 'rsqrt', - 'sinh', - 'tan', - - # Other math - 'pow', - - # Reduction - 'cumprod', - 'cumsum', - 'dist', - # 'mean', - 'norm', - 'prod', - 'std', - 'sum', - 'var', - - # Misc - 'renorm' -] - -version_strings = torch.__version__.split('.') -version_major = version_strings[0] -version_minor = version_strings[1] -version_num = float(version_major + "." + version_minor) -# Before torch 1.1, mean must be blacklisted. -if version_num < 1.1: - FP32_FUNCS.append('mean') - -# Before CUDA 9.1, batched matmul was missing fast FP16 kernels. We -# check the CUDA version -- if at least 9.1, then put the bmm -# functions on the fp16 list. Otherwise, put them on the fp32 list. -_bmms = ['addbmm', - 'baddbmm', - 'bmm'] - -if utils.is_cuda_enabled(): - # workaround https://github.com/facebookresearch/maskrcnn-benchmark/issues/802 - if utils.get_cuda_version() >= (9, 1, 0): - FP16_FUNCS.extend(_bmms) - else: - FP32_FUNCS.extend(_bmms) - -# Multi-tensor fns that may need type promotion -CASTS = [ - # Multi-tensor math - 'addcdiv', - 'addcmul', - 'atan2', - 'cross', - 'bilinear', - 'dot', - - # Element-wise _or_ tensor-wise math - 'add', - 'div', - 'mul', - - # Comparison - 'eq', - 'equal', - 'ge', - 'gt', - 'le', - 'lt', - 'ne' -] - -# Functions that take sequence arguments. We need to inspect the whole -# sequence and cast to the widest type. -SEQUENCE_CASTS = [ - 'cat', - 'stack' -] diff --git a/apex/amp/opt.py b/apex/amp/opt.py deleted file mode 100644 index baf3116..0000000 --- a/apex/amp/opt.py +++ /dev/null @@ -1,103 +0,0 @@ -import contextlib -import warnings - -from .scaler import LossScaler, master_params -from ._amp_state import maybe_print - -import numpy as np - -class OptimWrapper(object): - def __init__(self, optimizer, amp_handle, num_loss): - self._optimizer = optimizer - self._amp_handle = amp_handle - self._num_loss = num_loss - self._loss_idx = 0 - self._skip_next = [False] * num_loss - self._loss_scaler = [LossScaler('dynamic') for _ in range(num_loss)] - - @contextlib.contextmanager - def scale_loss(self, loss): - if not self._amp_handle.is_active(): - yield loss - return - - # When there are multiple losses per-optimizer, we need - # to save out current grad accumulation, since we won't be - # able to unscale this particulare loss once the grads are - # all mixed together. - cached_grads = [] - if self._loss_idx > 0: - for p in master_params(self._optimizer): - if p.grad is not None: - cached_grads.append(p.grad.data.detach().clone()) - else: - cached_grads.append(None) - self._optimizer.zero_grad() - - loss_scale = self._cur_loss_scaler().loss_scale() - yield loss * loss_scale - - self._cur_loss_scaler().clear_overflow_state() - self._cur_loss_scaler().unscale( - master_params(self._optimizer), - master_params(self._optimizer), - loss_scale) - self._skip_next[self._loss_idx] = self._cur_loss_scaler().update_scale() - self._loss_idx += 1 - - if len(cached_grads) > 0: - for p, cached_grad in zip(master_params(self._optimizer), - cached_grads): - if cached_grad is not None: - p.grad.data.add_(cached_grad) - cached_grads = [] - - def _cur_loss_scaler(self): - assert 0 <= self._loss_idx < self._num_loss - return self._loss_scaler[self._loss_idx] - - def step(self, closure=None): - if not self._amp_handle.is_active(): - return self._optimizer.step(closure=closure) - - self._loss_idx = 0 - - for group in self._optimizer.param_groups: - for p in group['params']: - self._amp_handle.remove_cache(p) - - if closure is not None: - raise NotImplementedError( - 'The `closure` argument is unsupported by the amp ' + - 'optimizer wrapper.') - if any(self._skip_next): - maybe_print('Gradient overflow, skipping update') - self._skip_next = [False] * self._num_loss - else: - return self._optimizer.step(closure=closure) - - # Forward any attribute lookups - def __getattr__(self, attr): - return getattr(self._optimizer, attr) - - # Forward all torch.optim.Optimizer methods - def __getstate__(self): - return self._optimizer.__getstate__() - - def __setstate__(self): - return self._optimizer.__setstate__() - - def __repr__(self): - return self._optimizer.__repr__() - - def state_dict(self): - return self._optimizer.state_dict() - - def load_state_dict(self, state_dict): - return self._optimizer.load_state_dict(state_dict) - - def zero_grad(self): - return self._optimizer.zero_grad() - - def add_param_group(self, param_group): - return self._optimizer.add_param_group(param_group) diff --git a/apex/amp/rnn_compat.py b/apex/amp/rnn_compat.py deleted file mode 100644 index 987dba7..0000000 --- a/apex/amp/rnn_compat.py +++ /dev/null @@ -1,53 +0,0 @@ -from . import utils, wrap - -import torch -_VF = torch._C._VariableFunctions -RNN_NAMES = ['rnn_relu', 'rnn_tanh', 'gru', 'lstm'] - -def _gen_VF_wrapper(name): - def wrapper(*args, **kwargs): - return getattr(_VF, name)(*args, **kwargs) - return wrapper - -# Some python magic to generate an object that has the rnn cell functions -# defined on it, all of which call into corresponding _VF version. -# Intended to patch torch.nn.modules.rnn._VF (aka, the ref named "_VF" -# imported at module scope within torch.nn.modules.rnn). This should -# not affect third-party importers of _VF.py. -class VariableFunctionsShim(object): - def __init__(self): - for name in RNN_NAMES: - for suffix in ['', '_cell']: - fn_name = name + suffix - setattr(self, fn_name, _gen_VF_wrapper(fn_name)) - -def has_old_rnns(): - try: - torch.nn.backends.thnn.backend.LSTMCell - return True - except: - return False - -def whitelist_rnn_cells(cast_fn, handle, verbose): - # Different module + function names in old/new RNN cases - if has_old_rnns(): - fn_names = ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell'] - mod = torch.nn.backends.thnn.backend - else: - fn_names = [x + '_cell' for x in RNN_NAMES] - mod = torch.nn.modules.rnn._VF - assert isinstance(mod, VariableFunctionsShim) - - # Insert casts on cell functions - for fn in fn_names: - wrap.cached_cast(mod, fn, cast_fn, handle, - try_caching=True, verbose=verbose) - - if has_old_rnns(): - # Special handling of `backward` for fused gru / lstm: - # The `backward` method calls Tensor.sum() (blacklist) internally, - # and then the resulting grad_input has the wrong type. - # TODO: where else is this a problem? - for rnn_type in ['GRUFused', 'LSTMFused']: - mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, rnn_type) - wrap.disable_casts(mod, 'backward', handle) diff --git a/apex/amp/scaler.py b/apex/amp/scaler.py deleted file mode 100644 index 15c70d4..0000000 --- a/apex/amp/scaler.py +++ /dev/null @@ -1,226 +0,0 @@ -import torch -from ..multi_tensor_apply import multi_tensor_applier -from ._amp_state import _amp_state, master_params, maybe_print -from itertools import product - -def scale_check_overflow_python(model_grad, master_grad, scale, check_overflow=False): - # Exception handling for 18.04 compatibility - if check_overflow: - if model_grad.is_sparse: - cpu_sum = float(model_grad.float()._values().sum()) - else: - cpu_sum = float(model_grad.float().sum()) - if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: - return True - - if master_grad is not model_grad: # copy_ probably internally short-circuits this - if model_grad.is_sparse: - master_grad.copy_(model_grad.to_dense()) - else: - master_grad.copy_(model_grad) - if scale != 1.0: - master_grad.mul_(scale) - return False - -def axpby_check_overflow_python(model_grad, stashed_grad, master_grad, a, b, check_overflow=False): - # Exception handling for 18.04 compatibility - if check_overflow: - if model_grad.is_sparse: - cpu_sum = float(model_grad.float()._values().sum()) - else: - cpu_sum = float(model_grad.float().sum()) - if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: - return True - - # if master_grad is not model_grad: # copy_ probably internally short-circuits this - # master_grad.copy_(model_grad) - assert stashed_grad.dtype == master_grad.dtype - converted_model_grad = model_grad.data.to(master_grad.dtype) - master_grad.data = a*converted_model_grad.data + b*stashed_grad.data - return False - -class LossScaler(object): - warned_no_fused_kernel = False - warned_unscaling_non_fp32_grad = False - has_fused_kernel = False - - def __init__(self, - loss_scale, - init_scale=2.**16, - scale_factor=2., - scale_window=2000, - min_loss_scale=None, - max_loss_scale=2.**24): - if loss_scale == "dynamic": - self.dynamic = True - self._loss_scale = min(max_loss_scale, init_scale) - else: - self.dynamic = False - self._loss_scale = loss_scale - self._max_loss_scale = max_loss_scale - self._min_loss_scale = min_loss_scale - self._scale_seq_len = scale_window - self._unskipped = 0 - self._has_overflow = False - self._overflow_buf = torch.cuda.IntTensor([0]) - if multi_tensor_applier.available: - import amp_C - LossScaler.has_fused_kernel = multi_tensor_applier.available - LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale - LossScaler.multi_tensor_axpby_cuda = amp_C.multi_tensor_axpby - else: - if not LossScaler.warned_no_fused_kernel: - maybe_print( - "Warning: multi_tensor_applier fused unscale kernel is unavailable, " - "possibly because apex was installed without --cuda_ext --cpp_ext. " - "Using Python fallback. Original ImportError was: " + - repr(multi_tensor_applier.import_err), - True) - LossScaler.has_fused_kernel = False - LossScaler.warned_no_fused_kernel = True - - def loss_scale(self): - return self._loss_scale - - def unscale_python(self, model_grads, master_grads, scale): - for model, master in zip(model_grads, master_grads): - if model is not None: - if not LossScaler.warned_unscaling_non_fp32_grad: - if master.dtype != torch.float32: - maybe_print( - "Attempting to unscale a grad with type {} ".format(master.type()) + - "Unscaling non-fp32 grads may indicate an error. " - "When using Amp, you don't need to call .half() on your model.") - LossScaler.warned_unscaling_non_fp32_grad = True - self._has_overflow = scale_check_overflow_python(model, - master, - 1./scale, - self.dynamic) - if self._has_overflow and self.dynamic: - break - - # unused_scale keeps some of the old API alive for hopefully a short time. - def unscale(self, model_grads, master_grads, unused_scale, models_are_masters=False, scale_override=None): - if self._has_overflow: - return - - scale = self._loss_scale - if scale_override is not None: - scale = scale_override - - if scale == 1.0 and models_are_masters and not self.dynamic: - return - - if LossScaler.has_fused_kernel: - # if (not LossScaler.warned_unscaling_non_fp32_grad - # and master_grads[0].dtype == torch.float16): - # print("Warning: unscaling grads that are not FP32. " - # "Unscaling non-fp32 grads may indicate an error. " - # "When using Amp, you don't need to call .half() on your model.") - # # Setting this to True unconditionally allows the possibility of an escape - # # if never-before-seen non-fp32 grads are created in some later iteration. - # LossScaler.warned_unscaling_non_fp32_grad = True - multi_tensor_applier(LossScaler.multi_tensor_scale_cuda, - self._overflow_buf, - [model_grads, master_grads], - 1./scale) - else: - self.unscale_python(model_grads, master_grads, scale) - - # Defer to update_scale - # If the fused kernel is available, we only need one D2H memcopy and sync. - # if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow: - # self._has_overflow = self._overflow_buf.item() - - def unscale_with_stashed_python(self, - model_grads, - stashed_master_grads, - master_grads, - a, - b): - for model, stashed, master in zip(model_grads, stashed_master_grads, master_grads): - if model is None and stashed is None: - continue - else: - if not LossScaler.warned_unscaling_non_fp32_grad: - if master.dtype != torch.float32: - maybe_print( - "Attempting to unscale a grad with type {} ".format(master.type()) + - "Unscaling non-fp32 grads may indicate an error. " - "When using Amp, you don't need to call .half() on your model.") - LossScaler.warned_unscaling_non_fp32_grad = True - self._has_overflow = axpby_check_overflow_python(model, - stashed, - master, - a, - b, - self.dynamic) - if self._has_overflow and self.dynamic: - break - - def unscale_with_stashed(self, - model_grads, - stashed_master_grads, - master_grads, - scale_override=None): - if self._has_overflow: - return - - grads_have_scale, stashed_have_scale, out_scale = self._loss_scale, 1.0, 1.0 - if scale_override is not None: - grads_have_scale, stashed_have_scale, out_scale = scale_override - - if LossScaler.has_fused_kernel: - if (not LossScaler.warned_unscaling_non_fp32_grad - and master_grads[0].dtype == torch.float16): - print("Warning: unscaling grads that are not FP32. " - "Unscaling non-fp32 grads may indicate an error. " - "When using Amp, you don't need to call .half() on your model.") - # Setting this to True unconditionally allows the possibility of an escape - # if never-before-seen non-fp32 grads are created in some later iteration. - LossScaler.warned_unscaling_non_fp32_grad = True - multi_tensor_applier(LossScaler.multi_tensor_axpby_cuda, - self._overflow_buf, - [model_grads, stashed_master_grads, master_grads], - out_scale/grads_have_scale, # 1./scale, - out_scale/stashed_have_scale, # 1.0, - 0) # check only arg 0, aka the incoming model grads, for infs - else: - self.unscale_with_stashed_python(model_grads, - stashed_master_grads, - master_grads, - out_scale/grads_have_scale, - out_scale/stashed_have_scale) - - # Defer to update_scale - # If the fused kernel is available, we only need one D2H memcopy and sync. - # if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow: - # self._has_overflow = self._overflow_buf.item() - - def clear_overflow_state(self): - self._has_overflow = False - if self.has_fused_kernel: - self._overflow_buf.zero_() - - # Separate so unscale() can be called more that once before updating. - def update_scale(self): - # If the fused kernel is available, we only need one D2H memcopy and sync. - if LossScaler.has_fused_kernel and self.dynamic and not self._has_overflow: - self._has_overflow = self._overflow_buf.item() - - if self._has_overflow and self.dynamic: - should_skip = True - if(self._min_loss_scale): - self._loss_scale = max(self._min_loss_scale, self._loss_scale/2.) - else: - self._loss_scale = self._loss_scale/2. - self._unskipped = 0 - else: - should_skip = False - self._unskipped += 1 - - if self._unskipped == self._scale_seq_len and self.dynamic: - self._loss_scale = min(self._max_loss_scale, self._loss_scale*2.) - self._unskipped = 0 - - return should_skip diff --git a/apex/amp/utils.py b/apex/amp/utils.py deleted file mode 100644 index c27fce5..0000000 --- a/apex/amp/utils.py +++ /dev/null @@ -1,232 +0,0 @@ -from . import compat - -import functools -import itertools - -import torch - -def is_cuda_enabled(): - return torch.version.cuda is not None - -def get_cuda_version(): - return tuple(int(x) for x in torch.version.cuda.split('.')) - -def is_fp_tensor(x): - if is_nested(x): - # Fast-fail version of all(is_fp_tensor) - for y in x: - if not is_fp_tensor(y): - return False - return True - return compat.is_tensor_like(x) and compat.is_floating_point(x) - -def is_nested(x): - return isinstance(x, tuple) or isinstance(x, list) - -def should_cache(x): - if is_nested(x): - # Fast-fail version of all(should_cache) - for y in x: - if not should_cache(y): - return False - return True - return isinstance(x, torch.nn.parameter.Parameter) and \ - type_string(x) == 'FloatTensor' - -def collect_fp_tensor_types(args, kwargs): - def collect_types(x, types): - if is_nested(x): - for y in x: - collect_types(y, types) - else: - types.add(type_string(x)) - - all_args = itertools.chain(args, kwargs.values()) - types = set() - for x in all_args: - if is_fp_tensor(x): - collect_types(x, types) - return types - -def type_string(x): - return x.type().split('.')[-1] - -def maybe_half(x, name='', verbose=False): - if is_nested(x): - return type(x)([maybe_half(y) for y in x]) - - if not x.is_cuda or type_string(x) == 'HalfTensor': - return x - else: - if verbose: - print('Float->Half ({})'.format(name)) - return x.half() - -def maybe_bfloat16(x, name='', verbose=False): - if is_nested(x): - return type(x)([maybe_bfloat16(y) for y in x]) - - if not x.is_cuda or type_string(x) == 'BFloat16Tensor': - return x - else: - if verbose: - print('Float->BFloat16 ({})'.format(name)) - return x.bfloat16() - -def maybe_float(x, name='', verbose=False): - if is_nested(x): - return type(x)([maybe_float(y) for y in x]) - - if not x.is_cuda or type_string(x) == 'FloatTensor': - return x - else: - if verbose: - print('Half->Float ({})'.format(name)) - return x.float() - -# NB: returneds casted `args`, mutates `kwargs` in-place -def casted_args(cast_fn, args, kwargs): - new_args = [] - for x in args: - if is_fp_tensor(x): - new_args.append(cast_fn(x)) - else: - new_args.append(x) - for k in kwargs: - val = kwargs[k] - if is_fp_tensor(val): - kwargs[k] = cast_fn(val) - return new_args - -def cached_cast(cast_fn, x, cache): - if is_nested(x): - return type(x)([cached_cast(y) for y in x]) - if x in cache: - cached_x = cache[x] - next_functions_available = False - if x.requires_grad and cached_x.requires_grad: - if len(cached_x.grad_fn.next_functions) > 1: - next_functions_available = True - # Make sure x is actually cached_x's autograd parent. - if next_functions_available and cached_x.grad_fn.next_functions[1][0].variable is not x: - raise RuntimeError("x and cache[x] both require grad, but x is not " - "cache[x]'s parent. This is likely an error.") - # During eval, it's possible to end up caching casted weights with - # requires_grad=False. On the next training iter, if cached_x is found - # and reused from the cache, it will not actually have x as its parent. - # Therefore, we choose to invalidate the cache (and force refreshing the cast) - # if x.requires_grad and cached_x.requires_grad do not match. - # - # During eval (i.e. running under with torch.no_grad()) the invalidation - # check would cause the cached value to be dropped every time, because - # cached_x would always be created with requires_grad=False, while x would - # still have requires_grad=True. This would render the cache effectively - # useless during eval. Therefore, if we are running under the no_grad() - # context manager (torch.is_grad_enabled=False) we elide the invalidation - # check, and use the cached value even though its requires_grad flag doesn't - # match. During eval, we don't care that there's no autograd-graph - # connection between x and cached_x. - if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad: - del cache[x] - elif x.requires_grad and cached_x.requires_grad and not next_functions_available: - del cache[x] - else: - return cached_x - - casted_x = cast_fn(x) - cache[x] = casted_x - return casted_x - -def verbosify(cast_fn, fn_name, verbose): - if verbose: - return functools.partial(cast_fn, name=fn_name, verbose=verbose) - else: - return cast_fn - -def as_inplace(fns): - for x in fns: - yield x + '_' - -def has_func(mod, fn): - if isinstance(mod, dict): - return fn in mod - else: - return hasattr(mod, fn) - -def get_func(mod, fn): - if isinstance(mod, dict): - return mod[fn] - else: - return getattr(mod, fn) - -def set_func(mod, fn, new_fn): - if isinstance(mod, dict): - mod[fn] = new_fn - else: - setattr(mod, fn, new_fn) - -def set_func_save(handle, mod, fn, new_fn): - cur_fn = get_func(mod, fn) - handle._save_func(mod, fn, cur_fn) - set_func(mod, fn, new_fn) - -# A couple problems get solved here: -# - The flat_weight buffer is disconnected from autograd graph, -# so the fp16 weights need to be derived from the input weights -# to this forward call, not the flat buffer. -# - The ordering of weights in the flat buffer is...idiosyncratic. -# First problem is solved with combination of set_ (to set up -# correct storage) and copy_ (so the fp16 weight derives from the -# fp32 one in autograd. -# Second is solved by doing ptr arithmetic on the fp32 weights -# to derive the correct offset. -# -# TODO: maybe this should actually use -# `torch._cudnn_rnn_flatten_weight`? But then I need to call -# on first iter and cache the right offsets. Ugh. -def synthesize_flattened_rnn_weights(fp32_weights, - fp16_flat_tensor, - rnn_fn='', - verbose=False): - fp16_weights = [] - fp32_base_ptr = fp32_weights[0][0].data_ptr() - for layer_weights in fp32_weights: - fp16_layer_weights = [] - for w_fp32 in layer_weights: - w_fp16 = w_fp32.new().half() - offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size() - w_fp16.set_(fp16_flat_tensor.storage(), - offset, - w_fp32.shape) - w_fp16.copy_(w_fp32) - if verbose: - print('Float->Half ({})'.format(rnn_fn)) - fp16_layer_weights.append(w_fp16) - fp16_weights.append(fp16_layer_weights) - return fp16_weights - -def _str_from_dtype(dtype=torch.float16): - type_to_str = {torch.float16 : 'Half', - torch.bfloat16 : 'BFloat16'} - return type_to_str[dtype] - -# Roughly same as above, just the `fp32_weights` aren't nested. -# Code kept separate for readability. -def new_synthesize_flattened_rnn_weights(fp32_weights, - fp16_flat_tensor, - rnn_fn='', - dtype=torch.float16, - verbose=False): - fp16_weights = [] - fp32_base_ptr = fp32_weights[0].data_ptr() - for w_fp32 in fp32_weights: - w_fp16 = w_fp32.new().to(dtype=dtype) - offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size() - w_fp16.set_(fp16_flat_tensor.storage(), - offset, - w_fp32.shape) - w_fp16.copy_(w_fp32) - if verbose: - print('Float->{} ({})'.format(_str_from_dtype(dtype), rnn_fn)) - fp16_weights.append(w_fp16) - return fp16_weights diff --git a/apex/amp/wrap.py b/apex/amp/wrap.py deleted file mode 100644 index d0a23fd..0000000 --- a/apex/amp/wrap.py +++ /dev/null @@ -1,286 +0,0 @@ -from . import compat -from . import utils -from ._amp_state import _amp_state -from . import rnn_compat - -import functools - -import torch - -def make_cast_wrapper(orig_fn, cast_fn, handle, - try_caching=False): - @functools.wraps(orig_fn) - def wrapper(*args, **kwargs): - if not handle.is_active(): - return orig_fn(*args, **kwargs) - - if try_caching and handle.has_cache: - args = list(args) - for i in range(len(args)): - if utils.should_cache(args[i]): - args[i] = utils.cached_cast(cast_fn, args[i], handle.cache) - for k in kwargs: - if utils.should_cache(kwargs[k]): - kwargs[k] = utils.cached_cast(cast_fn, kwargs[k], handle.cache) - new_args = utils.casted_args(cast_fn, - args, - kwargs) - return orig_fn(*new_args, **kwargs) - return wrapper - -def cached_cast(mod, fn, cast_fn, handle, - try_caching=False, verbose=False): - if not utils.has_func(mod, fn): - return - - orig_fn = utils.get_func(mod, fn) - cast_fn = utils.verbosify(cast_fn, fn, verbose) - wrapper = make_cast_wrapper(orig_fn, cast_fn, handle, try_caching) - utils.set_func_save(handle, mod, fn, wrapper) - -# `handle` arg is unused, but simplifies API to make `make_cast_wrapper` -# Annoyingly, make_promote_wrapper still uses the global handle. Once everyone -# is on the new API and I am free to get rid of handle, I can clean this up. -def make_promote_wrapper(orig_fn, cast_fn, handle=None): - @functools.wraps(orig_fn) - def wrapper(*args, **kwargs): - if not _amp_state.handle.is_active(): - return orig_fn(*args, **kwargs) - - types = utils.collect_fp_tensor_types(args, kwargs) - - if len(types) <= 1: - return orig_fn(*args, **kwargs) - elif len(types) == 2 and (types == set(['HalfTensor', 'FloatTensor']) - or types == set(['BFloat16Tensor', 'FloatTensor'])): - new_args = utils.casted_args(cast_fn, - args, - kwargs) - return orig_fn(*new_args, **kwargs) - else: - raise NotImplementedError('Do not know how to handle ' + - 'these types to promote: {}' - .format(types)) - return wrapper - -def promote(mod, fn, handle, verbose=False): - orig_fn = utils.get_func(mod, fn) - maybe_float = utils.verbosify(utils.maybe_float, fn, verbose) - wrapper = make_promote_wrapper(orig_fn, maybe_float) - utils.set_func_save(handle, mod, fn, wrapper) - -def sequence_promote(mod, fn, handle, verbose=False): - orig_fn = utils.get_func(mod, fn) - maybe_float = utils.verbosify(utils.maybe_float, fn, verbose) - @functools.wraps(orig_fn) - def wrapper(seq, *args, **kwargs): - if not _amp_state.handle.is_active(): - return orig_fn(seq, *args, **kwargs) - - types = set([utils.type_string(x) for x in seq]) - if len(types) <= 1: - return orig_fn(seq, *args, **kwargs) - elif (types == set(['HalfTensor', 'FloatTensor']) or - types == set(['BFloat16Tensor', 'FloatTensor'])): - cast_seq = utils.casted_args(maybe_float, - seq, {}) - return orig_fn(cast_seq, *args, **kwargs) - else: - # TODO: other mixed-type cases aren't due to amp. - # Just pass through? - return orig_fn(seq, *args, **kwargs) - utils.set_func_save(handle, mod, fn, wrapper) - -def promote_match_arg0(mod, fn, handle, verbose=False): - if not utils.has_func(mod, fn): - return - - orig_fn = utils.get_func(mod, fn) - @functools.wraps(orig_fn) - def wrapper(arg0, *args, **kwargs): - assert compat.is_tensor_like(arg0) - if not _amp_state.handle.is_active(): - return orig_fn(arg0, *args, **kwargs) - - if utils.type_string(arg0) == 'HalfTensor': - cast_fn = utils.maybe_half - if utils.type_string(arg0) == 'BFloat16Tensor': - cast_fn = utils.maybe_bfloat16 - elif utils.type_string(arg0) == 'FloatTensor': - cast_fn = utils.maybe_float - else: - return orig_fn(arg0, *args, **kwargs) - cast_fn = utils.verbosify(cast_fn, fn, verbose) - new_args = utils.casted_args(cast_fn, args, kwargs) - return orig_fn(arg0, *new_args, **kwargs) - utils.set_func_save(handle, mod, fn, wrapper) - -def err_if_any_half(mod, fn, handle, custom_err_msg=None): - if not utils.has_func(mod, fn): - return - - orig_fn = utils.get_func(mod, fn) - @functools.wraps(orig_fn) - def wrapper(*args, **kwargs): - types = utils.collect_fp_tensor_types(args, kwargs) - if 'HalfTensor' in types or 'BFloat16Tensor' in types: - if custom_err_msg: - raise NotImplementedError(custom_err_msg) - else: - raise NotImplementedError('Cannot call in-place function ' + - '{} with fp16 or bfloat16 args.'.format(fn)) - else: - return orig_fn(*args, **kwargs) - utils.set_func_save(handle, mod, fn, wrapper) - -def err_if_arg0_half(mod, fn, handle, verbose=False): - if not utils.has_func(mod, fn): - return - - orig_fn = utils.get_func(mod, fn) - @functools.wraps(orig_fn) - def wrapper(arg0, *args, **kwargs): - assert compat.is_tensor_like(arg0) - if utils.type_string(arg0) in {'HalfTensor', 'BFloat16Tensor'}: - raise NotImplementedError('Cannot call in-place method ' + - '{} with fp16 or bfloat16 args.'.format(fn)) - else: - cast_fn = utils.verbosify(utils.maybe_float, fn, verbose) - new_args = utils.casted_args(cast_fn, args, kwargs) - return orig_fn(arg0, *new_args, **kwargs) - utils.set_func_save(handle, mod, fn, wrapper) - -# Current RNN approach: -# - Wrap top-level `RNN` function in thnn backend -# - Will call into either CudnnRNN or AutogradRNN -# - Each of these are factory functions that return a per-iter -# `forward` function -# - We interpose on the factory function to: -# 1) Interpose on the actual forward function and put in casts -# 2) Insert an fp16 `flat_weight` if necessary -def rnn_cast(backend, fn, handle, verbose=False): - orig_rnn = utils.get_func(backend, fn) - @functools.wraps(orig_rnn) - def rnn_wrapper(*args, **kwargs): - flat_weight = kwargs.get('flat_weight') - if flat_weight is not None: - # We replace `flat_weight` with an uninitialized fp16 - # Tensor. The "actual" weight tensors (provided in `forward`), - # will then be set up as ptrs into the buffer and have the - # corresponding fp32 values copied in. - # We need to call `copy` on the "actual" weights so that the - # autograd graph correctly backprops from the wgrads computed - # inside cuDNN (on fp16 weights) into the fp32 weights. - assert utils.type_string(flat_weight) == 'FloatTensor' - if compat.tensor_is_float_tensor() or compat.tensor_is_variable(): - # Pre-0.4. A little slower, since it zeros out memory. - flat_weight_fp16 = flat_weight.new().half().resize_(flat_weight.shape) - else: - flat_weight_fp16 = torch.empty_like(flat_weight, - dtype=torch.float16) - kwargs['flat_weight'] = flat_weight_fp16 - else: - flat_weight_fp16 = None - - forward = orig_rnn(*args, **kwargs) - @functools.wraps(forward) - def fwd_wrapper(*fargs, **fkwargs): - assert len(fargs) == 3 or len(fargs) == 4 - inputs, weights, hiddens = fargs[:3] - assert utils.is_fp_tensor(inputs) - assert isinstance(weights, list) - cast_fn = utils.verbosify(utils.maybe_half, - fn, - verbose) - new_args = [] - - # 0) Inputs - new_args.append(cast_fn(inputs)) - - # 1) Weights - if flat_weight_fp16 is not None: - fp16_weights = utils.synthesize_flattened_rnn_weights( - weights, flat_weight_fp16, fn, verbose) - else: - fp16_weights = [[cast_fn(w) for w in layer] - for layer in weights] - new_args.append(fp16_weights) - - # 2) Inputs: either a tuple (for LSTM) or single tensor - if isinstance(hiddens, tuple): - new_args.append(tuple(cast_fn(x) for x in hiddens)) - elif utils.is_fp_tensor(hiddens): - new_args.append(cast_fn(hiddens)) - else: - # Hiddens can, in principle, be `None` -- pass through - new_args.append(hiddens) - - # 3) Batch sizes (0.4 or later only) - if len(fargs) == 4: - new_args.append(fargs[3]) - - return forward(*new_args, **fkwargs) - return fwd_wrapper - utils.set_func_save(handle, backend, fn, rnn_wrapper) - -def new_rnn_cast(fn, cast_fn, handle, verbose=False): - # Forward+backward compatibility around https://github.com/pytorch/pytorch/pull/15744 - # For rnn backend calls that route through _rnn_impls, we must patch the ref - # that _rnn_impls stashed. For rnn backend calls that directly invoke - # _VF., e.g. _VF.lstm, we can patch onto VariableFunctionsShim, - # which in turn has patched the ref named "_VF" in torch.nn.modules.rnn. - if utils.has_func(torch.nn.modules.rnn._rnn_impls, fn): - mod = torch.nn.modules.rnn._rnn_impls - else: - mod = torch.nn.modules.rnn._VF - assert isinstance(mod, rnn_compat.VariableFunctionsShim) - fn = fn.lower() - orig_fn = utils.get_func(mod, fn) - cast_fn = utils.verbosify(cast_fn, fn, verbose) - @functools.wraps(orig_fn) - def wrapper(*args, **kwargs): - # Exact call signature from modules/rnn.py - assert len(args) == 9 - assert len(kwargs) == 0 - - if not _amp_state.handle.is_active(): - return orig_fn(*args, **kwargs) - - if isinstance(args[6], bool): - params_idx = 2 # Not PackedSequence case - else: - params_idx = 3 # PackedSequence case - - if cast_fn == utils.maybe_half: - dtype = torch.half - elif cast_fn == utils.maybe_bfloat16: - dtype = torch.bfloat16 - else: - raise RuntimeError("Unsupported cast_fn passed. Supports only maybe_half and maybe_bfloat16") - new_args = [] - for i, arg in enumerate(args): - if i == params_idx: - num_params = sum([x.numel() for x in arg]) - fp16_weight_buf = args[0].new_empty((num_params,), - dtype=dtype) - casted_weights = utils.new_synthesize_flattened_rnn_weights( - arg, fp16_weight_buf, fn, dtype, verbose) - new_args.append(casted_weights) - elif utils.is_fp_tensor(arg): - new_args.append(cast_fn(arg)) - else: - new_args.append(arg) - - return orig_fn(*new_args) - utils.set_func_save(handle, mod, fn, wrapper) - -def disable_casts(mod, fn, handle): - if not utils.has_func(mod, fn): - return - - orig_fn = utils.get_func(mod, fn) - @functools.wraps(orig_fn) - def wrapper(*args, **kwargs): - with handle._disable_casts(): - return orig_fn(*args, **kwargs) - utils.set_func_save(handle, mod, fn, wrapper) diff --git a/apex/contrib/__init__.py b/apex/contrib/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/apex/contrib/bottleneck/__init__.py b/apex/contrib/bottleneck/__init__.py deleted file mode 100644 index 300b7c3..0000000 --- a/apex/contrib/bottleneck/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .bottleneck import Bottleneck, SpatialBottleneck -from .halo_exchangers import HaloExchangerNoComm, HaloExchangerAllGather, HaloExchangerSendRecv, HaloExchangerPeer diff --git a/apex/contrib/bottleneck/bottleneck.py b/apex/contrib/bottleneck/bottleneck.py deleted file mode 100644 index 5ea5694..0000000 --- a/apex/contrib/bottleneck/bottleneck.py +++ /dev/null @@ -1,749 +0,0 @@ -import functools as func - -import torch -import torch.distributed as dist -from torch import nn - -from apex import check_cudnn_version_and_warn -import fast_bottleneck -import nccl_p2p_cuda as inc - - -assert check_cudnn_version_and_warn(__name__, 8400) - - -def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'): - weight_tensor_nchw = tensor - nn.init.kaiming_uniform_(weight_tensor_nchw, a=a, mode=mode, nonlinearity=nonlinearity) - -def compute_scale_bias_one(nhwc, weight, bias, running_mean, running_var, w_scale, w_bias): - scale = weight * running_var.rsqrt() - bias = bias - running_mean * scale - w_scale.copy_(scale) - w_bias.copy_(bias) - -def compute_scale_bias_method(nhwc, args): - for arg in args: - # arg is tuple of (weight, bias, running_mean, running_var, w_scale, w_bias) - compute_scale_bias_one(nhwc, *arg) - -class FrozenBatchNorm2d(torch.jit.ScriptModule): - """ - BatchNorm2d where the batch statistics and the affine parameters are fixed - """ - def __init__(self, n): - super(FrozenBatchNorm2d, self).__init__() - self.register_buffer("weight", torch.ones(n)) - self.register_buffer("bias", torch.zeros(n)) - self.register_buffer("running_mean", torch.zeros(n)) - self.register_buffer("running_var", torch.ones(n)) - - @torch.jit.script_method - def get_scale_bias(self, nhwc): - # type: (bool) -> List[torch.Tensor] - scale = self.weight * self.running_var.rsqrt() - bias = self.bias - self.running_mean * scale - if nhwc: - scale = scale.reshape(1, 1, 1, -1) - bias = bias.reshape(1, 1, 1, -1) - else: - scale = scale.reshape(1, -1, 1, 1) - bias = bias.reshape(1, -1, 1, 1) - return scale, bias - - @torch.jit.script_method - def forward(self, x): - scale, bias = self.get_scale_bias(False) - return x * scale + bias - -@torch.jit.script -def drelu_dscale1(grad_o, output, scale1): - relu_mask = (output>0) - dx_relu = relu_mask * grad_o - g1 = dx_relu * scale1 - return g1, dx_relu - -@torch.jit.script -def drelu_dscale2(grad_o, output, scale1, scale2): - relu_mask = (output>0) - dx_relu = relu_mask * grad_o - g1 = dx_relu * scale1 - g2 = dx_relu * scale2 - return g1, g2 - -class BottleneckFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, nhwc, stride_1x1, scale, bias, x, *conv): - # TODO: clean up order of tensors - args = [x, *conv[0:3], *scale[0:3], *bias[0:3]] - ctx.downsample = len(conv) > 3 - if ctx.downsample: - args.append(conv[3]) - args.append(scale[3]) - args.append(bias[3]) - - # weight buffers are always in nhwc while shape can be nhwc or channels_last - # here we pass in flag and let c++ handle it - # alternatively, we can put all sizes into a fixed format and pass it in - outputs = fast_bottleneck.forward(nhwc, stride_1x1, args) - ctx.save_for_backward(*(args+outputs)) - # save relu outputs for drelu - ctx.nhwc = nhwc - ctx.stride_1x1 = stride_1x1 - return outputs[2] - - # backward relu is not exposed, MUL with mask used now - # only support dgrad - @staticmethod - def backward(ctx, grad_o): - outputs = ctx.saved_tensors[-3:] - - if ctx.downsample: - grad_conv3, grad_conv4 = drelu_dscale2(grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11]) - else: - grad_conv3, grad_conv4 = drelu_dscale1(grad_o, outputs[2], ctx.saved_tensors[6]) - - # create input vector for backward - t_list = [*ctx.saved_tensors[0:10]] - t_list.append(grad_conv3) - t_list.append(grad_conv4) - - # outputs used for wgrad and generating drelu mask - t_list.append(outputs[0]) - t_list.append(outputs[1]) - - # in case there is downsample - if ctx.downsample: - t_list.append(ctx.saved_tensors[10]) - - grads = fast_bottleneck.backward(ctx.nhwc, ctx.stride_1x1, t_list) - - return (None, None, None, None, *grads) - -bottleneck_function = BottleneckFunction.apply - -def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=dilation, groups=groups, bias=False, dilation=dilation) - -def conv1x1(in_planes, out_planes, stride=1): - """1x1 convolution""" - return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) - -class Bottleneck(torch.nn.Module): - # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) - # while original implementation places the stride at the first 1x1 convolution(self.conv1) - # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. - # This variant is also known as ResNet V1.5 and improves accuracy according to - # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. - # here we put it at 1x1 - - def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1, - dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False): - super(Bottleneck, self).__init__() - if groups != 1: - raise RuntimeError('Only support groups == 1') - if dilation != 1: - raise RuntimeError('Only support dilation == 1') - if norm_func == None: - norm_func = FrozenBatchNorm2d - else: - raise RuntimeError('Only support frozen BN now.') - - if stride != 1 or in_channels != out_channels: - self.downsample = nn.Sequential( - conv1x1(in_channels, out_channels, stride), - norm_func(out_channels), - ) - else: - self.downsample = None - - # Both self.conv2 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv1x1(in_channels, bottleneck_channels, stride) - self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels) - self.conv3 = conv1x1(bottleneck_channels, out_channels) - self.relu = nn.ReLU(inplace=True) - self.stride = stride - - self.bn1 = norm_func(bottleneck_channels) - self.bn2 = norm_func(bottleneck_channels) - self.bn3 = norm_func(out_channels) - self.w_scale = None - - self.use_cudnn = use_cudnn - - # setup conv weights - self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight] - if self.downsample is not None: - self.w_conv.append(self.downsample[0].weight) - - # init weight in nchw format before possible transpose - for w in self.w_conv: - kaiming_uniform_(w, a=1) - - # TODO: prevent unsupported case usage - # support cases - # native cudnn - # normal yes no - # channel_last yes yes - # explicit_nhwc no yes - self.explicit_nhwc = explicit_nhwc - if self.explicit_nhwc: - for p in self.parameters(): - with torch.no_grad(): - p.data = p.data.permute(0,2,3,1).contiguous() - - return - - # Returns single callable that recomputes scale and bias for all frozen batch-norms. - # This method must be called before cuda graphing. - # The callable it returns can be called anytime. - # Calling this method will prevent these from being computed every forward call. - def get_scale_bias_callable(self): - self.w_scale, self.w_bias, args = [], [], [] - batch_norms = [self.bn1, self.bn2, self.bn3] - if self.downsample is not None: - batch_norms.append(self.downsample[1]) - for bn in batch_norms: - s = torch.empty_like(bn.weight) - b = torch.empty_like(s) - args.append( (bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b) ) - if self.explicit_nhwc: - self.w_scale.append( s.reshape(1, 1, 1, -1) ) - self.w_bias.append( b.reshape(1, 1, 1, -1) ) - else: - self.w_scale.append( s.reshape(1, -1, 1, 1) ) - self.w_bias.append( b.reshape(1, -1, 1, 1) ) - return func.partial(compute_scale_bias_method, self.explicit_nhwc, args) - - def forward(self, x): - if self.use_cudnn: - if self.w_scale is None: - # calculate scale/bias from registered buffers - # TODO: make this better - s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc) - s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc) - s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc) - w_scale = [s1, s2, s3] - w_bias = [b1, b2, b3] - if self.downsample is not None: - s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc) - w_scale.append(s4) - w_bias.append(b4) - out = bottleneck_function(self.explicit_nhwc, self.stride, w_scale, w_bias, x, *self.w_conv) - else: - out = bottleneck_function(self.explicit_nhwc, self.stride, self.w_scale, self.w_bias, x, *self.w_conv) - return out - - if self.explicit_nhwc: - raise RuntimeError('explicit nhwc with native ops is not supported.') - - # fallback to native ops - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out - - -class SpatialBottleneckFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, use_delay_kernel, explicit_nhwc, stride_1x1, scale, bias, thresholdTop, thresholdBottom, x, *conv): - if spatial_group_size > 1: - stream1 = spatial_halo_exchanger.stream1 - stream2 = spatial_halo_exchanger.stream2 - stream3 = spatial_halo_exchanger.stream3 - - # TODO: clean up order of tensors - args = [x, *conv[0:3], *scale[0:3], *bias[0:3]] - ctx.downsample = len(conv) > 3 - if ctx.downsample: - args.append(conv[3]) - args.append(scale[3]) - args.append(bias[3]) - - # weight buffers are always in explicit_nhwc while shape can be explicit_nhwc or channels_last - # here we pass in flag and let c++ handle it - # alternatively, we can put all sizes into a fixed format and pass it in - outputs = fast_bottleneck.forward_init(explicit_nhwc, stride_1x1, args) - fast_bottleneck.forward_out1(explicit_nhwc, stride_1x1, args, outputs) - - if spatial_group_size > 1: - out1 = outputs[0] - if explicit_nhwc: - N,Hs,W,C = list(out1.shape) - memory_format = torch.contiguous_format - out1_pad = torch.empty([N,Hs+2,W,C], dtype=out1.dtype, device='cuda') - else: - N,C,Hs,W = list(out1.shape) - memory_format = torch.channels_last if out1.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format - out1_pad = torch.empty([N,C,Hs+2,W], dtype=out1.dtype, device='cuda', memory_format=memory_format) - stream1.wait_stream(torch.cuda.current_stream()) - if spatial_method != 2: stream3.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(stream1): - if explicit_nhwc: - top_out1_halo = out1_pad[:,:1,:,:] - btm_out1_halo = out1_pad[:,Hs+1:Hs+2,:,:] - spatial_halo_exchanger.left_right_halo_exchange(out1[:,:1,:,:], out1[:,Hs-1:,:,:], top_out1_halo, btm_out1_halo) - else: - top_out1_halo = out1_pad[:,:,:1,:] - btm_out1_halo = out1_pad[:,:,Hs+1:Hs+2,:] - spatial_halo_exchanger.left_right_halo_exchange(out1[:,:,:1,:], out1[:,:,Hs-1:,:], top_out1_halo, btm_out1_halo) - if spatial_method == 1: - # overlap mid convolution with halo transfer - if spatial_group_rank < spatial_group_size-1: - stream2.wait_stream(stream1) - with torch.cuda.stream(stream2): - if explicit_nhwc: - btm_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device) - btm_fat_halo[:,0:2,:,:].copy_(out1[:,Hs-2:,:,:]) - btm_fat_halo[:,2:,:,:].copy_(btm_out1_halo) - else: - btm_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device) - btm_fat_halo[:,:,0:2,:].copy_(out1[:,:,Hs-2:,:]) - btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo) - btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args) - if spatial_group_rank > 0: - with torch.cuda.stream(stream1): - if explicit_nhwc: - top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device) - top_fat_halo[:,:1,:,:].copy_(top_out1_halo) - top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:]) - else: - top_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device) - top_fat_halo[:,:,:1,:].copy_(top_out1_halo) - top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:]) - top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args) - if use_delay_kernel: inc.add_delay(10) - elif spatial_method != 2 and spatial_method != 3: - assert(False), "spatial_method must be 1, 2 or 3" - - if spatial_group_size <= 1: - fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs) - elif spatial_method == 1: - fast_bottleneck.forward_out2(explicit_nhwc, stride_1x1, args, outputs) - with torch.cuda.stream(stream3): - if explicit_nhwc: - out1_pad[:,1:Hs+1,:,:].copy_(out1) - else: - out1_pad[:,:,1:Hs+1,:].copy_(out1) - elif spatial_method == 2: - # wait for halo transfer to finish before doing a full convolution of padded x - if explicit_nhwc: - out1_pad[:,1:Hs+1,:,:].copy_(out1) - else: - out1_pad[:,:,1:Hs+1,:].copy_(out1) - torch.cuda.current_stream().wait_stream(stream1) - fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad) - elif spatial_method == 3: - fast_bottleneck.forward_out2_mask(explicit_nhwc, stride_1x1, args, outputs, thresholdTop, thresholdBottom) - with torch.cuda.stream(stream3): - if explicit_nhwc: - out1_pad[:,1:Hs+1,:,:].copy_(out1) - else: - out1_pad[:,:,1:Hs+1,:].copy_(out1) - - # compute halo cells for outputs[1] (out2) - if spatial_group_size > 1: - out2 = outputs[1] - if explicit_nhwc: - top_out2_halo = out2[:,:1,:,:] - btm_out2_halo = out2[:,Hs-1:,:,:] - else: - top_out2_halo = out2[:,:,:1,:] - btm_out2_halo = out2[:,:,Hs-1:,:] - if spatial_method == 1: - if spatial_group_rank > 0: - torch.cuda.current_stream().wait_stream(stream1) - top_out2_halo.copy_(top_out2) - if spatial_group_rank < spatial_group_size-1: - torch.cuda.current_stream().wait_stream(stream2) - btm_out2_halo.copy_(btm_out2) - elif spatial_method == 3: - # Note - # out2 halo correction cannot overlap with anything since it has - # to wait for out2_mask to finish, but itself has to finish before - # the first kernel of _forward_rest can launch. - # At least we can overlap the two halo correction kernels. - if spatial_group_rank < spatial_group_size-1: - stream2.wait_stream(stream1) # wait for halo transfers to finish - stream2.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish - with torch.cuda.stream(stream2): - w1by3 = args[2][:,2:3,:,:].clone() - btm_out1_halo = btm_out1_halo.clone() - btm_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, btm_out1_halo, args, w1by3, btm_out2_halo.clone()) - btm_out2_halo.copy_(btm_out2) - if spatial_group_rank > 0: - stream1.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish - with torch.cuda.stream(stream1): - w1by3 = args[2][:,:1,:,:].clone() - top_out1_halo = top_out1_halo.clone() - top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.clone()) - top_out2_halo.copy_(top_out2) - if spatial_group_rank < spatial_group_size-1: - torch.cuda.current_stream().wait_stream(stream2) - if spatial_group_rank > 0: - torch.cuda.current_stream().wait_stream(stream1) - - fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs) - # save halos for backward pass - if spatial_group_size > 1: - if spatial_method != 2: - # make sure copy of mid-section of out1 into out1_pad is done before exiting - torch.cuda.current_stream().wait_stream(stream3) - ctx.save_for_backward(*(args+outputs+[out1_pad,])) - else: - ctx.save_for_backward(*(args+outputs)) - # save relu outputs for drelu - ctx.explicit_nhwc = explicit_nhwc - ctx.stride_1x1 = stride_1x1 - ctx.spatial_group_size = spatial_group_size - if spatial_group_size > 1: - ctx.spatial_group_rank = spatial_group_rank - ctx.spatial_halo_exchanger = spatial_halo_exchanger - ctx.spatial_method = spatial_method - ctx.use_delay_kernel = use_delay_kernel - ctx.thresholdTop = thresholdTop - ctx.thresholdBottom = thresholdBottom - ctx.stream1 = stream1 - ctx.stream2 = stream2 - ctx.stream3 = stream3 - return outputs[2] - - # backward relu is not exposed, MUL with mask used now - # only support dgrad - @staticmethod - def backward(ctx, grad_o): - if ctx.spatial_group_size > 1: - out1_pad = ctx.saved_tensors[-1] - outputs = ctx.saved_tensors[-4:-1] - else: - outputs = ctx.saved_tensors[-3:] - - if ctx.downsample: - grad_conv3, grad_conv4 = drelu_dscale2(grad_o, outputs[2], ctx.saved_tensors[6], ctx.saved_tensors[11]) - else: - grad_conv3, grad_conv4 = drelu_dscale1(grad_o, outputs[2], ctx.saved_tensors[6]) - - # create input vector for backward - t_list = [*ctx.saved_tensors[0:10]] - t_list.append(grad_conv3) - t_list.append(grad_conv4) - - # outputs used for wgrad and generating drelu mask - t_list.append(outputs[0]) - t_list.append(outputs[1]) - - # in case there is downsample - if ctx.downsample: - t_list.append(ctx.saved_tensors[10]) - - grads = fast_bottleneck.backward_init(ctx.explicit_nhwc, ctx.stride_1x1, t_list) - wgrad3_stream = torch.cuda.Stream() - wgrad3_stream.wait_stream(torch.cuda.current_stream()) - grad_out2 = fast_bottleneck.backward_grad_out2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads) - wgrad2_stream = torch.cuda.Stream() - wgrad2_stream.wait_stream(torch.cuda.current_stream()) - # do halo exchange of grad_out2 here - # compute halo cells for grad_out1 - if ctx.spatial_group_size > 1: - if ctx.explicit_nhwc: - N,Hs,W,C = list(grad_out2.shape) - else: - N,C,Hs,W = list(grad_out2.shape) - relu1 = t_list[12] - ctx.stream1.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(ctx.stream1): - top_halo, btm_halo = ctx.spatial_halo_exchanger.left_right_halo_exchange(grad_out2[:,:1,:,:], grad_out2[:,Hs-1:,:,:]) - # copy halos to send buffer - if ctx.spatial_method == 1 or ctx.spatial_method == 2: - # 1 -> halo recompute approach - # 2 -> wait for concatenated halos, then do single conv on full input (not implemented yet for bprop) - if ctx.spatial_group_rank < ctx.spatial_group_size-1: - ctx.stream2.wait_stream(ctx.stream1) - with torch.cuda.stream(ctx.stream2): - if ctx.explicit_nhwc: - btm_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device) - btm_fat_halo[:,:2,:,:].copy_(grad_out2[:,Hs-2:,:,:]) - btm_fat_halo[:,2:,:,:].copy_(btm_halo) - btm_fat_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device) - btm_fat_relu_halo[:,:2,:,:].copy_(relu1[:,Hs-2:,:,:]) - btm_fat_relu_halo[:,2:,:,:].zero_() - else: - btm_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device) - btm_fat_halo[:,:,:2,:].copy_(grad_out2[:,:,Hs-2:,:]) - btm_fat_halo[:,:,2:,:].copy_(btm_halo) - btm_fat_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device) - btm_fat_relu_halo[:,:,:2,:].copy_(relu1[:,:,Hs-2:,:]) - btm_fat_relu_halo[:,:,2:,:].zero_() - btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, btm_fat_halo, btm_fat_relu_halo) - if ctx.explicit_nhwc: - btm_grad_out1_halo = btm_grad_out1_halo[:,1:2,:,:] - else: - btm_grad_out1_halo = btm_grad_out1_halo[:,:,1:2,:] - if ctx.spatial_group_rank > 0: - with torch.cuda.stream(ctx.stream1): - if ctx.explicit_nhwc: - top_fat_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device) - top_fat_halo[:,:1,:,:].copy_(top_halo) - top_fat_halo[:,1:,:,:].copy_(grad_out2[:,:2,:,:]) - top_fat_relu_halo = torch.empty((N,3,W,C),dtype=grad_out2.dtype,device=grad_out2.device) - top_fat_relu_halo[:,:1,:,:].zero_() - top_fat_relu_halo[:,1:,:,:].copy_(relu1[:,:2,:,:]) - else: - top_fat_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device) - top_fat_halo[:,:,:1,:].copy_(top_halo) - top_fat_halo[:,:,1:,:].copy_(grad_out2[:,:,:2,:]) - top_fat_relu_halo = torch.empty((N,C,3,W),dtype=grad_out2.dtype,device=grad_out2.device) - top_fat_relu_halo[:,:,:1,:].zero_() - top_fat_relu_halo[:,:,1:,:].copy_(relu1[:,:,:2,:]) - top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, top_fat_halo, top_fat_relu_halo) - if ctx.explicit_nhwc: - top_grad_out1_halo = top_grad_out1_halo[:,1:2,:,:] - else: - top_grad_out1_halo = top_grad_out1_halo[:,:,1:2,:] - if ctx.use_delay_kernel: inc.add_delay(10) - elif ctx.spatial_method != 3: - assert(False), "spatial_method must be 1, 2 or 3" - - # compute grad_out1 for internal cells - if ctx.spatial_group_size <= 1 or ctx.spatial_method == 1 or ctx.spatial_method == 2: - grad_out1 = fast_bottleneck.backward_grad_out1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2) - elif ctx.spatial_group_size > 1 and ctx.spatial_method == 3: - grad_out1 = fast_bottleneck.backward_grad_out1_mask(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, ctx.thresholdTop, ctx.thresholdBottom) - - # apply halo cells to grad_out1 - if ctx.spatial_group_size > 1: - w = t_list[2] - z = t_list[4] - relu1 = t_list[12] - #print("w.shape = %s, z.shape = %s, relu1.shape = %s" % (str(list(w.shape)), str(list(z.shape)), str(list(relu1.shape)))) - if ctx.spatial_method == 1 or ctx.spatial_method == 2: - if ctx.spatial_group_rank < ctx.spatial_group_size-1: - torch.cuda.current_stream().wait_stream(ctx.stream2) - if ctx.explicit_nhwc: - grad_out1[:,Hs-1:,:,:].copy_(btm_grad_out1_halo) - else: - grad_out1[:,:,Hs-1:,:].copy_(btm_grad_out1_halo) - #print("ctx.spatial_group_rank = %d, apply grad_out1 btm halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape)))) - if ctx.spatial_group_rank > 0: - torch.cuda.current_stream().wait_stream(ctx.stream1) - if ctx.explicit_nhwc: - grad_out1[:,:1,:,:].copy_(top_grad_out1_halo) - else: - grad_out1[:,:,:1,:].copy_(top_grad_out1_halo) - #print("ctx.spatial_group_rank = %d, apply grad_out1 top halo (grad_out1.shape = %s)" % (ctx.spatial_group_rank, str(list(grad_out1.shape)))) - elif ctx.spatial_method == 3: - if ctx.spatial_group_rank < ctx.spatial_group_size-1: - if ctx.explicit_nhwc: - btm_relu_halo = relu1[:,Hs-1:,:,:].clone() - btm_grad_out1 = grad_out1[:,Hs-1:,:,:] - else: - btm_relu_halo = relu1[:,:,Hs-1:,:].clone() - btm_grad_out1 = grad_out1[:,:,Hs-1:,:] - w1by3 = w[:,:1,:,:].clone() - ctx.stream2.wait_stream(ctx.stream1) # wait for halo transfers to finish - ctx.stream2.wait_stream(torch.cuda.current_stream()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel - with torch.cuda.stream(ctx.stream2): - btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, btm_halo, btm_relu_halo, btm_grad_out1.clone()) - btm_grad_out1.copy_(btm_grad_out1_halo) - if ctx.spatial_group_rank > 0: - if ctx.explicit_nhwc: - top_relu_halo = relu1[:,:1,:,:].clone() - top_grad_out1 = grad_out1[:,:1,:,:] - else: - top_relu_halo = relu1[:,:,:1,:].clone() - top_grad_out1 = grad_out1[:,:,:1,:] - w1by3 = w[:,2:,:,:].clone() - ctx.stream1.wait_stream(torch.cuda.current_stream()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel - with torch.cuda.stream(ctx.stream1): - top_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, top_halo, top_relu_halo, top_grad_out1.clone()) - top_grad_out1.copy_(top_grad_out1_halo) - if ctx.spatial_group_rank < ctx.spatial_group_size-1: - torch.cuda.current_stream().wait_stream(ctx.stream2) # wait for halo correction to finish - if ctx.spatial_group_rank > 0: - torch.cuda.current_stream().wait_stream(ctx.stream1) - - wgrad1_stream = torch.cuda.Stream() - wgrad1_stream.wait_stream(torch.cuda.current_stream()) - fast_bottleneck.backward_rest(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2, grad_out1) - with torch.cuda.stream(wgrad3_stream): - fast_bottleneck.backward_wgrad3(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads) - with torch.cuda.stream(wgrad2_stream): - if ctx.spatial_group_size > 1: - fast_bottleneck.backward_wgrad2_pad(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, out1_pad, grad_out2) - else: - fast_bottleneck.backward_wgrad2(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out2) - with torch.cuda.stream(wgrad1_stream): - fast_bottleneck.backward_wgrad1(ctx.explicit_nhwc, ctx.stride_1x1, t_list, grads, grad_out1) - torch.cuda.current_stream().wait_stream(wgrad3_stream) - torch.cuda.current_stream().wait_stream(wgrad2_stream) - torch.cuda.current_stream().wait_stream(wgrad1_stream) - - return (None, None, None, None, None, None, None, None, None, None, None, None, *grads) - -spatial_bottleneck_function = SpatialBottleneckFunction.apply - -class SpatialBottleneck(torch.nn.Module): - # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) - # while original implementation places the stride at the first 1x1 convolution(self.conv1) - # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. - # This variant is also known as ResNet V1.5 and improves accuracy according to - # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. - # here we put it at 1x1 - - def __init__(self, in_channels, bottleneck_channels, out_channels, stride=1, groups=1, - dilation=1, norm_func=None, use_cudnn=False, explicit_nhwc=False, - spatial_parallel_args=None): - super(SpatialBottleneck, self).__init__() - if groups != 1: - raise RuntimeError('Only support groups == 1') - if dilation != 1: - raise RuntimeError('Only support dilation == 1') - if norm_func == None: - norm_func = FrozenBatchNorm2d - else: - raise RuntimeError('Only support frozen BN now.') - - if stride != 1 or in_channels != out_channels: - self.downsample = nn.Sequential( - conv1x1(in_channels, out_channels, stride), - norm_func(out_channels), - ) - else: - self.downsample = None - - # Both self.conv2 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv1x1(in_channels, bottleneck_channels, stride) - self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels) - self.conv3 = conv1x1(bottleneck_channels, out_channels) - self.relu = nn.ReLU(inplace=True) - self.stride = stride - - self.bn1 = norm_func(bottleneck_channels) - self.bn2 = norm_func(bottleneck_channels) - self.bn3 = norm_func(out_channels) - self.w_scale = None - - self.use_cudnn = use_cudnn - - # setup conv weights - self.w_conv = [self.conv1.weight, self.conv2.weight, self.conv3.weight] - if self.downsample is not None: - self.w_conv.append(self.downsample[0].weight) - - # init weight in nchw format before possible transpose - for w in self.w_conv: - kaiming_uniform_(w, a=1) - - self.thresholdTop, self.thresholdBottom = None, None - - # TODO: prevent unsupported case usage - # support cases - # native cudnn - # normal yes no - # channel_last yes yes - # explicit_nhwc no yes - self.explicit_nhwc = explicit_nhwc - if self.explicit_nhwc: - for p in self.parameters(): - with torch.no_grad(): - p.data = p.data.permute(0,2,3,1).contiguous() - - # spatial communicator - if spatial_parallel_args is None: - self.spatial_parallel_args = (1, 0, None, None, 0, False) - else: - self.spatial_parallel_args = spatial_parallel_args - return - - # Returns single callable that recomputes scale and bias for all frozen batch-norms. - # This method must be called before cuda graphing. - # The callable it returns can be called anytime. - # Calling this method will prevent these from being computed every forward call. - def get_scale_bias_callable(self): - self.w_scale, self.w_bias, args = [], [], [] - batch_norms = [self.bn1, self.bn2, self.bn3] - if self.downsample is not None: - batch_norms.append(self.downsample[1]) - for bn in batch_norms: - s = torch.empty_like(bn.weight) - b = torch.empty_like(s) - args.append( (bn.weight, bn.bias, bn.running_mean, bn.running_var, s, b) ) - if self.explicit_nhwc: - self.w_scale.append( s.reshape(1, 1, 1, -1) ) - self.w_bias.append( b.reshape(1, 1, 1, -1) ) - else: - self.w_scale.append( s.reshape(1, -1, 1, 1) ) - self.w_bias.append( b.reshape(1, -1, 1, 1) ) - return func.partial(compute_scale_bias_method, self.explicit_nhwc, args) - - def forward(self, x): - if self.use_cudnn: - if self.thresholdTop is None: - spatial_group_size, spatial_group_rank, _, _, _, _ = self.spatial_parallel_args - if self.explicit_nhwc: - N,H,W,C = list(x.shape) - else: - N,C,H,W = list(x.shape) - self.thresholdTop = torch.tensor([1 if spatial_group_rank > 0 else 0], dtype=torch.int32, device='cuda') - self.thresholdBottom = torch.tensor([H-2 if spatial_group_rank < spatial_group_size - 1 else H-1], dtype=torch.int32, device='cuda') - - if self.w_scale is None: - # calculate scale/bias from registered buffers - # TODO: make this better - s1, b1 = self.bn1.get_scale_bias(self.explicit_nhwc) - s2, b2 = self.bn2.get_scale_bias(self.explicit_nhwc) - s3, b3 = self.bn3.get_scale_bias(self.explicit_nhwc) - w_scale = [s1, s2, s3] - w_bias = [b1, b2, b3] - if self.downsample is not None: - s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc) - w_scale.append(s4) - w_bias.append(b4) - out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv) - else: - out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, self.w_scale, self.w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv) - return out - - if self.explicit_nhwc: - raise RuntimeError('explicit nhwc with native ops is not supported.') - - # fallback to native ops - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out - diff --git a/apex/contrib/bottleneck/bottleneck_module_test.py b/apex/contrib/bottleneck/bottleneck_module_test.py deleted file mode 100644 index 16d1e55..0000000 --- a/apex/contrib/bottleneck/bottleneck_module_test.py +++ /dev/null @@ -1,254 +0,0 @@ -import torch -from apex.contrib.bottleneck import Bottleneck, SpatialBottleneck -from apex.contrib.bottleneck import HaloExchangerNoComm, HaloExchangerAllGather, HaloExchangerSendRecv, HaloExchangerPeer -from apex.contrib.peer_memory import PeerMemoryPool - - -def ground_truth_bottleneck(C, dtype, explicit_nhwc): - bottleneck = Bottleneck(C,C,C,use_cudnn=True,explicit_nhwc=explicit_nhwc) - bottleneck.to(dtype=dtype, device='cuda') - for p in bottleneck.parameters(): - torch.distributed.broadcast(p, 0) - for b in bottleneck.buffers(): - torch.distributed.broadcast(b, 0) - return bottleneck - - -def print_bottleneck_p_and_b(bottleneck): - with torch.no_grad(): - for n,p in bottleneck.named_parameters(): - print("%s :: %s" % (n, str(p.norm(p=2,dtype=torch.float32)))) - for n,p in bottleneck.named_buffers(): - print("%s :: %s" % (n, str(p.norm(p=2,dtype=torch.float32)))) - - -def has_nan(x): - if isinstance(x, list) or isinstance(x, tuple): - for xx in x: - if torch.any(torch.isnan(xx)): - return True - return False - elif isinstance(x, dict): - for k,v in x.items(): - if torch.any(torch.isnan(v)): - return True - else: - return torch.any(torch.isnan(x)) - - -def rel_diff_t(xx1, xx2): - return ((xx1 - xx2).norm(p=2,dtype=torch.float32) / (xx1 + xx2).norm(p=2,dtype=torch.float32)).item() - - -def rel_diff(x1, x2): - if isinstance(x1, list) or isinstance(x1, tuple): - return [rel_diff_t(xx1,xx2) for xx1,xx2 in zip(x1,x2)] - elif isinstance(x1, dict): - return [rel_diff_t(xx1, xx2) for (k1,xx1), (k2,xx2) in zip(x1.items(),x2.items())] - else: - return rel_diff_t(x1,x2) - - -def graph_it(bottleneck, x): - print("Graphing") - with torch.no_grad(): - x = x.clone() - x.grad = None - x.requires_grad = True - return torch.cuda.make_graphed_callables(bottleneck, (x,)) - - -def clone_inputs(bottleneck, x, dy=None): - with torch.no_grad(): - x = x.clone() - x.grad = None - x.requires_grad = True - if dy is None: - y = bottleneck(x) - dy = torch.randn_like(y) / 1e2 - torch.distributed.broadcast(dy, 0) - return x, dy - - -def fprop_and_bprop(bottleneck, x, dy): - y = bottleneck(x) - y.backward(dy) - dgrad = x.grad.detach() - wgrad = {} - for n,p in bottleneck.named_parameters(): - wgrad[n] = p.grad.detach() - return x, y, dy, dgrad, wgrad - - -def ground_truth(N, C, H, W, dtype, memory_format, bottleneck): - if memory_format == 1: - # 1 -> explicit nhwc - explicit_nhwc = True - with torch.no_grad(): - x = torch.randn([N,H,W,C], dtype=dtype, device='cuda') - torch.distributed.broadcast(x, 0) - x, dy = clone_inputs(bottleneck, x) - return fprop_and_bprop(bottleneck, x, dy) - else: - # 2 -> native nhwc - # 3 -> nchw - explicit_nhwc = False - assert(False), "Not implemented yet" - - -def print_ground_truth(gt): - x, y, dy, dgrad, wgrad = gt - if has_nan(y) or has_nan(dgrad) or has_nan(wgrad): - print("Error! Ground truth has NAN") - else: - print("Ok! No NAN found in ground truth") - - -def apply_to_different_bottleneck(gt, bottleneck): - with torch.no_grad(): - x, _, dy, _, _ = gt - x, dy = clone_inputs(bottleneck, x, dy) - return fprop_and_bprop(bottleneck, x, dy) - - -def compare_single_field(results, f1, f2, l0, l1, l2): - if has_nan(f1) and has_nan(f2): - results[l0] = "both NAN" - elif has_nan(f1): - results[l0] = "%s.%s NAN" % (l1, l0) - elif has_nan(f2): - results[l0] = "%s.%s NAN" % (l2, l0) - else: - results[l0] = "%s" % (str(rel_diff(f1,f2))) - - -def compare(gt, bt): - x1, y1, dy1, dgrad1, wgrad1 = gt - x2, y2, dy2, dgrad2, wgrad2 = bt - results = {} - compare_single_field(results, y1, y2, "y", "gt", "bt") - compare_single_field(results, dy1, dy2, "dy", "gt", "bt") - compare_single_field(results, dgrad1, dgrad2, "dgrad", "gt", "bt") - compare_single_field(results, wgrad1, wgrad2, "wgrad", "gt", "bt") - for i in range(torch.distributed.get_world_size()): - if i == torch.distributed.get_rank(): - print(i,results) - torch.distributed.barrier() - - -def spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args): - spatial_bottleneck = SpatialBottleneck(C,C,C,use_cudnn=True,explicit_nhwc=explicit_nhwc,spatial_parallel_args=spatial_parallel_args) - spatial_bottleneck.to(dtype=dtype, device='cuda') - with torch.no_grad(): - sp = {} - for n,p in spatial_bottleneck.named_parameters(): - sp[n] = p - for n,p in gt_bottleneck.named_parameters(): - sp[n].copy_(p) - sb = {} - for n,b in spatial_bottleneck.named_buffers(): - sb[n] = b - for n,b in gt_bottleneck.named_buffers(): - sb[n].copy_(b) - return spatial_bottleneck - -def n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp32_reduce=False): - assert(explicit_nhwc), "Only tested for explicit nhwc" - - x, _, dy, _, _ = gt - N, H, W, C = list(x.shape) # Tensor is already shaped properly for n-way parallel - dtype = x.dtype - - spatial_group_size = world_size - spatial_group_rank = rank - spatial_communicator = None - spatial_halo_exchanger = halex - spatial_method = 1 # 1 -> overlap halo and main conv, 2 -> wait for halo, conv on padded x - use_delay_kernel = False - spatial_parallel_args = (spatial_group_size, spatial_group_rank, spatial_communicator, spatial_halo_exchanger, spatial_method, use_delay_kernel) - spatial_bottleneck = spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, spatial_parallel_args) - - with torch.no_grad(): - Hs = H // spatial_group_size - xs = x[:,spatial_group_rank*Hs:(spatial_group_rank+1)*Hs,:,:].clone() - dys = dy[:,spatial_group_rank*Hs:(spatial_group_rank+1)*Hs,:,:].clone() - xs.requires_grad = True - - spatial_bottleneck = graph_it(spatial_bottleneck, xs) - _, y, _, dgrad, wgrad = fprop_and_bprop(spatial_bottleneck, xs, dys) - - # gather output pieces - for n,p in wgrad.items(): - if fp32_reduce: - p32 = p.float() - torch.distributed.all_reduce(p32) - p.copy_(p32.half()) - else: - torch.distributed.all_reduce(p) - ys = [torch.empty_like(y) for _ in range(spatial_group_size)] - torch.distributed.all_gather(ys,y) - y = torch.cat(ys,dim=1) - dgrads = [torch.empty_like(dgrad) for _ in range(spatial_group_size)] - torch.distributed.all_gather(dgrads,dgrad) - dgrad = torch.cat(dgrads,dim=1) - return x, y, dy, dgrad, wgrad - - -def main(): - torch.use_deterministic_algorithms(True) - - torch.distributed.init_process_group("nccl") - rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - torch.cuda.set_device(rank) - - explicit_nhwc = True - - dtype = torch.float16 - N, C, H, W = 1, 64, 200, 336 - Hs = ((H+8*world_size-1) // (8*world_size)) * 8 - H = Hs*world_size - gt_bottleneck = ground_truth_bottleneck(C, dtype, explicit_nhwc) - gt = ground_truth(N, C, H, W, dtype, 1, gt_bottleneck) - - # verify that spatial bottleneck with group_size 1 produces same results as ground truth bottleneck - spatial_bottleneck = spatial_parallel_bottleneck(C, dtype, explicit_nhwc, gt_bottleneck, None) - bt = apply_to_different_bottleneck(gt, spatial_bottleneck) - compare(gt, bt) - #print_bottleneck_p_and_b(gt_bottleneck) - #print_bottleneck_p_and_b(spatial_bottleneck) - - group_size = world_size - group = rank // group_size - ranks = [group*group_size+i for i in range(group_size)] - rank_in_group = rank % group_size - - spatial_group_size = world_size - spatial_communicator = None - - peer_pool = PeerMemoryPool(64*1024*1024, 2*1024*1024, ranks) - - #class HaloExchangerNoComm(HaloExchanger): - # def __init__(self, ranks, rank_in_group): - #class HaloExchangerAllGather(HaloExchanger): - # def __init__(self, ranks, rank_in_group, comm): - #class HaloExchangerSendRecv(HaloExchanger): - # def __init__(self, ranks, rank_in_group): - #class HaloExchangerPeer(HaloExchanger): - # def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1): - - #halex = HaloExchangerAllGather(ranks, rank_in_group) - #halex = HaloExchangerSendRecv(ranks, rank_in_group) - - halex = HaloExchangerPeer(ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1) - #print("halex.signals = %s" % (str(halex.signals))) - # Make sure peer memory halo exchanger has finished initializing flags on all ranks before proceeding - #torch.cuda.synchronize() - #torch.distributed.barrier() - - bt2 = n_way_spatial(halex, gt_bottleneck, gt, explicit_nhwc, world_size, rank, fp32_reduce=True) - compare(gt, bt2) - - -if __name__ == "__main__": - main() diff --git a/apex/contrib/bottleneck/halo_exchangers.py b/apex/contrib/bottleneck/halo_exchangers.py deleted file mode 100644 index b627fb2..0000000 --- a/apex/contrib/bottleneck/halo_exchangers.py +++ /dev/null @@ -1,171 +0,0 @@ -import torch -import torch.distributed as dist -from torch import nn -import nccl_p2p_cuda as inc -import peer_memory_cuda as pm - -# Communication free halo exchanger. -# NB! This halo exchanger does not exchange halos with neighbors as it should, it merely swaps the inputs -# NB! This is only useful for performance testing. -# NB! Do not use for actual production runs -class HaloExchanger(object): - def __init__(self, ranks, rank_in_group): - self.stream1 = torch.cuda.Stream() - self.stream2 = torch.cuda.Stream() - self.stream3 = torch.cuda.Stream() - self.group_size = len(ranks) - self.ranks = ranks - self.rank_in_group = rank_in_group - self.wrap_around_left_rank_in_group = (rank_in_group + self.group_size - 1) % self.group_size - self.wrap_around_right_rank_in_group = (rank_in_group + 1) % self.group_size - self.left_rank = ranks[rank_in_group-1] if rank_in_group > 0 else -1 - self.left_zero = True if rank_in_group == 0 else False - self.right_rank = ranks[rank_in_group+1] if rank_in_group < self.group_size - 1 else -1 - self.right_zero = True if rank_in_group == self.group_size - 1 else False - -class HaloExchangerNoComm(HaloExchanger): - def __init__(self, ranks, rank_in_group): - super(HaloExchangerNoComm, self).__init__(ranks, rank_in_group) - - def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None): - if left_input_halo is None: - return right_output_halo, left_output_halo - else: - left_input_halo.copy_(right_output_halo) - right_input_halo.copy_(left_output_halo) - -class HaloExchangerAllGather(HaloExchanger): - def __init__(self, ranks, rank_in_group, comm): - super(HaloExchangerAllGather, self).__init__(ranks, rank_in_group) - # self.comm must be NCCL process_group created with torch.distributed.new_group(ranks=ranks) - self.comm = comm - - def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None): - N,Hh,W,C = list(left_output_halo.shape) - send_halos = torch.empty((N,2*Hh,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device) - send_halos[:,:Hh,:,:].copy_(left_output_halo) - send_halos[:,Hh:,:,:].copy_(right_output_halo) - all_halos = torch.empty((N,2*Hh*self.group_size,W,C),dtype=left_output_halo.dtype,device=left_output_halo.device) - all_halos = [all_halos[:,i*2*Hh:(i+1)*2*Hh,:,:] for i in range(self.group_size)] - torch.distributed.all_gather(all_halos,send_halos,group=self.comm,no_copy=True) - ag_left_input_halo = all_halos[self.wrap_around_left_rank_in_group][:,Hh:,:,:] - ag_right_input_halo = all_halos[self.wrap_around_right_rank_in_group][:,:Hh,:,:] - if left_input_halo is None: - if self.left_zero: - ag_left_input_halo.zero_() - if self.right_zero: - ag_right_input_halo.zero_() - return ag_left_input_halo, ag_right_input_halo - else: - if self.left_zero: - left_input_halo.zero_() - else: - left_input_halo.copy_(ag_left_input_halo) - if self.right_zero: - right_input_halo.zero_() - else: - right_input_halo.copy_(ag_right_input_halo) - -class HaloExchangerSendRecv(HaloExchanger): - def __init__(self, ranks, rank_in_group): - super(HaloExchangerSendRecv, self).__init__(ranks, rank_in_group) - nccl_id = inc.get_unique_nccl_id(1).cuda() - torch.distributed.broadcast(nccl_id, 0) - nccl_id = nccl_id.cpu() - print("%d :: nccl_id = %s" % (torch.distributed.get_rank(), str(nccl_id))) - # Create another global nccl communicator in addition to the one created by torch.distributed.init_process_group("nccl") - # This is unavoidable because the underlying NCCL communicator torch.distributed creates is a protected variable, hence - # it cannot be accessed from another class. - # TODO: Figure out a way to avoid creating a second global communicator - assert(torch.distributed.get_rank() == self.ranks[self.rank_in_group]), "ranks[%d](%d) != torch.distributed.get_rank()(%d)" % (self.rank_in_group, self.ranks[self.rank_in_group], torch.distributed.get_rank()) - self.handle = inc.init_nccl_comm(nccl_id, torch.distributed.get_rank(), torch.distributed.get_world_size()) - - def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None): - if left_input_halo is None: - left_input_halo, right_input_halo = inc.left_right_halo_exchange(self.handle, self.left_rank, self.right_rank , left_output_halo, right_output_halo) - return left_input_halo, right_input_halo - else: - inc.left_right_halo_exchange_inplace(self.handle, self.left_rank, self.right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo) - -class HaloExchangerPeer(HaloExchanger): - def __init__(self, ranks, rank_in_group, peer_pool, explicit_nhwc, numSM=1): - super(HaloExchangerPeer, self).__init__(ranks, rank_in_group) - self.diagnostics = False - self.explicit_nhwc = explicit_nhwc - self.numSM = numSM - self.peer_pool = peer_pool - self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False) - self.signals[self.rank_in_group].zero_() - - def left_right_halo_exchange(self, left_output_halo, right_output_halo, left_input_halo=None, right_input_halo=None): - inplace = False if left_input_halo is None and right_input_halo is None else True - if not inplace: - left_input_halo = torch.empty_like(right_output_halo) - right_input_halo = torch.empty_like(left_output_halo) - channels_last = left_output_halo.is_contiguous(memory_format=torch.channels_last) and not self.explicit_nhwc - left_tx = self.peer_pool.allocate_peer_tensors(list(left_output_halo.shape), left_output_halo.dtype, channels_last, True) - right_tx = self.peer_pool.allocate_peer_tensors(list(right_output_halo.shape), right_output_halo.dtype, channels_last, True) - pm.push_pull_halos_1d( - self.diagnostics, self.explicit_nhwc, self.numSM, - self.left_zero, left_output_halo, left_tx[self.rank_in_group], right_tx[self.wrap_around_left_rank_in_group], left_input_halo, - self.right_zero, right_output_halo, right_tx[self.rank_in_group], left_tx[self.wrap_around_right_rank_in_group], right_input_halo, - self.signals[self.wrap_around_left_rank_in_group], self.signals[self.wrap_around_right_rank_in_group], self.signals[self.rank_in_group] - ) - if not inplace: - return left_input_halo, right_input_halo - -# Class that combines input volume with halos from neighbors (1d). -class HaloPadder: - def __init__(self, halo_ex): - self.halo_ex = halo_ex - self.stream1 = torch.cuda.Stream() - self.stream2 = torch.cuda.Stream() - - def __call__(self, y, half_halo, explicit_nhwc, H_split): - channels_last = not explicit_nhwc and y.is_contiguous(memory_format=torch.channels_last) - if explicit_nhwc: - N,H,W,C = list(y.shape) - if H_split: - padded_shape = [N,H+2*half_halo,W,C] - ypad = torch.empty(shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format) - yleft = ypad[:,:half_halo,:,:] - ymid = ypad[:,half_halo:H+half_halo,:,:] - yright = ypad[:,H+half_halo:H+2*half_halo,:,:] - oleft = y[:,:half_halo,:,:] - oright = y[:,H-half_halo:,:,:] - else: - padded_shape = [N,H,W+2*half_halo,C] - ypad = torch.empty(shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.contiguous_format) - yleft = ypad[:,:,:half_halo,:] - ymid = ypad[:,:,half_halo:W+half_halo,:] - yright = ypad[:,:,W+half_halo:W+2*half_halo,:] - oleft = y[:,:,:half_halo,:] - oright = y[:,:,W-half_halo:,:] - else: - N,C,H,W = list(y.shape) - if H_split: - padded_shape = [N,C,H+2*half_halo,W] - ypad = torch.empty(shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last) - yleft = ypad[:,:,:half_halo,:] - ymid = ypad[:,:,half_halo:H+half_halo,:] - yright = ypad[:,:,H+half_halo:H+2*half_halo,:] - oleft = y[:,:,:half_halo,:] - oright = y[:,:,H-half_halo:,:] - else: - padded_shape = [N,C,H,W+2*half_halo] - ypad = torch.empty(shape=padded_shape, dtype=y.dtype, device=y.device, memory_format=torch.channels_last) - yleft = ypad[:,:,:,:half_halo] - ymid = ypad[:,:,:,half_halo:W+half_halo] - yright = ypad[:,:,:,W+half_halo:W+2*half_halo] - oleft = y[:,:,:,:half_halo] - oright = y[:,:,:,W-half_halo:] - with torch.cuda.stream(self.stream1): - self.halo_ex(oleft, oright, yleft, yright) - with torch.cuda.stream(self.stream2): - ymid.copy_(y) - return ypad - - def wait(self): - current_stream = torch.cuda.current_stream() - current_stream.wait_stream(self.stream1) - current_stream.wait_stream(self.stream2) diff --git a/apex/contrib/bottleneck/test.py b/apex/contrib/bottleneck/test.py deleted file mode 100644 index 2c3c621..0000000 --- a/apex/contrib/bottleneck/test.py +++ /dev/null @@ -1,71 +0,0 @@ -import torch -from bottleneck import Bottleneck -torch.manual_seed(23337) - -# use True to print layerwise sum for all outputs in reference code path -DEBUG = False#True - -for stride, o_channel in [(1,32), (1,128), (2,32)]: - print("testing stride ==", stride, ", in_channel == 32 , out_channel ==", o_channel) - a_ = torch.randn(17,32,28,28) - - a = a_.cuda().half().to(memory_format=torch.channels_last).requires_grad_() - model = Bottleneck(32,8,o_channel,stride=stride).cuda().half().to(memory_format=torch.channels_last) - - # test model - b = model(a) - b.mean().backward() - d_grad = a.grad.float() - a.grad = None - torch.cuda.synchronize() - - if DEBUG: - print("[DEBUG] ref dx :", d_grad.sum().item()) - # print wgrad. we don't need to reset since later cpp print before accumulation - for i, w in enumerate(model.w_conv): - print("[DEBUG] ref wgrad{} :".format(i+1), w.grad.sum().item()) - - wgrads = [] - for w in model.w_conv: - wgrads.append(w.grad.float()) - - model.use_cudnn = True - model.zero_grad() - c = model(a) - c.mean().backward() - - torch.cuda.synchronize() - print("comparing native and channels_last:") - print("max error fprop:", (b-c).abs().max().item(), "max elem:", b.abs().max().item()) - print("max error dgrad:", (d_grad-a.grad.float()).abs().max().item(), "max elem:", d_grad.abs().max().item()) - for i, (w, wgrad) in enumerate(zip(model.w_conv, wgrads)): - print("max error wgrad{}:".format(i+1), (wgrad - w.grad.float()).abs().max().item(), "max elem:", wgrad.abs().max().item()) - - nhwc_a = a_.permute(0,2,3,1).contiguous().cuda().half().requires_grad_() - nhwc_model = Bottleneck(32,8,o_channel,stride=stride,explicit_nhwc=True, use_cudnn=True).cuda().half() - for p,q in zip(model.parameters(), nhwc_model.parameters()): - # model's storage is already in nhwc, we clone and assign to explicit nhwc model - q.data.copy_(p.data.permute(0,2,3,1).contiguous()) - for p,q in zip(model.buffers(), nhwc_model.buffers()): - q.data.copy_(p.data) - - d = nhwc_model(nhwc_a) - d.mean().backward() - torch.cuda.synchronize() - - # reset reference to cudnn channels_last permute - #c_s = c.storage().tolist() - #d_s = d.storage().tolist() - #print(max([x-y for x,y in zip(c_s,d_s)])) - c = c.contiguous(memory_format=torch.contiguous_format).permute(0,2,3,1).contiguous() - d_grad = a.grad.float().permute(0,2,3,1).contiguous() - wgrads = [] - for w in model.w_conv: - wgrads.append(w.grad.float().permute(0,2,3,1).contiguous()) - - torch.cuda.synchronize() - print("comparing nhwc and channels_last:") - print("max error fprop:", (d-c).abs().max().item(), "max elem:", c.abs().max().item()) - print("max error dgrad:", (d_grad-nhwc_a.grad.float()).abs().max().item(), "max elem:", d_grad.abs().max().item()) - for i, (w, wgrad) in enumerate(zip(nhwc_model.w_conv, wgrads)): - print("max error wgrad{}:".format(i+1), (wgrad - w.grad.float()).abs().max().item(), "max elem:", wgrad.abs().max().item()) diff --git a/apex/contrib/clip_grad/__init__.py b/apex/contrib/clip_grad/__init__.py deleted file mode 100644 index cc9f501..0000000 --- a/apex/contrib/clip_grad/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .clip_grad import clip_grad_norm_ diff --git a/apex/contrib/clip_grad/clip_grad.py b/apex/contrib/clip_grad/clip_grad.py deleted file mode 100644 index b641135..0000000 --- a/apex/contrib/clip_grad/clip_grad.py +++ /dev/null @@ -1,128 +0,0 @@ -from typing import Union, Iterable - -import torch - -_kernel_import_succeeded = False -try: - import amp_C - from apex.multi_tensor_apply import multi_tensor_applier - _kernel_import_succeeded = True -except ImportError: - _kernel_import_succeeded = False - -_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]] - - -def clip_grad_norm_( - parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0, - error_if_nonfinite: bool = False) -> torch.Tensor: - r"""Clips gradient norm of an iterable of parameters. - - The norm is computed over all gradients together, as if they were - concatenated into a single vector. Gradients are modified in-place. - - This is identical to torch.nn.utils.clip_grad_norm_, except it - uses a fused CUDA kernel when computing the 2-norm of GPU tensors - in float32 and float16. - - Args: - parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a - single Tensor that will have gradients normalized - max_norm (float or int): max norm of the gradients - norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for - infinity norm. - error_if_nonfinite (bool): if True, an error is thrown if the total - norm of the gradients from :attr:`parameters` is ``nan``, - ``inf``, or ``-inf``. Default: False (will switch to True in the future) - - Returns: - Total norm of the parameters (viewed as a single vector). - - """ - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - parameters = [p for p in parameters if p.grad is not None] - max_norm = float(max_norm) - norm_type = float(norm_type) - - # Trivial case - if len(parameters) == 0: - return torch.tensor(0.) - - # Fallback implementation - if not (_kernel_import_succeeded - and norm_type == 2.0 - and any(p.is_cuda for p in parameters)): - return torch.nn.utils.clip_grad_norm_( - parameters, - max_norm, - norm_type=norm_type, - error_if_nonfinite = error_if_nonfinite, - ) - - # Find fp32 and fp16 gradients on GPU - device = next(p.device for p in parameters if p.is_cuda) - grads_fp32, grads_fp16, grads_misc = [], [], [] - for p in parameters: - grad = p.grad.detach() - if p.dtype == torch.float32 and p.device == device: - grads_fp32.append(grad) - elif p.dtype == torch.float16 and p.device == device: - grads_fp16.append(grad) - else: - grads_misc.append(grad) - - # Compute gradient L2 norms - norms = [] - dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device=device) - if grads_fp32: - norms.append( - multi_tensor_applier( - amp_C.multi_tensor_l2norm, - dummy_overflow_buf, - [grads_fp32], - False, - )[0] - ) - if grads_fp16: - norms.append( - multi_tensor_applier( - amp_C.multi_tensor_l2norm, - dummy_overflow_buf, - [grads_fp16], - False, - )[0], - ) - for g in grads_misc: - norms.append(torch.linalg.norm(g).unsqueeze(0).to(device)) - total_norm = torch.linalg.norm(torch.cat(norms)) - - # Check for non-finite values - if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): - raise RuntimeError( - f'The total norm of order {norm_type} for gradients from ' - '`parameters` is non-finite, so it cannot be clipped. To disable ' - 'this error and scale the gradients by the non-finite norm anyway, ' - 'set `error_if_nonfinite=False`') - - # Scale gradients - clip_coef = max_norm / (total_norm + 1e-6) - clip_coef_clamped = torch.clamp(clip_coef, max=1.0) - if grads_fp32: - multi_tensor_applier( - amp_C.multi_tensor_scale, - dummy_overflow_buf, - [grads_fp32, grads_fp32], - clip_coef_clamped, - ) - if grads_fp16: - multi_tensor_applier( - amp_C.multi_tensor_scale, - dummy_overflow_buf, - [grads_fp16, grads_fp16], - clip_coef_clamped, - ) - for g in grads_misc: - g.mul_(clip_coef_clamped.to(g.device)) - - return total_norm diff --git a/apex/contrib/conv_bias_relu/__init__.py b/apex/contrib/conv_bias_relu/__init__.py deleted file mode 100644 index a257106..0000000 --- a/apex/contrib/conv_bias_relu/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .conv_bias_relu import ConvBiasReLU, ConvBias, ConvBiasMaskReLU - diff --git a/apex/contrib/conv_bias_relu/conv_bias_relu.py b/apex/contrib/conv_bias_relu/conv_bias_relu.py deleted file mode 100644 index b3e66c5..0000000 --- a/apex/contrib/conv_bias_relu/conv_bias_relu.py +++ /dev/null @@ -1,81 +0,0 @@ -import pdb - -import torch -from torch.autograd import gradcheck - -from apex import check_cudnn_version_and_warn -import fused_conv_bias_relu - -check_cudnn_version_and_warn(__name__, 8400) - - -class ConvBiasReLU_(torch.autograd.Function): - @staticmethod - @torch.cuda.amp.custom_fwd(cast_inputs=torch.half) - def forward(ctx, x, weight, bias, padding, stride): - outputs = fused_conv_bias_relu.forward([x, weight, bias], padding, stride) - ctx.save_for_backward(x, weight, outputs[0]) - ctx.padding = padding - ctx.stride = stride - - return outputs[0] - - @staticmethod - @torch.cuda.amp.custom_bwd - def backward(ctx, grad_output): - bwd_args = [*ctx.saved_tensors, grad_output] - padding = ctx.padding - stride = ctx.stride - grads = fused_conv_bias_relu.backward(bwd_args, padding, stride) - - return grads[0], grads[1], grads[2], None, None - - -class ConvBiasMaskReLU_(torch.autograd.Function): - @staticmethod - @torch.cuda.amp.custom_fwd(cast_inputs=torch.half) - def forward(ctx, x, weight, bias, mask, padding, stride): - outputs = fused_conv_bias_relu.forward_mask([x, weight, bias, mask], padding, stride) - ctx.save_for_backward(x, weight, outputs[0]) - ctx.padding = padding - ctx.stride = stride - - return outputs[0] - - @staticmethod - @torch.cuda.amp.custom_bwd - def backward(ctx, grad_output): - bwd_args = [*ctx.saved_tensors, grad_output] - padding = ctx.padding - stride = ctx.stride - grads = fused_conv_bias_relu.backward(bwd_args, padding, stride) - - return grads[0], grads[1], grads[2], None, None, None - - -class ConvBias_(torch.autograd.Function): - @staticmethod - @torch.cuda.amp.custom_fwd(cast_inputs=torch.half) - def forward(ctx, x, weight, bias, padding, stride): - outputs = fused_conv_bias_relu.forward_no_relu([x, weight, bias], padding, stride) - ctx.save_for_backward(x, weight) - ctx.padding = padding - ctx.stride = stride - - return outputs[0] - - @staticmethod - @torch.cuda.amp.custom_bwd - def backward(ctx, grad_output): - bwd_args = [*ctx.saved_tensors, grad_output] - padding = ctx.padding - stride = ctx.stride - grads = fused_conv_bias_relu.backward_no_relu(bwd_args, padding, stride) - - return grads[0], grads[1], grads[2], None, None - - -ConvBiasReLU = ConvBiasReLU_.apply -ConvBiasMaskReLU = ConvBiasMaskReLU_.apply -ConvBias = ConvBias_.apply - diff --git a/apex/contrib/csrc/bottleneck/bottleneck.cpp b/apex/contrib/csrc/bottleneck/bottleneck.cpp deleted file mode 100644 index 9a0c340..0000000 --- a/apex/contrib/csrc/bottleneck/bottleneck.cpp +++ /dev/null @@ -1,4073 +0,0 @@ -#include -#include // for getcudnnhandle -#include -#include -#include -#include - -#include - -#ifdef DEBUG -#define DEBUG_MSG(str) do { std::cout << str << std::endl; } while( false ) -#else -#define DEBUG_MSG(str) do { } while ( false ) -#endif - -#ifdef DEBUG_CUDNN -#define DEBUG_CUDNN_MSG(buf, str) do { buf << str << std::endl; } while( false ) -#else -#define DEBUG_CUDNN_MSG(buf, str) do { } while ( false ) -#endif - -#define checkCudnnErr(...) \ - do { \ - int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \ - if (err) { \ - return; \ - } \ - } while (0) - - -int checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) { - if (code) { - printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n", file, line, (int)code, cudnnGetErrorString(code), expr); - return 1; - } - return 0; -} - -void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort = true); -#define checkCUDAError(val) { checkError((val), #val, __FILE__, __LINE__); } // in-line regular function - -void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort) -{ - if (code != cudaSuccess) - { - const char * errorMessage = cudaGetErrorString(code); - fprintf(stderr, "CUDA error returned from \"%s\" at %s:%d, Error code: %d (%s)\n", func, file, line, code, errorMessage); - if (abort){ - cudaDeviceReset(); - exit(code); - } - } -} - -void generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, cudnnTensorFormat_t filterFormat) { - // For INT8x4 and INT8x32 we still compute standard strides here to input - // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref. - if (filterFormat == CUDNN_TENSOR_NCHW) { - strideA[nbDims - 1] = 1; - for (int64_t d = nbDims - 2; d >= 0; d--) { - strideA[d] = strideA[d + 1] * dimA[d + 1]; - } - } else { - // Here we assume that the format is CUDNN_TENSOR_NHWC - strideA[1] = 1; - strideA[nbDims - 1] = strideA[1] * dimA[1]; - for (int64_t d = nbDims - 2; d >= 2; d--) { - strideA[d] = strideA[d + 1] * dimA[d + 1]; - } - strideA[0] = strideA[2] * dimA[2]; - } -} - - -int getFwdConvDilatedFilterDim(int filterDim, int dilation) { - return ((filterDim - 1) * dilation) + 1; -} - -int getFwdConvPaddedImageDim(int tensorDim, int pad) { - return tensorDim + (2 * pad); -} - -int getFwdConvOutputDim( - int tensorDim, - int pad, - int filterDim, - int stride, - int dilation) -{ - int p = (getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation)) / stride + 1; - return (p); -} - -enum { - X_TENSOR, - Y_TENSOR, - W_TENSOR, - Z_TENSOR, - B_TENSOR, - AFTERADD_TENSOR, - AFTERBIAS_TENSOR, - AFTERCONV_TENSOR, - OPTIONAL, - AFTEROPT_TENSOR, -}; - -using common_conv_descriptors = - std::tuple; - - -common_conv_descriptors -create_common_descriptors(int64_t* x_dim_padded, - int64_t* padA, - int64_t* convstrideA, - int64_t* dilationA, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType, - cudnnConvolutionMode_t mode) { - const int convDim = 2; - - int64_t strideA_padded[4]; - int64_t outstrideA_padded[4]; - int64_t filterstrideA_padded[4]; - - generateStrides(w_dim_padded, filterstrideA_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(x_dim_padded, strideA_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(y_dim_padded, outstrideA_padded, 4, CUDNN_TENSOR_NHWC); - - return common_conv_descriptors(cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, strideA_padded) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, outstrideA_padded) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, w_dim_padded) - .setStrides(4, filterstrideA_padded) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(mode) - .setNDims(convDim) - .setStrides(convDim, convstrideA) - .setPrePadding(convDim, padA) - .setPostPadding(convDim, padA) - .setDilation(convDim, dilationA) - .build()); -} - -using common_convbias_descriptors = std::tuple; - -common_convbias_descriptors -create_conv_bias_add_act_descriptors(int64_t* x_dim_padded, - int64_t* padA, - int64_t* convstrideA, - int64_t* dilationA, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType) { - const int convDim = 2; - - int64_t b_dim_padded[4]; - b_dim_padded[0] = 1; - b_dim_padded[1] = y_dim_padded[1]; - b_dim_padded[2] = 1; - b_dim_padded[3] = 1; - - int64_t x_stride_padded[4]; - int64_t y_stride_padded[4]; - int64_t w_stride_padded[4]; - int64_t b_stride_padded[4]; - - generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC); - - return common_convbias_descriptors(cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, w_dim_padded) - .setStrides(4, w_stride_padded) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, b_dim_padded) - .setStrides(4, b_stride_padded) - .setId('z') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, b_dim_padded) - .setStrides(4, b_stride_padded) - .setId('b') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setVirtual() - .setId('A') // after add - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setVirtual() - .setId('B') // after bias - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('C') // after conv - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('i') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('D') // after optional add - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_FLOAT) - .build()); -} - -// tensor descriptors used for dgrad -enum { - X_OR_DX_TENSOR, - DY_TENSOR, - W_OR_DW_TENSOR, - SCALE_TENSOR, - RELU_TENSOR, - AFTER_DCONV_TENSOR, - AFTER_DRELU_TENSOR, -}; - -using dconv_descriptors = std::tuple; - -dconv_descriptors -create_dconv_descriptors(int64_t* x_dim_padded, - int64_t* padA, - int64_t* convstrideA, - int64_t* dilationA, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType) { - const int convDim = 2; - - int64_t b_dim_padded[4]; - b_dim_padded[0] = 1; - b_dim_padded[1] = x_dim_padded[1]; - b_dim_padded[2] = 1; - b_dim_padded[3] = 1; - - int64_t x_stride_padded[4]; - int64_t y_stride_padded[4]; - int64_t w_stride_padded[4]; - int64_t b_stride_padded[4]; - - generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC); - - return dconv_descriptors(cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, w_dim_padded) - .setStrides(4, w_stride_padded) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, b_dim_padded) - .setStrides(4, b_stride_padded) - .setId('s') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setId('r') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setVirtual() - .setId('A') // after dconv - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setVirtual() - .setId('B') // after drelu - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build()); -} - -// create a cache for plan -std::unordered_map plan_cache; - -// TODO: better name -std::string getConvFusionString(int64_t* x_dim_padded, - int64_t* padA, - int64_t* convstrideA, - int64_t* dilationA, - int64_t* w_dim_padded, - cudnnDataType_t dataType, - std::string fusion_string) { - - for(int i=0;i<4;i++) { - fusion_string += 'X'; - fusion_string += std::to_string(x_dim_padded[i]); - } - for(int i=0;i<4;i++) { - fusion_string += 'W'; - fusion_string += std::to_string(w_dim_padded[i]); - } - for(int i=0;i<2;i++) { - fusion_string += 'P'; - fusion_string += std::to_string(padA[i]); - } - for(int i=0;i<2;i++) { - fusion_string += 'S'; - fusion_string += std::to_string(convstrideA[i]); - } - for(int i=0;i<2;i++) { - fusion_string += 'D'; - fusion_string += std::to_string(dilationA[i]); - } - fusion_string += 'T'; - fusion_string += std::to_string(dataType); - return fusion_string; -} - -cudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_, - std::stringstream& log_buf, - cudnn_frontend::OperationGraph& opGraph, - std::string cache_string, - bool use_heuristic = true){ - auto it = plan_cache.find(cache_string); - if (it != plan_cache.end()) { - DEBUG_CUDNN_MSG(log_buf, "Found plan in cache"); - return it->second; - } else { - if (use_heuristic){ - // TODO: confirm which mode to use - auto heuristics = cudnn_frontend::EngineHeuristicsBuilder() - .setOperationGraph(opGraph) - .setHeurMode(CUDNN_HEUR_MODE_INSTANT) - .build(); - // try 3 times for now as WAR for no heuristic training - int max_tries = 3, count = 0; - auto& engine_configs = heuristics.getEngineConfig(max_tries); - while(true) { - try { - plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder() - .setHandle(handle_) - .setEngineConfig(engine_configs[count], opGraph.getTag()) - .build())); - break; - } catch (cudnn_frontend::cudnnException e) { - if (++count == max_tries) throw e; - } - } - }else{ - DEBUG_CUDNN_MSG(log_buf, "No plan in cache"); - // How many engines support this operation graph ? - auto total_engines = opGraph.getEngineCount(); - DEBUG_CUDNN_MSG(log_buf, opGraph.describe() << " has " << total_engines << " engines."); - // We have to randomly pick one engine from [0, total_engines) - // Selecting "0" by default - auto engine = cudnn_frontend::EngineBuilder().setGlobalEngineIdx(0).setOperationGraph(opGraph).build(); - DEBUG_CUDNN_MSG(log_buf, engine.describe()); - auto& knobs = engine.getSupportedKnobs(); - for (auto it = std::begin(knobs); it != std::end(knobs); ++it) { - DEBUG_CUDNN_MSG(log_buf, it->describe()); - } - if (knobs.begin() != knobs.end()) { - DEBUG_CUDNN_MSG(log_buf, "Updated knob choice"); - knobs.begin()->setChoice(knobs.begin()->getMinValue() + 1); - DEBUG_CUDNN_MSG(log_buf, knobs.begin()->describe()); - } - - // Createmplacee the requisite engine config - auto engine_config = cudnn_frontend::EngineConfigBuilder().setEngine(engine).build(); - DEBUG_CUDNN_MSG(log_buf, engine_config.describe()); - plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(engine_config).build())); - } - - return plan_cache.find(cache_string)->second; - } -} - -void -run_conv_scale_bias_add_activation(int64_t* x_dim_padded, - int64_t* pad, - int64_t* convstride, - int64_t* dilation, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrY, - at::Half* devPtrZ, - at::Half* devPtrB, - at::Half* devPtrI) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - - // Creates the necessary tensor descriptors - common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors( - x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - - // Define the add operation - auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_MUL) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); - - // Define the bias operation - auto biasDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); - - // optional add - auto addDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); - - // Define the activation operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_RELU_FWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, pad) - .setPostPadding(convDim, pad) - .setDilation(convDim, dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - float alpha = 1.0f; - float beta = 0.0f; - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setwDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // Create a Add Node with scaling parameters. - auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(conv_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(scaleDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); - - // Create a Bias Node. - auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(scale_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(biasDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); - - // Create a optional add Node. - auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(bias_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(addDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, add_op.describe()); - - - // Create an Activation Node. - auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor()) - .setyDesc(std::get(tensors)) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); - - // Create an Operation Graph. In this case it is convolution add bias activation - std::array ops = {&conv_op, &scale_op, &bias_op, devPtrI ? &add_op : &act_op, &act_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(devPtrI ? ops.size() : 4, ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI}; - int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(devPtrI ? 6 : 5, data_ptrs) - .setUids(devPtrI ? 6 : 5, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; - } -} - -void -run_conv_scale_bias(int64_t* x_dim_padded, - int64_t* pad, - int64_t* convstride, - int64_t* dilation, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrY, - at::Half* devPtrZ, - at::Half* devPtrB) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - - // Creates the necessary tensor descriptors - common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors( - x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - - // Define the add operation - auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_MUL) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); - - // Define the bias operation - auto addDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, pad) - .setPostPadding(convDim, pad) - .setDilation(convDim, dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - float alpha = 1.0f; - float beta = 0.0f; - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setwDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // Create a Add Node with scaling parameters. - auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(conv_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) // TODO: change enum to aftermul - .setpwDesc(scaleDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); - - // Create a Bias Node. - auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(scale_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(addDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, add_op.describe()); - - // Create an Operation Graph. In this case it is convolution add bias activation - std::array ops = {&conv_op, &scale_op, &add_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB}; - int64_t uids[] = {'x', 'y', 'w', 'z', 'b'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(5, data_ptrs) - .setUids(5, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; - } -} - - -void -run_dconv_drelu_dscale(int64_t* x_dim_padded, - int64_t* pad, - int64_t* convstride, - int64_t* dilation, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrY, - at::Half* devPtrZ, - at::Half* devPtrR) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - - // Creates the necessary tensor descriptors - dconv_descriptors tensors = create_dconv_descriptors( - x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, pad) - .setPostPadding(convDim, pad) - .setDilation(convDim, dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - // Define the activation backward operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_RELU_BWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Define the scale backward operation - auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_MUL) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); - - float alpha = 1.0f; - float beta = 0.0f; - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) - .setdxDesc(std::get(tensors)) - .setwDesc(std::get(tensors)) - .setdyDesc(std::get(tensors)) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // TODO: do we need getOutputTensor(), and what it returns in backward case? - // Create an relu backward Node. - auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setdyDesc(std::get(tensors)) - .setxDesc(std::get(tensors)) - .setdxDesc(std::get(tensors)) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); - - // Create a Scale Node. - auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(scaleDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); - - // Create an Operation Graph. In this case it is convolution add bias activation - std::array ops = {&conv_op, &act_op, &scale_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR}; - int64_t uids[] = {'x', 'y', 'w', 's', 'r'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(5, data_ptrs) - .setUids(5, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; - } -} - -void -run_dconv(int64_t* x_dim_padded, - int64_t* pad, - int64_t* convstride, - int64_t* dilation, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrY, - cudnnBackendDescriptorType_t mode) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - - // Creates the necessary tensor descriptors - dconv_descriptors tensors = create_dconv_descriptors( - x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, pad) - .setPostPadding(convDim, pad) - .setDilation(convDim, dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - float alpha = 1.0f; - float beta = 0.0f; - - // Create a convolution Node - // mode should be one of following - // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR - // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR - auto conv_op_builder = cudnn_frontend::OperationBuilder(mode); - if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) { - conv_op_builder.setdxDesc(std::get(tensors)) - .setwDesc(std::get(tensors)) - .setdyDesc(std::get(tensors)) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta); - } - else { - conv_op_builder.setxDesc(std::get(tensors)) - .setdwDesc(std::get(tensors)) - .setdyDesc(std::get(tensors)) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta); - } - auto conv_op = conv_op_builder.build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // Create an Operation Graph. In this case it is convolution add bias activation - std::array ops = {&conv_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY, devPtrW}; - int64_t uids[] = {'x', 'y', 'w'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(3, data_ptrs) - .setUids(3, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; - } -} - -void -run_dconv_add(int64_t* x_dim_padded, - int64_t* pad, - int64_t* convstride, - int64_t* dilation, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrY, - at::Half* devPtrR) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - - // Creates the necessary tensor descriptors - dconv_descriptors tensors = create_dconv_descriptors( - x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, pad) - .setPostPadding(convDim, pad) - .setDilation(convDim, dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - // Define the add backward operation - auto addDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); - - float alpha = 1.0f; - float beta = 0.0f; - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) - .setdxDesc(std::get(tensors)) - .setwDesc(std::get(tensors)) - .setdyDesc(std::get(tensors)) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // TODO: do we need getOutputTensor(), and what it returns in backward case? - // Create add Node. - auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(addDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, add_op.describe()); - - // Create an Operation Graph. In this case it is convolution add bias activation - std::array ops = {&conv_op, &add_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrR}; - int64_t uids[] = {'x', 'y', 'w', 'r'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(4, data_ptrs) - .setUids(4, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; - } -} - - -// inputs contains x,w,z,b,(i) -std::vector bottleneck_forward(bool explicit_nhwc, int stride_1X1, std::vector inputs) { - - std::cout << std::fixed; - // create output vector - std::vector outputs; - auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; - - // setup dimensions - int64_t dimA[] = {0, 0, 0, 0}; - int64_t filterdimA1[] = {0, 0, 0, 0}; - int64_t filterdimA2[] = {0, 0, 0, 0}; - int64_t filterdimA3[] = {0, 0, 0, 0}; - int64_t filterdimA4[] = {0, 0, 0, 0}; - - // All dim calculation after this order of n,c,h,w - int axis[] {0,1,2,3}; - if (explicit_nhwc) { - axis[0] = 0; - axis[1] = 3; - axis[2] = 1; - axis[3] = 2; - } - for (int dim=0;dim<4;dim++) { - dimA[dim] = inputs[0].size(axis[dim]); - filterdimA1[dim] = inputs[1].size(axis[dim]); - filterdimA2[dim] = inputs[2].size(axis[dim]); - filterdimA3[dim] = inputs[3].size(axis[dim]); - } - if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { - for (int dim=0;dim<4;dim++) { - filterdimA4[dim] = inputs[10].size(axis[dim]); - } - } - - // output dim in n,c,h,w used by backend - int64_t outdimA1[] = {0, 0, 0, 0}; // Computed Below - int64_t outdimA2[] = {0, 0, 0, 0}; // Computed Below - int64_t outdimA3[] = {0, 0, 0, 0}; // Computed Below - - // use these fixed value for test run - int64_t padA[] = {0, 0}; - int64_t padA1[] = {1, 1}; - int64_t dilationA[] = {1, 1}; - int64_t convstrideA[] = {1, 1}; - int64_t convstride1X1[] = {stride_1X1, stride_1X1}; - - // compute output from pad/stride/dilation - outdimA1[0] = dimA[0]; - outdimA1[1] = filterdimA1[0]; - for (int dim = 0; dim < 2; dim++) { - outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]); - } - - outdimA2[0] = outdimA1[0]; - outdimA2[1] = filterdimA2[0]; - for (int dim = 0; dim < 2; dim++) { - outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]); - } - - outdimA3[0] = outdimA2[0]; - outdimA3[1] = filterdimA3[0]; - for (int dim = 0; dim < 2; dim++) { - outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]); - } - - // Create output tensor in the correct shape in pytorch's view - int64_t outdim1[] = {0, 0, 0, 0}; - int64_t outdim2[] = {0, 0, 0, 0}; - int64_t outdim3[] = {0, 0, 0, 0}; - if (explicit_nhwc) { - axis[0] = 0; - axis[1] = 2; - axis[2] = 3; - axis[3] = 1; - } - for (int dim=0;dim<4;dim++) { - outdim1[dim] = outdimA1[axis[dim]]; - outdim2[dim] = outdimA2[axis[dim]]; - outdim3[dim] = outdimA3[axis[dim]]; - } - - // run - at::Half* x = inputs[0].data_ptr(); - at::Half* w = inputs[1].data_ptr(); - at::Half* z = inputs[4].data_ptr(); - at::Half* b = inputs[7].data_ptr(); - auto out1 = at::empty(outdim1, inputs[0].type(), output_format); - at::Half* y1 = out1.data_ptr(); - - run_conv_scale_bias_add_activation(dimA, - padA, - convstride1X1, - dilationA, - filterdimA1, - outdimA1, - CUDNN_DATA_HALF, - x, - w, - y1, - z, - b, - nullptr); - - DEBUG_MSG("[DEBUG] new relu1 : " << out1.to(at::kFloat).sum().item()); - - w = inputs[2].data_ptr(); - z = inputs[5].data_ptr(); - b = inputs[8].data_ptr(); - auto out2 = at::empty(outdim2, inputs[0].type(), output_format); - at::Half* y2 = out2.data_ptr(); - - run_conv_scale_bias_add_activation(outdimA1, - padA1, - convstrideA, - dilationA, - filterdimA2, - outdimA2, - CUDNN_DATA_HALF, - y1, - w, - y2, - z, - b, - nullptr); - DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item()); - - // create output of conv3 - auto out3 = at::empty(outdim3, inputs[0].type(), output_format); - at::Half* y3 = out3.data_ptr(); - - // create output of conv4 that may exist - auto identity = at::empty_like(out3); - at::Half* yi = identity.data_ptr(); - - if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]){ - - w = inputs[10].data_ptr(); - z = inputs[11].data_ptr(); - b = inputs[12].data_ptr(); - run_conv_scale_bias(dimA, - padA, - convstride1X1, - dilationA, - filterdimA4, - outdimA3, - CUDNN_DATA_HALF, - x, - w, - yi, - z, - b); - DEBUG_MSG("[DEBUG] new downsample : " << identity.to(at::kFloat).sum().item()); - } - else { - yi = x; - } - - w = inputs[3].data_ptr(); - z = inputs[6].data_ptr(); - b = inputs[9].data_ptr(); - - run_conv_scale_bias_add_activation(outdimA2, - padA, - convstrideA, - dilationA, - filterdimA3, - outdimA3, - CUDNN_DATA_HALF, - y2, - w, - y3, - z, - b, - yi); - DEBUG_MSG("[DEBUG] new relu3 : " << out3.to(at::kFloat).sum().item()); - - outputs.push_back(out1); - outputs.push_back(out2); - outputs.push_back(out3); - - return outputs; -} - -std::vector bottleneck_backward(bool explicit_nhwc, int stride_1X1, std::vector inputs) { - - bool requires_grad = inputs[0].requires_grad(); - - std::cout << std::fixed; - // create output vector - std::vector outputs; - auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; - - // setup dimensions - int64_t dimA[] = {0, 0, 0, 0}; - int64_t filterdimA1[] = {0, 0, 0, 0}; - int64_t filterdimA2[] = {0, 0, 0, 0}; - int64_t filterdimA3[] = {0, 0, 0, 0}; - int64_t filterdimA4[] = {0, 0, 0, 0}; - - // All dim calculation after this order of n,c,h,w - int axis[] {0,1,2,3}; - if (explicit_nhwc) { - axis[0] = 0; - axis[1] = 3; - axis[2] = 1; - axis[3] = 2; - } - for (int dim=0;dim<4;dim++) { - dimA[dim] = inputs[0].size(axis[dim]); - filterdimA1[dim] = inputs[1].size(axis[dim]); - filterdimA2[dim] = inputs[2].size(axis[dim]); - filterdimA3[dim] = inputs[3].size(axis[dim]); - } - if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { - for (int dim=0;dim<4;dim++) { - filterdimA4[dim] = inputs[14].size(axis[dim]); - } - } - - // output dim in n,c,h,w used by backend - int64_t outdimA1[] = {0, 0, 0, 0}; // Computed Below - int64_t outdimA2[] = {0, 0, 0, 0}; // Computed Below - int64_t outdimA3[] = {0, 0, 0, 0}; // Computed Below - - // use these fixed value for test run - int64_t padA[] = {0, 0}; - int64_t padA1[] = {1, 1}; - int64_t dilationA[] = {1, 1}; - int64_t convstrideA[] = {1, 1}; - int64_t convstride1X1[] = {stride_1X1, stride_1X1}; - - // compute output from pad/stride/dilation - outdimA1[0] = dimA[0]; - outdimA1[1] = filterdimA1[0]; - for (int dim = 0; dim < 2; dim++) { - outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]); - } - - outdimA2[0] = outdimA1[0]; - outdimA2[1] = filterdimA2[0]; - for (int dim = 0; dim < 2; dim++) { - outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]); - } - - outdimA3[0] = outdimA2[0]; - outdimA3[1] = filterdimA3[0]; - for (int dim = 0; dim < 2; dim++) { - outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]); - } - - // Create output tensor in the correct shape in pytorch's view - int64_t outdim1[] = {0, 0, 0, 0}; - int64_t outdim2[] = {0, 0, 0, 0}; - int64_t outdim3[] = {0, 0, 0, 0}; - if (explicit_nhwc) { - axis[0] = 0; - axis[1] = 2; - axis[2] = 3; - axis[3] = 1; - } - for (int dim=0;dim<4;dim++) { - outdim1[dim] = outdimA1[axis[dim]]; - outdim2[dim] = outdimA2[axis[dim]]; - outdim3[dim] = outdimA3[axis[dim]]; - } - - // dconv3+drelu2+dscale2 - at::Half* conv_in = inputs[13].data_ptr(); - at::Half* dy3 = inputs[10].data_ptr(); - - DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item()); - - // wgrad - auto wgrad3 = at::empty_like(inputs[3]); - at::Half* dw3 = wgrad3.data_ptr(); - run_dconv(outdimA2, - padA, - convstrideA, - dilationA, - filterdimA3, - outdimA3, - CUDNN_DATA_HALF, - conv_in, - dw3, - dy3, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); - - // dgrad - auto grad_out2 = at::empty(outdim2, inputs[0].type(), output_format); - at::Half* dy2 = grad_out2.data_ptr(); - at::Half* w = inputs[3].data_ptr(); - at::Half* z = inputs[5].data_ptr(); - - at::Half* relu2 = inputs[13].data_ptr(); - - run_dconv_drelu_dscale(outdimA2, - padA, - convstrideA, - dilationA, - filterdimA3, - outdimA3, - CUDNN_DATA_HALF, - dy2, - w, - dy3, - z, - relu2); - - DEBUG_MSG("[DEBUG] new dconv2 : " << grad_out2.to(at::kFloat).sum().item()); - - // dconv2+drelu1+dscale1 - conv_in = inputs[12].data_ptr(); - - // wgrad - auto wgrad2 = at::empty_like(inputs[2]); - at::Half* dw2 = wgrad2.data_ptr(); - run_dconv(outdimA1, - padA1, - convstrideA, - dilationA, - filterdimA2, - outdimA2, - CUDNN_DATA_HALF, - conv_in, - dw2, - dy2, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); - - // dgrad - auto grad_out1 = at::empty(outdim1, inputs[0].type(), output_format); - at::Half* dy1 = grad_out1.data_ptr(); - w = inputs[2].data_ptr(); - z = inputs[4].data_ptr(); - - at::Half* relu1 = inputs[12].data_ptr(); - // fused dgrad - run_dconv_drelu_dscale(outdimA1, - padA1, - convstrideA, - dilationA, - filterdimA2, - outdimA2, - CUDNN_DATA_HALF, - dy1, - w, - dy2, - z, - relu1); - -/* - // backward strided conv cannot be fused - // if stride == 1 but channel changes, we can fuse here - if (stride_1X1 != 1){ - // dgrad - run_dconv(outdimA1, - padA1, - convstride1X1, - dilationA, - filterdimA2, - outdimA2, - CUDNN_DATA_HALF, - dy1, - w, - dy2, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); - - // mul fused mask - grad_out1.mul_(inputs[15]); - } - else { - at::Half* relu1 = inputs[12].data_ptr(); - // fused dgrad - run_dconv_drelu_dscale(outdimA1, - padA1, - convstride1X1, - dilationA, - filterdimA2, - outdimA2, - CUDNN_DATA_HALF, - dy1, - w, - dy2, - z, - relu1); - } -*/ - DEBUG_MSG("[DEBUG] new dconv1 : " << grad_out1.to(at::kFloat).sum().item()); - - // create grads of conv4 that may exist - auto grad_x_conv4 = at::empty_like(inputs[0]); - at::Half* dx_conv4 = grad_x_conv4.data_ptr(); - at::Tensor wgrad4; - - // x used for dconv1 and dconv4 wgrad - at::Half* x = inputs[0].data_ptr(); - - if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]){ - w = inputs[14].data_ptr(); - at::Half* dy_conv4 = inputs[11].data_ptr(); - if (requires_grad) { - run_dconv(dimA, - padA, - convstride1X1, - dilationA, - filterdimA4, - outdimA3, - CUDNN_DATA_HALF, - dx_conv4, - w, - dy_conv4, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); - // we don't print here since we can't hook out this grad in pytorch alone to compare, due to addition with dx - // DEBUG_MSG("[DEBUG] new dx_identity : " << grad_x_conv4.to(at::kFloat).sum().item()); - } - // wgrad - wgrad4 = at::empty_like(inputs[14]); - at::Half* dw4 = wgrad4.data_ptr(); - run_dconv(dimA, - padA, - convstride1X1, - dilationA, - filterdimA4, - outdimA3, - CUDNN_DATA_HALF, - x, - dw4, - dy_conv4, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); - } - else { - // if there is no downsample, dx_conv4 is fork of drelu3 - dx_conv4 = inputs[11].data_ptr(); - } - - // dconv1+add - // wgrad - auto wgrad1 = at::empty_like(inputs[1]); - at::Half* dw1 = wgrad1.data_ptr(); - run_dconv(dimA, - padA, - convstride1X1, - dilationA, - filterdimA1, - outdimA1, - CUDNN_DATA_HALF, - x, - dw1, - dy1, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); - - // dgrad - w = inputs[1].data_ptr(); - auto grad_x = at::empty_like(inputs[0]); - at::Half* dx = grad_x.data_ptr(); - - // backward strided conv cannot be fused - // if stride == 1 but channel changes, we can fuse here - if (requires_grad){ - if (stride_1X1 != 1){ - run_dconv(dimA, - padA, - convstride1X1, - dilationA, - filterdimA1, - outdimA1, - CUDNN_DATA_HALF, - dx, - w, - dy1, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); - // add 2 together - grad_x.add_(grad_x_conv4); - } - else { - run_dconv_add(dimA, - padA, - convstride1X1, - dilationA, - filterdimA1, - outdimA1, - CUDNN_DATA_HALF, - dx, - w, - dy1, - dx_conv4); - } - } - - DEBUG_MSG("[DEBUG] new dx : " << grad_x.to(at::kFloat).sum().item()); - DEBUG_MSG("[DEBUG] new wgrad1 : " << wgrad1.to(at::kFloat).sum().item()); - DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item()); - DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item()); - outputs.push_back(grad_x); - outputs.push_back(wgrad1); - outputs.push_back(wgrad2); - outputs.push_back(wgrad3); - - if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { - DEBUG_MSG("[DEBUG] new wgrad4 : " << wgrad4.to(at::kFloat).sum().item()); - outputs.push_back(wgrad4); - } - - return outputs; -} - -namespace { - -enum { - X_TENSOR, - Y_TENSOR, - W_TENSOR, - Z_TENSOR, - B_TENSOR, - AFTERADD_TENSOR, - AFTERBIAS_TENSOR, - AFTERCONV_TENSOR, - OPTIONAL, - AFTEROPT_TENSOR, - AFTERACT_TENSOR, - GEN_INDEX_TENSOR, - MASK_TOP_TENSOR, - MASK_BOTTOM_TENSOR, - MASK_TENSOR, - THRESHOLD_TOP_TENSOR, - THRESHOLD_BOTTOM_TENSOR, -}; - -using masked_convbias_descriptors = std::tuple; - -masked_convbias_descriptors -create_conv_bias_add_act_mask_descriptors(int64_t* x_dim_padded, - int64_t* padA, - int64_t* convstrideA, - int64_t* dilationA, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - int64_t* threshold_dim, - cudnnDataType_t dataType) { - const int convDim = 2; - - int64_t b_dim_padded[4]; - b_dim_padded[0] = 1; - b_dim_padded[1] = y_dim_padded[1]; - b_dim_padded[2] = 1; - b_dim_padded[3] = 1; - - int64_t x_stride_padded[4]; - int64_t y_stride_padded[4]; - int64_t w_stride_padded[4]; - int64_t b_stride_padded[4]; - int64_t threshold_stride[4]; - - generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC); - - return masked_convbias_descriptors(cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, w_dim_padded) - .setStrides(4, w_stride_padded) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, b_dim_padded) - .setStrides(4, b_stride_padded) - .setId('z') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, b_dim_padded) - .setStrides(4, b_stride_padded) - .setId('b') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setVirtual() - .setId('A') // after add - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setVirtual() - .setId('B') // after bias - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('C') // after conv - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('i') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('D') // after optional add - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('E') // after act for masked - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('I') // output of the gen index operation - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_INT32) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('m') // top half of the mask created after the less than - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_BOOLEAN) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('n') // bottom half of the mask - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_BOOLEAN) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('M') // OR of the top and bottom masks - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_BOOLEAN) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, threshold_dim) - .setStrides(4, threshold_stride) - .setId('t') // threshold for creating the top mask - .setAlignment(16) - .setDataType(CUDNN_DATA_INT32) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, threshold_dim) - .setStrides(4, threshold_stride) - .setId('u') // threshold for creating the bottom mask - .setAlignment(16) - .setDataType(CUDNN_DATA_INT32) - .build()); -} - -// tensor descriptors used for dgrad -enum { - X_OR_DX_TENSOR, - DY_TENSOR, - W_OR_DW_TENSOR, - SCALE_TENSOR, - RELU_TENSOR, - AFTER_DCONV_TENSOR, - AFTER_DRELU_TENSOR, - DGRAD_INPUT_TENSOR, - DGRAD_OPTIONAL_TENSOR, - DGRAD_GEN_INDEX_TENSOR, - DGRAD_MASK_TOP_TENSOR, - DGRAD_MASK_BOTTOM_TENSOR, - DGRAD_MASK_TENSOR, - DGRAD_THRESHOLD_TOP_TENSOR, - DGRAD_THRESHOLD_BOTTOM_TENSOR, -}; - -using dconv_add_descriptors = std::tuple; - -dconv_add_descriptors -create_dconv_add_descriptors(int64_t* x_dim_padded, - int64_t* padA, - int64_t* convstrideA, - int64_t* dilationA, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType) { - const int convDim = 2; - - int64_t b_dim_padded[4]; - b_dim_padded[0] = 1; - b_dim_padded[1] = x_dim_padded[1]; - b_dim_padded[2] = 1; - b_dim_padded[3] = 1; - - int64_t x_stride_padded[4]; - int64_t y_stride_padded[4]; - int64_t w_stride_padded[4]; - int64_t b_stride_padded[4]; - - generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC); - - return dconv_add_descriptors(cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, w_dim_padded) - .setStrides(4, w_stride_padded) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, b_dim_padded) - .setStrides(4, b_stride_padded) - .setId('s') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setId('r') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setVirtual() - .setId('A') // after dconv - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setVirtual() - .setId('B') // after drelu - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('i') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('D') // after optional add - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_FLOAT) - .build()); -} - -using dconv_mask_descriptors = std::tuple; - -dconv_mask_descriptors -create_dconv_mask_descriptors(int64_t* x_dim_padded, - int64_t* padA, - int64_t* convstrideA, - int64_t* dilationA, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - int64_t* threshold_dim, - cudnnDataType_t dataType) { - const int convDim = 2; - - int64_t b_dim_padded[4]; - b_dim_padded[0] = 1; - b_dim_padded[1] = x_dim_padded[1]; - b_dim_padded[2] = 1; - b_dim_padded[3] = 1; - - int64_t x_stride_padded[4]; - int64_t y_stride_padded[4]; - int64_t w_stride_padded[4]; - int64_t b_stride_padded[4]; - int64_t threshold_stride[4]; - - generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC); - generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC); - - return dconv_mask_descriptors(cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, w_dim_padded) - .setStrides(4, w_stride_padded) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, b_dim_padded) - .setStrides(4, b_stride_padded) - .setId('s') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setId('r') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setVirtual() - .setId('A') // after dconv - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, x_dim_padded) - .setStrides(4, x_stride_padded) - .setVirtual() - .setId('B') // after drelu - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('i') - .setAlignment(16) - .setDataType(dataType) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('D') // after optional add - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_FLOAT) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('I') // output of the gen index operation - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_INT32) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('m') // top half of the mask created after the less than - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_BOOLEAN) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('n') // bottom half of the mask - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_BOOLEAN) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, y_dim_padded) - .setStrides(4, y_stride_padded) - .setId('M') // OR of the top and bottom masks - .setAlignment(16) - .setVirtual() - .setDataType(CUDNN_DATA_BOOLEAN) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, threshold_dim) - .setStrides(4, threshold_stride) - .setId('t') // threshold for creating the top mask - .setAlignment(16) - .setDataType(CUDNN_DATA_INT32) - .build(), - cudnn_frontend::TensorBuilder() - .setDim(4, threshold_dim) - .setStrides(4, threshold_stride) - .setId('u') // threshold for creating the bottom mask - .setAlignment(16) - .setDataType(CUDNN_DATA_INT32) - .build()); -} - -void -run_conv_add_scale_bias_activation(int64_t* x_dim_padded, - int64_t* pad, - int64_t* convstride, - int64_t* dilation, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrY, - at::Half* devPtrZ, - at::Half* devPtrB, - at::Half* devPtrI) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - - // Creates the necessary tensor descriptors - common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors( - x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - - // Define the add operation - auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_MUL) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); - - // Define the bias operation - auto biasDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); - - // optional add - auto addDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); - - // Define the activation operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_RELU_FWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, pad) - .setPostPadding(convDim, pad) - .setDilation(convDim, dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - float alpha = 1.0f; - float beta = 0.0f; - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setwDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // create an add node. - auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(conv_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(addDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, add_op.describe()); - - // Create a Add Node with scaling parameters. - auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(add_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(scaleDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); - - // Create a Bias Node. - auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(scale_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(biasDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); - - // Create an Activation Node. - auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(bias_op.getOutputTensor()) - .setyDesc(std::get(tensors)) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); - - // Create an Operation Graph. In this case it is convolution add bias activation - std::array ops = {&conv_op, &add_op, &scale_op, &bias_op, &act_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI}; - int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(6, data_ptrs) - .setUids(6, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; - } -} - -void -run_conv_scale_bias_add_activation_mask(int64_t* x_dim_padded, - int64_t* pad, - int64_t* convstride, - int64_t* dilation, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - int64_t* threshold_dim, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrY, - at::Half* devPtrZ, - at::Half* devPtrB, - at::Half* devPtrI, - int* devPtrT, - int* devPtrU, - int axis) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - - // Creates the necessary tensor descriptors - masked_convbias_descriptors tensors = create_conv_bias_add_act_mask_descriptors( - x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, threshold_dim, dataType); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - - // Define the add operation - auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_MUL) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); - - // Define the bias operation - auto biasDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); - - // optional add - auto addDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); - - // Define the activation operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_RELU_FWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, pad) - .setPostPadding(convDim, pad) - .setDilation(convDim, dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - // Define the genIndex descriptor - auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_GEN_INDEX) - .setMathPrecision(CUDNN_DATA_FLOAT) - .setAxis(axis) - .build(); - DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe()); - - // Define the lessThan descriptor - auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_CMP_LT) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe()); - - // Define the greaterThan descriptor - auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_CMP_GT) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe()); - - // Define the logical_or descriptor - auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_LOGICAL_OR) - .setMathPrecision(CUDNN_DATA_BOOLEAN) - .build(); - DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe()); - - // Define the binary_selection descriptor - auto selectionDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_BINARY_SELECT) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe()); - - float alpha = 1.0f; - float beta = 0.0f; - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setwDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // Create a Add Node with scaling parameters. - auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(conv_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(scaleDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); - - // Create a Bias Node. - auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(scale_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(biasDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); - - // Create a optional add Node. - auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(bias_op.getOutputTensor()) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(addDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, add_op.describe()); - - - // Create an Activation Node. - auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor()) - .setyDesc(std::get(tensors)) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); - - // Create a Gen_Index Node. - auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(genIndexDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe()); - - // Create a LessThan Node. - auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(lessThanDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe()); - - // Create a GreaterThan Node. - auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(greaterThanDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe()); - - // Create a LogicalOr Node. - auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(logicalOrDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe()); - - // Create a Binary_Selection Node. - auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .settDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(selectionDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, selection_op.describe()); - - // Create an Operation Graph. In this case it is convolution add bias activation - if (devPtrI) { - - std::array ops = {&conv_op, &scale_op, &bias_op, &add_op, &act_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI, devPtrT, devPtrU}; - int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i', 't', 'u'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(8, data_ptrs) - .setUids(8, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } else { - - std::array ops = {&conv_op, &scale_op, &bias_op, &act_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrT, devPtrU}; - int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 't', 'u'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(7, data_ptrs) - .setUids(7, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; - } -} - -void -run_dconv_add_drelu_dscale(int64_t* x_dim_padded, - int64_t* pad, - int64_t* convstride, - int64_t* dilation, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrY, - at::Half* devPtrZ, - at::Half* devPtrR, - at::Half* devPtrI) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - - // Creates the necessary tensor descriptors - dconv_add_descriptors tensors = create_dconv_add_descriptors( - x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, pad) - .setPostPadding(convDim, pad) - .setDilation(convDim, dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - // optional add - auto addDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, addDesc.describe()); - - // Define the activation backward operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_RELU_BWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Define the scale backward operation - auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_MUL) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); - - float alpha = 1.0f; - float beta = 0.0f; - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) - .setdxDesc(std::get(tensors)) - .setwDesc(std::get(tensors)) - .setdyDesc(std::get(tensors)) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // Create add Node. - auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(addDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, add_op.describe()); - - // TODO: do we need getOutputTensor(), and what it returns in backward case? - // Create an relu backward Node. - auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setdyDesc(std::get(tensors)) - .setxDesc(std::get(tensors)) - .setdxDesc(std::get(tensors)) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); - - // Create a Scale Node. - auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(scaleDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); - - // Create an Operation Graph. In this case it is convolution add bias activation - std::array ops = {&conv_op, &add_op, &act_op, &scale_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrI}; - int64_t uids[] = {'x', 'y', 'w', 's', 'r', 'i'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(6, data_ptrs) - .setUids(6, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; - } -} - -void -run_dconv_drelu_dscale_mask(int64_t* x_dim_padded, - int64_t* pad, - int64_t* convstride, - int64_t* dilation, - int64_t* w_dim_padded, - int64_t* y_dim_padded, - int64_t* threshold_dim, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrY, - at::Half* devPtrZ, - at::Half* devPtrR, - int* devPtrT, - int* devPtrU, - int axis) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - - // Creates the necessary tensor descriptors - dconv_mask_descriptors tensors = create_dconv_mask_descriptors( - x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, threshold_dim, dataType); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - DEBUG_CUDNN_MSG(log_buf, std::get(tensors).describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, pad) - .setPostPadding(convDim, pad) - .setDilation(convDim, dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - // Define the activation backward operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_RELU_BWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Define the scale backward operation - auto scaleDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_MUL) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe()); - - // Define the genIndex descriptor - auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_GEN_INDEX) - .setMathPrecision(CUDNN_DATA_FLOAT) - .setAxis(axis) - .build(); - DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe()); - - // Define the lessThan descriptor - auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_CMP_LT) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe()); - - // Define the greaterThan descriptor - auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_CMP_GT) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe()); - - // Define the logical_or descriptor - auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_LOGICAL_OR) - .setMathPrecision(CUDNN_DATA_BOOLEAN) - .build(); - DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe()); - - // Define the binary_selection descriptor - auto selectionDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_BINARY_SELECT) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe()); - - float alpha = 1.0f; - float beta = 0.0f; - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) - .setdxDesc(std::get(tensors)) - .setwDesc(std::get(tensors)) - .setdyDesc(std::get(tensors)) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // TODO: do we need getOutputTensor(), and what it returns in backward case? - // Create an relu backward Node. - auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setdyDesc(std::get(tensors)) - .setxDesc(std::get(tensors)) - .setdxDesc(std::get(tensors)) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); - - // Create a Scale Node. - auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(scaleDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, scale_op.describe()); - - // Create a Gen_Index Node. - auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(genIndexDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe()); - - // Create a LessThan Node. - auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(lessThanDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe()); - - // Create a GreaterThan Node. - auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(greaterThanDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe()); - - // Create a LogicalOr Node. - auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(logicalOrDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe()); - - // Create a Binary_Selection Node. - auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(std::get(tensors)) - .setbDesc(std::get(tensors)) - .settDesc(std::get(tensors)) - .setyDesc(std::get(tensors)) - .setpwDesc(selectionDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, selection_op.describe()); - - // Create an Operation Graph. In this case it is convolution add bias activation - std::array ops = {&conv_op, &act_op, &scale_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrT, devPtrU}; - int64_t uids[] = {'x', 'y', 'w', 's', 'r', 't', 'u'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(7, data_ptrs) - .setUids(7, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; - } -} - -struct bottleneck_forward_status { - - int64_t dimA[4]; - int64_t filterdimA1[4]; - int64_t filterdimA2[4]; - int64_t filterdimA2hh[4]; - int64_t filterdimA3[4]; - int64_t filterdimA4[4]; - - int64_t threshdim[4]; - - int axis[4]; - - int64_t outdimA0[4]; - int64_t outdimA1[4]; - int64_t outdimA1b[4]; // out1_pad - int64_t outdimA2[4]; - int64_t outdimA3[4]; - int64_t outdimA4[4]; - - int64_t padA[2]; - int64_t padA1[2]; - int64_t padA2[2]; // halo padding - int64_t dilationA[2]; - int64_t convstrideA[2]; - int64_t convstride1X1[2]; - - int64_t outdim0[4]; // halo input shape - int64_t outdim1[4]; - int64_t outdim1b[4]; - int64_t outdim2[4]; - int64_t outdim3[4]; - int64_t outdim4[4]; // halo output shape - - void init(bool explicit_nhwc, int stride_1X1, std::vector inputs) { - dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0; - filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0; - filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0; - filterdimA2hh[0] = filterdimA2hh[1] = filterdimA2hh[2] = filterdimA2hh[3] = 0; - filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0; - filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0; - threshdim[0] = threshdim[1] = threshdim[2] = threshdim[3] = 1; - - // All dim calculation after this order of n,c,h,w - if (explicit_nhwc) { - axis[0] = 0; - axis[1] = 3; - axis[2] = 1; - axis[3] = 2; - } else { - axis[0] = 0; - axis[1] = 1; - axis[2] = 2; - axis[3] = 3; - } - - for (int dim=0;dim<4;dim++) { - dimA[dim] = inputs[0].size(axis[dim]); - filterdimA1[dim] = inputs[1].size(axis[dim]); - filterdimA2[dim] = inputs[2].size(axis[dim]); - filterdimA3[dim] = inputs[3].size(axis[dim]); - } - if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { - for (int dim=0;dim<4;dim++) { - filterdimA4[dim] = inputs[10].size(axis[dim]); - } - } - for (int dim=0;dim<4;dim++) { - if (dim == 2) { - filterdimA2hh[dim] = 1; - } else { - filterdimA2hh[dim] = filterdimA2[dim]; - } - } - - // output dim in n,c,h,w used by backend - outdimA0[0] = outdimA0[1] = outdimA0[2] = outdimA0[3] = 0; - outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0; - outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0; - outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0; - outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0; - outdimA4[0] = outdimA4[1] = outdimA4[2] = outdimA4[3] = 0; - - // use these fixed value for test run - padA[0] = 0; padA[1] = 0; - padA1[0] = 1; padA1[1] = 1; - padA2[0] = 0; padA2[1] = 1; - dilationA[0] = 1; dilationA[1] = 1; - convstrideA[0] = 1; convstrideA[1] = 1; - convstride1X1[0] = stride_1X1; convstride1X1[1] = stride_1X1; - - // compute output from pad/stride/dilation - outdimA1[0] = dimA[0]; - outdimA1[1] = filterdimA1[0]; - for (int dim = 0; dim < 2; dim++) { - outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]); - } - for (int dim = 0; dim < 4; dim++) { - if (dim == 2) { - outdimA1b[dim] = outdimA1[dim] + 2; - } else { - outdimA1b[dim] = outdimA1[dim]; - } - } - - outdimA2[0] = outdimA1[0]; - outdimA2[1] = filterdimA2[0]; - for (int dim = 0; dim < 2; dim++) { - outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]); - } - - for (int dim = 0; dim < 4; dim++) { - if (dim == 2) { - outdimA0[dim] = 3; - outdimA4[dim] = 1; - } else { - outdimA0[dim] = outdimA1[dim]; - outdimA4[dim] = outdimA2[dim]; - } - } - - outdimA3[0] = outdimA2[0]; - outdimA3[1] = filterdimA3[0]; - for (int dim = 0; dim < 2; dim++) { - outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]); - } - - // Create output tensor in the correct shape in pytorch's view - outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0; - outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0; - outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0; - outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0; - if (explicit_nhwc) { - axis[0] = 0; - axis[1] = 2; - axis[2] = 3; - axis[3] = 1; - } - for (int dim=0;dim<4;dim++) { - outdim0[dim] = outdimA0[axis[dim]]; - outdim1[dim] = outdimA1[axis[dim]]; - outdim1b[dim] = outdimA1b[axis[dim]]; - outdim2[dim] = outdimA2[axis[dim]]; - outdim3[dim] = outdimA3[axis[dim]]; - outdim4[dim] = outdimA4[axis[dim]]; - } - } -}; - -bottleneck_forward_status forward_state; - -} // end of anonymous namespace - -std::vector bottleneck_forward_init(bool explicit_nhwc, int stride_1X1, std::vector inputs) { - // NB! Bottleneck_forward and bottleneck_backward are NOT thread safe method. - // NB! We use a global object to store state. - forward_state.init(explicit_nhwc, stride_1X1, inputs); - - // create output vector - std::vector outputs; - auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; - - //printf("outdim1 = (%d,%d,%d,%d)\n",forward_state.outdim1[0],forward_state.outdim1[1],forward_state.outdim1[2],forward_state.outdim1[3]); - auto out1 = at::empty(forward_state.outdim1, inputs[0].type(), output_format); - auto out2 = at::empty(forward_state.outdim2, inputs[0].type(), output_format); - auto out3 = at::empty(forward_state.outdim3, inputs[0].type(), output_format); - - outputs.push_back(out1); - outputs.push_back(out2); - outputs.push_back(out3); - - return outputs; -} - -// inputs contains x,w,z,b,(i) -void bottleneck_forward_out1(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs) { - - std::cout << std::fixed; - - // run - at::Half* x = inputs[0].data_ptr(); - at::Half* w = inputs[1].data_ptr(); - at::Half* z = inputs[4].data_ptr(); - at::Half* b = inputs[7].data_ptr(); - auto out1 = outputs[0]; - at::Half* y1 = out1.data_ptr(); - - run_conv_scale_bias_add_activation(forward_state.dimA, - forward_state.padA, - forward_state.convstride1X1, - forward_state.dilationA, - forward_state.filterdimA1, - forward_state.outdimA1, - CUDNN_DATA_HALF, - x, - w, - y1, - z, - b, - nullptr); - - DEBUG_MSG("[DEBUG] new relu1 : " << out1.to(at::kFloat).sum().item()); -} - -// computes halo (top or bottom) from fat halo input. -// fat halo input is 3 pixels wide in H. -at::Tensor bottleneck_forward_out2_halo(bool explicit_nhwc, at::Tensor fat_halo_y1, std::vector inputs) { - - auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; - - // run - at::Half* w = inputs[2].data_ptr(); - at::Half* z = inputs[5].data_ptr(); - at::Half* b = inputs[8].data_ptr(); - - at::Half* y1 = fat_halo_y1.data_ptr(); - - auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), output_format); - at::Half* y2 = halo_y2.data_ptr(); - - run_conv_scale_bias_add_activation(forward_state.outdimA0, - forward_state.padA2, - forward_state.convstrideA, - forward_state.dilationA, - forward_state.filterdimA2, - forward_state.outdimA4, - CUDNN_DATA_HALF, - y1, - w, - y2, - z, - b, - nullptr); - - return halo_y2; -} - -// compute halo correction term (top or bottom) from slim halo input (N,C,1,W). -// slim halo input is 1 pixel wide in H. -at::Tensor bottleneck_forward_out2_halo_corr(bool explicit_nhwc, at::Tensor slim_halo_y1, std::vector inputs, at::Tensor w1by3, at::Tensor out2_part_halo) { - - auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; - - // run - at::Half* w = w1by3.data_ptr(); // C,C,1,3 - at::Half* z = inputs[5].data_ptr(); - at::Half* b = inputs[8].data_ptr(); - - at::Half* y1 = slim_halo_y1.data_ptr(); - - at::Half* prev_out2 = out2_part_halo.data_ptr(); - - auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), output_format); - at::Half* y2 = halo_y2.data_ptr(); - - run_conv_add_scale_bias_activation(forward_state.outdimA4, - forward_state.padA2, - forward_state.convstrideA, - forward_state.dilationA, - forward_state.filterdimA2hh, - forward_state.outdimA4, - CUDNN_DATA_HALF, - y1, - w, - y2, - z, - b, - prev_out2); - - return halo_y2; -} - -void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs) { - - std::cout << std::fixed; - - // from _out1 method - at::Half* x = inputs[0].data_ptr(); - auto out1 = outputs[0]; - at::Half* y1 = out1.data_ptr(); - - // run - at::Half* w = inputs[2].data_ptr(); - at::Half* z = inputs[5].data_ptr(); - at::Half* b = inputs[8].data_ptr(); - auto out2 = outputs[1]; - at::Half* y2 = out2.data_ptr(); - - //printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]); - //printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]); - //printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]); - //printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]); - //printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]); - //printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]); - run_conv_scale_bias_add_activation(forward_state.outdimA1, - forward_state.padA1, - forward_state.convstrideA, - forward_state.dilationA, - forward_state.filterdimA2, - forward_state.outdimA2, - CUDNN_DATA_HALF, - y1, - w, - y2, - z, - b, - nullptr); - DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item()); -} - -void bottleneck_forward_out2_mask(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor thresholdTop, at::Tensor thresholdBottom) { - - std::cout << std::fixed; - - // from _out1 method - at::Half* x = inputs[0].data_ptr(); - auto out1 = outputs[0]; - at::Half* y1 = out1.data_ptr(); - - // run - at::Half* w = inputs[2].data_ptr(); - at::Half* z = inputs[5].data_ptr(); - at::Half* b = inputs[8].data_ptr(); - auto out2 = outputs[1]; - at::Half* y2 = out2.data_ptr(); - - //printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]); - //printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]); - //printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]); - //printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]); - //printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]); - //printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]); - run_conv_scale_bias_add_activation_mask(forward_state.outdimA1, - forward_state.padA1, - forward_state.convstrideA, - forward_state.dilationA, - forward_state.filterdimA2, - forward_state.outdimA2, - forward_state.threshdim, - CUDNN_DATA_HALF, - y1, - w, - y2, - z, - b, - nullptr, - thresholdTop.data_ptr(), - thresholdBottom.data_ptr(), - 2); // axis == 1 -> Does this assume explicit NHWC? - DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item()); -} - -void bottleneck_forward_out2_pad(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor out1_pad) { - - std::cout << std::fixed; - - // from _out1 method - at::Half* x = inputs[0].data_ptr(); - auto out1 = outputs[0]; - at::Half* y1 = out1_pad.data_ptr(); - - // run - at::Half* w = inputs[2].data_ptr(); - at::Half* z = inputs[5].data_ptr(); - at::Half* b = inputs[8].data_ptr(); - auto out2 = outputs[1]; - at::Half* y2 = out2.data_ptr(); - - //printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]); - //printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]); - //printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]); - //printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]); - //printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]); - //printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]); - run_conv_scale_bias_add_activation(forward_state.outdimA1b, - forward_state.padA2, - forward_state.convstrideA, - forward_state.dilationA, - forward_state.filterdimA2, - forward_state.outdimA2, - CUDNN_DATA_HALF, - y1, - w, - y2, - z, - b, - nullptr); - DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item()); -} - -void bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs) { - - std::cout << std::fixed; - - // from _out1 method - at::Half* x = inputs[0].data_ptr(); - - // create output of conv3 - auto out3 = outputs[2]; - at::Half* y3 = out3.data_ptr(); - - // create output of conv4 that may exist - auto identity = at::empty_like(out3); - at::Half* yi = identity.data_ptr(); - - at::Half *w, *z, *b; - - if (stride_1X1 != 1 || forward_state.filterdimA3[0] != forward_state.dimA[1]){ - - w = inputs[10].data_ptr(); - z = inputs[11].data_ptr(); - b = inputs[12].data_ptr(); - run_conv_scale_bias(forward_state.dimA, - forward_state.padA, - forward_state.convstride1X1, - forward_state.dilationA, - forward_state.filterdimA4, - forward_state.outdimA3, - CUDNN_DATA_HALF, - x, - w, - yi, - z, - b); - DEBUG_MSG("[DEBUG] new downsample : " << identity.to(at::kFloat).sum().item()); - } - else { - yi = x; - } - - auto out2 = outputs[1]; - at::Half* y2 = out2.data_ptr(); - - w = inputs[3].data_ptr(); - z = inputs[6].data_ptr(); - b = inputs[9].data_ptr(); - - run_conv_scale_bias_add_activation(forward_state.outdimA2, - forward_state.padA, - forward_state.convstrideA, - forward_state.dilationA, - forward_state.filterdimA3, - forward_state.outdimA3, - CUDNN_DATA_HALF, - y2, - w, - y3, - z, - b, - yi); - DEBUG_MSG("[DEBUG] new relu3 : " << out3.to(at::kFloat).sum().item()); -} - -namespace { - -struct bottleneck_backward_state { - - int64_t dimA[4]; - int64_t filterdimA1[4]; - int64_t filterdimA2[4]; - int64_t filterdimA3[4]; - int64_t filterdimA4[4]; - int64_t filterdimA2hh[4]; // Cin,Cout,1,3 - int64_t threshdim[4]; - - int axis[4]; - - int64_t outdimA1[4]; // grad_out1 - int64_t outdimA1b[4]; // out1_pad - int64_t outdimA2[4]; // grad_out2 - int64_t outdimA3[4]; - int64_t outdimA1h[4]; // output: grad_out1 halo (H=3) - int64_t outdimA2h[4]; // input : grad_out2 halo cells (H=3) - int64_t outdimA1hh[4]; // input: grad_out2 halo (H=1) - int64_t outdimA2hh[4]; // input: out1 halo (H=1) - - int64_t padA[2]; - int64_t padA1[2]; - int64_t padA2[2]; - int64_t dilationA[2]; - int64_t convstrideA[2]; - int64_t convstride1X1[2]; - - int64_t filterdim2hh[4]; // Cin,1,3,Cout - - int64_t outdim1[4]; - int64_t outdim1b[4]; - int64_t outdim2[4]; - int64_t outdim3[4]; - int64_t outdim1h[4]; - int64_t outdim1hh[4]; - - void init(bool explicit_nhwc, int stride_1X1, std::vector inputs) { - // setup dimensions - dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0; - filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0; - filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0; - filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0; - filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0; - filterdimA2hh[0] = filterdimA2hh[1] = filterdimA2hh[2] = filterdimA2hh[3] = 0; - threshdim[0] = threshdim[1] = threshdim[2] = threshdim[3] = 1; - - // All dim calculation after this order of n,c,h,w - if (explicit_nhwc) { - axis[0] = 0; - axis[1] = 3; - axis[2] = 1; - axis[3] = 2; - } else { - axis[0] = 0; - axis[1] = 1; - axis[2] = 2; - axis[3] = 3; - } - - for (int dim=0;dim<4;dim++) { - dimA[dim] = inputs[0].size(axis[dim]); - filterdimA1[dim] = inputs[1].size(axis[dim]); - filterdimA2[dim] = inputs[2].size(axis[dim]); - filterdimA3[dim] = inputs[3].size(axis[dim]); - } - if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) { - for (int dim=0;dim<4;dim++) { - filterdimA4[dim] = inputs[14].size(axis[dim]); - } - } - - for (int dim=0;dim<4;dim++) { - if (dim == 2) { - filterdimA2hh[dim] = 1; - } else { - filterdimA2hh[dim] = filterdimA2[dim]; - } - } - - // output dim in n,c,h,w used by backend - outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0; - outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0; - outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0; - outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0; - outdimA1h[0] = outdimA1h[1] = outdimA1h[2] = outdimA1h[3] = 0; - outdimA2h[0] = outdimA2h[1] = outdimA2h[2] = outdimA2h[3] = 0; - outdimA1hh[0] = outdimA1hh[1] = outdimA1hh[2] = outdimA1hh[3] = 0; - outdimA2hh[0] = outdimA2hh[1] = outdimA2hh[2] = outdimA2hh[3] = 0; - - // use these fixed value for test run - padA[0] = 0; padA[1] = 0; - padA1[0] = 1; padA1[1] = 1; - padA2[0] = 0; padA2[1] = 1; - dilationA[0] = 1; dilationA[1] = 1; - convstrideA[0] = 1; convstrideA[1] = 1; - convstride1X1[0] = stride_1X1; convstride1X1[1] = stride_1X1; - - // compute output from pad/stride/dilation - outdimA1[0] = dimA[0]; - outdimA1[1] = filterdimA1[0]; - for (int dim = 0; dim < 2; dim++) { - outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]); - } - for (int dim = 0; dim < 4; dim++) { - if (dim == 2) { - outdimA1b[dim] = outdimA1[dim] + 2; - } else { - outdimA1b[dim] = outdimA1[dim]; - } - } - - outdimA2[0] = outdimA1[0]; - outdimA2[1] = filterdimA2[0]; - for (int dim = 0; dim < 2; dim++) { - outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]); - } - - outdimA3[0] = outdimA2[0]; - outdimA3[1] = filterdimA3[0]; - for (int dim = 0; dim < 2; dim++) { - outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]); - } - - for (int dim = 0; dim < 4; dim++) { - if (dim == 2) { - outdimA1h[dim] = 3; - outdimA2h[dim] = 3; - outdimA1hh[dim] = 1; - outdimA2hh[dim] = 1; - } else { - outdimA1h[dim] = outdimA1[dim]; - outdimA2h[dim] = outdimA2[dim]; - outdimA1hh[dim] = outdimA1[dim]; - outdimA2hh[dim] = outdimA2[dim]; - } - } - - // Create output tensor in the correct shape in pytorch's view - outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0; - outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0; - outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0; - outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0; - outdim1h[0] = outdim1h[1] = outdim1h[2] = outdim1h[3] = 0; - outdim1hh[0] = outdim1hh[1] = outdim1hh[2] = outdim1hh[3] = 0; - filterdim2hh[0] = filterdim2hh[1] = filterdim2hh[2] = filterdim2hh[3] = 0; - if (explicit_nhwc) { - axis[0] = 0; - axis[1] = 2; - axis[2] = 3; - axis[3] = 1; - } - for (int dim=0;dim<4;dim++) { - outdim1[dim] = outdimA1[axis[dim]]; - outdim1b[dim] = outdimA1b[axis[dim]]; - outdim2[dim] = outdimA2[axis[dim]]; - outdim3[dim] = outdimA3[axis[dim]]; - outdim1h[dim] = outdimA1h[axis[dim]]; - outdim1hh[dim] = outdimA1hh[axis[dim]]; - filterdim2hh[dim] = filterdimA2hh[axis[dim]]; - } - } -}; - -bottleneck_backward_state backward_state; - -} - -std::vector bottleneck_backward_init(bool explicit_nhwc, int stride_1X1, std::vector inputs) { - - std::cout << std::fixed; - - backward_state.init(explicit_nhwc, stride_1X1, inputs); - - // create output vector - std::vector outputs; - auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; - - auto grad_x = at::empty_like(inputs[0]); - auto wgrad1 = at::empty_like(inputs[1]); - auto wgrad2 = at::empty_like(inputs[2]); - auto wgrad3 = at::empty_like(inputs[3]); - - outputs.push_back(grad_x); - outputs.push_back(wgrad1); - outputs.push_back(wgrad2); - outputs.push_back(wgrad3); - if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) { - auto wgrad4 = at::empty_like(inputs[14]); - outputs.push_back(wgrad4); - } - - return outputs; -} - -void bottleneck_backward_wgrad3(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs) { - - // dconv3+drelu2+dscale2 - at::Half* conv_in = inputs[13].data_ptr(); - at::Half* dy3 = inputs[10].data_ptr(); - - // wgrad - auto wgrad3 = outputs[3]; - at::Half* dw3 = wgrad3.data_ptr(); - run_dconv(backward_state.outdimA2, - backward_state.padA, - backward_state.convstrideA, - backward_state.dilationA, - backward_state.filterdimA3, - backward_state.outdimA3, - CUDNN_DATA_HALF, - conv_in, - dw3, - dy3, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); - DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item()); - -} - -at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs) { - - bool requires_grad = inputs[0].requires_grad(); - - std::cout << std::fixed; - auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; - - // dconv3+drelu2+dscale2 - at::Half* conv_in = inputs[13].data_ptr(); - at::Half* dy3 = inputs[10].data_ptr(); - - DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item()); - - // dgrad - auto grad_out2 = at::empty(backward_state.outdim2, inputs[0].type(), output_format); - at::Half* dy2 = grad_out2.data_ptr(); - at::Half* w = inputs[3].data_ptr(); - at::Half* z = inputs[5].data_ptr(); - - at::Half* relu2 = inputs[13].data_ptr(); - - run_dconv_drelu_dscale(backward_state.outdimA2, - backward_state.padA, - backward_state.convstrideA, - backward_state.dilationA, - backward_state.filterdimA3, - backward_state.outdimA3, - CUDNN_DATA_HALF, - dy2, - w, - dy3, - z, - relu2); - - // do halo exchange of dy2 here - - DEBUG_MSG("[DEBUG] new dconv2 : " << grad_out2.to(at::kFloat).sum().item()); - - return grad_out2; -} - -at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor grad_out2) { - - bool requires_grad = inputs[0].requires_grad(); - - std::cout << std::fixed; - auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; - - // dgrad - at::Half* dy2 = grad_out2.data_ptr(); - - // dgrad - auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format); - at::Half* dy1 = grad_out1.data_ptr(); - at::Half* w = inputs[2].data_ptr(); - at::Half* z = inputs[4].data_ptr(); - - at::Half* relu1 = inputs[12].data_ptr(); - //printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3)); - - // fused dgrad - //printf("backward_state.outdim1 = {%d,%d,%d,%d}\n",backward_state.outdim1[0],backward_state.outdim1[1],backward_state.outdim1[2],backward_state.outdim1[3]); - run_dconv_drelu_dscale(backward_state.outdimA1, - backward_state.padA1, - backward_state.convstrideA, - backward_state.dilationA, - backward_state.filterdimA2, - backward_state.outdimA2, - CUDNN_DATA_HALF, - dy1, - w, - dy2, - z, - relu1); - - return grad_out1; -} - -at::Tensor bottleneck_backward_grad_out1_mask(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor grad_out2, at::Tensor thresholdTop, at::Tensor thresholdBottom) { - - bool requires_grad = inputs[0].requires_grad(); - - std::cout << std::fixed; - auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; - - // dgrad - at::Half* dy2 = grad_out2.data_ptr(); - - // dgrad - auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format); - at::Half* dy1 = grad_out1.data_ptr(); - at::Half* w = inputs[2].data_ptr(); - at::Half* z = inputs[4].data_ptr(); - - at::Half* relu1 = inputs[12].data_ptr(); - //printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3)); - - // fused dgrad - run_dconv_drelu_dscale_mask(backward_state.outdimA1, - backward_state.padA1, - backward_state.convstrideA, - backward_state.dilationA, - backward_state.filterdimA2, - backward_state.outdimA2, - backward_state.threshdim, - CUDNN_DATA_HALF, - dy1, - w, - dy2, - z, - relu1, - thresholdTop.data_ptr(), - thresholdBottom.data_ptr(), - 2); - - return grad_out1; -} - -// perform backward data 1x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,1,W,C] with padding=(0,1) to produce output of shape [N,1,W,C] -at::Tensor bottleneck_backward_grad_out1_halo_corr(bool explicit_nhwc, int stride_1X1, std::vector inputs, at::Tensor w1by3, std::vector outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo, at::Tensor part_grad_out1) { - - bool requires_grad = inputs[0].requires_grad(); - - std::cout << std::fixed; - auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; - - // dgrad - at::Half* dy2h = grad_out2_halo.data_ptr(); - - // dgrad - auto grad_out1_halo = at::empty(backward_state.outdim1hh, inputs[0].type(), output_format); - at::Half* dy1h = grad_out1_halo.data_ptr(); - //at::Half* w = inputs[2].data_ptr(); // use w1by3 instead, which is a sliced version of inputs[2] - at::Half* w = w1by3.data_ptr(); - at::Half* z = inputs[4].data_ptr(); - at::Half* relu1h = relu1_halo.data_ptr(); - at::Half* pdy1h = part_grad_out1.data_ptr(); - - //printf("relu.shape = [%d,%d,%d,%d]\n",relu1_halo.size(0),relu1_halo.size(1),relu1_halo.size(2),relu1_halo.size(3)); - // fused dgrad - //printf("backward_state.outdimA1h = {%d,%d,%d,%d}\n",backward_state.outdimA1h[0],backward_state.outdimA1h[1],backward_state.outdimA1h[2],backward_state.outdimA1h[3]); - //printf("backward_state.outdimA2h = {%d,%d,%d,%d}\n",backward_state.outdimA2h[0],backward_state.outdimA2h[1],backward_state.outdimA2h[2],backward_state.outdimA2h[3]); - //printf("backward_state.filterdimA2 = {%d,%d,%d,%d}\n",backward_state.filterdimA2[0],backward_state.filterdimA2[1],backward_state.filterdimA2[2],backward_state.filterdimA2[3]); - run_dconv_add_drelu_dscale(backward_state.outdimA1hh, - backward_state.padA2, // 0,1 - backward_state.convstrideA, - backward_state.dilationA, - backward_state.filterdimA2hh, // C,1,3,C - backward_state.outdimA2hh, - CUDNN_DATA_HALF, - dy1h, - w, - dy2h, - z, - relu1h, - pdy1h); - - return grad_out1_halo; -} - -// perform backward data 3x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,3,W,C] with padding=(1,1) to produce output of shape [N,3,W,C] -at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo) { - - bool requires_grad = inputs[0].requires_grad(); - - std::cout << std::fixed; - auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; - - // dgrad - at::Half* dy2h = grad_out2_halo.data_ptr(); - - // dgrad - auto grad_out1_halo = at::empty(backward_state.outdim1h, inputs[0].type(), output_format); - at::Half* dy1h = grad_out1_halo.data_ptr(); - at::Half* w = inputs[2].data_ptr(); - at::Half* z = inputs[4].data_ptr(); - - at::Half* relu1h = relu1_halo.data_ptr(); - //printf("relu.shape = [%d,%d,%d,%d]\n",relu1_halo.size(0),relu1_halo.size(1),relu1_halo.size(2),relu1_halo.size(3)); - // fused dgrad - //printf("backward_state.outdimA1h = {%d,%d,%d,%d}\n",backward_state.outdimA1h[0],backward_state.outdimA1h[1],backward_state.outdimA1h[2],backward_state.outdimA1h[3]); - //printf("backward_state.outdimA2h = {%d,%d,%d,%d}\n",backward_state.outdimA2h[0],backward_state.outdimA2h[1],backward_state.outdimA2h[2],backward_state.outdimA2h[3]); - //printf("backward_state.filterdimA2 = {%d,%d,%d,%d}\n",backward_state.filterdimA2[0],backward_state.filterdimA2[1],backward_state.filterdimA2[2],backward_state.filterdimA2[3]); - run_dconv_drelu_dscale(backward_state.outdimA1h, - backward_state.padA1, - backward_state.convstrideA, - backward_state.dilationA, - backward_state.filterdimA2, - backward_state.outdimA2h, - CUDNN_DATA_HALF, - dy1h, - w, - dy2h, - z, - relu1h); - - return grad_out1_halo; -} - -void bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor input, at::Tensor grad_out2) { - - std::cout << std::fixed; - auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; - - // dgrad - at::Half* dy2 = grad_out2.data_ptr(); - - // dconv2+drelu1+dscale1 - at::Half* conv_in = input.data_ptr(); - - // wgrad - auto wgrad2 = outputs[2]; - at::Half* dw2 = wgrad2.data_ptr(); - - //printf("outdimA1b = (%d,%d,%d,%d)\n",backward_state.outdimA1b[0],backward_state.outdimA1b[1],backward_state.outdimA1b[2],backward_state.outdimA1b[3]); - //printf("backward_state.padA2 = {%d,%d}\n",backward_state.padA2[0],backward_state.padA2[1]); - run_dconv(backward_state.outdimA1b, // conv_in.shape (including H halos) - backward_state.padA2, // 0, 1 - backward_state.convstrideA, - backward_state.dilationA, - backward_state.filterdimA2, // dw2.shape - backward_state.outdimA2, // dy2.shape - CUDNN_DATA_HALF, - conv_in, - dw2, - dy2, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); - DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item()); -} - -void bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor grad_out2) { - - bool requires_grad = inputs[0].requires_grad(); - - std::cout << std::fixed; - auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; - - // dgrad - at::Half* dy2 = grad_out2.data_ptr(); - - // dconv2+drelu1+dscale1 - at::Half* conv_in = inputs[12].data_ptr(); - - // wgrad - auto wgrad2 = outputs[2]; - at::Half* dw2 = wgrad2.data_ptr(); - - //printf("outdimA1 = (%d,%d,%d,%d)\n",backward_state.outdimA1[0],backward_state.outdimA1[1],backward_state.outdimA1[2],backward_state.outdimA1[3]); - run_dconv(backward_state.outdimA1, - backward_state.padA1, - backward_state.convstrideA, - backward_state.dilationA, - backward_state.filterdimA2, - backward_state.outdimA2, - CUDNN_DATA_HALF, - conv_in, - dw2, - dy2, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); - DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item()); -} - -// compute halo cells for input volume of dimension [N,1,W,C] with padding=(0,1) to produce output volume of dimension [N,1,W,C] -// input and grad_out2_halo tensors are all of same shape -// output tensor is of shape [Cin,1,3,Cout] (regular filter dims are [Cin,3,3,Cout] -at::Tensor bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor input, at::Tensor grad_out2_halo) { - - bool requires_grad = inputs[0].requires_grad(); - - std::cout << std::fixed; - auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; - - // dgrad - at::Half* dy2 = grad_out2_halo.data_ptr(); - - // dconv2+drelu1+dscale1 - at::Half* conv_in = input.data_ptr(); - - // wgrad - auto wgrad2_halo = at::empty(backward_state.filterdim2hh, input.type(), output_format); - at::Half* dw2 = wgrad2_halo.data_ptr(); - - //printf("backward_state.outdimA1hh = {%d,%d,%d,%d}\n",backward_state.outdimA1hh[0],backward_state.outdimA1hh[1],backward_state.outdimA1hh[2],backward_state.outdimA1hh[3]); - //printf("backward_state.outdimA2hh = {%d,%d,%d,%d}\n",backward_state.outdimA2hh[0],backward_state.outdimA2hh[1],backward_state.outdimA2hh[2],backward_state.outdimA2hh[3]); - //printf("backward_state.filterdim2hh = {%d,%d,%d,%d}\n",backward_state.filterdim2hh[0],backward_state.filterdim2hh[1],backward_state.filterdim2hh[2],backward_state.filterdim2hh[3]); - //printf("backward_state.filterdimA2hh = {%d,%d,%d,%d}\n",backward_state.filterdimA2hh[0],backward_state.filterdimA2hh[1],backward_state.filterdimA2hh[2],backward_state.filterdimA2hh[3]); - //printf("backward_state.padA2 = {%d,%d}\n",backward_state.padA2[0],backward_state.padA2[1]); - run_dconv(backward_state.outdimA1hh, // N,C,1,W - backward_state.padA2, // 0, 1 - backward_state.convstrideA, - backward_state.dilationA, - backward_state.filterdimA2hh, // Cin,Cout,1,3 - backward_state.outdimA2hh, // N,C,1,W - CUDNN_DATA_HALF, - conv_in, - dw2, - dy2, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); - - return wgrad2_halo; -} - -void bottleneck_backward_wgrad1(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor grad_out1) { - - at::Half* x = inputs[0].data_ptr(); - at::Half* dy1 = grad_out1.data_ptr(); - - // dconv1+add - // wgrad - auto wgrad1 = outputs[1]; - at::Half* dw1 = wgrad1.data_ptr(); - run_dconv(backward_state.dimA, - backward_state.padA, - backward_state.convstride1X1, - backward_state.dilationA, - backward_state.filterdimA1, - backward_state.outdimA1, - CUDNN_DATA_HALF, - x, - dw1, - dy1, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); - -} - -void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector inputs, std::vector outputs, at::Tensor grad_out2, at::Tensor grad_out1) { - - bool requires_grad = inputs[0].requires_grad(); - - std::cout << std::fixed; - auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast; - - // dgrad - at::Half* dy2 = grad_out2.data_ptr(); - at::Half* dy1 = grad_out1.data_ptr(); - -/* - // backward strided conv cannot be fused - // if stride == 1 but channel changes, we can fuse here - if (stride_1X1 != 1){ - // dgrad - run_dconv(outdimA1, - padA1, - convstride1X1, - dilationA, - filterdimA2, - outdimA2, - CUDNN_DATA_HALF, - dy1, - w, - dy2, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); - - // mul fused mask - grad_out1.mul_(inputs[15]); - } - else { - at::Half* relu1 = inputs[12].data_ptr(); - // fused dgrad - run_dconv_drelu_dscale(outdimA1, - padA1, - convstride1X1, - dilationA, - filterdimA2, - outdimA2, - CUDNN_DATA_HALF, - dy1, - w, - dy2, - z, - relu1); - } -*/ - DEBUG_MSG("[DEBUG] new dconv1 : " << grad_out1.to(at::kFloat).sum().item()); - - // create grads of conv4 that may exist - auto grad_x_conv4 = at::empty_like(inputs[0]); - at::Half* dx_conv4 = grad_x_conv4.data_ptr(); - at::Tensor wgrad4; - - // x used for dconv1 and dconv4 wgrad - at::Half* x = inputs[0].data_ptr(); - - at::Half* w = NULL; - - if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]){ - w = inputs[14].data_ptr(); - at::Half* dy_conv4 = inputs[11].data_ptr(); - if (requires_grad) { - run_dconv(backward_state.dimA, - backward_state.padA, - backward_state.convstride1X1, - backward_state.dilationA, - backward_state.filterdimA4, - backward_state.outdimA3, - CUDNN_DATA_HALF, - dx_conv4, - w, - dy_conv4, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); - // we don't print here since we can't hook out this grad in pytorch alone to compare, due to addition with dx - // DEBUG_MSG("[DEBUG] new dx_identity : " << grad_x_conv4.to(at::kFloat).sum().item()); - } - // wgrad - wgrad4 = outputs[4]; - at::Half* dw4 = wgrad4.data_ptr(); - run_dconv(backward_state.dimA, - backward_state.padA, - backward_state.convstride1X1, - backward_state.dilationA, - backward_state.filterdimA4, - backward_state.outdimA3, - CUDNN_DATA_HALF, - x, - dw4, - dy_conv4, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); - } - else { - // if there is no downsample, dx_conv4 is fork of drelu3 - dx_conv4 = inputs[11].data_ptr(); - } - - // dgrad - w = inputs[1].data_ptr(); - auto grad_x = outputs[0]; - at::Half* dx = grad_x.data_ptr(); - - // backward strided conv cannot be fused - // if stride == 1 but channel changes, we can fuse here - if (requires_grad){ - if (stride_1X1 != 1){ - run_dconv(backward_state.dimA, - backward_state.padA, - backward_state.convstride1X1, - backward_state.dilationA, - backward_state.filterdimA1, - backward_state.outdimA1, - CUDNN_DATA_HALF, - dx, - w, - dy1, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); - // add 2 together - grad_x.add_(grad_x_conv4); - } - else { - run_dconv_add(backward_state.dimA, - backward_state.padA, - backward_state.convstride1X1, - backward_state.dilationA, - backward_state.filterdimA1, - backward_state.outdimA1, - CUDNN_DATA_HALF, - dx, - w, - dy1, - dx_conv4); - } - } - - DEBUG_MSG("[DEBUG] new dx : " << grad_x.to(at::kFloat).sum().item()); - DEBUG_MSG("[DEBUG] new wgrad1 : " << wgrad1.to(at::kFloat).sum().item()); - - if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) { - DEBUG_MSG("[DEBUG] new wgrad4 : " << wgrad4.to(at::kFloat).sum().item()); - } -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &bottleneck_forward, "Bottleneck block forward"); - m.def("backward", &bottleneck_backward, "Bottleneck block backward"); - m.def("forward_init", &bottleneck_forward_init, "Bottleneck block init"); - m.def("forward_out1", &bottleneck_forward_out1, "Bottleneck block forward"); - m.def("forward_out2", &bottleneck_forward_out2, "Bottleneck block forward"); - m.def("forward_out2_mask", &bottleneck_forward_out2_mask, "Bottleneck block forward"); - m.def("forward_out2_halo", &bottleneck_forward_out2_halo, "Bottleneck block forward"); - m.def("forward_out2_halo_corr", &bottleneck_forward_out2_halo_corr, "Bottleneck block forward"); - m.def("forward_out2_pad", &bottleneck_forward_out2_pad, "Bottleneck block forward"); - m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward"); - m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init"); - m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward"); - m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward"); - m.def("backward_grad_out1_mask", &bottleneck_backward_grad_out1_mask, "Bottleneck block backward"); - m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward"); - m.def("backward_grad_out1_halo_corr", &bottleneck_backward_grad_out1_halo_corr, "Bottleneck block backward"); - m.def("backward_wgrad2_pad", &bottleneck_backward_wgrad2_pad, "Bottleneck block backward"); - m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward"); - m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward"); - m.def("backward_wgrad3", &bottleneck_backward_wgrad3, "Bottleneck block backward"); - m.def("backward_wgrad1", &bottleneck_backward_wgrad1, "Bottleneck block backward"); - m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward"); -} diff --git a/apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp b/apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp deleted file mode 100644 index 66f89ef..0000000 --- a/apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp +++ /dev/null @@ -1,1639 +0,0 @@ -#include -#include // for getcudnnhandle -#include -#include -#include -#include - -#include - -#ifdef DEBUG -#define DEBUG_MSG(str) do { std::cout << str << std::endl; } while( false ) -#else -#define DEBUG_MSG(str) do { } while ( false ) -#endif - -#ifdef DEBUG_CUDNN -#define DEBUG_CUDNN_MSG(buf, str) do { buf << str << std::endl; } while( false ) -#else -#define DEBUG_CUDNN_MSG(buf, str) do { } while ( false ) -#endif - -#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(at::MemoryFormat::ChannelsLast), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -#define checkCudnnErr(...) \ - do { \ - int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \ - if (err) { \ - return; \ - } \ - } while (0) - - -int checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) { - if (code) { - printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n", file, line, (int)code, cudnnGetErrorString(code), expr); - return 1; - } - return 0; -} - -void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort = true); -#define checkCUDAError(val) { checkError((val), #val, __FILE__, __LINE__); } // in-line regular function - -void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort) { - if (code != cudaSuccess) - { - const char * errorMessage = cudaGetErrorString(code); - fprintf(stderr, "CUDA error returned from \"%s\" at %s:%d, Error code: %d (%s)\n", func, file, line, code, errorMessage); - if (abort){ - cudaDeviceReset(); - exit(code); - } - } -} - -void generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, cudnnTensorFormat_t filterFormat) { - // For INT8x4 and INT8x32 we still compute standard strides here to input - // into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref. - if (filterFormat == CUDNN_TENSOR_NCHW) { - strideA[nbDims - 1] = 1; - for (int64_t d = nbDims - 2; d >= 0; d--) { - strideA[d] = strideA[d + 1] * dimA[d + 1]; - } - } else { - // Here we assume that the format is CUDNN_TENSOR_NHWC - strideA[1] = 1; - strideA[nbDims - 1] = strideA[1] * dimA[1]; - for (int64_t d = nbDims - 2; d >= 2; d--) { - strideA[d] = strideA[d + 1] * dimA[d + 1]; - } - strideA[0] = strideA[2] * dimA[2]; - } -} - - -int getFwdConvDilatedFilterDim(int filterDim, int dilation) { - return ((filterDim - 1) * dilation) + 1; -} - - -int getFwdConvPaddedImageDim(int tensorDim, int pad) { - return tensorDim + (2 * pad); -} - - -int getFwdConvOutputDim(int tensorDim, - int pad, - int filterDim, - int stride, - int dilation) { - int p = (getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation)) / stride + 1; - return (p); -} - - -// create a cache for plan -std::unordered_map plan_cache; - - -std::string getConvFusionString(int64_t* x_dim_padded, - int64_t* padA, - int64_t* convstrideA, - int64_t* dilationA, - int64_t* w_dim_padded, - cudnnDataType_t dataType, - std::string fusion_string) { - - for(int i=0;i<4;i++) { - fusion_string += 'X'; - fusion_string += std::to_string(x_dim_padded[i]); - } - for(int i=0;i<4;i++) { - fusion_string += 'W'; - fusion_string += std::to_string(w_dim_padded[i]); - } - for(int i=0;i<2;i++) { - fusion_string += 'P'; - fusion_string += std::to_string(padA[i]); - } - for(int i=0;i<2;i++) { - fusion_string += 'S'; - fusion_string += std::to_string(convstrideA[i]); - } - for(int i=0;i<2;i++) { - fusion_string += 'D'; - fusion_string += std::to_string(dilationA[i]); - } - fusion_string += 'T'; - fusion_string += std::to_string(dataType); - return fusion_string; -} - - -cudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_, - std::stringstream& log_buf, - cudnn_frontend::OperationGraph& opGraph, - std::string cache_string, - bool use_heuristic = true){ - auto it = plan_cache.find(cache_string); - if (it != plan_cache.end()) { - DEBUG_CUDNN_MSG(log_buf, "Found plan in cache"); - return it->second; - } else { - if (use_heuristic){ - // TODO: confirm which mode to use - auto heuristics = cudnn_frontend::EngineHeuristicsBuilder() - .setOperationGraph(opGraph) - .setHeurMode(CUDNN_HEUR_MODE_INSTANT) - .build(); - // try 3 times for now as WAR for no heuristic training - int max_tries = 3, count = 0; - auto& engine_configs = heuristics.getEngineConfig(max_tries); - while(true) { - try { - plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder() - .setHandle(handle_) - .setEngineConfig(engine_configs[count], opGraph.getTag()) - .build())); - break; - } catch (cudnn_frontend::cudnnException e) { - if (++count == max_tries) throw e; - } - } - }else{ - DEBUG_CUDNN_MSG(log_buf, "No plan in cache"); - // How many engines support this operation graph ? - auto total_engines = opGraph.getEngineCount(); - DEBUG_CUDNN_MSG(log_buf, opGraph.describe() << " has " << total_engines << " engines."); - // We have to randomly pick one engine from [0, total_engines) - // Selecting "0" by default - auto engine = cudnn_frontend::EngineBuilder().setGlobalEngineIdx(0).setOperationGraph(opGraph).build(); - DEBUG_CUDNN_MSG(log_buf, engine.describe()); - auto& knobs = engine.getSupportedKnobs(); - for (auto it = std::begin(knobs); it != std::end(knobs); ++it) { - DEBUG_CUDNN_MSG(log_buf, it->describe()); - } - if (knobs.begin() != knobs.end()) { - DEBUG_CUDNN_MSG(log_buf, "Updated knob choice"); - knobs.begin()->setChoice(knobs.begin()->getMinValue() + 1); - DEBUG_CUDNN_MSG(log_buf, knobs.begin()->describe()); - } - - // Createmplacee the requisite engine config - auto engine_config = cudnn_frontend::EngineConfigBuilder().setEngine(engine).build(); - DEBUG_CUDNN_MSG(log_buf, engine_config.describe()); - plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(engine_config).build())); - } - - return plan_cache.find(cache_string)->second; - } -} - - -void -run_conv_bias(int64_t* x_dim, - int64_t* w_dim, - int64_t* y_dim, - int64_t* conv_pad, - int64_t* convstride, - int64_t* dilation, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrB, - at::Half* devPtrY) { - - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - - try { - int convDim = 2; - float alpha = 1.0f; - float beta = 0.0f; - int64_t b_dim[] = {1, y_dim[1], 1, 1}; - - // Creates the necessary tensor descriptors - int64_t stride[4]; - generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto xTensor = cudnn_frontend::TensorBuilder() - .setDim(4, x_dim) - .setStrides(4, stride) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); - - generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto wTensor = cudnn_frontend::TensorBuilder() - .setDim(4, w_dim) - .setStrides(4, stride) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterConvTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('c') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .setVirtual() - .build(); - DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe()); - - generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto bTensor = cudnn_frontend::TensorBuilder() - .setDim(4, b_dim) - .setStrides(4, stride) - .setId('b') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, bTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterBiasTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe()); - - // Define the bias operation - auto biasDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, conv_pad) - .setPostPadding(convDim, conv_pad) - .setDilation(convDim, dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) - .setxDesc(xTensor) - .setwDesc(wTensor) - .setyDesc(afterConvTensor) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // Create a Bias Node. - auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(conv_op.getOutputTensor()) - .setbDesc(bTensor) - .setyDesc(afterBiasTensor) - .setpwDesc(biasDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); - - // Create an Operation Graph. In this case it is convolution bias activation - std::array ops = {&conv_op, &bias_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(2, ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim, conv_pad, convstride, dilation, w_dim, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrY}; - int64_t uids[] = {'x', 'w', 'b', 'y'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(4, data_ptrs) - .setUids(4, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; - } -} - - -void -run_conv_bias_mask_relu(int64_t* x_dim, - int64_t* w_dim, - int64_t* y_dim, - int64_t* conv_pad, - int64_t* conv_stride, - int64_t* conv_dilation, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrB, - int8_t* devPtrM, - at::Half* devPtrY) { - - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - - try { - int conv_dim = 2; - float alpha = 1.0f; - float beta = 0.0f; - int64_t b_dim[] = {1, y_dim[1], 1, 1}; - - // Creates the necessary tensor descriptors - int64_t stride[4]; - generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto xTensor = cudnn_frontend::TensorBuilder() - .setDim(4, x_dim) - .setStrides(4, stride) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); - - generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto wTensor = cudnn_frontend::TensorBuilder() - .setDim(4, w_dim) - .setStrides(4, stride) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto mTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('m') - .setAlignment(16) - .setDataType(CUDNN_DATA_INT8) - .build(); - DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterConvTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('c') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .setVirtual() - .build(); - DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe()); - - generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto bTensor = cudnn_frontend::TensorBuilder() - .setDim(4, b_dim) - .setStrides(4, stride) - .setId('b') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, bTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterBiasTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('B') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .setVirtual() - .build(); - DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterMaskTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('M') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .setVirtual() - .build(); - DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterReLUTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(conv_dim) - .setStrides(conv_dim, conv_stride) - .setPrePadding(conv_dim, conv_pad) - .setPostPadding(conv_dim, conv_pad) - .setDilation(conv_dim, conv_dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - // Define the bias operation - auto biasDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); - - // Define the mask operation - auto maskDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_MUL) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - - // Define the activation operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_RELU_FWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) - .setxDesc(xTensor) - .setwDesc(wTensor) - .setyDesc(afterConvTensor) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // Create a Bias Node - auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(conv_op.getOutputTensor()) - .setbDesc(bTensor) - .setyDesc(afterBiasTensor) - .setpwDesc(biasDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); - - // create a Mask Node - auto mask_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(bias_op.getOutputTensor()) - .setbDesc(mTensor) - .setyDesc(afterMaskTensor) - .setpwDesc(maskDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, mask_op.describe()); - - // Create an Activation Node - auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(mask_op.getOutputTensor()) - .setyDesc(afterReLUTensor) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); - - // Create an Operation Graph. In this case it is convolution bias activation - std::array ops = {&conv_op, &bias_op, &mask_op, &act_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(4, ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrM, devPtrY}; - int64_t uids[] = {'x', 'w', 'b', 'm', 'y'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(5, data_ptrs) - .setUids(5, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; - } -} - - -void -run_conv_bias_relu(int64_t* x_dim, - int64_t* w_dim, - int64_t* y_dim, - int64_t* conv_pad, - int64_t* conv_stride, - int64_t* conv_dilation, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrB, - at::Half* devPtrY) { - - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - - try { - int conv_dim = 2; - float alpha = 1.0f; - float beta = 0.0f; - int64_t b_dim[] = {1, y_dim[1], 1, 1}; - - // Creates the necessary tensor descriptors - int64_t stride[4]; - generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto xTensor = cudnn_frontend::TensorBuilder() - .setDim(4, x_dim) - .setStrides(4, stride) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); - - generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto wTensor = cudnn_frontend::TensorBuilder() - .setDim(4, w_dim) - .setStrides(4, stride) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterConvTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('c') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .setVirtual() - .build(); - DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe()); - - generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto bTensor = cudnn_frontend::TensorBuilder() - .setDim(4, b_dim) - .setStrides(4, stride) - .setId('b') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, bTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterBiasTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('B') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .setVirtual() - .build(); - DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto afterReLUTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(conv_dim) - .setStrides(conv_dim, conv_stride) - .setPrePadding(conv_dim, conv_pad) - .setPostPadding(conv_dim, conv_pad) - .setDilation(conv_dim, conv_dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - // Define the bias operation - auto biasDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_ADD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); - - // Define the activation operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_RELU_FWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) - .setxDesc(xTensor) - .setwDesc(wTensor) - .setyDesc(afterConvTensor) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // Create a Bias Node. - auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(conv_op.getOutputTensor()) - .setbDesc(bTensor) - .setyDesc(afterBiasTensor) - .setpwDesc(biasDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); - - // Create an Activation Node. - auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(bias_op.getOutputTensor()) - .setyDesc(afterReLUTensor) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); - - // Create an Operation Graph. In this case it is convolution bias activation - std::array ops = {&conv_op, &bias_op, &act_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(3, ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrY}; - int64_t uids[] = {'x', 'w', 'b', 'y'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(4, data_ptrs) - .setUids(4, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; - } -} - - -void -run_drelu_dbias(int64_t* dy_dim, - cudnnDataType_t dataType, - at::Half* devPtrDY, - at::Half* devPtrR, - at::Half* devPtrDR, - float* devPtrDB) { - - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - - try { - int convDim = 2; - float alpha = 1.0f; - float beta = 0.0f; - int64_t b_dim[] = {1, dy_dim[1], 1, 1}; - - // Creates the necessary tensor descriptors - int64_t stride[4]; - generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto dyTensor = cudnn_frontend::TensorBuilder() - .setDim(4, dy_dim) - .setStrides(4, stride) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, dyTensor.describe()); - - generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto rTensor = cudnn_frontend::TensorBuilder() - .setDim(4, dy_dim) - .setStrides(4, stride) - .setId('r') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, rTensor.describe()); - - generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto inActGradTensor = cudnn_frontend::TensorBuilder() - .setDim(4, dy_dim) - .setStrides(4, stride) - .setId('R') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, inActGradTensor.describe()); - - generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto biasGradTensor = cudnn_frontend::TensorBuilder() - .setDim(4, b_dim) - .setStrides(4, stride) - .setId('y') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, biasGradTensor.describe()); - - // Define the activation backward operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_RELU_BWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Define the bias backward operation - auto biasDesc = cudnn_frontend::ReductionDescBuilder() - .setMathPrecision(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) - .build(); - DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); - - // Create an relu backward Node - auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setdyDesc(dyTensor) - .setxDesc(rTensor) - .setdxDesc(inActGradTensor) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); - - // Create bias node - auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(inActGradTensor) - .setyDesc(biasGradTensor) - .setreductionDesc(biasDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); - - // Create an Operation Graph. In this case it is bias only - std::array ops = {&act_op, &bias_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - // creating unique dummy values - int64_t pad_dummy[] = {20, 20}; - int64_t stride_dummy[] = {20, 20}; - int64_t dilation_dummy[] = {20, 20}; - auto cache_string = getConvFusionString(dy_dim, pad_dummy, stride_dummy, dilation_dummy, b_dim, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrDY, devPtrR, devPtrDR, devPtrDB}; - int64_t uids[] = {'x', 'r', 'R', 'y'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(4, data_ptrs) - .setUids(4, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; - } -} - - -void -run_dconv_drelu_dbias(int64_t* x_dim, - int64_t* w_dim, - int64_t* y_dim, - int64_t* pad, - int64_t* convstride, - int64_t* dilation, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrR, - at::Half* devPtrRg, - float* devPtrY) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - float alpha = 1.0f; - float beta = 0.0f; - int64_t b_dim[] = {1, x_dim[1], 1, 1}; - - int64_t stride[4]; - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto outConvGradTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, outConvGradTensor.describe()); - - generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto wTensor = cudnn_frontend::TensorBuilder() - .setDim(4, w_dim) - .setStrides(4, stride) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); - - generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto inConvGradTensor = cudnn_frontend::TensorBuilder() - .setDim(4, x_dim) - .setStrides(4, stride) - .setId('A') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .setVirtual() - .build(); - DEBUG_CUDNN_MSG(log_buf, inConvGradTensor.describe()); - - generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto rTensor = cudnn_frontend::TensorBuilder() - .setDim(4, x_dim) - .setStrides(4, stride) - .setId('r') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, rTensor.describe()); - - generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto inReLUGradTensor = cudnn_frontend::TensorBuilder() - .setDim(4, x_dim) - .setStrides(4, stride) - .setId('R') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, inReLUGradTensor.describe()); - - generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto inBiasGradTensor = cudnn_frontend::TensorBuilder() - .setDim(4, b_dim) - .setStrides(4, stride) - .setId('y') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, inBiasGradTensor.describe()); - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(convDim) - .setStrides(convDim, convstride) - .setPrePadding(convDim, pad) - .setPostPadding(convDim, pad) - .setDilation(convDim, dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - // Define the activation backward operation - auto actDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_RELU_BWD) - .setMathPrecision(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, actDesc.describe()); - - // Define the bias backward operation - auto biasDesc = cudnn_frontend::ReductionDescBuilder() - .setMathPrecision(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) - .build(); - DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); - - // Create a convolution Node - auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) - .setdyDesc(outConvGradTensor) - .setwDesc(wTensor) - .setdxDesc(inConvGradTensor) - .setcDesc(convDesc) - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // Create an relu backward Node - auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setdyDesc(inConvGradTensor) - .setxDesc(rTensor) - .setdxDesc(inReLUGradTensor) - .setpwDesc(actDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, act_op.describe()); - - // Create bias node - auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(inReLUGradTensor) - .setyDesc(inBiasGradTensor) - .setreductionDesc(biasDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); - - // Create an Operation Graph. In this case it is bias only - std::array ops = {&conv_op, &act_op, &bias_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim, pad, convstride, dilation, w_dim, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrW, devPtrR, devPtrRg, devPtrY}; - int64_t uids[] = {'x', 'w', 'r', 'R', 'y'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(5, data_ptrs) - .setUids(5, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; - } - -} - - -void -run_dconv(int64_t* x_dim, - int64_t* w_dim, - int64_t* y_dim, - int64_t* conv_pad, - int64_t* conv_stride, - int64_t* conv_dilation, - cudnnDataType_t dataType, - at::Half* devPtrX, - at::Half* devPtrW, - at::Half* devPtrY, - cudnnBackendDescriptorType_t mode) { - - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - - try { - int conv_dim = 2; - float alpha = 1.0f; - float beta = 0.0f; - - // Define the convolution problem - int64_t stride[4]; - generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto xTensor = cudnn_frontend::TensorBuilder() - .setDim(4, x_dim) - .setStrides(4, stride) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); - - generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto wTensor = cudnn_frontend::TensorBuilder() - .setDim(4, w_dim) - .setStrides(4, stride) - .setId('w') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, wTensor.describe()); - - generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto yTensor = cudnn_frontend::TensorBuilder() - .setDim(4, y_dim) - .setStrides(4, stride) - .setId('y') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, yTensor.describe()); - - - // Define the convolution problem - auto convDesc = cudnn_frontend::ConvDescBuilder() - .setDataType(CUDNN_DATA_FLOAT) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(conv_dim) - .setStrides(conv_dim, conv_stride) - .setPrePadding(conv_dim, conv_pad) - .setPostPadding(conv_dim, conv_pad) - .setDilation(conv_dim, conv_dilation) - .build(); - DEBUG_CUDNN_MSG(log_buf, convDesc.describe()); - - // Create a convolution node - // mode should be one of following - // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR - // CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR - auto conv_op_builder = cudnn_frontend::OperationBuilder(mode); - if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) { - conv_op_builder.setdxDesc(xTensor) - .setwDesc(wTensor) - .setdyDesc(yTensor) - .setcDesc(convDesc); - } - else { - conv_op_builder.setxDesc(xTensor) - .setdwDesc(wTensor) - .setdyDesc(yTensor) - .setcDesc(convDesc); - } - auto conv_op = conv_op_builder - .setAlpha(alpha) - .setBeta(beta) - .build(); - DEBUG_CUDNN_MSG(log_buf, conv_op.describe()); - - // Create an Operation Graph. In this case it is convolution add bias activation - std::array ops = {&conv_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrW, devPtrY}; - int64_t uids[] = {'x', 'w', 'y'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(3, data_ptrs) - .setUids(3, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; - } -} - - -void -run_dbias(int64_t* x_dim, - cudnnDataType_t dataType, - at::Half* devPtrX, - float* devPtrY) { - cudnnHandle_t handle_ = torch::native::getCudnnHandle(); - std::stringstream log_buf; - try { - int convDim = 2; - int64_t b_dim[] = {1, x_dim[1], 1, 1}; - - int64_t stride[4]; - generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto xTensor = cudnn_frontend::TensorBuilder() - .setDim(4, x_dim) - .setStrides(4, stride) - .setId('x') - .setAlignment(16) - .setDataType(dataType) - .build(); - DEBUG_CUDNN_MSG(log_buf, xTensor.describe()); - - generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC); - auto yTensor = cudnn_frontend::TensorBuilder() - .setDim(4, b_dim) - .setStrides(4, stride) - .setId('y') - .setAlignment(16) - .setDataType(CUDNN_DATA_FLOAT) - .build(); - DEBUG_CUDNN_MSG(log_buf, yTensor.describe()); - - // Define the bias backward operation - auto biasDesc = cudnn_frontend::ReductionDescBuilder() - .setMathPrecision(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) - .build(); - DEBUG_CUDNN_MSG(log_buf, biasDesc.describe()); - - // Create bias node - auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(xTensor) - .setyDesc(yTensor) - .setreductionDesc(biasDesc) - .build(); - DEBUG_CUDNN_MSG(log_buf, bias_op.describe()); - - // Create an Operation Graph. In this case it is bias only - std::array ops = {&bias_op}; - - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(ops.size(), ops.data()) - .build(); - - // Create string encoding for plan caching - int64_t pad_dummy[] = {10, 10}; - int64_t stride_dummy[] = {10, 10}; - int64_t dilation_dummy[] = {10, 10}; - auto cache_string = getConvFusionString(x_dim, pad_dummy, stride_dummy, dilation_dummy, b_dim, dataType, opGraph.getTag()); - DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string); - - auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string); - DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag()); - - auto workspace_size = plan.getWorkspaceSize(); - DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size); - - void* workspace_ptr = nullptr; - auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat)); - if (workspace_size > 0) { - workspace_ptr = workspace_tensor.data_ptr(); - } - void* data_ptrs[] = {devPtrX, devPtrY}; - int64_t uids[] = {'x', 'y'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(2, data_ptrs) - .setUids(2, uids) - .build(); - DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe()); - cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc()); - checkCudnnErr(status); - cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status); - } catch (cudnn_frontend::cudnnException e) { - std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl; - } - -} - - -std::vector conv_bias_mask_relu_forward(std::vector inputs, int64_t padding, int64_t stride) { - std::cout << std::fixed; - - // create output vector - std::vector outputs; - auto output_format = at::MemoryFormat::ChannelsLast; - - // setup dimensions - int64_t x_dim[] = {0, 0, 0, 0}; - int64_t w_dim[] = {0, 0, 0, 0}; - - // All dim calculation after this order of n,c,h,w - int axis[] = {0, 1, 2, 3}; - for (int dim = 0; dim < 4; dim++) { - x_dim[dim] = inputs[0].size(axis[dim]); - w_dim[dim] = inputs[1].size(axis[dim]); - } - - // output dim in n,c,h,w used by backend - int64_t y_dim[] = {0, 0, 0, 0}; - - // use these fixed values - int64_t conv_pad[] = {padding, padding}; - int64_t conv_stride[] = {stride, stride}; - int64_t conv_dilation[] = {1, 1}; - - // compute output from pad/stride/dilation - y_dim[0] = x_dim[0]; - y_dim[1] = w_dim[0]; - for (int dim = 0; dim < 2; dim++) { - y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]); - } - - // run - at::Half* x = inputs[0].data_ptr(); - at::Half* w = inputs[1].data_ptr(); - at::Half* b = inputs[2].data_ptr(); - int8_t* m = inputs[3].data_ptr(); - auto out = at::empty(y_dim, inputs[0].type(), output_format); - at::Half* y = out.data_ptr(); - - run_conv_bias_mask_relu(x_dim, - w_dim, - y_dim, - conv_pad, - conv_stride, - conv_dilation, - CUDNN_DATA_HALF, - x, - w, - b, - m, - y); - - DEBUG_MSG("[DEBUG] conv-bias-mask-relu : " << y.to(at::kFloat).sum().item()); - - outputs.push_back(out); - - return outputs; -} - - -std::vector conv_bias_relu_forward(std::vector inputs, int64_t padding, int64_t stride) { - std::cout << std::fixed; - - // create output vector - std::vector outputs; - auto output_format = at::MemoryFormat::ChannelsLast; - - // setup dimensions - int64_t x_dim[] = {0, 0, 0, 0}; - int64_t w_dim[] = {0, 0, 0, 0}; - - // All dim calculation after this order of n,c,h,w - int axis[] = {0, 1, 2, 3}; - for (int dim = 0; dim < 4; dim++) { - x_dim[dim] = inputs[0].size(axis[dim]); - w_dim[dim] = inputs[1].size(axis[dim]); - } - - // output dim in n,c,h,w used by backend - int64_t y_dim[] = {0, 0, 0, 0}; - - // use these fixed values - int64_t conv_pad[] = {padding, padding}; - int64_t conv_stride[] = {stride, stride}; - int64_t conv_dilation[] = {1, 1}; - - // compute output from pad/stride/dilation - y_dim[0] = x_dim[0]; - y_dim[1] = w_dim[0]; - for (int dim = 0; dim < 2; dim++) { - y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]); - } - - // run - at::Half* x = inputs[0].data_ptr(); - at::Half* w = inputs[1].data_ptr(); - at::Half* b = inputs[2].data_ptr(); - auto out = at::empty(y_dim, inputs[0].type(), output_format); - at::Half* y = out.data_ptr(); - - run_conv_bias_relu(x_dim, - w_dim, - y_dim, - conv_pad, - conv_stride, - conv_dilation, - CUDNN_DATA_HALF, - x, - w, - b, - y); - - DEBUG_MSG("[DEBUG] conv-bias-relu : " << y.to(at::kFloat).sum().item()); - - outputs.push_back(out); - - return outputs; -} - - -std::vector conv_bias_relu_backward(std::vector inputs, int64_t padding, int64_t stride) { - bool requires_grad = inputs[0].requires_grad(); - - for (int i = 0; i <= 3; i++) { - CHECK_INPUT(inputs[i]); - } - - std::cout << std::fixed; - - // create output vector - std::vector outputs; - auto output_format = at::MemoryFormat::ChannelsLast; - - // setup dimensions - int64_t x_dim[] = {0, 0, 0, 0}; - int64_t w_dim[] = {0, 0, 0, 0}; - int64_t y_dim[] = {0, 0, 0, 0}; - - // All dim calculation after this order of n,c,h,w - int axis[] = {0, 1, 2, 3}; - for (int dim = 0; dim < 4; dim++) { - x_dim[dim] = inputs[0].size(axis[dim]); - w_dim[dim] = inputs[1].size(axis[dim]); - y_dim[dim] = inputs[3].size(axis[dim]); - } - - int64_t b_dim[] = {1, y_dim[1], 1, 1}; - - int64_t conv_pad[] = {padding, padding}; - int64_t conv_stride[] = {stride, stride}; - int64_t conv_dilation[] = {1, 1}; - - // run - // drelu-dbias - at::Half* dy = inputs[3].data_ptr(); - at::Half* r = inputs[2].data_ptr(); - auto drelu = at::empty_like(inputs[2]); - at::Half* dr = drelu.data_ptr(); - auto options = at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false); - auto bgrad = at::empty(b_dim, options, output_format); - float* db = bgrad.data_ptr(); - run_drelu_dbias(y_dim, - CUDNN_DATA_HALF, - dy, - r, - dr, - db); - - // conv wgrad - at::Half* x = inputs[0].data_ptr(); - auto wgrad = at::empty_like(inputs[1]); - at::Half* dw = wgrad.data_ptr(); - run_dconv(x_dim, - w_dim, - y_dim, - conv_pad, - conv_stride, - conv_dilation, - CUDNN_DATA_HALF, - x, - dw, - dr, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); - - // conv dgrad - at::Half* w = inputs[1].data_ptr(); - auto dgrad = at::empty_like(inputs[0]); - at::Half* dx = dgrad.data_ptr(); - run_dconv(x_dim, - w_dim, - y_dim, - conv_pad, - conv_stride, - conv_dilation, - CUDNN_DATA_HALF, - dx, - w, - dr, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); - - outputs.push_back(dgrad); - outputs.push_back(wgrad); - outputs.push_back(bgrad); - - return outputs; - -} - -std::vector conv_bias_forward(std::vector inputs, int64_t padding, int64_t stride) { - std::cout << std::fixed; - - // create output vector - std::vector outputs; - auto output_format = at::MemoryFormat::ChannelsLast; - - // setup dimensions - int64_t x_dim[] = {0, 0, 0, 0}; - int64_t w_dim[] = {0, 0, 0, 0}; - - // All dim calculation after this order of n,c,h,w - int axis[] = {0, 1, 2, 3}; - for (int dim = 0; dim < 4; dim++) { - x_dim[dim] = inputs[0].size(axis[dim]); - w_dim[dim] = inputs[1].size(axis[dim]); - } - - // output dim in n,c,h,w used by backend - int64_t y_dim[] = {0, 0, 0, 0}; - - // use these fixed values - int64_t conv_pad[] = {padding, padding}; - int64_t conv_stride[] = {stride, stride}; - int64_t conv_dilation[] = {1, 1}; - - // compute output from pad/stride/dilation - y_dim[0] = x_dim[0]; - y_dim[1] = w_dim[0]; - for (int dim = 0; dim < 2; dim++) { - y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]); - } - - // run - at::Half* x = inputs[0].data_ptr(); - at::Half* w = inputs[1].data_ptr(); - at::Half* b = inputs[2].data_ptr(); - auto out = at::empty(y_dim, inputs[0].type(), output_format); - at::Half* y = out.data_ptr(); - - run_conv_bias(x_dim, - w_dim, - y_dim, - conv_pad, - conv_stride, - conv_dilation, - CUDNN_DATA_HALF, - x, - w, - b, - y); - - DEBUG_MSG("[DEBUG] conv-bias : " << y.to(at::kFloat).sum().item()); - - outputs.push_back(out); - - return outputs; -} - - -std::vector conv_bias_backward(std::vector inputs, int64_t padding, int64_t stride) { - bool requires_grad = inputs[0].requires_grad(); - - for (int i = 0; i <= 2; i++) { - CHECK_INPUT(inputs[i]); - } - - std::cout << std::fixed; - - // create output vector - std::vector outputs; - auto output_format = at::MemoryFormat::ChannelsLast; - - // setup dimensions - int64_t x_dim[] = {0, 0, 0, 0}; - int64_t w_dim[] = {0, 0, 0, 0}; - int64_t y_dim[] = {0, 0, 0, 0}; - - // All dim calculation after this order of n,c,h,w - int axis[] = {0, 1, 2, 3}; - for (int dim = 0; dim < 4; dim++) { - x_dim[dim] = inputs[0].size(axis[dim]); - w_dim[dim] = inputs[1].size(axis[dim]); - y_dim[dim] = inputs[2].size(axis[dim]); - } - - int64_t b_dim[] = {1, y_dim[1], 1, 1}; - - int64_t conv_pad[] = {padding, padding}; - int64_t conv_stride[] = {stride, stride}; - int64_t conv_dilation[] = {1, 1}; - - // run - // dbias - at::Half* dy = inputs[2].data_ptr(); - auto options = at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false); - auto bgrad = at::empty(b_dim, options, output_format); - float* db = bgrad.data_ptr(); - run_dbias(y_dim, - CUDNN_DATA_HALF, - dy, - db); - - // conv wgrad - at::Half* x = inputs[0].data_ptr(); - auto wgrad = at::empty_like(inputs[1]); - at::Half* dw = wgrad.data_ptr(); - run_dconv(x_dim, - w_dim, - y_dim, - conv_pad, - conv_stride, - conv_dilation, - CUDNN_DATA_HALF, - x, - dw, - dy, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR); - - // conv dgrad - at::Half* w = inputs[1].data_ptr(); - auto dgrad = at::empty_like(inputs[0]); - at::Half* dx = dgrad.data_ptr(); - run_dconv(x_dim, - w_dim, - y_dim, - conv_pad, - conv_stride, - conv_dilation, - CUDNN_DATA_HALF, - dx, - w, - dy, - CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR); - - outputs.push_back(dgrad); - outputs.push_back(wgrad); - outputs.push_back(bgrad); - - return outputs; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &conv_bias_relu_forward, "Fused Conv-Bias-ReLU forward"); - m.def("backward", &conv_bias_relu_backward, "Fused Conv-Bias-ReLU backward"); - m.def("forward_no_relu", &conv_bias_forward, "Fused Conv-Bias forward"); - m.def("backward_no_relu", &conv_bias_backward, "Fused Conv-Bias backward"); - m.def("forward_mask", &conv_bias_mask_relu_forward, "Fused Conv-Bias-Mask-ReLU forward"); -} - diff --git a/apex/contrib/csrc/cudnn-frontend b/apex/contrib/csrc/cudnn-frontend deleted file mode 160000 index fa61199..0000000 --- a/apex/contrib/csrc/cudnn-frontend +++ /dev/null @@ -1 +0,0 @@ -Subproject commit fa611998a360cbabaa2dcc7c9859748144114fc0 diff --git a/apex/contrib/csrc/fmha/fmha_api.cpp b/apex/contrib/csrc/fmha/fmha_api.cpp deleted file mode 100644 index 07865b6..0000000 --- a/apex/contrib/csrc/fmha/fmha_api.cpp +++ /dev/null @@ -1,361 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#include -#include - -#include "fmha.h" - -void set_params(Fused_multihead_attention_fprop_params ¶ms, - // sizes - const size_t b, - const size_t s, - const size_t h, - const size_t d, - // device pointers - void *qkv_packed_d, - void *cu_seqlens_d, - void *o_packed_d, - void *s_d, - float p_dropout) { - - Data_type acc_type = DATA_TYPE_FP32; - Data_type data_type = DATA_TYPE_FP16; - - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - - // Set the pointers and strides. - params.qkv_ptr = qkv_packed_d; - params.qkv_stride_in_bytes = get_size_in_bytes(h * 3 * d, data_type); - params.o_ptr = o_packed_d; - params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type); - - params.cu_seqlens = static_cast(cu_seqlens_d); - - // S = softmax(P) - params.s_ptr = s_d; - params.s_stride_in_bytes = get_size_in_bytes(b * h * s, data_type); - - // Set the dimensions. - params.b = b; - params.h = h; - params.s = s; - params.d = d; - - // Set the different scale values. - const float scale_bmm1 = 1.f / sqrtf(d); - constexpr float scale_softmax = 1.f; - constexpr float scale_bmm2 = 1.f; - - set_alpha(params.scale_bmm1, scale_bmm1, data_type); - set_alpha(params.scale_softmax, scale_softmax, acc_type); - set_alpha(params.scale_bmm2, scale_bmm2, data_type); - - // Set this to probability of keeping an element to simplify things. - params.p_dropout = 1.f - p_dropout; - params.rp_dropout = 1.f / params.p_dropout; - TORCH_CHECK(p_dropout < 1.f); - set_alpha(params.scale_dropout, params.rp_dropout, data_type); -} - -std::vector -mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i - const at::Tensor &cu_seqlens, // b+1 - const float p_dropout, - const int max_seq_len, - const bool is_training, - const bool is_nl, - const bool zero_tensors, - c10::optional gen_) { - - auto dprops = at::cuda::getCurrentDeviceProperties(); - TORCH_CHECK(dprops->major == 8 && dprops->minor == 0); - auto stream = at::cuda::getCurrentCUDAStream().stream(); - Launch_params launch_params(dprops, stream, is_training, is_nl); - - int seq_len = 512; - auto launch = &run_fmha_fp16_512_64_sm80; - if( max_seq_len <= 128 ) { - seq_len = 128; - launch = &run_fmha_fp16_128_64_sm80; - } else if( max_seq_len <= 256 ) { - seq_len = 256; - launch = &run_fmha_fp16_256_64_sm80; - } else if( max_seq_len <= 384 ) { - seq_len = 384; - launch = &run_fmha_fp16_384_64_sm80; - } else if( max_seq_len <= 512 ) { - seq_len = 512; - launch = &run_fmha_fp16_512_64_sm80; - } else { - TORCH_CHECK(false); - } - - TORCH_CHECK(qkv.is_cuda()) - TORCH_CHECK(cu_seqlens.is_cuda()) - - TORCH_CHECK(qkv.is_contiguous()) - TORCH_CHECK(cu_seqlens.is_contiguous()) - - TORCH_CHECK(cu_seqlens.dim() == 1); - TORCH_CHECK(qkv.dim() == 4); - - const auto sizes = qkv.sizes(); - - TORCH_CHECK(sizes[THREE_DIM] == 3); - - const int batch_size = cu_seqlens.numel() - 1; - const int total = sizes[TOTAL_DIM]; - const int num_heads = sizes[H_DIM]; - const int head_size = sizes[D_DIM]; - TORCH_CHECK(batch_size > 0); - TORCH_CHECK(head_size == 64); - auto opts = qkv.options(); - - auto ctx = torch::empty({ total, num_heads, head_size }, opts); - - auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts); - - if( zero_tensors ) { - ctx.zero_(); - s.zero_(); - } - - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); - - - set_params(launch_params.params, - batch_size, - seq_len, - num_heads, - head_size, - qkv.data_ptr(), - cu_seqlens.data_ptr(), - ctx.data_ptr(), - s.data_ptr(), - p_dropout); - - launch(launch_params, /*configure=*/ true); - // number of times random will be generated per thread, to offset philox counter in thc random - // state - int64_t counter_offset = launch_params.elts_per_thread; - at::PhiloxCudaState rng_engine_inputs; - - if( is_training ) { - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - launch_params.params.philox_args = gen->philox_cuda_state(counter_offset); - } - - launch(launch_params, /*configure=*/ false); - - return { ctx, s }; -} - - -std::vector -mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size - const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i - at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP - const at::Tensor &cu_seqlens, // b+1 - const float p_dropout, // probability to drop - const int max_seq_len, // max sequence length to choose the kernel - const bool zero_tensors -) { - auto dprops = at::cuda::getCurrentDeviceProperties(); - TORCH_CHECK(dprops->major == 8 && dprops->minor == 0); - int seq_len = 512; - auto launch = &run_fmha_dgrad_fp16_512_64_sm80; - if( max_seq_len <= 128 ) { - seq_len = 128; - launch = &run_fmha_dgrad_fp16_128_64_sm80; - } else if( max_seq_len <= 256 ) { - seq_len = 256; - launch = &run_fmha_dgrad_fp16_256_64_sm80; - } else if( max_seq_len <= 384 ) { - seq_len = 384; - launch = &run_fmha_dgrad_fp16_384_64_sm80; - } else if( max_seq_len <= 512 ) { - seq_len = 512; - launch = &run_fmha_dgrad_fp16_512_64_sm80; - } else { - TORCH_CHECK(false); - } - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - TORCH_CHECK(qkv.dtype() == torch::kFloat16); - TORCH_CHECK(dout.dtype() == torch::kFloat16); - TORCH_CHECK(softmax.dtype() == torch::kFloat16); - TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32); - - TORCH_CHECK(qkv.is_cuda()); - TORCH_CHECK(cu_seqlens.is_cuda()); - - TORCH_CHECK(qkv.is_contiguous()); - TORCH_CHECK(cu_seqlens.is_contiguous()); - - TORCH_CHECK(cu_seqlens.dim() == 1); - TORCH_CHECK(qkv.dim() == 4); - - const auto sizes = qkv.sizes(); - - TORCH_CHECK(sizes[THREE_DIM] == 3); - - const int batch_size = cu_seqlens.numel() - 1; - const int num_heads = sizes[H_DIM]; - const int head_size = sizes[D_DIM]; - TORCH_CHECK(batch_size > 0); - TORCH_CHECK(head_size == 64); - - auto dqkv = torch::empty_like(qkv); - - if( zero_tensors ) { - dqkv.zero_(); - } - - Fused_multihead_attention_fprop_params params; - - set_params(params, - batch_size, - seq_len, - num_heads, - head_size, - qkv.data_ptr(), - cu_seqlens.data_ptr(), - dout.data_ptr(), // we set o_ptr to dout - softmax.data_ptr(), // softmax gets overwritten by dP! - p_dropout); - - // we're re-using these scales - Data_type acc_type = DATA_TYPE_FP32; - set_alpha(params.scale_bmm1, 1.f, acc_type); - set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type); - set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16); - params.dqkv_ptr = dqkv.data_ptr(); - - launch(params, stream); - return { dqkv, softmax }; -} - -std::vector mha_bwd_nl(const at::Tensor &dout, // total x num_heads, x head_size - const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i - at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP - const at::Tensor &cu_seqlens, // b+1 - const float p_dropout, // probability to drop - const int max_seq_len, // max sequence length to choose the kernel - const bool zero_tensors -) { - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - TORCH_CHECK(qkv.is_cuda()) - TORCH_CHECK(cu_seqlens.is_cuda()) - - TORCH_CHECK(qkv.is_contiguous()) - TORCH_CHECK(cu_seqlens.is_contiguous()) - - TORCH_CHECK(cu_seqlens.dim() == 1); - - TORCH_CHECK(qkv.dim() == 4); - - const auto sizes = qkv.sizes(); - - TORCH_CHECK(sizes[THREE_DIM] == 3); - - const int batch_size = cu_seqlens.numel() - 1; - - const int total = sizes[TOTAL_DIM]; - const int num_heads = sizes[H_DIM]; - const int head_size = sizes[D_DIM]; - TORCH_CHECK(batch_size > 0); - TORCH_CHECK(head_size == 64); - - int seq_len = 512; - auto launch = &run_fmha_dgrad_fp16_512_64_sm80_nl; - - auto opts = qkv.options(); - - auto dqkv = torch::empty_like(qkv); - - if( zero_tensors ) { - dqkv.zero_(); - } - - int num_chunks = 2; - if( batch_size == 1 ) { - num_chunks = 4; - }else if( batch_size == 2 ) { - num_chunks = 3; - } - auto dkv = torch::empty({total, num_chunks, 2, num_heads, head_size}, opts); - - Fused_multihead_attention_fprop_params params; - - set_params(params, - batch_size, - seq_len, - num_heads, - head_size, - qkv.data_ptr(), - cu_seqlens.data_ptr(), - dout.data_ptr(), // o_ptr = dout - softmax.data_ptr(), // softmax gets overwritten by dP! - p_dropout); - - params.dkv_ptr = dkv.data_ptr(); - - Data_type acc_type = DATA_TYPE_FP32; - set_alpha(params.scale_bmm1, 1.f, acc_type); - set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type); - set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16); - params.dqkv_ptr = dqkv.data_ptr(); - - launch(params, num_chunks, stream); - - //SPLIT-K reduction of num_chunks dK, dV parts - - // The equivalent of the following Pytorch code: - // using namespace torch::indexing; - // at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)}); - // torch::sum_out(view_out, dkv, 1); - - const int hidden_size = num_heads * head_size; - fmha_run_noloop_reduce( - dqkv.data_ptr(), dkv.data_ptr(), cu_seqlens.data_ptr(), hidden_size, batch_size, total, num_chunks, stream); - - return { dqkv, softmax, dkv }; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "Fused Multi-head Self-attention for BERT"; - m.def("fwd", &mha_fwd, "Forward pass"); - m.def("bwd", &mha_bwd, "Backward pass"); - m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)"); -} diff --git a/apex/contrib/csrc/fmha/src/fmha.h b/apex/contrib/csrc/fmha/src/fmha.h deleted file mode 100644 index d01a915..0000000 --- a/apex/contrib/csrc/fmha/src/fmha.h +++ /dev/null @@ -1,163 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include -#include - -#ifdef OLD_GENERATOR_PATH -#include -#else -#include -#endif - -#include - -#include - - -constexpr int TOTAL_DIM = 0; -constexpr int THREE_DIM = 1; -constexpr int H_DIM = 2; -constexpr int D_DIM = 3; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Qkv_params { - // The QKV matrices. - void * __restrict__ qkv_ptr; - - // The stride between rows of the Q, K and V matrices. - size_t qkv_stride_in_bytes; - - // The number of heads. - int h; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Fused_multihead_attention_fprop_params : public Qkv_params { - - // The dQKV matrices. - void * __restrict__ dqkv_ptr; - - // Temporary for dKV. - void * __restrict__ dkv_ptr; - - // The O matrix (output). - void * __restrict__ o_ptr; - - // The stride between rows of O. - int64_t o_stride_in_bytes; - - // The pointer to the S matrix, overwritten by the dP matrix (bwd). - void * __restrict__ s_ptr; - // The stride between rows of the S matrix. - int64_t s_stride_in_bytes; - - // The dimensions. - int b, s, d; - - // The scaling factors for the kernel. - uint32_t scale_bmm1, scale_softmax, scale_bmm2; - - // array of length b+1 holding starting offset of each sequence. - int * __restrict__ cu_seqlens; - - // The dropout probability (probability of keeping an activation). - float p_dropout; - - // Scale factor of 1 / (1 - p_dropout). - float rp_dropout; - - // Scale factor of 1 / (1 - p_dropout), in half2. - uint32_t scale_dropout; - - // Random state. - at::PhiloxCudaState philox_args; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Launch_params{ - Launch_params(cudaDeviceProp * props_, - cudaStream_t stream_, - bool is_training_, - bool is_nl_) - : elts_per_thread(0) - , props(props_) - , stream(stream_) - , is_training(is_training_) - , is_nl(is_nl_) { - } - - size_t elts_per_thread; - - cudaDeviceProp * props; - - cudaStream_t stream; - - bool is_training; - - Kernel_params params; - int num_full_heads; - int num_main_groups; - int heads_last_wave; - int main_steps; - int rest_steps; - bool is_nl; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run_fmha_fp16_128_64_sm80(Launch_params &launch_params, const bool configure); -void run_fmha_fp16_256_64_sm80(Launch_params &launch_params, const bool configure); -void run_fmha_fp16_384_64_sm80(Launch_params &launch_params, const bool configure); -void run_fmha_fp16_512_64_sm80(Launch_params &launch_params, const bool configure); - -void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream); -void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream); -void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream); -void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream); - -void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params ¶ms, const bool is_training, const int num_chunks, cudaStream_t stream); - -void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params ¶ms, const int num_chunks, cudaStream_t stream); - -void fmha_run_noloop_reduce(void *out, - const void *in, - const int *cu_seqlens, - const int hidden_size, - const int batch_size, - const int total, - const int num_chunks, - cudaStream_t stream); - - diff --git a/apex/contrib/csrc/fmha/src/fmha/gemm.h b/apex/contrib/csrc/fmha/src/fmha/gemm.h deleted file mode 100644 index 62529a2..0000000 --- a/apex/contrib/csrc/fmha/src/fmha/gemm.h +++ /dev/null @@ -1,314 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include - -#define FMHA_DIV_UP(m, n) (((m) + (n)-1) / (n)) - -namespace fmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_ > -struct Fragment_base_ { - - // The data type. - using Data_type = Data_type_; - // default input type - using Input_type_ = Data_type_; - // Does it store the array of elements. - enum { HAS_ELTS = BITS_PER_ELT_ >= 8 }; - // The number of elements. - enum { NUM_ELTS = NUM_ELTS_ }; - // The size of element in bits. - enum { BITS_PER_ELT = BITS_PER_ELT_ }; - // The size of byte of a single register. - enum { BYTES_PER_REG = 4 }; - // The size in bits. - enum { BITS_PER_REG = BYTES_PER_REG * 8 }; - // The number of registers needed to store the fragment. - enum { NUM_REGS = Div_up::VALUE }; - // The size in bytes (as returned by sizeof(Fragment_base<>). - enum { SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG }; - // The alignment. - enum { ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : Min::VALUE }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The type of the elements. - typename Data_type_, - // The number of elements. - int NUM_ELTS_, - // The alignment if you want to force a value -- use 0 otherwise. - int ALIGNMENT_ = 0, - // The base class. - typename Base_ = Fragment_base_ -> -struct alignas(static_cast(Base_::ALIGNMENT)) Fragment : public Base_ { - - // The size of a load/store. - enum { BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t) }; - - // Clear the fragment. Using PTX in that code seems to produce better SASS... - inline __device__ void clear() { - #pragma unroll - for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) { - asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) : ); - } - } - - // Immutable access to a register. - inline __device__ const uint32_t& reg(int ii) const { - return this->regs_[ii]; - } - - // Mutable access to a register. - inline __device__ uint32_t& reg(int ii) { - return this->regs_[ii]; - } - - uint32_t regs_[Base_::NUM_REGS]; - - // Immutable access to the elements. - inline __device__ const Data_type_& elt(int ii) const { - return reinterpret_cast(&this->regs_[0])[ii]; - } - - // Mutable access to the elements. - inline __device__ Data_type_& elt(int ii) { - return reinterpret_cast(&this->regs_[0])[ii]; - } - - // Immutable access to the elements with a cast. - template< typename Cast_type > - inline __device__ const Cast_type& elt_as(int ii) const { - return reinterpret_cast(&this->regs_[0])[ii]; - } - - // Mutable access to the elements. - template< typename Cast_type > - inline __device__ Cast_type& elt_as(int ii) { - return reinterpret_cast(&this->regs_[0])[ii]; - } - - // Add another fragment. - inline __device__ void add(const Fragment &other) { - #pragma unroll - for( int ii = 0; ii < NUM_ELTS_; ++ii ) { - this->elt(ii) += other.elt(ii); - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Layout > -struct Fragment_a : public Fragment { -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Layout > -struct Fragment_b : public Fragment { -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Fragment_accumulator : public Fragment { - - // The base class. - using Base = Fragment; - - // Add two fragments. - template< typename Other_fragment_ > - inline __device__ void add(const Other_fragment_ &other) { - for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) { - this->elt(ii) = this->elt(ii) + other.elt(ii); - } - } - - // Do the HMMA. - template< typename Layout_a, typename Layout_b > - inline __device__ void mma(const Fragment_a &a, - const Fragment_b &b) { - asm volatile( \ - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \ - " {%0, %1, %2, %3}, \n" \ - " {%4, %5, %6, %7}, \n" \ - " {%8, %9}, \n" \ - " {%0, %1, %2, %3}; \n" \ - : "+f"( elt(0)), "+f"( elt(1)), "+f"( elt(2)), "+f"( elt(3)) - : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)) - , "r"(b.reg(0)), "r"(b.reg(1))); - asm volatile( \ - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \ - " {%0, %1, %2, %3}, \n" \ - " {%4, %5, %6, %7}, \n" \ - " {%8, %9}, \n" \ - " {%0, %1, %2, %3}; \n" \ - : "+f"( elt(4)), "+f"( elt(5)), "+f"( elt(6)), "+f"( elt(7)) - : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3)) - , "r"(b.reg(2)), "r"(b.reg(3))); - } - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Fragment, int M, int N > -inline __device__ void clear(Fragment (&frag)[M][N]) { - #pragma unroll - for( int mi = 0; mi < M; ++mi ) { - #pragma unroll - for( int ni = 0; ni < N; ++ni ) { - frag[mi][ni].clear(); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Accumulator_type, int WARPS_K > -struct Clear_accumulator { -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int WARPS_K > -struct Clear_accumulator { - template< typename Acc, int M, int N > - static inline __device__ void apply(Acc (&acc)[M][N], bool = false) { - fmha::clear(acc); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) { - - #pragma unroll - for( int mi = 0; mi < M; ++mi ) { - #pragma unroll - for( int ni = 0; ni < N; ++ni ) { - acc[mi][ni].mma(a[mi], b[ni]); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The number of rows in the CTA tile. - int M_, - // The number of cols in the CTA tile. - int N_, - // The number of elements in the the K dimension of the GEMM loop. - int K_, - // The number of rows of warps. - int WARPS_M_, - // The number of cols of warps. - int WARPS_N_, - // The number of warps in the K dimension of the GEMM loop. - int WARPS_K_> -struct Cta_tile_ { - - enum { M = M_, N = N_, K = K_ }; - // The number of warps. - enum { WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_ }; - // The number of warps per CTA. - enum { WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K }; - // The number of threads per warp. - enum { THREADS_PER_WARP = 32 }; - // The number of threads per CTA. - enum { THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Hmma_tile { - // The number of elements computed with a single warp-MMA. - enum { M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16 }; - - // The number of elements computed with a single CTA-MMA. - enum { - M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M, - N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N, - K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K - }; - - // The number of MMAs needed to compute the GEMM. - enum { - MMAS_M = Div_up::VALUE, - MMAS_N = Div_up::VALUE, - MMAS_K = Div_up::VALUE, - }; - - // The number of elements computed per warp. - enum { - M_PER_WARP = MMAS_M * M_PER_MMA, - N_PER_WARP = MMAS_N * N_PER_MMA, - K_PER_WARP = MMAS_K * K_PER_MMA, - }; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -using A_type = uint16_t; -using B_type = uint16_t; -using C_type = uint16_t; -using Accumulator_type = float; -using Epilogue_type = float; - -constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8; -constexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8; -constexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -using Cta_tile_extd = Cta_tile_; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -using Cta_tile_with_k_with_padding = Cta_tile_extd::VALUE, - Cta_tile_::WARPS_M, - Cta_tile_::WARPS_N, - Cta_tile_::WARPS_K>; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace fmha diff --git a/apex/contrib/csrc/fmha/src/fmha/gmem_tile.h b/apex/contrib/csrc/fmha/src/fmha/gmem_tile.h deleted file mode 100644 index 5c86dd8..0000000 --- a/apex/contrib/csrc/fmha/src/fmha/gmem_tile.h +++ /dev/null @@ -1,456 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -namespace fmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The dimensions of the tile computed by the CTA. - typename Cta_tile, - // The number of bits per element. - int BITS_PER_ELEMENT, - // The number of rows of Q, K or V loaded by this tile. - int ROWS, - // The number of columns. - int COLS, - // The number of matrics. - int NUM_MATS = 3 -> -struct Gmem_tile_qkv { - - // The size of each LDG. - enum { BYTES_PER_LDG = 16 }; - // The size of a row in bytes. - enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 }; - - // The number of threads to load a "row" of the matrix. - enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG }; - - // The number of "rows" loaded per LDG. - enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; - // The number of LDGs needed to load a chunk of the Q matrix. - enum { LDGS = fmha::Div_up::VALUE }; - - // Ctor. - template< typename Params, typename BInfo > - inline __device__ Gmem_tile_qkv(const Params ¶ms, const int qkv_offset, const BInfo &binfo, const int tidx) - : params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes) - , actual_seqlen(binfo.actual_seqlen) - , qkv_ptr_(reinterpret_cast(params.qkv_ptr)) { - - // Compute the position in the sequence (within the CTA for the moment). - int row = tidx / THREADS_PER_ROW; - // Compute the position of the thread in the row. - int col = tidx % THREADS_PER_ROW; - - // Store the row as we need it to disable the loads. - row_ = row; - - // The row offset in the batched GEMM. For each seq element, we store QKV in that order. - int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes; - // Add the block index. - row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW; - - // Assemble the final pointer. - qkv_ptr_ += row_offset + col * BYTES_PER_LDG; - } - - // Store data to shared memory. - template< typename Smem_tile > - inline __device__ void commit(Smem_tile &smem_tile) { - smem_tile.store(fetch_); - } - - // Load data from memory. - template< typename Smem_tile > - inline __device__ void load(Smem_tile &smem_tile) { - const void *ptrs[LDGS]; - uint32_t preds[LDGS]; - #pragma unroll - for( int ii = 0; ii < LDGS; ++ii ) { - ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; - preds[ii] = ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)); - fetch_[ii] = make_uint4(0, 0, 0, 0); - } - - // not packing predicates removes restrictions (e.g. FP16 384, 4 warps) - Ldg_functor fct(fetch_, ptrs); - #pragma unroll - for( int ii = 0; ii < LDGS; ++ii ) { - fct.load(ii, preds[ii]); - } - } - - // Store data to memory. - inline __device__ void store(const uint4 (&data)[LDGS]) { - #pragma unroll - for( int ii = 0; ii < LDGS; ++ii ) { - char *ptr = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; - if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) { - fmha::stg(ptr, data[ii]); - } - } - } - - // Move the pointer to the next location. - inline __device__ void move() { - qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_; - actual_seqlen -= ROWS; - } - - inline __device__ void move(int steps) { - qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_ * steps; - actual_seqlen -= ROWS * steps; - } - - // The stride between rows for the QKV matrice. - int64_t params_qkv_stride_in_bytes_; - // The pointer. - char *qkv_ptr_; - // The fetch registers. - uint4 fetch_[LDGS]; - // Keep track of the row the thread is processing as we move the tile. - int row_; - // The length of the sequence loaded by that memory tile. - int actual_seqlen; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Cta_tile > -struct Gmem_tile_o { - - // The mma tile. - using Mma_tile = fmha::Hmma_tile; - - // The size of each element. - enum { BYTES_PER_ELEMENT = 2 }; - // The size of a row in bytes. - enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT }; - - // The number of threads to store a "row" of the matrix. - enum { THREADS_PER_ROW = 16 }; - // The size of each STG. - enum { BYTES_PER_STG = BYTES_PER_ROW / THREADS_PER_ROW }; - - // The number of "rows" stored per iteration of the loop. The output of 1 MMA. - enum { ROWS = Cta_tile::M }; - // The number of "rows" stored per iteration of the loop. The output of 1 MMA. - enum { ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA }; - // The number of outter loop for the stores. - enum { LOOPS = ROWS / ROWS_PER_LOOP }; - - // The number of "rows" stored per STG. - enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; - // Do we have to guard against partial writes/reads. - enum { HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0 }; - // The number of STGs needed to store a chunk of the Q matrix. - enum { STGS_PER_LOOP = fmha::Div_up::VALUE }; - // The number of STGs needed to store a chunk of the Q matrix in total. - enum { STGS = STGS_PER_LOOP * LOOPS }; - - // Ctor. - template - inline __device__ Gmem_tile_o(const Params ¶ms, const BInfo &binfo, int tidx) - : params_o_stride_in_bytes_(params.o_stride_in_bytes) - , actual_seqlen_(binfo.actual_seqlen) - , o_ptr_(reinterpret_cast(params.o_ptr)) { - - // Compute the position in the sequence (within the CTA for the moment). - int row = tidx / THREADS_PER_ROW; - // Compute the position of the thread in the row. - int col = tidx % THREADS_PER_ROW; - - // Store the row as we need it to disable loads. - row_ = row; - - // The row offset in the batched GEMM. - int64_t row_offset = (int64_t)row * params.o_stride_in_bytes + binfo.bidx * BYTES_PER_ROW; - // Assemble the final pointer. - o_ptr_ += row_offset + col * BYTES_PER_STG; - - // Is that thread active on the last STG? - if( HAS_INCOMPLETE_STG ) { - is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M; - } - } - - // Store data to global memory. - inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) { - - #pragma unroll - for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) { - int jj = mi * STGS_PER_LOOP + ii; - if( this->row_ + jj * ROWS_PER_STG >= this->actual_seqlen_ ) { - break; - } - - float x = reinterpret_cast(src[ii].x); - float y = reinterpret_cast(src[ii].y); - float z = reinterpret_cast(src[ii].z); - float w = reinterpret_cast(src[ii].w); - uint2 out = float4_to_half4(x, y, z, w); - if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) { - fmha::stg(this->o_ptr_ + jj * ROWS_PER_STG * this->params_o_stride_in_bytes_, out); - } - } - } - - // Move the pointer to the next location. - inline __device__ void move() { - row_ += ROWS; - o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_; - } - - inline __device__ void move(const int steps) { - row_ += ROWS * steps; - o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_ * steps; - } - - // The stride between rows for the QKV matrice. - int64_t params_o_stride_in_bytes_; - // The pointer. - char *o_ptr_; - // Is the thread active for the last STG? - int is_active_for_last_stg_; - // Keep track of the row to disable loads. - int row_; - // The length of the sequence loaded by that memory tile. - int actual_seqlen_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Cta_tile, int BYTES_PER_ELEMENT > -struct Gmem_tile_mma_sd { - - // The mma tile. - using Mma_tile = fmha::Hmma_tile; - - // Each STG stores 8 elements. - enum { BYTES_PER_STG = BYTES_PER_ELEMENT * 8 }; - // The number of MMAs in the M dimension. - enum { MMAS_M = Mma_tile::MMAS_M }; - // The number of MMAs in the N dimension. - enum { MMAS_N = Mma_tile::MMAS_N }; - // The number of rows computed per MMA per thread block. - enum { M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA }; - // The number of cols computed per MMA per thread block. - enum { N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA }; - // The number of threads per block. - enum { THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA }; - // The size of each row in bytes. I.e. how many bytes are stored per STG. - enum { BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG }; - // The fixed sequence length. - enum { SEQLEN = Cta_tile::N }; - // The distance between two blocks (in bytes). - enum { BLOCK_STRIDE_BYTES = SEQLEN * SEQLEN * BYTES_PER_ELEMENT }; - // The distance between elements stored per loop (in bytes). - enum { LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW }; - - // The type of elements stored per STG. - using Type = typename fmha::Uint_from_size_in_bytes::Type; - - // Ctor. - template - inline __device__ Gmem_tile_mma_sd(void *ptr, const Params ¶ms, const int bidb, const int bidh, const int tidx) - : ptr_(static_cast(ptr)) { - - // The block index. - size_t bidx = bidb * params.h + bidh; - - // Set store location for each thread at the beginning of the loop - ptr_ += bidx * BLOCK_STRIDE_BYTES + tidx * BYTES_PER_STG; - } - - // Store to global memory. - inline __device__ void store(const Type &data, const int mi, const int ni) { - size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; - fmha::stg(ptr_ + offset, data); - } - - // Load from global memory. - inline __device__ void load(Type &data, const int mi, const int ni) { - size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW; - fmha::ldg(data, ptr_ + offset); - } - - // Move to the next tile. - inline __device__ void move() { - ptr_ += LOOP_STRIDE_BYTES; - } - inline __device__ void move(const int steps) { - ptr_ += LOOP_STRIDE_BYTES * steps; - } - - // The pointer in global memory. - char *ptr_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Cta_tile, typename Base = Gmem_tile_mma_sd > -struct Gmem_tile_mma_s : public Base { - - // The number of mmas in the vertical dimension. - enum { M = Base::MMAS_M }; - // The number of mmas in the horizontal dimension. - enum { N = Base::MMAS_N }; - // The type of the vectors stored by each STG. - using Type = typename Base::Type; - - // Ctor. - template< typename Params, typename Block_info > - inline __device__ Gmem_tile_mma_s(const Params ¶ms, const Block_info& binfo, const int tidx) - : Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) { - } - - // Store to global memory. - template - inline __device__ void store(const float (&softmax)[2 * M][4 * N], const Mask &mask) { - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - - float tmp00 = softmax[2 * mi + 0][4 * ni + 0]; - float tmp01 = softmax[2 * mi + 0][4 * ni + 1]; - float tmp02 = softmax[2 * mi + 0][4 * ni + 2]; - float tmp03 = softmax[2 * mi + 0][4 * ni + 3]; - - float tmp10 = softmax[2 * mi + 1][4 * ni + 0]; - float tmp11 = softmax[2 * mi + 1][4 * ni + 1]; - float tmp12 = softmax[2 * mi + 1][4 * ni + 2]; - float tmp13 = softmax[2 * mi + 1][4 * ni + 3]; - - uint4 dst; - dst.x = fmha::float2_to_half2(tmp00, tmp01); - dst.y = fmha::float2_to_half2(tmp02, tmp03); - dst.z = fmha::float2_to_half2(tmp10, tmp11); - dst.w = fmha::float2_to_half2(tmp12, tmp13); - if( mask.is_valid(mi, ni, 0, 0) ) { - Base::store(dst, mi, ni); - } - } - } - } - - // Store to global memory. - template - inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){ - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - uint4 dst; - dst.x = frag[ni][mi].reg(0); - dst.y = frag[ni][mi].reg(2); - dst.z = frag[ni][mi].reg(1); - dst.w = frag[ni][mi].reg(3); - if( mask.any_valid(mi, ni) ) { - Base::store(dst, mi, ni); - } - } - } - } - - // Load from global memory. - template - inline __device__ void load(uint4 (®s)[M][N], const Mask &mask) { - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - regs[mi][ni] = make_uint4(0, 0, 0, 0); - if( mask.any_valid(mi, ni) ) { - Base::load(regs[mi][ni], mi, ni); - } - } - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The dimensions of the tile computed by the CTA. - typename Cta_tile, - // The base class. - typename Base = fmha::Gmem_tile_qkv -> -struct Gmem_tile_dout : public Base { - - // Ctor. - template - inline __device__ Gmem_tile_dout(const Params ¶ms, const BInfo &binfo, int tidx) - : Base(params, 0, binfo, tidx) { - - this->qkv_ptr_ = reinterpret_cast(params.o_ptr); - this->params_qkv_stride_in_bytes_ = params.o_stride_in_bytes; // needed for move - - // Compute the position of the thread in the row. - int col = tidx % Base::THREADS_PER_ROW; - - // The row offset in the batched GEMM. For each seq element, we store O in that order. - int64_t row_offset = (int64_t)this->row_ * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW; - - // Assemble the final pointer. - this->qkv_ptr_ += row_offset + col * Base::BYTES_PER_LDG; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Cta_tile, typename Base = fmha::Gmem_tile_o > -struct Gmem_tile_dq : public Base { - - // Ctor. - template - inline __device__ Gmem_tile_dq(const Params ¶ms, const BInfo &binfo, int tidx) - : Base(params, binfo, tidx) { - this->o_ptr_ = reinterpret_cast(params.dqkv_ptr); - this->params_o_stride_in_bytes_ = params.qkv_stride_in_bytes; // needed for move - - // Compute the position of the thread in the row. - int col = tidx % Base::THREADS_PER_ROW; - - // The row offset in the batched GEMM. For each seq element, we store O in that order. - int64_t row_offset = (int64_t)this->row_ * params.qkv_stride_in_bytes + - (binfo.sum_s * 3 * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW; - - // Assemble the final pointer. - this->o_ptr_ += row_offset + col * Base::BYTES_PER_STG; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace fmha - diff --git a/apex/contrib/csrc/fmha/src/fmha/kernel_traits.h b/apex/contrib/csrc/fmha/src/fmha/kernel_traits.h deleted file mode 100644 index d51b47c..0000000 --- a/apex/contrib/csrc/fmha/src/fmha/kernel_traits.h +++ /dev/null @@ -1,97 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FMHA_kernel_traits { - - // The CTA description for the 1st GEMM. - using Cta_tile_p = fmha::Cta_tile_extd; - // The CTA description for the 2nd GEMM. - using Cta_tile_o = fmha::Cta_tile_extd; - - // Do we use one buffer for K and V. - enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u }; - // Do we keep K in registers. - enum { K_IN_REGS = (FLAGS & 0x10u) == 0u }; - - // The global memory tile to load Q. - using Gmem_tile_q = fmha::Gmem_tile_qkv; - - // The shared memory tile to swizzle Q. - using Smem_tile_q = fmha::Smem_tile_a; - - // The global memory tile to load K. - using Gmem_tile_k = fmha::Gmem_tile_qkv; - // The shared memory tile to swizzle K. - using Smem_tile_k = fmha::Smem_tile_b; - - // The global memory tile to load V. - using Gmem_tile_v = fmha::Gmem_tile_qkv; - // The shared memory tile to swizzle V. - using Smem_tile_v = fmha::Smem_tile_v; - - // The global memory tile to store O. - using Gmem_tile_o = fmha::Gmem_tile_o; - // The shared memory tile for O. - using Smem_tile_o = fmha::Smem_tile_o; - - // The global memory tile to load/store S. - using Gmem_tile_s = fmha::Gmem_tile_mma_s; - - // The shared memory tile to transpose S. - using Smem_tile_st = fmha::Smem_tile_mma_transposed; - - using Gmem_tile_do = fmha::Gmem_tile_dout; - - // Make sure the number of threads match. - static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, ""); - - // The number of threads. - enum { THREADS = Cta_tile_p::THREADS_PER_CTA }; - // Make sure the number of threads matches both CTAs. - static_assert((int)THREADS == (int)Cta_tile_o::THREADS_PER_CTA, ""); - - // The amount of shared memory needed to load Q and K. - enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE }; - // The extra amount of shared memory needed to load V. - enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE }; - // The amount of shared memory needed for Q, K and V.. - enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V }; - // The amount of shared memory needed to load Q and store O. - enum { BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE }; - - // The amount of shared memory needed for Q, K, V and O. - enum { BYTES_PER_SMEM = fmha::Max::VALUE }; - // Make sure we have enough shared memory. - static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, ""); -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/apex/contrib/csrc/fmha/src/fmha/mask.h b/apex/contrib/csrc/fmha/src/fmha/mask.h deleted file mode 100644 index 020258a..0000000 --- a/apex/contrib/csrc/fmha/src/fmha/mask.h +++ /dev/null @@ -1,81 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -namespace fmha { - - -template -struct Mask { - using Mma_tile = fmha::Hmma_tile; - - template - __device__ Mask(const Params ¶ms, const BInfo &blockInfo, int tidx) { - - actual_seqlen = blockInfo.actual_seqlen; - - const int warp = tidx / Cta_tile::THREADS_PER_WARP; - const int lane = tidx % Cta_tile::THREADS_PER_WARP; - - static_assert(Cta_tile::WARPS_K == 1, ""); - - // find the warp in the Cta tile - const int warp_n = (warp / Cta_tile::WARPS_M); - const int warp_m = (warp % Cta_tile::WARPS_M); - // decompose warp into 8x4 tile - const int quad = lane / 4; - const int tid = (lane % 4) * 2; - row = warp_m * 16 + quad; - col = warp_n * 16 + tid; - } - - inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const { - - // ii and jj iterate over the 2x4 fragment - const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen; - //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen; - return col_valid; - // return row_valid && col_valid; - } - - //BERT Mask: if upper left is invalid, none are valid - inline __device__ bool any_valid(int mi, int ni) const { - return is_valid(mi, ni, 0, 0); - } - - inline __device__ void load(int it) { - row_offset = it * Cta_tile::M + row; - } - int row_offset; - - int row; - int col; - int actual_seqlen; -}; - -} // namespace fmha diff --git a/apex/contrib/csrc/fmha/src/fmha/smem_tile.h b/apex/contrib/csrc/fmha/src/fmha/smem_tile.h deleted file mode 100644 index 8087914..0000000 --- a/apex/contrib/csrc/fmha/src/fmha/smem_tile.h +++ /dev/null @@ -1,1286 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include -#include - -namespace fmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The description of the tile computed by this CTA. - typename Cta_tile, - // The number of rows in the 2D shared memory buffer. - int M_, - // The number of cols. - int N_, - // The size in bits of each element. - int BITS_PER_ELEMENT_, - // The number of bytes per STS. - int BYTES_PER_STS_ = 16, - // The number of buffers. (Used in multistage and double buffer cases.) - int BUFFERS_PER_TILE_ = 1, - // Do we enable the fast path for LDS.128 and friends. - int ENABLE_LDS_FAST_PATH_ = 0, - // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. - int ROWS_PER_XOR_PATTERN_ = 8, - // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. - int COLS_PER_XOR_PATTERN_ = 1, - // Use or not predicates - bool USE_PREDICATES_ = true -> -struct Smem_tile_without_skews { - - // The size in bits of each element. - enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ }; - // The size in bytes of a single STS. - enum { BYTES_PER_STS = BYTES_PER_STS_ }; - // The number of elements per STS. - enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT }; - // To support arbitrary N, we pad some values to a power-of-2. - enum { N_WITH_PADDING = Next_power_of_two::VALUE }; - // The number of bytes per row without packing of rows. - enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 }; - // The number of bytes per row -- we want at least 128B per row. - enum { BYTES_PER_ROW = Max::VALUE }; - // The number of rows in shared memory (two rows may be packed into a single one). - enum { ROWS = M_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW }; - - // The number of threads per row. - enum { THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS }; - // The number of threads per row. - enum { THREADS_PER_ROW = Min::VALUE }; - - // The number of STS per row. - enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS }; - // It must be at least one. - static_assert(STS_PER_ROW >= 1, ""); - // The number of rows written with a single STS. - enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; - // Make sure we write to at least one row per STS. Thanks Dr. Obvious ;) - static_assert(ROWS_PER_STS >= 1, ""); - // The number of STS needed to store all rows. - enum { STS_PER_COL = Div_up::VALUE }; - // The number of STS in total. - enum { STS = STS_PER_COL * STS_PER_ROW }; - - // The size of one buffer in bytes in shared memory. - enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA }; - // The number of buffers. - enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ }; - // The size in bytes of total buffers. - enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE }; - // The boundary for smem_read_offset and smem_write_offset increment. - enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER }; - - // Do we enable the LDS.128 fast path? - enum { ENABLE_LDS_FAST_PATH = ENABLE_LDS_FAST_PATH_ }; - static_assert(ENABLE_LDS_FAST_PATH == 0); - // The number of rows that are used for the XOR swizzling to allow fast STS/LDS. - enum { ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_ }; - // The number of cols that are used for the XOR swizzling to allow fast STS/LDS. - enum { COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS }; - // Use or not predicates - enum { USE_PREDICATES = USE_PREDICATES_ }; - - // The type of elements that are stored in shared memory by each thread. - using Store_type = typename Uint_from_size_in_bytes::Type; - - // Ctor. - inline __device__ Smem_tile_without_skews(void *smem, int tidx) - : smem_(__nvvm_get_smem_pointer(smem)) { - - // The row written by a thread. See doc/mma_smem_layout.xlsx. - int smem_write_row = tidx / THREADS_PER_ROW; - - // The XOR pattern. - int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN; - // Compute the column and apply the XOR pattern. - int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor; - - // The offset. - this->smem_write_offset_ = smem_write_row*BYTES_PER_ROW + smem_write_col*BYTES_PER_STS; - - // TODO: Why not merge it with the read offset? - this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0); - this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0); - } - - // Compute the store pointers. - template< int N > - inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) { - #pragma unroll - for( int ii = 0; ii < N; ++ii ) { - // Decompose the STS into row/col. - int row = ii / STS_PER_ROW; - int col = ii % STS_PER_ROW; - - // Assemble the offset. - int offset = smem_write_offset_ + row*ROWS_PER_STS*BYTES_PER_ROW; - - // Take the column into account. - if( STS_PER_ROW > 1 ) { - offset += col*THREADS_PER_ROW*BYTES_PER_STS; - } - - // Apply the XOR pattern if needed. - if( ROWS_PER_STS < ROWS_PER_XOR_PATTERN ) { - const int m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN; - offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS; - } - - // Assemble the final pointer :) - ptrs[ii] = smem_ + offset + smem_write_buffer_; - } - } - - inline __device__ void debug_reset() { - for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) { - for( int row = 0; row < ROWS; ++row ) { - for( int col = 0; col < BYTES_PER_ROW; col += 4 ) { - if( threadIdx.x == 0 ) { - uint32_t val = 0x0; - sts(val, smem_ + row*BYTES_PER_ROW + col + buffer); - } - } - } - } - } - - // Print the content of the tile (only for debug ;)). - inline __device__ void debug_print() const { - for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) { - for( int row = 0; row < ROWS; ++row ) { - for( int col = 0; col < BYTES_PER_ROW; col += 4 ) { - if( threadIdx.x == 0 ) { - uint32_t val; - lds(val, smem_ + row*BYTES_PER_ROW + col + buffer); - printf("block=(x=%2d, y=%2d, z=%2d) (smem_=%2d, buffer=%2d, row=%2d, byte=%4d)=0x%08x\n", - blockIdx.x, - blockIdx.y, - blockIdx.z, - smem_, - buffer, - row, - col, - val); - } - } - } - } - } - - // Move the read offset to next buffer. - inline __device__ void move_to_next_read_buffer() { - if( BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) { - this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY; - } else if( BUFFERS_PER_TILE > 1 ) { - this->smem_read_buffer_ += BYTES_PER_BUFFER; - } - } - - // Move the read offset to next buffer. TODO: Remove this member function!!! - inline __device__ void move_next_read_buffer() { - this->move_to_next_read_buffer(); - } - - // Move the read offset to next N buffer (circular-buffer). - inline __device__ void move_to_next_read_buffer(int N) { - if( BUFFERS_PER_TILE > 1 ) { - this->smem_read_buffer_ += N * BYTES_PER_BUFFER; - this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0; - } - } - - // Move the read offset to next N buffer (circular-buffer). TODO: Remove this member function!!! - inline __device__ void move_next_read_buffer(int N) { - this->move_to_next_read_buffer(N); - } - - // Move the write offset to next buffer. - inline __device__ void move_to_next_write_buffer() { - if( BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) { - this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY; - } else if( BUFFERS_PER_TILE > 1 ) { - this->smem_write_buffer_ += BYTES_PER_BUFFER; - } - } - - // Move the write offset to next buffer. TODO: Remove that member function! - inline __device__ void move_next_write_buffer() { - this->move_to_next_write_buffer(); - } - - // Move the read offset. - inline __device__ void move_read_offset(int delta) { - this->smem_read_offset_ += delta; - } - - // Move the write offset. - inline __device__ void move_write_offset(int delta) { - this->smem_write_offset_ += delta; - } - - // Store to the tile in shared memory. - template< int N > - inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) { - uint32_t smem_ptrs[N]; - this->compute_store_pointers(smem_ptrs); - sts(smem_ptrs, data); - } - - // Store to the tile in shared memory. - template< int N, int M > - inline __device__ void store(const Store_type (&data)[N], uint32_t (&preds)[M], uint64_t = 0) { - uint32_t smem_ptrs[N]; - this->compute_store_pointers(smem_ptrs); - sts(smem_ptrs, data, preds); - } - - // Store to the tile in shared memory. - template< int N > - inline __device__ void store(const Store_type (&data)[N], uint32_t preds, uint64_t = 0) { - this->store(data, preds); - } - - // Store to the tile in shared memory. - template< int N > - inline __device__ void store(const void* (&gmem_ptrs)[N], uint32_t preds, uint64_t = 0) { - uint32_t tmp[1] = { preds }; - this->store(gmem_ptrs, tmp); - } - - // The shared memory pointer. - uint32_t smem_; - // The read offset. Reserve 4 offsets if needed. - int smem_read_offset_; - // The write offset. - int smem_write_offset_; - // The buffer base offset for read. - int smem_read_buffer_; - // The buffer base offset for write. - int smem_write_buffer_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The dimensions of the tile computed by the CTA. - typename Cta_tile, - // The layout of the tile. - typename Layout, - // The size of the STS. - int BYTES_PER_STS = 16, - // The number of buffers per tile. - int BUFFERS_PER_TILE = 1, - // Use or not predicates - bool USE_PREDICATES = true -> -struct Smem_tile_a { -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int MMAS_K, int MMAS_K_WITH_PADDING > -struct Compute_reset_mask { - // The potential mask. - enum { HALF = MMAS_K_WITH_PADDING / 2 }; - // The remainder. - enum { MOD = MMAS_K % HALF }; - // The final value. - enum { VALUE = (MMAS_K == MOD ? 0 : HALF) | Compute_reset_mask::VALUE }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int MMAS_K_WITH_PADDING > -struct Compute_reset_mask<0, MMAS_K_WITH_PADDING> { - enum { VALUE = 0 }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int MMAS_K > -struct Compute_reset_mask { - enum { VALUE = MMAS_K - 1 }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -struct Rows_per_xor_pattern_a { - // The size in bits. - enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_A }; - // The number of rows. - enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -struct Rows_per_xor_pattern_row_a : public Rows_per_xor_pattern_a { -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The dimensions of the tile computed by the CTA. - typename Cta_tile, - // The size of the STS. - int BYTES_PER_STS, - // The number of buffers per tile. - int BUFFERS_PER_TILE, - // How many rows to use for the XOR pattern to avoid bank conflicts? - int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_a::VALUE -> -struct Smem_tile_row_a : public Smem_tile_without_skews { - // The MMA tile. - using Mma_tile = fmha::Hmma_tile; - // The base class. - using Base = Smem_tile_without_skews; - // The fragment. - using Fragment = Fragment_a; - - // When we use padding to reach a power of two, special care has to be taken. - using Cta_tile_with_padding = Cta_tile_with_k_with_padding; - // The number of MMAs. - using Mma_tile_with_padding = fmha::Hmma_tile; - - // The size of a single LDS in bytes. - enum { BYTES_PER_LDS = 16 }; - - // Ctor. - inline __device__ Smem_tile_row_a(void *smem, int tidx) : Base(smem, tidx) { - - // For documentation on the layout, see doc/mma_smem_layout.xlsx. - - // The number of warps. - const int WARPS_M = Cta_tile::WARPS_M; - const int WARPS_N = Cta_tile::WARPS_N; - const int WARPS_K = Cta_tile::WARPS_K; - - static_assert(WARPS_M == 1); - static_assert(WARPS_N == 4 || WARPS_N == 8); - static_assert(WARPS_K == 1); - static_assert(Base::ROWS_PER_XOR_PATTERN == 8); - - // The row and column read by the thread. - int smem_read_row = (tidx & 0x0f); - int smem_read_col = (tidx & 0x07); - smem_read_col ^= (tidx & 0x10) / 16; - - // The shared memory offset. - this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS; - } - - // Rewind smem_read_offset for last LDS phase in main loop. - inline __device__ void reverse_smem_read_offset(int ki = 0) { - // Undo the pointer increment for the next ni. - // Should match the load function below for ki = 0. - if( Mma_tile_with_padding::MMAS_K >= 2 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * 2; - } - } - - // Load from shared memory. - inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) { - #pragma unroll - for( int mi = 0; mi < Mma_tile::MMAS_M; ++mi ) { - // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows). - int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; - - // Load using LDSM.M88.4. - uint4 tmp; - ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); - - // Store the value into the fragment. - a[mi].reg(0) = tmp.x; - a[mi].reg(1) = tmp.y; - a[mi].reg(2) = tmp.z; - a[mi].reg(3) = tmp.w; - } - - // Move the offset to the next possition. See doc/mma_smem_layout.xlsx. - static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); - if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) { - this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7 ) { - this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3 ) { - this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1 ) { - this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 2 ) { - this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2; - } - } - - // Reset the read offset. - inline __device__ void reset_read_offset() { - // The number of MMAs in the K dimension. - enum { MMAS_K = Mma_tile::MMAS_K }; - // The number of MMAs in the K dimension when we include padding. - enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; - // Assemble the mask. - enum { MASK = Compute_reset_mask::VALUE }; - - // Reset the read offset. - this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2; - } - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The dimensions of the tile computed by the CTA. - typename Cta_tile, - // The size of the STS. - int BYTES_PER_STS, - // The number of buffers per tile. - int BUFFERS_PER_TILE -> -struct Smem_tile_a - : public Smem_tile_row_a { - // The base class. - using Base = Smem_tile_row_a; - - // Ctor. - inline __device__ Smem_tile_a(void *smem, int tidx) : Base(smem, tidx) { - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The dimensions of the tile computed by the CTA. - typename Cta_tile, - // The layout of the tile. - typename Layout, - // The size of the STS. - int BYTES_PER_STS = 16, - // The number of buffers per tile. - int BUFFERS_PER_TILE = 1, - // Use or not predicates - bool USE_PREDICATES = true -> -struct Smem_tile_b { -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -struct Rows_per_xor_pattern_b { - // The size in bits. - enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_B }; - // The number of rows. - enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -struct Rows_per_xor_pattern_col_b : public Rows_per_xor_pattern_b { -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The dimensions of the tile computed by the CTA. - typename Cta_tile, - // The size of the STS. - int BYTES_PER_STS, - // The number of buffers per tile. - int BUFFERS_PER_TILE, - // How many rows to use for the XOR pattern to avoid bank conflicts? - int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_col_b::VALUE -> -struct Smem_tile_col_b : public Smem_tile_without_skews { - // The MMA tile. - using Mma_tile = fmha::Hmma_tile; - // The base class. - using Base = Smem_tile_without_skews; - // The fragment. - using Fragment = Fragment_b< Col>; - - // When we use padding to reach a power of two, special care has to be taken. - using Cta_tile_with_padding = Cta_tile_with_k_with_padding< Cta_tile>; - // The number of MMAs. - using Mma_tile_with_padding = fmha::Hmma_tile; - - // The size of a single LDS in bytes. - enum { BYTES_PER_LDS = 16 }; - - // The number of STS per thread - enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA }; - // The number of STS per thread must be at least 1. - enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE }; - - // Ctor. - inline __device__ Smem_tile_col_b(void *smem, int tidx) : Base(smem, tidx) { - - // For documentation on the layout, see doc/mma_smem_layout.xlsx. - - // The number of warps. - const int WARPS_M = Cta_tile::WARPS_M; - const int WARPS_N = Cta_tile::WARPS_N; - const int WARPS_K = Cta_tile::WARPS_K; - static_assert(Base::ROWS_PER_XOR_PATTERN == 8); - static_assert(WARPS_M == 1); - static_assert(WARPS_N == 4 || WARPS_N == 8); - static_assert(WARPS_K == 1); - - // The masks to select the warps. - const int WARP_MASK_N = Warp_masks::N; - - // The divisor for the warps. - const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; - - // The row and column read by the thread. - int smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA + - (tidx & 0x07) + - (tidx & 0x10) / 2; - int smem_read_col = (tidx & 0x07); - smem_read_col ^= (tidx & 0x08) / 8; - // The shared memory offset. - this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS; - } - - // Rewind smem_read_offset for last LDS phase in main loop. - inline __device__ void reverse_smem_read_offset(int ki = 0) { - // Undo the pointer increment for the next ni. - // Should match the load function below for ki = 0. - if( Mma_tile_with_padding::MMAS_K >= 2 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * 2; - } - } - - // Load from shared memory. - inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { - #pragma unroll - for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) { - // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows). - int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING; - - // Load using LDSM.M88.4. - uint4 tmp; - ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset); - - // Store the value into the fragment. - b[ni].reg(0) = tmp.x; - b[ni].reg(1) = tmp.y; - b[ni].reg(2) = tmp.z; - b[ni].reg(3) = tmp.w; - } - - // Move the offset to the next possition. See doc/mma_smem_layout.xlsx. - static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented"); - if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) { - this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7 ) { - this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3 ) { - this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1 ) { - this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2; - } else if( Mma_tile_with_padding::MMAS_K >= 2 ) { - this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2; - } - } - - // Reset the read offset. - inline __device__ void reset_read_offset() { - // The number of MMAs in the K dimension. - enum { MMAS_K = Mma_tile::MMAS_K }; - // The number of MMAs in the K dimension when we include padding. - enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K }; - // Assemble the mask. - enum { MASK = Compute_reset_mask::VALUE }; - - // Reset the read offset. - this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The dimensions of the tile computed by the CTA. - typename Cta_tile, - // The size of the STS. - int BYTES_PER_STS, - // The number of buffers per tile. - int BUFFERS_PER_TILE -> -struct Smem_tile_b< Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE > - : public Smem_tile_col_b { - - // The base class. - using Base = Smem_tile_col_b< Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>; - - // Ctor. - inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) { - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -struct Rows_per_xor_pattern_row_b : public Rows_per_xor_pattern_b< N> { -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template< - // The dimensions of the tile computed by the CTA. - typename Cta_tile, - // The size of the STS. - int BYTES_PER_STS, - // The number of buffers per tile. - int BUFFERS_PER_TILE, - // How many rows to use for the XOR pattern to avoid bank conflicts? - int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_b::VALUE, - // How many cols to use for the XOR pattern to avoid bank conflicts? - int COLS_PER_XOR_PATTERN_ = 1 -> -struct Smem_tile_row_b : public Smem_tile_without_skews { - - // The MMA tile. - using Mma_tile = fmha::Hmma_tile; - // The base class. - using Base = Smem_tile_without_skews; - // The fragment. - using Fragment = Fragment_b; - - // Can we use LDSM? No if the data type is 32-bit large. - enum { USE_LDSMT = fmha::BITS_PER_ELEMENT_B == 16 }; - // The size of a single LDS in bytes. - enum { BYTES_PER_LDS = USE_LDSMT ? 16 : 4 }; - // The number of elements per LDS. - enum { ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / fmha::BITS_PER_ELEMENT_B }; - - // The number of STS per thread - enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA }; - // The number of STS per thread must be at least 1. - enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE }; - - // Ctor. - inline __device__ Smem_tile_row_b(void *smem, int tidx) : Base(smem, tidx) { - - // The number of warps. - const int WARPS_M = Cta_tile::WARPS_M; - const int WARPS_N = Cta_tile::WARPS_N; - const int WARPS_K = Cta_tile::WARPS_K; - static_assert(WARPS_K == 1); - static_assert(WARPS_M == 4 || WARPS_M == 8); - static_assert(WARPS_N == 1); - - // The masks to select the warps. - const int WARP_MASK_N = Warp_masks::N; - const int WARP_MASK_K = Warp_masks::K; - - // The divisor for the warps. - const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP; - const int WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP; - - // The row/col read by the thread. - int smem_read_row, smem_read_col; - - static_assert(USE_LDSMT); - static_assert(Base::ROWS_PER_XOR_PATTERN == 8); - - smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 + - (tidx & 0x07) + (tidx & 0x08); - smem_read_col = (tidx & 0x07); - smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16; - - // The shared memory offset. - this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS; - - // Fill zeroes for group conv - } - - // Rewind smem_read_offset for last LDS phase in main loop. - inline __device__ void reverse_smem_read_offset(int ki = 0) { - // The size of each element in bits. - const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B; - // The size in bytes of the data needed to compute an MMA per CTA. - const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8; - - #pragma unroll - for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) { - // Undo the pointer increment for the next ni. - // Should match the load function below for ki = 0. - if( BYTES_PER_MMA_PER_CTA >= 128 ) { - // Nothing to do! - } else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) { - this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; - } else if( BYTES_PER_MMA_PER_CTA == 64 ) { - // Nothing to do! - } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6); - } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * 2; - } - } - - // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels) - if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && - Mma_tile::MMAS_N % 2 == 1 ) { - this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; - } - } - - // Load from shared memory. - inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { - // The size of each element in bits. - const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B; - // The size in bytes of the data needed to compute an MMA per CTA. - const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8; - - #pragma unroll - for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) { - // Prepare the offset. - int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW; - if ( BYTES_PER_MMA_PER_CTA == 32 ) { - offset += this->smem_read_offset_; - } else if ( BYTES_PER_MMA_PER_CTA == 64 ) { - offset += this->smem_read_offset_ + (ni/2) * BYTES_PER_MMA_PER_CTA * 2; - } else { - offset += this->smem_read_offset_ + (ni ) * BYTES_PER_MMA_PER_CTA; - } - - // Load the data using LDSM.MT88.2. - uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset; - uint4 tmp; - if( USE_LDSMT ) { - ldsmt(tmp, ptr); - } else { - lds(tmp.x, (ptr ) + 0*Base::BYTES_PER_ROW); - lds(tmp.y, (ptr ) + 4*Base::BYTES_PER_ROW); - lds(tmp.z, (ptr ^ 32) + 0*Base::BYTES_PER_ROW); - lds(tmp.w, (ptr ^ 32) + 4*Base::BYTES_PER_ROW); - } - - // Store those values in the fragment. - b[ni].reg(0) = tmp.x; - b[ni].reg(1) = tmp.y; - b[ni].reg(2) = tmp.z; - b[ni].reg(3) = tmp.w; - - // Move the pointer for the next ni. I expect the compiler to not recompute those. - if( BYTES_PER_MMA_PER_CTA >= 128 ) { - // Nothing to do! - } else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) { - this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; - } else if( BYTES_PER_MMA_PER_CTA == 64 ) { - // Nothing to do! - } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6); - } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * 2; - } - } - - // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels) - if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && - Mma_tile::MMAS_N % 2 == 1 ) { - this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA; - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The dimensions of the tile computed by the CTA. - typename Cta_tile, - // The size of the STS. - int BYTES_PER_STS, - // The number of buffers per tile. - int BUFFERS_PER_TILE -> -struct Smem_tile_b - : public Smem_tile_row_b { - - // The base class. - using Base = Smem_tile_row_b; - - // Ctor. - inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) { - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Smem_tile_v : public fmha::Smem_tile_without_skews { - - // The base class. - using Base = Smem_tile_without_skews; - // The MMA tile. - using Mma_tile = fmha::Hmma_tile; - // The fragment. - using Fragment = Fragment_b< fmha::Col>; - - // The size of a single LDS in bytes. - enum { BYTES_PER_LDS = 16 }; - - // Ctor. - inline __device__ Smem_tile_v(void *smem, int tidx) : Base(smem, tidx) { - - // The row/col read by the thread. - int read_row, read_col; - - static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8)); - - read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f); - read_col = (tidx & 0x07); - read_col ^= (tidx & 0x10) / 16; - - // The shared memory offset. - this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS; - } - - // Load from shared memory. - inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) { -#pragma unroll - for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) { - // Jump by 16 * #warps row. - int row = ki * 16 * Cta_tile::WARPS_K; - - // Load the data using LDSM.MT88.2. - uint4 tmp; - fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW); - b[ni].reg(0) = tmp.x; - b[ni].reg(1) = tmp.y; - b[ni].reg(2) = tmp.z; - b[ni].reg(3) = tmp.w; - - // Move the pointer for the next ni. I expect the compiler to not recompute those. - if( Mma_tile::MMAS_N == 4 ) { - this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6); - } else { - assert(false); // Not implemented! - } - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Smem_tile_o { - - // The MMA tile. - using Mma_tile = fmha::Hmma_tile; - // The accumulators. - using Accumulator = fmha::Fragment_accumulator; - // The accumulators. - using Data_type = typename Accumulator::Data_type; - - // The size of each element. - enum { BYTES_PER_ELEMENT = sizeof(Data_type) }; - // The size of each STS. - enum { BYTES_PER_STS = 8 }; - // The size of each row in shared memory. - enum { BYTES_PER_ROW = Cta_tile::N * Cta_tile::WARPS_K * BYTES_PER_ELEMENT }; - - // The size of each LDS. - enum { BYTES_PER_LDS = 16 }; - enum { THREADS_PER_ROW = 16 }; - - // The number of rows. - enum { ROWS = Cta_tile::M }; - // The number of "rows" to process per loop iteration (in the "epilogue"). - enum { ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA }; - // The number of outer loops. - enum { LOOPS = ROWS / ROWS_PER_LOOP }; - // Make sure it matches our expectations. - static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, ""); - - // The number of rows loaded per LDS. - enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; - // Do we have to guard against partial writes/reads. - enum { HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0 }; - // The total number of LDS per loop. - enum { LDS_PER_LOOP = fmha::Div_up::VALUE }; - - // The amount of shared memory. - enum { BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW }; - - // The write pointer. - uint32_t smem_write_, smem_read_; - // Is the thread active for the last LDS of the series? - int is_active_for_last_lds_; - - static_assert(BYTES_PER_ROW == 64 * 4 * Cta_tile::WARPS_K); - static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, ""); - - // Ctor. - inline __device__ Smem_tile_o(void *smem, int tidx) { - - // Get a 32-bit value for the shared memory address. - uint32_t smem_ = __nvvm_get_smem_pointer(smem); - - static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8)); - - int write_row = (tidx & 0x1c) / 4; - int write_col = (tidx); - - // Assemble the write pointer. - smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; - - // The element read by each thread. - int read_row = tidx / THREADS_PER_ROW; - int read_col = tidx % THREADS_PER_ROW; - - // Take the XOR pattern into account for the column. - read_col ^= 2 * (read_row & 0x7); - - // Assemble the read pointer. - this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; - - // Is that thread active on the last LDS? - if( HAS_INCOMPLETE_LDS ) { - this->is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M; - } - } - - // Load the output fragments. - inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const { - #pragma unroll - for( int ii = 0; ii < LDS_PER_LOOP; ++ii ) { - - // Load the elements before the reduction (split-K). - uint4 tmp[Cta_tile::WARPS_K]; - #pragma unroll - for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) { - int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT; - if( !HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_) ) { - fmha::lds(tmp[jj], this->smem_read_ + imm); - } - } - - // Perform the reduction. - out[ii] = tmp[0]; - #pragma unroll - for( int jj = 1; jj < Cta_tile::WARPS_K; ++jj ) { - out[ii] = fmha::fadd4(out[ii], tmp[jj]); - } - } - } - // Store the accumulators. - template - inline __device__ void store(const Accumulator (&acc)[M][N], int mi) { - enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA }; - #pragma unroll - for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) { - - // The number of MMAs that are stored per loop iteration. - enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS }; - - // Store 1st column of the different MMAs. - #pragma unroll - for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) { - // Precompute the immediates to jump between rows. - int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; - int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; - uint2 tmp0, tmp1; - tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0); - tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1); - - tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2); - tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3); - - // Store. - fmha::sts(this->smem_write_ + row_0, tmp0); - fmha::sts(this->smem_write_ + row_1, tmp1); - } - - // Swizzle the write pointer using a XOR of 16B. - this->smem_write_ ^= 32; - - // Store 2nd column of the different MMAs. - #pragma unroll - for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) { - // Precompute the immediates to jump between rows. - int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW; - int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW; - - uint2 tmp0, tmp1; - tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4); - tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5); - - tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6); - tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7); - // Store. - fmha::sts(this->smem_write_ + row_0, tmp0); - fmha::sts(this->smem_write_ + row_1, tmp1); - } - - // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B. - this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32; - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Smem_tile_mma { - - using Mma_tile = fmha::Hmma_tile; - using Fragment = fmha::Fragment_a; - - enum { COLS = Cta_tile::N }; - enum { BYTES_PER_ELT = 2 }; - enum { BYTES_PER_STS = 4 }; - enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT }; // TODO - enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW }; - - enum { WARPS_M = Cta_tile::WARPS_M }; - enum { WARPS_N = Cta_tile::WARPS_N }; - enum { WARPS_K = Cta_tile::WARPS_K }; - - static_assert(WARPS_K == 1); - inline __device__ Smem_tile_mma(char *smem, int tidx) { - smem_ = __nvvm_get_smem_pointer(smem); - - int write_col, write_row; - static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1); - if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) { - write_row = (tidx & 0x1c) / 4; - write_col = (tidx & 0xe0) / 4 + (tidx & 0x03); - } else { - write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4; - write_col = (tidx & 0x03); - } - write_col ^= (write_row & 0x07) * 4; - - write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; - } - - template - inline __device__ void store(const uint4 (®s)[M][N]) { - static_assert(COLS == Cta_tile::N); - for( int mi = 0; mi < M; mi++ ) { - for( int ni = 0; ni < N; ni++ ) { - size_t offset = write_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); - offset ^= 4 * BYTES_PER_STS; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); - } - } - } - - uint32_t smem_; - uint32_t write_offset_; - uint32_t warp_m; - uint32_t warp_n; - uint32_t lane; -}; - -template< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>> -struct Smem_tile_mma_transposed : public Base { - enum { BYTES_PER_LDS = 16 }; - enum { BYTES_PER_ROW = Base::BYTES_PER_ROW }; - enum { BYTES_PER_ELT = Base::BYTES_PER_ELT }; - enum { WARPS_M = Base::WARPS_M }; - enum { WARPS_N = Base::WARPS_N }; - static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)); - using Fragment = typename Base::Fragment; - inline __device__ Smem_tile_mma_transposed(char *smem, int tidx) : Base(smem, tidx) { - - static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8)); - int read_row, read_col; - read_row = (tidx & 0x0f); - read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16; - - read_col ^= (read_row & 0x07); - read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; - } - - template - inline __device__ void load(Fragment (&frag)[M][N]) { - static_assert(Base::COLS == Cta_tile::N); - for( int mi = 0; mi < M; mi++ ) { - for( int ni = 0; ni < N; ni++ ) { - size_t offset = read_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint4 dst; - fmha::ldsmt(dst, this->smem_ + offset); - frag[mi][ni].reg(0) = dst.x; - frag[mi][ni].reg(1) = dst.z; // Fragment A regs col major! - frag[mi][ni].reg(2) = dst.y; - frag[mi][ni].reg(3) = dst.w; - } - } - } - - uint32_t read_offset_; -}; - -template< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>> -struct Smem_tile_mma_epilogue : public Base { - enum { BYTES_PER_LDS = 16 }; - enum { BYTES_PER_ROW = Base::BYTES_PER_ROW }; - enum { BYTES_PER_ELT = Base::BYTES_PER_ELT }; - enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS }; - static_assert(THREADS_PER_ROW * BYTES_PER_LDS == BYTES_PER_ROW); - enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW }; - enum { NUM_LDS = Cta_tile::M / ROWS_PER_LDS }; - static_assert(NUM_LDS * ROWS_PER_LDS == Cta_tile::M); - enum { WARPS_M = Base::WARPS_M }; - enum { WARPS_N = Base::WARPS_N }; - static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1); - - using Acc = fmha::Fragment_accumulator; - - inline __device__ Smem_tile_mma_epilogue(char *smem, int tidx) : Base(smem, tidx) { - const int read_row = tidx / THREADS_PER_ROW; - int read_col = tidx % THREADS_PER_ROW; - read_col ^= (read_row & 0x07); - read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; - } - - inline __device__ void load(uint4 (&data)[NUM_LDS]) { - for( int ii = 0; ii < NUM_LDS; ii++ ) { - size_t offset = read_offset_ + ii * ROWS_PER_LDS * BYTES_PER_ROW; - fmha::lds(data[ii], this->smem_ + offset); - } - } - - template - inline __device__ void store(const Acc (&acc)[M][N]){ - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - // 1st row - 4 elements per row. - float tmp00 = acc[mi][ni].elt(0); - float tmp01 = acc[mi][ni].elt(1); - float tmp02 = acc[mi][ni].elt(4); - float tmp03 = acc[mi][ni].elt(5); - // 2nd row - 4 elements per row. - float tmp10 = acc[mi][ni].elt(2); - float tmp11 = acc[mi][ni].elt(3); - float tmp12 = acc[mi][ni].elt(6); - float tmp13 = acc[mi][ni].elt(7); - - uint32_t x = fmha::float2_to_half2(tmp00, tmp01); - uint32_t y = fmha::float2_to_half2(tmp02, tmp03); - uint32_t z = fmha::float2_to_half2(tmp10, tmp11); - uint32_t w = fmha::float2_to_half2(tmp12, tmp13); - - size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; - fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x); - fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, z); - offset ^= 4 * Base::BYTES_PER_STS; - fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y); - fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w); - } - } - } - - template - inline __device__ void store(const uint4 (®s)[M][N]) { - for( int mi = 0; mi < M; mi++ ) { - for( int ni = 0; ni < N; ni++ ) { - size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; - fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); - fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); - offset ^= 4 * Base::BYTES_PER_STS; - fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); - fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); - } - } - } - - uint32_t read_offset_; -}; - -} // namespace fmha diff --git a/apex/contrib/csrc/fmha/src/fmha/softmax.h b/apex/contrib/csrc/fmha/src/fmha/softmax.h deleted file mode 100644 index 153f42d..0000000 --- a/apex/contrib/csrc/fmha/src/fmha/softmax.h +++ /dev/null @@ -1,395 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -namespace fmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Sum_ { - enum { IS_SUM = 1 }; - static inline __device__ float apply(float x, float y) { - return x + y; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Max_ { - enum { IS_SUM = 0 }; - static inline __device__ float apply(float x, float y) { - return x > y ? x : y; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float apply_exp_(float x, float max) { - return __expf(x - max); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template struct ReadType {}; -template<> struct ReadType<4> { using T = float;}; -template<> struct ReadType<8> { using T = float2;}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Smem_tile_reduce { - // Helper class to distribute MMA tiles reduced over rows per warp over quads. - - // The Mma tile. - using Mma_tile = fmha::Hmma_tile; - - // The number of MMAs in M/N dimensions. - enum { MMAS_M = Mma_tile::MMAS_M }; - enum { MMAS_N = Mma_tile::MMAS_N }; - - enum { WARPS_M = Cta_tile::WARPS_M }; - enum { WARPS_N = Cta_tile::WARPS_N }; - - - static constexpr int ROWS = WARPS_M * MMAS_M * 16; - static constexpr int COLS = WARPS_N; - static_assert(COLS == 4 || COLS == 8); - static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8; - static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float); - static constexpr int ELTS_PER_TILE = ROWS * COLS; - - static constexpr int THREADS_PER_GROUP = Kernel_traits::Gmem_tile_o::THREADS_PER_ROW; - static_assert(THREADS_PER_GROUP == 16); // DEBUG - static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP; - static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS; - static_assert(LOOPS == 1); - - using read_t = typename ReadType::T; - - __device__ inline Smem_tile_reduce(float *smem_, const int tidx) { - - int lane = tidx % 32; - int warp = tidx / 32; - - int warp_m = warp % WARPS_M; - int warp_n = warp / WARPS_M; - - qid_ = lane % 4; - int qp = lane / 4; - - // Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps. - // This won't affect reading as we assume commutative reduction ops. - const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN); - smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col]; - smem_read_ = &reinterpret_cast(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_]; - - } - - __device__ inline void store(float (&frag)[2 * MMAS_M]) { - if( qid_ == 0 ) { - #pragma unroll - for( int mi = 0; mi < MMAS_M; mi++ ) { - int offset = mi * 16 * WARPS_N; - smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0]; - smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1]; - } - } - } - - __device__ inline void load(read_t (&frag)[2 * MMAS_M]) { - #pragma unroll - for( int mi = 0; mi < MMAS_M; mi++ ) { - int offset = mi * 16 * 4; - frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4]; - frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4]; - } - } - - int qid_; - float *smem_write_; - read_t *smem_read_; - -}; - - -template -struct Softmax_base { - - // The Mma tile. - using Mma_tile = fmha::Hmma_tile; - - // The number of MMAs in M/N dimensions. - enum { MMAS_M = Mma_tile::MMAS_M }; - enum { MMAS_N = Mma_tile::MMAS_N }; - - // The number of groups of warp such that we have at most 4 warps writing consecutive elements. - enum { GROUPS = fmha::Div_up::VALUE }; - // The number of elements that we are going to store per row. - enum { ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS }; - // The number of rows. - enum { ROWS = Cta_tile::M * GROUPS }; - // The total number of elements. - enum { ELEMENTS = ROWS * ELEMENTS_PER_ROW }; - - // Ctor. - template - inline __device__ Softmax_base(const Params ¶ms, void *smem, int bidb, int tidx) - : // packed_mask_ptr_(reinterpret_cast(params.packed_mask_ptr)), - smem_(reinterpret_cast(smem)), tidx_(tidx) { - - // Move to the 1st mask loaded by the thread+ tidx; - // packed_mask_ptr_ += bidb * params.packed_mask_stride_in_bytes + tidx * sizeof(uint32_t); - - // Extract the position in the warp. - int warp = tidx / Cta_tile::THREADS_PER_WARP; - int lane = tidx % Cta_tile::THREADS_PER_WARP; - - // Decompose the warp index into M and N. - int warp_m = warp % Cta_tile::WARPS_M; - int warp_n = warp / Cta_tile::WARPS_M; - - // Decompose the warp-n index into group/position-inside-the-group. - int warp_g = warp_n / ELEMENTS_PER_ROW; - int warp_i = warp_n % ELEMENTS_PER_ROW; - - // The location written by the threads. - int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4; - int write_col = warp_i; - - // Assemble the write pointer. - smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col]; - - // Assemble the read pointer. - smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4]; - } - - template - inline __device__ void apply_mask(const Mask &mask) { - #pragma unroll - for( int mi = 0; mi < MMAS_M; ++mi ) { - #pragma unroll - for( int ii = 0; ii < 2; ++ii ) { - #pragma unroll - for( int ni = 0; ni < MMAS_N; ++ni ) { - #pragma unroll - for( int jj = 0; jj < 4; ++jj ) { - if( !mask.is_valid(mi, ni, ii, jj) ) { - elt_[2 * mi + ii][4 * ni + jj] = -INFINITY; - } - } - } - } - } - } - - // Apply the exp to all the elements. - inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) { - #pragma unroll - for( int mi = 0; mi < MMAS_M * 2; ++mi ) { - #pragma unroll - for( int ni = 0; ni < MMAS_N * 4; ++ni ) { - elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]); - } - } - } - - // Scale all the elements. - inline __device__ void scale(const float (&sum)[MMAS_M * 2]) { - // Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal. - float inv_sum[MMAS_M * 2]; - #pragma unroll - for( int mi = 0; mi < MMAS_M * 2; ++mi ) { - inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi]; - } - - // Update the values. - #pragma unroll - for( int mi = 0; mi < MMAS_M * 2; ++mi ) { - #pragma unroll - for( int ni = 0; ni < MMAS_N * 4; ++ni ) { - elt_[mi][ni] *= inv_sum[mi]; - } - } - } - - // The pointer to the mask. - const char *packed_mask_ptr_; - // Shared memory for the CTA-wide reduction. - float *smem_, *smem_write_, *smem_read_; - // The current thread index. - int tidx_; - // The elements. - float elt_[MMAS_M * 2][MMAS_N * 4]; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Softmax : public Softmax_base { - - // The base class. - using Base = Softmax_base; - // The fragment. - using Fragment_a = fmha::Fragment_a; - - static_assert(Fragment_a::NUM_REGS == 4); - - enum { WARPS_M = Cta_tile::WARPS_M }; - enum { WARPS_N = Cta_tile::WARPS_N }; - // The MMAs. - enum { MMAS_M = Base::MMAS_M }; - enum { MMAS_N = Base::MMAS_N }; - - // The accumulators. - using Accumulator = fmha::Fragment_accumulator; - using Accumulator_out = Fragment; - static_assert(Accumulator_out::NUM_REGS == 4); - - static_assert(std::is_same::value); - - using Smem_tile_red = Smem_tile_reduce; - static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N); - // Ctor. - template - inline __device__ Softmax(const Params ¶ms, void *smem, int bidb, int tidx) - : Base(params, smem, bidb, tidx) - , params_scale_bmm1_(params.scale_bmm1) - , smem_sum_(static_cast(smem), tidx) - , smem_max_(static_cast(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) { - } - - // Pack the data to a fragment for the next GEMM. - template - inline __device__ void pack(Fragment_a (&dst)[K][M]) const { - #pragma unroll - for( int mi = 0; mi < M; ++mi ) { - #pragma unroll - for( int ki = 0; ki < K; ++ki ) { - - // 1st row - 4 elements per row. - float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0]; - float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1]; - float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2]; - float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3]; - - // 2nd row - 4 elements per row. - float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0]; - float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1]; - float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2]; - float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3]; - - // Pack to 4 registers. - dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01); - dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11); - dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03); - dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13); - } - } - } - - // Scale FP32 fragments - inline __device__ void unpack(const Accumulator (&acc)[MMAS_M][MMAS_N]) { - const float scalef = reinterpret_cast(this->params_scale_bmm1_); - - #pragma unroll - for( int mi = 0; mi < MMAS_M; ++mi ) { - #pragma unroll - for( int ni = 0; ni < MMAS_N; ++ni ) { - // 1st row - 4 elements per row. - this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef; - this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef; - this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef; - this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef; - // 2nd row - 4 elements per row. - this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef; - this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef; - this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef; - this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef; - } - } - } - // Scale FP32 fragments - inline __device__ void unpack_noscale(const Accumulator (&acc)[MMAS_M][MMAS_N]) { - - #pragma unroll - for( int mi = 0; mi < MMAS_M; ++mi ) { - #pragma unroll - for( int ni = 0; ni < MMAS_N; ++ni ) { - // 1st row - 4 elements per row. - this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0); - this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1); - this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4); - this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5); - // 2nd row - 4 elements per row. - this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2); - this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3); - this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6); - this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7); - } - } - } - - - - template - __device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red & smem_red) { - for( int mi = 0; mi < 2 * MMAS_M; mi++ ) { - frag[mi] = this->elt_[mi][0]; - for( int ni = 1; ni < 4 * MMAS_N; ni++ ) { - frag[mi] = op(frag[mi], this->elt_[mi][ni]); - } - } - quad_reduce(frag, frag, op); - - smem_red.store(frag); - __syncthreads(); - typename Smem_tile_red::read_t tmp[2 * MMAS_M]; - smem_red.load(tmp); - - quad_allreduce(frag, tmp, op); - } - - __device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){ - MaxOp max; - reduce_(frag, max, smem_max_); - } - - __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){ - SumOp sum; - reduce_(frag, sum, smem_sum_); - } - - - const uint32_t params_scale_bmm1_; - Smem_tile_red smem_max_; - Smem_tile_red smem_sum_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace fmha diff --git a/apex/contrib/csrc/fmha/src/fmha/utils.h b/apex/contrib/csrc/fmha/src/fmha/utils.h deleted file mode 100644 index bedba0e..0000000 --- a/apex/contrib/csrc/fmha/src/fmha/utils.h +++ /dev/null @@ -1,1038 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include -#include -#include - -extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void *ptr); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace fmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Row {}; -struct Col {}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int M, bool = (M & (M-1)) == 0 > -struct Next_power_of_two { -}; - -template< int M > -struct Next_power_of_two< M, true > { enum { VALUE = M }; }; -template<> -struct Next_power_of_two< 3, false> { enum { VALUE = 4 }; }; -template<> -struct Next_power_of_two< 5, false> { enum { VALUE = 8 }; }; -template<> -struct Next_power_of_two< 6, false> { enum { VALUE = 8 }; }; -template<> -struct Next_power_of_two< 7, false> { enum { VALUE = 8 }; }; -template<> -struct Next_power_of_two< 9, false> { enum { VALUE = 16 }; }; -template<> -struct Next_power_of_two< 10, false> { enum { VALUE = 16 }; }; -template<> -struct Next_power_of_two< 11, false> { enum { VALUE = 16 }; }; -template<> -struct Next_power_of_two< 12, false> { enum { VALUE = 16 }; }; -template<> -struct Next_power_of_two< 13, false> { enum { VALUE = 16 }; }; -template<> -struct Next_power_of_two< 14, false> { enum { VALUE = 16 }; }; -template<> -struct Next_power_of_two< 15, false> { enum { VALUE = 16 }; }; -template<> -struct Next_power_of_two< 24, false> { enum { VALUE = 32 }; }; -template<> -struct Next_power_of_two< 48, false> { enum { VALUE = 64 }; }; -template<> -struct Next_power_of_two< 80, false> { enum { VALUE = 128 }; }; -template<> -struct Next_power_of_two< 96, false> { enum { VALUE = 128 }; }; -template<> -struct Next_power_of_two<112, false> { enum { VALUE = 128 }; }; -template<> -struct Next_power_of_two<144, false> { enum { VALUE = 256 }; }; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N, bool = (N & (N-1)) == 0 > -struct Prev_power_of_two { -}; - -template< int N > -struct Prev_power_of_two< N, true > { enum { VALUE = N }; }; -template<> -struct Prev_power_of_two< 3, false> { enum { VALUE = 2 }; }; -template<> -struct Prev_power_of_two< 5, false> { enum { VALUE = 4 }; }; -template<> -struct Prev_power_of_two< 6, false> { enum { VALUE = 4 }; }; -template<> -struct Prev_power_of_two< 7, false> { enum { VALUE = 4 }; }; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int M, int N > -struct Div_up { - enum { VALUE = (M + N-1) / N }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int A, int B > -struct Max { - enum { VALUE = A >= B ? A : B }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int A, int B, int C > -struct Max_3 { - enum { VALUE = Max::VALUE, C>::VALUE }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int A, int B > -struct Min { - enum { VALUE = A <= B ? A : B }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int SIZE_IN_BYTES > -struct Uint_from_size_in_bytes { -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Uint_from_size_in_bytes<1> { - using Type = uint8_t; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Uint_from_size_in_bytes<2> { - using Type = uint16_t; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Uint_from_size_in_bytes<4> { - using Type = uint32_t; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Uint_from_size_in_bytes<8> { - using Type = uint2; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Uint_from_size_in_bytes<16> { - using Type = uint4; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int WARPS_M, int WARPS_N, int WARPS_K > -struct Warp_masks { -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Warp_masks<8, 1, 1> { enum { M = 0xe0, N = 0x00, K = 0x00 }; }; -template<> -struct Warp_masks<4, 2, 1> { enum { M = 0x60, N = 0x80, K = 0x00 }; }; -template<> -struct Warp_masks<4, 1, 2> { enum { M = 0x60, N = 0x00, K = 0x80 }; }; -template<> -struct Warp_masks<4, 1, 1> { enum { M = 0x60, N = 0x00, K = 0x00 }; }; -template<> -struct Warp_masks<2, 4, 1> { enum { M = 0x20, N = 0xc0, K = 0x00 }; }; -template<> -struct Warp_masks<2, 2, 2> { enum { M = 0x20, N = 0x40, K = 0x80 }; }; -template<> -struct Warp_masks<2, 2, 1> { enum { M = 0x20, N = 0x40, K = 0x00 }; }; -template<> -struct Warp_masks<2, 1, 2> { enum { M = 0x20, N = 0x00, K = 0x40 }; }; -template<> -struct Warp_masks<2, 1, 1> { enum { M = 0x20, N = 0x00, K = 0x00 }; }; -template<> -struct Warp_masks<1, 8, 1> { enum { M = 0x00, N = 0xe0, K = 0x00 }; }; -template<> -struct Warp_masks<1, 4, 2> { enum { M = 0x00, N = 0x60, K = 0x80 }; }; -template<> -struct Warp_masks<1, 4, 1> { enum { M = 0x00, N = 0x60, K = 0x00 }; }; -template<> -struct Warp_masks<1, 2, 2> { enum { M = 0x00, N = 0x20, K = 0x40 }; }; -template<> -struct Warp_masks<1, 2, 1> { enum { M = 0x00, N = 0x20, K = 0x00 }; }; -template<> -struct Warp_masks<1, 1, 4> { enum { M = 0x00, N = 0x00, K = 0x60 }; }; -template<> -struct Warp_masks<1, 1, 2> { enum { M = 0x00, N = 0x00, K = 0x20 }; }; -template<> -struct Warp_masks<1, 1, 1> { enum { M = 0x00, N = 0x00, K = 0x00 }; }; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename T > -inline __device__ __host__ T div_up(T m, T n) { - return (m + n-1) / n; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline int clz(int x) { - for( int i = 31; i >= 0; --i ) { - if( (1 << i) & x ) { - return 31 - i; - } - } - return 32; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline int find_log_2(int x, bool round_up = false) { - int a = 31 - clz(x); - if( round_up ) { - a += (x & (x-1)) ? 1 : 0; - } - return a; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) { - uint32_t c; - asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) { - uint32_t c; - asm volatile("min.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t hmul2(uint32_t a, uint32_t b) { - uint32_t c; - asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint2 hmul4(uint2 a, uint2 b) { - uint2 c; - c.x = hmul2(a.x, b.x); - c.y = hmul2(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint4 hmul8(uint4 a, uint4 b) { - uint4 c; - c.x = hmul2(a.x, b.x); - c.y = hmul2(a.y, b.y); - c.z = hmul2(a.z, b.z); - c.w = hmul2(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint4 hmul8(uint32_t a, uint4 b) { - uint4 c; - c.x = hmul2(a, b.x); - c.y = hmul2(a, b.y); - c.z = hmul2(a, b.z); - c.w = hmul2(a, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) { - uint32_t res; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(lb)); -#else - const uint32_t zero = 0u; - asm volatile( \ - "{\n" \ - "\t .reg .f16x2 sela;\n" \ - "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \ - "\t and.b32 %0, sela, %1;\n" - "}\n" : "=r"(res) : "r"(x), "r"(zero)); -#endif - return res; -} -static inline __device__ uint32_t habs2(uint32_t x) { - uint32_t res; - asm volatile( "abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x)); - return res; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -template< typename T > -static inline __device__ T clamp(T x, T lb, T ub) { - return x < lb ? lb : (x > ub ? ub : x); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint16_t clamp_to_zero(uint16_t x) { - uint16_t mask; - asm volatile("set.gtu %0, %1, 0;" : "=h"(mask) : "h"(x)); - return mask & x; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint16_t float_to_half(float f) { - uint16_t h; - asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f)); - return h; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t float2_to_half2(float a, float b) { - uint32_t c; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a)); -#else - uint16_t lo = float_to_half(a); - uint16_t hi = float_to_half(b); - asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(c) : "h"(lo), "h"(hi)); -#endif - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t float_to_half2(float a) { - return float2_to_half2(a,a); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t float2_to_half2(const float2 &f) { - return float2_to_half2(f.x, f.y); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint2 float4_to_half4(float x, float y, float z, float w) { - uint2 d; - d.x = float2_to_half2(x, y); - d.y = float2_to_half2(z, w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) { - uint32_t d; - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c) { - uint32_t d; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c)); -#else - d = hrelu2(hfma2(a, b, c)); -#endif - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t h0_h0(uint32_t x) { - uint32_t y; - asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n" - : "=r"(y) : "r"(x)); - return y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ float h0_to_float(uint32_t h2) { - float f; - asm volatile("{\n" \ - ".reg .f16 lo, hi;\n" \ - "mov.b32 {lo, hi}, %1;\n" \ - "cvt.f32.f16 %0, lo;\n" \ - "}\n" : "=f"(f) : "r"(h2)); - return f; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t h1_h1(uint32_t x) { - uint32_t y; - asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n" - : "=r"(y) : "r"(x)); - return y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint16_t hadd(uint16_t a, uint16_t b) { - uint16_t d; - asm volatile("add.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint32_t hadd(uint32_t a, uint32_t b) { - return hadd2(a, b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint2 hadd4(uint2 a, uint2 b) { - uint2 c; - c.x = hadd2(a.x, b.x); - c.y = hadd2(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint2 hadd(uint2 a, uint2 b) { - return hadd4(a, b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint4 hadd8(uint4 a, uint4 b) { - uint4 c; - c.x = hadd2(a.x, b.x); - c.y = hadd2(a.y, b.y); - c.z = hadd2(a.z, b.z); - c.w = hadd2(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint4 fadd4(uint4 a, uint4 b) { - float4 c; - c.x = reinterpret_cast(a.x) + reinterpret_cast(b.x); - c.y = reinterpret_cast(a.y) + reinterpret_cast(b.y); - c.z = reinterpret_cast(a.z) + reinterpret_cast(b.z); - c.w = reinterpret_cast(a.w) + reinterpret_cast(b.w); - return reinterpret_cast(c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint4 hadd(uint4 a, uint4 b) { - return hadd8(a, b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ float half_to_float(uint16_t h) { - float f; - asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); - return f; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ float2 half2_to_float2(uint32_t x) { - uint16_t lo, hi; - asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(x)); - return make_float2(half_to_float(lo), half_to_float(hi)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ void half2_to_float2(float &x, float &y, uint32_t h) { - float2 tmp = half2_to_float2(h); - x = tmp.x; - y = tmp.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) { - uint16_t d; - asm volatile("fma.rn.f16 %0, %1, %2, %3;" : "=h"(d) : "h"(a), "h"(b), "h"(c)); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ uint16_t hmul(uint16_t a, uint16_t b) { - uint16_t d; - asm volatile("mul.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b)); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline __device__ float sigmoid(float x) { - return 1.f / (1.f + expf(-x)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void clear(uint16_t &dst) { - dst = uint16_t(0); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void clear(uint32_t &dst) { - dst = 0u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void clear(uint2 &dst) { - dst = make_uint2(0u, 0u); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void clear(uint4 &dst) { - dst = make_uint4(0u, 0u, 0u, 0u); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// P R E D I C A T E P A C K I N G -// -//////////////////////////////////////////////////////////////////////////////////////////////////// -enum { BYTES_PER_REG = 4, PREDS_PER_BYTE = 4, PREDS_PER_REG = BYTES_PER_REG * PREDS_PER_BYTE }; - - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// G E N E R I C P R E D I C A T E D L D G S T S -// -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N, int M, typename Functor > -inline __device__ void load_(Functor &fct, const uint32_t (&preds)[M]) { - - // The number of complete bytes (where we use all the predicates in a byte). - enum { COMPLETE = N / PREDS_PER_BYTE }; - // Make sure we did allocate enough predicates. - static_assert(Div_up::VALUE <= M, ""); - // The remainder. - enum { REMAINDER = N - COMPLETE * PREDS_PER_BYTE }; - // Make sure we got the math right and the remainder is between 0 and 3. - static_assert(REMAINDER >= 0 && REMAINDER <= 3, ""); - // The mask to extract the predicates. - enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 }; - - // Clear the fetch registers. - #pragma unroll - for( int ii = 0; ii < N; ++ii ) { - fct.clear(ii); - } - - // Run complete steps. - bool p[PREDS_PER_BYTE]; - #pragma unroll - for( int ii = 0; ii < COMPLETE; ++ii ) { - - // The predicate. - uint32_t reg = preds[ii / BYTES_PER_REG]; - - // Extract the predicates. - #pragma unroll - for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) { - uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj); - p[jj] = (reg & mask) != 0u; - } - - // Issue the loads. - #pragma unroll - for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) { - fct.load(ii * PREDS_PER_BYTE + jj, p[jj]); - } - } - - // Skip the rest of the code if we do not have a remainder. - if( REMAINDER > 0 ) { - - // The mask to extract the predicates. - enum { REMAINDER_MASK = (1 << REMAINDER) - 1 }; - - // The predicate register. - uint32_t reg = preds[COMPLETE / BYTES_PER_REG]; - - // Extract the predicates. - #pragma unroll - for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) { - uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj); - p[jj] = (reg & mask) != 0u; - } - - // Issue the loads. - #pragma unroll - for( int ii = 0; ii < REMAINDER; ++ii ) { - fct.load(COMPLETE * PREDS_PER_BYTE + ii, p[ii]); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int M, typename Functor > -inline __device__ void load_(Functor &fct, uint32_t preds) { - uint32_t tmp[1] = { preds }; - load_(fct, tmp); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// L D G -// -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldg(uint8_t &dst, const void *ptr) { - dst = *reinterpret_cast(ptr); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldg(uint16_t &dst, const void *ptr) { - dst = *reinterpret_cast(ptr); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldg(uint32_t &dst, const void *ptr) { - dst = *reinterpret_cast(ptr); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldg(uint2 &dst, const void *ptr) { - dst = *reinterpret_cast(ptr); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldg(uint4 &dst, const void *ptr) { - dst = *reinterpret_cast(ptr); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Data_type, int N > -struct Ldg_functor { - // Ctor. - inline __device__ Ldg_functor(Data_type (&fetch)[N], const void* (&ptrs)[N]) - : fetch_(fetch), ptrs_(ptrs) { - } - - // Clear the element. - inline __device__ void clear(int ii) { - fmha::clear(fetch_[ii]); - } - - // Trigger the loads. - inline __device__ void load(int ii, bool p) { - if( p ) { - ldg(fetch_[ii], ptrs_[ii]); - } - } - - // The fetch registers. - Data_type (&fetch_)[N]; - // The pointers. - const void* (&ptrs_)[N]; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Data_type, int N, int M > -inline __device__ void ldg_(Data_type (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { - Ldg_functor fct(fetch, ptrs); - load_(fct, preds); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N, int M > -inline __device__ void ldg(uint8_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { - ldg_(fetch, ptrs, preds); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N, int M > -inline __device__ void ldg(uint16_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { - ldg_(fetch, ptrs, preds); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N, int M > -inline __device__ void ldg(uint32_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { - ldg_(fetch, ptrs, preds); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N, int M > -inline __device__ void ldg(uint2 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { - ldg_(fetch, ptrs, preds); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N, int M > -inline __device__ void ldg(uint4 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) { - ldg_(fetch, ptrs, preds); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// L D S -// -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void lds(uint16_t &dst, uint32_t ptr) { - asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(dst) : "r"(ptr)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void lds(uint32_t &dst, uint32_t ptr) { - asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(dst) : "r"(ptr)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void lds(uint2 &dst, uint32_t ptr) { - asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void lds(uint4 &dst, uint32_t ptr) { - asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n" - : "=r"(dst.x) - , "=r"(dst.y) - , "=r"(dst.z) - , "=r"(dst.w) - : "r"(ptr)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// L D S M -// -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldsm(uint32_t &dst, uint32_t ptr) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 - asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" - : "=r"(dst) : "r"(ptr)); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldsmt(uint32_t &dst, uint32_t ptr) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 - asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n" - : "=r"(dst) : "r"(ptr)); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldsm(uint2 &dst, uint32_t ptr) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 - asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n" - : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldsmt(uint2 &dst, uint32_t ptr) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 - asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n" - : "=r"(dst.x), "=r"(dst.y) : "r"(ptr)); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldsm(uint4 &dst, uint32_t ptr) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" - : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr)); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void ldsmt(uint4 &dst, uint32_t ptr) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730 - asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n" - : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr)); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// S T G -// -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void stg(void *ptr, uint8_t val) { - *reinterpret_cast(ptr) = val; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void stg(void *ptr, uint16_t val) { - *reinterpret_cast(ptr) = val; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void stg(void *ptr, uint32_t val) { - *reinterpret_cast(ptr) = val; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void stg(void *ptr, uint2 val) { - *reinterpret_cast(ptr) = val; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void stg(void *ptr, uint4 val) { - *reinterpret_cast(ptr) = val; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// -// S T S -// -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void sts(uint32_t ptr, uint16_t val) { - asm volatile("st.shared.b16 [%0], %1;\n" : : "r"(ptr), "h"(val)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void sts(uint32_t ptr, uint32_t val) { - asm volatile("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void sts(uint32_t ptr, uint2 val) { - asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n" - : - : "r"(ptr) - , "r"(val.x) - , "r"(val.y)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void sts(uint32_t ptr, uint4 val) { - asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n" - : - : "r"(ptr) - , "r"(val.x) - , "r"(val.y) - , "r"(val.z) - , "r"(val.w)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Data_type, int N > -inline __device__ void sts_(uint32_t (&ptrs)[N], const Data_type (&data)[N]) { - #pragma unroll - for( int ii = 0; ii < N; ++ii ) { - sts(ptrs[ii], data[ii]); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -inline __device__ void sts(uint32_t (&ptrs)[N], const uint16_t (&data)[N]) { - sts_(ptrs, data); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -inline __device__ void sts(uint32_t (&ptrs)[N], const uint32_t (&data)[N]) { - sts_(ptrs, data); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -inline __device__ void sts(uint32_t (&ptrs)[N], const uint2 (&data)[N]) { - sts_(ptrs, data); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -inline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) { - sts_(ptrs, data); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MaxOp { -__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct SumOp { -__device__ inline T operator()(T const & x, T const & y) { return x + y; } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Allreduce { - static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); - template - static __device__ inline T run(T x, Operator &op) { - constexpr int OFFSET = THREADS / 2; - x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); - return Allreduce::run(x, op); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Allreduce<2> { -template -static __device__ inline T run(T x, Operator &op) { - x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); - return x; -} -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator &op) { - #pragma unroll - for(int mi=0; mi < M; mi++){ - dst[mi] = src[mi]; - dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2)); - dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1)); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator &op) { - float tmp[M]; - #pragma unroll - for(int mi=0; mi < M; mi++){ - tmp[mi] = op(src[mi].x, src[mi].y); - } - quad_reduce(dst, tmp, op); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator &op) { - #pragma unroll - for(int mi=0; mi < M; mi++){ - dst[mi] = src[mi]; - dst[mi] = Allreduce<4>::run(dst[mi], op); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator &op) { - float tmp[M]; - #pragma unroll - for(int mi=0; mi < M; mi++){ - tmp[mi] = op(src[mi].x, src[mi].y); - } - quad_allreduce(dst, tmp, op); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace fmha diff --git a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu deleted file mode 100644 index 517a5b7..0000000 --- a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu +++ /dev/null @@ -1,60 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#include "fmha.h" -#include "fmha_dgrad_kernel_1xN_reload.h" - -using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>; - -extern "C" __global__ void fmha_dgrad_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params) { - fmha::compute_dv_1xN(params); - fmha::compute_dq_dk_1xN(params); -} - -void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream) { - - constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); - constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; - constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; - constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; - - using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>; - constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; - static_assert(smem_size_s == 16 * 128 * 2); - static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N); - - constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax; - constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v; - constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk); - - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute( - fmha_dgrad_fp16_128_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - dim3 grid(params.h, params.b); - fmha_dgrad_fp16_128_64_sm80_kernel<<>>(params); -} diff --git a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu deleted file mode 100644 index ac22a16..0000000 --- a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu +++ /dev/null @@ -1,60 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#include "fmha.h" -#include "fmha_dgrad_kernel_1xN_reload.h" - -using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>; - -extern "C" __global__ void fmha_dgrad_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params) { - fmha::compute_dv_1xN(params); - fmha::compute_dq_dk_1xN(params); -} - -void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream) { - - constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); - constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; - constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; - constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; - - using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>; - constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; - static_assert(smem_size_s == 16 * 256 * 2); - static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N); - - constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax; - constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v; - constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk); - - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute( - fmha_dgrad_fp16_256_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - dim3 grid(params.h, params.b); - fmha_dgrad_fp16_256_64_sm80_kernel<<>>(params); -} diff --git a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu deleted file mode 100644 index 7081438..0000000 --- a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu +++ /dev/null @@ -1,60 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#include "fmha.h" -#include "fmha_dgrad_kernel_1xN_reload.h" - -using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 8, 0x08u>; - -extern "C" __global__ void fmha_dgrad_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params) { - fmha::compute_dv_1xN(params); - fmha::compute_dq_dk_1xN(params); -} - -void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream) { - - constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); - constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; - constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; - constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; - - using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>; - constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; - static_assert(smem_size_s == 16 * 384 * 2); - static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N); - - constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax; - constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v; - constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk); - - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute( - fmha_dgrad_fp16_384_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - dim3 grid(params.h, params.b); - fmha_dgrad_fp16_384_64_sm80_kernel<<>>(params); -} diff --git a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu deleted file mode 100644 index 735006c..0000000 --- a/apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu +++ /dev/null @@ -1,105 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#include "fmha.h" -#include "fmha_dgrad_kernel_1xN_reload.h" -#include "fmha_dgrad_kernel_1xN_reload_nl.h" - -using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>; - -extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params) { - fmha::compute_dv_1xN(params); - fmha::compute_dq_dk_1xN(params); -} - -template -__global__ -void fmha_dgrad_fp16_512_64_sm80_nl_kernel(Fused_multihead_attention_fprop_params params){ - fmha::compute_dv_1xN_nl(params); - fmha::compute_dq_dk_1xN_nl(params); -} - -void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream) { - - constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); - constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; - constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; - constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; - - using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>; - constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; - static_assert(smem_size_s == 16 * 512 * 2); - static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N); - - constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax; - constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v; - constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk); - - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute( - fmha_dgrad_fp16_512_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - dim3 grid(params.h, params.b); - fmha_dgrad_fp16_512_64_sm80_kernel<<>>(params); -} - -void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params ¶ms, const int num_chunks, cudaStream_t stream) { - - constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); - constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; - constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; - constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; - - using Smem_tile_s = fmha::Smem_tile_mma_transposed; - constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; - static_assert(smem_size_s == 16 * 512 * 2); - static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N); - - constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax; - constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v; - constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk); - - auto kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>; - - if( num_chunks == 2 ) { - kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>; - }else if( num_chunks == 3 ) { - kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<3>; - } else { - assert(false && "Unsupperted number of chunks"); - } - - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - - dim3 grid(params.h, params.b, num_chunks); - - kernel<<>>(params); - - FMHA_CHECK_CUDA(cudaPeekAtLastError()); -} diff --git a/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h b/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h deleted file mode 100644 index 3c4b817..0000000 --- a/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h +++ /dev/null @@ -1,558 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include "fmha_kernel.h" -#include -#include - -namespace fmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void compute_dv_1xN(const Params ¶ms) { - - // The description of the CTA tile for the 1st batched GEMM. - using Cta_tile_p = typename Kernel_traits::Cta_tile_p; - // The description of the CTA tile for the 2nd batched GEMM. - using Cta_tile_dv = - fmha::Cta_tile_extd; - - static_assert(Cta_tile_dv::M == 512 || Cta_tile_dv::M == 384 || Cta_tile_dv::M == 256 || Cta_tile_dv::M == 128); - static_assert(Cta_tile_dv::N == 64); - static_assert(Cta_tile_dv::K == 16); - - // The MMA tile for the 1st GEMM. - using Mma_tile_p = fmha::Hmma_tile; - // The MMA tile for the 2nd GEMM. - using Mma_tile_dv = fmha::Hmma_tile; - - // The global memory tile to load Q. - using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; - // The shared memory tile to swizzle Q. - // using Smem_tile_q = typename Kernel_traits::Smem_tile_q; - using Smem_tile_q = fmha::Smem_tile_a; - // The shared memory tile to reload Q as fragment b. - using Smem_tile_qt = fmha::Smem_tile_b; - - // The global memory tile to load K. - using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k; - // The shared memory tile to swizzle K. - using Smem_tile_k = typename Kernel_traits::Smem_tile_k; - - // The global memory tile to load V. - using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle V. - using Smem_tile_v = typename Kernel_traits::Smem_tile_v; - - // The global memory tile to store O. - using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o; - // The shared memory tile to swizzle O. - using Smem_tile_o = typename Kernel_traits::Smem_tile_o; - - // The global memory tile to store dV. - using Gmem_tile_dv = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle dV. - using Smem_tile_dv = fmha::Smem_tile_mma_epilogue; - static_assert(Smem_tile_dv::NUM_LDS == Gmem_tile_dv::LDGS); - static_assert(Smem_tile_dv::THREADS_PER_ROW == Gmem_tile_dv::THREADS_PER_ROW); - - using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; - using Smem_tile_st = typename Kernel_traits::Smem_tile_st; - using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do; - - // Shared memory. - extern __shared__ char smem_[]; - - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.x; - // The thread index. - const int tidx = threadIdx.x; - - const BlockInfoPadded binfo(params, bidb, bidh, tidx); - if( binfo.stop_early() ) - return; - Mask mask(params, binfo, tidx); - - // Allocate the global memory tile loader for Q. - Gmem_tile_do gmem_q(params, binfo, tidx); // treating dout as Q - // Allocate the shared memory tile loader for Q. - Smem_tile_q smem_q(&smem_[0], tidx); - Smem_tile_qt smem_qt(&smem_[0], tidx); - Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx); - - // Allocate the global memory tile loader for K. - Gmem_tile_k gmem_k(params, 2, binfo, tidx); // treating V as K - // Allocate the shared memory tile loader for K. - Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx); - - // Trigger the loads for Q. - gmem_q.load(smem_q); - // Trigger the loads for K. - gmem_k.load(smem_k); - - // Commit the data for Q and K to shared memory. - gmem_q.commit(smem_q); - gmem_k.commit(smem_k); - - // Make sure the data is in shared memory. - __syncthreads(); - - // Load the fragments for Q. - typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M]; - smem_q.load(frag_q[0], 0); - - typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dv::MMAS_N]; - static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4); - static_assert(Mma_tile_dv::MMAS_K == 1); - smem_qt.load(frag_qt[0], 0); - - // Load the fragments for K. We keep the data in registers during the entire kernel. - typename Smem_tile_k::Fragment frag_k[2][Mma_tile_p::MMAS_N]; - smem_k.load(frag_k[0], 0); - - enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 }; - - Gmem_tile_s gmem_s(params, binfo, tidx); - - // Create the object to do the softmax. - using Softmax = fmha::Softmax; - Softmax softmax( - params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], bidb, tidx); - - enum { THREADS_PER_ROW = 32 }; - enum { M = Mma_tile_p::MMAS_M }; - enum { N = Mma_tile_p::MMAS_N }; - - // Declare the accumulators for the 2nd gemm. - fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N]; - fmha::Clear_accumulator::apply(acc_dv); - - enum { STEPS = Cta_tile_p::N / Cta_tile_p::M }; - // Load over the entire sequence length. - for( int l = 0; l < STEPS; l++ ) { - const int loop = l * Cta_tile_p::M; - if( loop >= binfo.actual_seqlen ) - break; - - // Load S - uint4 s_regs[M][N]; - gmem_s.load(s_regs, mask); - fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; - fmha::Clear_accumulator::apply(acc_p); - // Do this part of P^T = (Q * K^T)^T. - #pragma unroll - for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_q.load(frag_q[ki & 1], ki); - smem_k.load(frag_k[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); - } - - // Store s * dmask to smem for transpose - smem_s.store(s_regs); - - // Declare the accumulators for the 1st gemm. - // Do the final stage of math. - { - int ki = Mma_tile_p::MMAS_K; - fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); - } - // Trigger the load for the next Q values. We're using double buffering, so reading qt is safe - if( l < STEPS - 1) { - smem_q.move_to_next_write_buffer(); - gmem_q.move(); - gmem_q.load(smem_q); - } - - - // Convert from the accumulator type to FP32 for Softmax. - softmax.unpack(acc_p); - - float s_mat[2 * M][4 * N]; - - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - uint4 &dst = s_regs[mi][ni]; - fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 0], s_mat[2 * mi + 0][4 * ni + 1], dst.x); - fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 2], s_mat[2 * mi + 0][4 * ni + 3], dst.y); - fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 0], s_mat[2 * mi + 1][4 * ni + 1], dst.z); - fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 2], s_mat[2 * mi + 1][4 * ni + 3], dst.w); - } - } - - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ii = 0; ii < 2; ii++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - #pragma unroll - for( int jj = 0; jj < 4; jj++ ) { - float & s_dmask = s_mat[2 * mi + ii][4 * ni + jj]; - const bool drop = reinterpret_cast(s_dmask) & 0x80000000; - const float d_s = drop ? 0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout; - s_dmask = fabsf(s_dmask); - softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * fabsf(s_dmask); - } - } - } - } - - float p_sum[2 * M]; - softmax.reduce_sum(p_sum); - - const float scalef = reinterpret_cast(params.scale_softmax); - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ii = 0; ii < 2; ii++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - #pragma unroll - for( int jj = 0; jj < 4; jj++ ) { - softmax.elt_[2 * mi + ii][4 * ni + jj] -= p_sum[2 * mi + ii] * (s_mat[2 * mi + ii][4 * ni + jj]) ; - softmax.elt_[2 * mi + ii][4 * ni + jj] *= scalef; - } - } - } - } - typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M]; - smem_s.load(frag_s); - for( int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++ ) { - for( int mi = 0; mi < Mma_tile_dv::MMAS_M; mi++ ) { - for( int ii = 0; ii < Smem_tile_st::Fragment::NUM_REGS; ii++ ) { - frag_s[ki][mi].reg(ii) = fmha::hmul2(frag_s[ki][mi].reg(ii), params.scale_dropout); - frag_s[ki][mi].reg(ii) = fmha::hrelu2(frag_s[ki][mi].reg(ii)); - } - } - } - - gmem_s.store(softmax.elt_, mask); - gmem_s.move(); - - #pragma unroll - for( int ki = 1; ki < Mma_tile_dv::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_qt.load(frag_qt[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); - } - - // Do the final stage of math. - { - int ki = Mma_tile_dv::MMAS_K; - fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); - } - // Commit the values for Q into shared memory. - if(l < STEPS - 1) { - gmem_q.commit(smem_q); - } - - // Make sure we are reading from the correct buffer. - smem_q.move_to_next_read_buffer(); - smem_qt.move_to_next_read_buffer(); - - // Make sure the data is in shared memory. - __syncthreads(); - - // Trigger the loads for the values of Q for the next iteration. - smem_q.load(frag_q[0], 0); - smem_k.load(frag_k[0], 0); - smem_qt.load(frag_qt[0], 0); - - } // Outer loop over the sequence length. - - // Epilogue swizzle for dV - Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx); - smem_dv.store(acc_dv); - - __syncthreads(); - uint4 dv_out[Smem_tile_dv::NUM_LDS]; - smem_dv.load(dv_out); - Qkv_params dv_params; - dv_params.qkv_ptr = params.dqkv_ptr; - dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes; - dv_params.h = params.h; - Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx); - gmem_dv.store(dv_out); -} - -template -inline __device__ void compute_dq_dk_1xN(const Params ¶ms) { - - // The description of the CTA tile for the 1st batched GEMM. - using Cta_tile_p = typename Kernel_traits::Cta_tile_p; - using Cta_tile_o = typename Kernel_traits::Cta_tile_o; - // The description of the CTA tile for the 2nd batched GEMM. - using Cta_tile_dk = - fmha::Cta_tile_extd; - static_assert(Cta_tile_dk::M == 512 || Cta_tile_dk::M == 384 || Cta_tile_dk::M == 256 || Cta_tile_dk::M == 128); - static_assert(Cta_tile_dk::N == 64); - static_assert(Cta_tile_dk::K == 16); - - // The MMA tile for the 1st GEMM. - using Mma_tile_p = fmha::Hmma_tile; - using Mma_tile_o = fmha::Hmma_tile; - // The MMA tile for the 2nd GEMM. - using Mma_tile_dk = fmha::Hmma_tile; - - // The global memory tile to load Q. - using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; - // The shared memory tile to swizzle Q. - using Smem_tile_q = typename Kernel_traits::Smem_tile_q; - - // The global memory tile to load K. - using Gmem_tile_k = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle K. - using Smem_tile_k = typename Kernel_traits::Smem_tile_v; // K is used like V in fprop - - // The global memory tile to load V. - using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle V. - using Smem_tile_v = typename Kernel_traits::Smem_tile_v; - - // The global memory tile to store O. - // using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o; - using Gmem_tile_o = fmha::Gmem_tile_dq; - // The shared memory tile to swizzle O. - using Smem_tile_o = typename Kernel_traits::Smem_tile_o; - - // The global memory tile to store dK. - using Gmem_tile_dk = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle dK. - using Smem_tile_dk = fmha::Smem_tile_mma_epilogue; - static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS); - static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW); - - // The shared memory tile to reload Q transposed. - using Smem_tile_qt = fmha::Smem_tile_b; - - using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; - - using Smem_tile_st = typename Kernel_traits::Smem_tile_st; - - - enum { M = Mma_tile_p::MMAS_M }; - enum { N = Mma_tile_p::MMAS_N }; - static_assert(M == Mma_tile_o::MMAS_M); - static_assert(N == Mma_tile_o::MMAS_K); - // Shared memory. - extern __shared__ char smem_[]; - - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.x; - // The thread index. - const int tidx = threadIdx.x; - - const BlockInfoPadded binfo(params, bidb, bidh, tidx); - if( binfo.stop_early() ) - return; - - Mask mask(params, binfo, tidx); - - // Allocate the global memory tile loader for Q. - Gmem_tile_q gmem_q(params, 0, binfo, tidx); - // Allocate the shared memory tile loader for Q. - Smem_tile_q smem_q(&smem_[0], tidx); - Smem_tile_qt smem_qt(&smem_[0], tidx); - Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx); - - // Allocate the global memory tile loader for K. - Gmem_tile_k gmem_k(params, 1, binfo, tidx); - // Allocate the shared memory tile loader for K. - Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx); - - // Allocate the global memory tile loader for O. - Gmem_tile_o gmem_o(params, binfo, tidx); - // Allocate the shared memory tile loader for O. We use the same as K so be careful!!! - Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx); - - // Trigger the loads for Q. - gmem_q.load(smem_q); - // Trigger the loads for K. - gmem_k.load(smem_k); - - Gmem_tile_s gmem_s(params, binfo, tidx); - // Load dP - uint4 s_regs[M][N]; - gmem_s.load(s_regs, mask); - gmem_s.move(); - - // Commit the data for Q and K to shared memory. - gmem_q.commit(smem_q); - gmem_k.commit(smem_k); - - // Make sure the data is in shared memory. - __syncthreads(); - - typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N]; - smem_qt.load(frag_qt[0], 0); - typename Smem_tile_k::Fragment frag_k[2][Mma_tile_o::MMAS_N]; - smem_k.load(frag_k[0], 0); - - enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 }; - - enum { THREADS_PER_ROW = 32 }; - enum { STEPS = Cta_tile_p::N / Cta_tile_p::M }; - - // Declare the accumulators for the 2nd gemm. - fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N]; - fmha::Clear_accumulator::apply(acc_dk); - - // Load over the entire sequence length. - for( int l=0;l= binfo.actual_seqlen ) - break; - - // Pack dP as Fragment_a - fmha::Fragment_a frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - uint4 &dst = s_regs[mi][ni]; - frag_p[ni][mi].reg(0) = dst.x; // row 0, cols 0,1 - frag_p[ni][mi].reg(1) = dst.z; // row 8, cols 0,1 - frag_p[ni][mi].reg(2) = dst.y; // row 0, cols 8,9 - frag_p[ni][mi].reg(3) = dst.w; // row 8, cols 8,9 - } - } - - // Declare the accumulators for the 1st gemm. - fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N]; - fmha::Clear_accumulator::apply(acc_o); - - // Do this part of O = P^T * V^T. dQ = dP x dK - #pragma unroll - for( int ki = 1; ki < Mma_tile_o::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_k.load(frag_k[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]); - } - - // Do the final stage of math. - { - int ki = Mma_tile_o::MMAS_K; - fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]); - } - - // Store dP to smem for transpose - smem_s.store(s_regs); - if(l < STEPS - 1) { - // Load next part of S - gmem_s.load(s_regs, mask); - gmem_s.move(); - smem_q.move_to_next_write_buffer(); - gmem_q.move(); - gmem_q.load(smem_q); - } - // Loop over MMAS_M. - #pragma unroll - for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) { - - // Swizzle the elements and do the final reduction. - smem_o.store(acc_o, ii); - - // Make sure the data is in shared memory. - __syncthreads(); - - // Load from shared memory. - uint4 out[Gmem_tile_o::STGS_PER_LOOP]; - smem_o.load(out); - - // Make sure the data was read from shared memory. - if( ii < Gmem_tile_o::LOOPS - 1 ) { - __syncthreads(); - } - - // Output the values. - gmem_o.store(out, ii); - } - - // Move to the next part of the output. - gmem_o.move(); - - typename Smem_tile_st::Fragment frag_s[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M]; - smem_s.load(frag_s); - - #pragma unroll - for( int ki = 1; ki < Mma_tile_dk::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_qt.load(frag_qt[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); - } - - // Do the final stage of math. - { - int ki = Mma_tile_dk::MMAS_K; - fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); - } - - // Commit the values for Q into shared memory. - if( l < STEPS - 1) { - gmem_q.commit(smem_q); - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // Trigger the loads for the values of Q for the next iteration. - smem_qt.load(frag_qt[0], 0); - smem_k.load(frag_k[0], 0); - - } // Outer loop over the sequence length. - - // Epilogue swizzle for dK - Smem_tile_dk smem_dk(&smem_[0], tidx); - smem_dk.store(acc_dk); - __syncthreads(); - uint4 dk_out[Smem_tile_dk::NUM_LDS]; - smem_dk.load(dk_out); - Qkv_params dk_params; - dk_params.qkv_ptr = params.dqkv_ptr; - dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes; - dk_params.h = params.h; - Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx); - gmem_dk.store(dk_out); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace fmha diff --git a/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h b/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h deleted file mode 100644 index 26776d4..0000000 --- a/apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload_nl.h +++ /dev/null @@ -1,569 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include "fmha_kernel.h" -#include -#include - -namespace fmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void compute_dv_1xN_nl(const Params ¶ms) { - - // The description of the CTA tile for the 1st batched GEMM. - using Cta_tile_p = typename Kernel_traits::Cta_tile_p; - // The description of the CTA tile for the 2nd batched GEMM. - using Cta_tile_dv = fmha::Cta_tile_extd; - - static_assert(Cta_tile_dv::M == 512 || Cta_tile_dv::M == 384 || Cta_tile_dv::M == 256 || Cta_tile_dv::M == 128); - static_assert(Cta_tile_dv::N == 64); - static_assert(Cta_tile_dv::K == 16); - - // The MMA tile for the 1st GEMM. - using Mma_tile_p = fmha::Hmma_tile; - // The MMA tile for the 2nd GEMM. - using Mma_tile_dv = fmha::Hmma_tile; - - // The global memory tile to load Q. - using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; - // The shared memory tile to swizzle Q. - using Smem_tile_q = fmha::Smem_tile_a; - // The shared memory tile to reload Q as fragment b. - using Smem_tile_qt = fmha::Smem_tile_b; - - // The global memory tile to load K. - using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k; - // The shared memory tile to swizzle K. - using Smem_tile_k = typename Kernel_traits::Smem_tile_k; - - // The global memory tile to load V. - using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle V. - using Smem_tile_v = typename Kernel_traits::Smem_tile_v; - - // The global memory tile to store dV. - using Gmem_tile_dv = fmha::Gmem_tile_qkv; - - // The shared memory tile to swizzle dV. - using Smem_tile_dv = fmha::Smem_tile_mma_epilogue; - static_assert(Smem_tile_dv::NUM_LDS == Gmem_tile_dv::LDGS); - static_assert(Smem_tile_dv::THREADS_PER_ROW == Gmem_tile_dv::THREADS_PER_ROW); - - using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; - using Smem_tile_st = typename Kernel_traits::Smem_tile_st; - using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do; - - // Shared memory. - extern __shared__ char smem_[]; - - // The block index for the chunk. - const int bidc = blockIdx.z; - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.x; - // The thread index. - const int tidx = threadIdx.x; - - const BlockInfoPadded binfo(params, bidb, bidh, tidx); - if( binfo.stop_early() ) - return; - fmha::Mask mask(params, binfo, tidx); - - // Allocate the global memory tile loader for Q. - Gmem_tile_do gmem_q(params, binfo, tidx); // treating dout as Q - // Allocate the shared memory tile loader for Q. - Smem_tile_q smem_q(&smem_[0], tidx); - Smem_tile_qt smem_qt(&smem_[0], tidx); - Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx); - - // Allocate the global memory tile loader for K. - Gmem_tile_k gmem_k(params, 2, binfo, tidx); // treating V as K - // Allocate the shared memory tile loader for K. - Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx); - - Gmem_tile_s gmem_s(params, binfo, tidx); - - using Noloop = Noloop_traits; - - Noloop nl_traits(bidc, binfo); - nl_traits.move_all(gmem_q, gmem_s); - - // Trigger the loads for Q. - gmem_q.load(smem_q); - // Trigger the loads for K. - gmem_k.load(smem_k); - - // Commit the data for Q and K to shared memory. - gmem_q.commit(smem_q); - gmem_k.commit(smem_k); - - // Make sure the data is in shared memory. - __syncthreads(); - - // Load the fragments for Q. - typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M]; - smem_q.load(frag_q[0], 0); - - typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dv::MMAS_N]; - static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4); - static_assert(Mma_tile_dv::MMAS_K == 1); - smem_qt.load(frag_qt[0], 0); - - // Load the fragments for K. We keep the data in registers during the entire kernel. - typename Smem_tile_k::Fragment frag_k[2][Mma_tile_p::MMAS_N]; - smem_k.load(frag_k[0], 0); - - enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 }; - - // Create the object to do the softmax. - using Softmax = fmha::Softmax; - Softmax softmax( - params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], bidb, tidx); - - enum { THREADS_PER_ROW = 32 }; - enum { M = Mma_tile_p::MMAS_M }; - enum { N = Mma_tile_p::MMAS_N }; - - // Declare the accumulators for the 2nd gemm. - fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N]; - fmha::Clear_accumulator::apply(acc_dv); - - // Load over the entire sequence length. - for(int l = 0; l < nl_traits.num_steps_;l++) { - - uint4 s_regs[M][N]; - gmem_s.load(s_regs, mask); - fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; - fmha::Clear_accumulator::apply(acc_p); - // Do this part of P^T = (Q * K^T)^T. - #pragma unroll - for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_q.load(frag_q[ki & 1], ki); - smem_k.load(frag_k[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); - } - - smem_s.store(s_regs); - - // Declare the accumulators for the 1st gemm. - // Do the final stage of math. - { - int ki = Mma_tile_p::MMAS_K; - fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); - } - // Trigger the load for the next Q values. We're using double buffering, so reading qt is safe - if(l < nl_traits.num_steps_ - 1) { - smem_q.move_to_next_write_buffer(); - gmem_q.move(); - gmem_q.load(smem_q); - } - // Convert from the accumulator type to FP32 for Softmax. - softmax.unpack(acc_p); - - float s_mat[2 * M][4 * N]; - - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - uint4 &dst = s_regs[mi][ni]; - fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 0], s_mat[2 * mi + 0][4 * ni + 1], dst.x); - fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 2], s_mat[2 * mi + 0][4 * ni + 3], dst.y); - fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 0], s_mat[2 * mi + 1][4 * ni + 1], dst.z); - fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 2], s_mat[2 * mi + 1][4 * ni + 3], dst.w); - } - } - - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ii = 0; ii < 2; ii++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - #pragma unroll - for( int jj = 0; jj < 4; jj++ ) { - float & s_dmask = s_mat[2 * mi + ii][4 * ni + jj]; - const bool drop = reinterpret_cast(s_dmask) & 0x80000000; - const float d_s= drop ? 0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout; - s_dmask = fabsf(s_dmask); - softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * (s_dmask); - } - } - } - } - - float p_sum[2 * M]; - softmax.reduce_sum(p_sum); - - const float scalef = reinterpret_cast(params.scale_softmax); - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ii = 0; ii < 2; ii++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - #pragma unroll - for( int jj = 0; jj < 4; jj++ ) { - softmax.elt_[2 * mi + ii][4 * ni + jj] -= p_sum[2 * mi + ii] * (s_mat[2 * mi + ii][4 * ni + jj]) ; - softmax.elt_[2 * mi + ii][4 * ni + jj] *= scalef; - } - } - } - } - - typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M]; - smem_s.load(frag_s); - for( int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++ ) { - for( int mi = 0; mi < Mma_tile_dv::MMAS_M; mi++ ) { - for( int ii = 0; ii < Smem_tile_st::Fragment::NUM_REGS; ii++ ) { - frag_s[ki][mi].reg(ii) = fmha::hmul2(frag_s[ki][mi].reg(ii), params.scale_dropout); - frag_s[ki][mi].reg(ii) = fmha::hrelu2(frag_s[ki][mi].reg(ii)); - } - } - } - - gmem_s.store(softmax.elt_, mask); - gmem_s.move(); - - static_assert(Mma_tile_dv::MMAS_K == 1); // DEBUG - #pragma unroll - for( int ki = 1; ki < Mma_tile_dv::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_qt.load(frag_qt[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); - } - - // Do the final stage of math. - { - int ki = Mma_tile_dv::MMAS_K; - fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); - } - // Commit the values for Q into shared memory. - if(l < nl_traits.num_steps_ - 1) { - gmem_q.commit(smem_q); - } - - // Make sure we are reading from the correct buffer. - smem_q.move_to_next_read_buffer(); - smem_qt.move_to_next_read_buffer(); - - // Make sure the data is in shared memory. - __syncthreads(); - - // Trigger the loads for the values of Q for the next iteration. - smem_q.load(frag_q[0], 0); - smem_k.load(frag_k[0], 0); - smem_qt.load(frag_qt[0], 0); - - } // Outer loop over the sequence length. - - // Epilogue for dV = (S * D)' * dout'. We're fully exposed to this! - - // Epilogue swizzle for dV - Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx); - smem_dv.store(acc_dv); - - __syncthreads(); - - uint4 dv_out[Smem_tile_dv::NUM_LDS]; - smem_dv.load(dv_out); - Qkv_params dv_params; - dv_params.qkv_ptr = params.dkv_ptr; - dv_params.qkv_stride_in_bytes = params.h * 2 * CHUNKS * params.d * sizeof(half); - dv_params.h = params.h; - Gmem_tile_dv gmem_dv(dv_params, nl_traits.get_idx_dv(), binfo, tidx); - gmem_dv.store(dv_out); -} - -template -inline __device__ void compute_dq_dk_1xN_nl(const Params ¶ms) { - - // The description of the CTA tile for the 1st batched GEMM. - using Cta_tile_p = typename Kernel_traits::Cta_tile_p; - using Cta_tile_o = typename Kernel_traits::Cta_tile_o; - // The description of the CTA tile for the 2nd batched GEMM. - using Cta_tile_dk = fmha::Cta_tile_extd; - - static_assert(Cta_tile_dk::M == 512 || Cta_tile_dk::M == 384 || Cta_tile_dk::M == 256 || Cta_tile_dk::M == 128); - static_assert(Cta_tile_dk::N == 64); - static_assert(Cta_tile_dk::K == 16); - - // The MMA tile for the 1st GEMM. - using Mma_tile_p = fmha::Hmma_tile; - using Mma_tile_o = fmha::Hmma_tile; - // The MMA tile for the 2nd GEMM. - using Mma_tile_dk = fmha::Hmma_tile; - - // The global memory tile to load Q. - using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; - // The shared memory tile to swizzle Q. - using Smem_tile_q = typename Kernel_traits::Smem_tile_q; - - // The global memory tile to load K. - using Gmem_tile_k = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle K. - using Smem_tile_k = typename Kernel_traits::Smem_tile_v; // K is used like V in fprop - - // The global memory tile to load V. - using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle V. - using Smem_tile_v = typename Kernel_traits::Smem_tile_v; - - // The global memory tile to store O. - using Gmem_tile_o = Gmem_tile_dq; - // The shared memory tile to swizzle O. - using Smem_tile_o = typename Kernel_traits::Smem_tile_o; - - // The global memory tile to store dK. - using Gmem_tile_dk = fmha::Gmem_tile_qkv; - - // The shared memory tile to swizzle dK. - using Smem_tile_dk = fmha::Smem_tile_mma_epilogue; - static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS); - static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW); - - // The shared memory tile to reload Q transposed. - using Smem_tile_qt = fmha::Smem_tile_b; - - // The global memory tile to load dP, stored in S - using Gmem_tile_s = Gmem_tile_mma_s; - // The shared memory tile to transpose dP. - using Smem_tile_st = Smem_tile_mma_transposed; - - using Noloop = Noloop_traits; - - enum { M = Mma_tile_p::MMAS_M }; - enum { N = Mma_tile_p::MMAS_N }; - static_assert(M == Mma_tile_o::MMAS_M); - static_assert(N == Mma_tile_o::MMAS_K); - // Shared memory. - extern __shared__ char smem_[]; - - const int bidc = blockIdx.z; - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.x; - // The thread index. - const int tidx = threadIdx.x; - - const BlockInfoPadded binfo(params, bidb, bidh, tidx); - if( binfo.stop_early() ) - return; - - fmha::Mask mask(params, binfo, tidx); - - // Allocate the global memory tile loader for Q. - Gmem_tile_q gmem_q(params, 0, binfo, tidx); - // Allocate the shared memory tile loader for Q (as B). - Smem_tile_qt smem_qt(&smem_[0], tidx); - // Allocate the global memory tile loader for dP. - Gmem_tile_s gmem_s(params, binfo, tidx); - // Allocate the shared memory tile loader for dP. - Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx); - - // Allocate the global memory tile loader for K. - Gmem_tile_k gmem_k(params, 1, binfo, tidx); - // Allocate the shared memory tile loader for K. - Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx); - - // Allocate the global memory tile loader for O. - Gmem_tile_o gmem_o(params, binfo, tidx); - // Allocate the shared memory tile loader for O. We use the same as K so be careful!!! - Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx); - - Noloop nl_traits(bidc, binfo); - - nl_traits.move_all(gmem_q, gmem_o, gmem_s); - - // Trigger the loads for Q. - gmem_q.load(smem_qt); - // Trigger the loads for K. - gmem_k.load(smem_k); - - uint4 s_regs[M][N]; - gmem_s.load(s_regs, mask); - - // Commit the data for Q and K to shared memory. - gmem_q.commit(smem_qt); - gmem_k.commit(smem_k); - - // Make sure the data is in shared memory. - __syncthreads(); - - typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N]; - smem_qt.load(frag_qt[0], 0); - typename Smem_tile_k::Fragment frag_k[2][Mma_tile_o::MMAS_N]; - smem_k.load(frag_k[0], 0); - - enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 }; - - enum { THREADS_PER_ROW = 32 }; - - // Declare the accumulators for the 2nd gemm. - fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N]; - fmha::Clear_accumulator::apply(acc_dk); - - // Load over the entire sequence length. - for(int l=0;l < nl_traits.num_steps_; l++) { - - // Pack dP as Fragment_a - fmha::Fragment_a frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - uint4 &dst = s_regs[mi][ni]; - frag_p[ni][mi].reg(0) = dst.x; - frag_p[ni][mi].reg(1) = dst.z; - frag_p[ni][mi].reg(2) = dst.y; - frag_p[ni][mi].reg(3) = dst.w; - } - } - smem_s.store(s_regs); - if(l < nl_traits.num_steps_- 1) { - // Load next part of S - gmem_s.move(); - gmem_s.load(s_regs, mask); - // Trigger the load for the next Q values. - smem_qt.move_to_next_write_buffer(); - gmem_q.move(); - gmem_q.load(smem_qt); - } - // Declare the accumulators for the 1st gemm. - fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N]; - fmha::Clear_accumulator::apply(acc_o); - - // Do this part of O = P^T * V^T. dQ = dP x dK - #pragma unroll - for( int ki = 1; ki < Mma_tile_o::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_k.load(frag_k[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]); - } - - // Do the final stage of math. - { - int ki = Mma_tile_o::MMAS_K; - fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]); - } - - static_assert(Gmem_tile_o::LOOPS == 1); //DEBUG - // Loop over MMAS_M. - #pragma unroll - for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) { - - // Swizzle the elements and do the final reduction. - smem_o.store(acc_o, ii); - - // Make sure the data is in shared memory. - __syncthreads(); - - // Load from shared memory. - uint4 out[Gmem_tile_o::STGS_PER_LOOP]; - smem_o.load(out); - - // Make sure the data was read from shared memory. - if( ii < Gmem_tile_o::LOOPS - 1 ) { - __syncthreads(); - } - - // Output the values. - gmem_o.store(out, ii); - } - - // Move to the next part of the output. - gmem_o.move(); - - typename Smem_tile_st::Fragment frag_s[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M]; - smem_s.load(frag_s); - - static_assert(Mma_tile_dk::MMAS_K == 1); // DEBUG - - #pragma unroll - for( int ki = 1; ki < Mma_tile_dk::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - smem_qt.load(frag_qt[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); - } - - // Do the final stage of math. - { - int ki = Mma_tile_dk::MMAS_K; - fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]); - } - - // Commit the values for Q into shared memory. - if(l < nl_traits.num_steps_- 1) { - gmem_q.commit(smem_qt); - __syncthreads(); - // Trigger the loads for the values of Q for the next iteration. - smem_qt.load(frag_qt[0], 0); - smem_k.load(frag_k[0], 0); - } - - } // Outer loop over the sequence length. - - // Epilogue for dK = dP' * dq. We're fully exposed to this! - - // Epilogue swizzle for dK - Smem_tile_dk smem_dk(&smem_[0], tidx); - smem_dk.store(acc_dk); - - __syncthreads(); - - uint4 dk_out[Smem_tile_dk::NUM_LDS]; - smem_dk.load(dk_out); - Qkv_params dk_params; - dk_params.qkv_ptr = params.dkv_ptr; - dk_params.qkv_stride_in_bytes = params.h * 2 * CHUNKS * params.d * sizeof(half); - dk_params.h = params.h; - Gmem_tile_dk gmem_dk(dk_params, nl_traits.get_idx_dk(), binfo, tidx); - gmem_dk.store(dk_out); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace fmha diff --git a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu deleted file mode 100644 index 9ebcbc5..0000000 --- a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu +++ /dev/null @@ -1,84 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#include "fmha.h" -#include "fmha_fprop_kernel_1xN.h" - -using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>; - -template -__global__ -void fmha_fprop_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params, - const int num_full_heads, - const int num_main_groups, - const int main_group_size, - const int main_steps, - const int rest_steps) { - - fmha::device_1xN( - params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps); -} - -void run_fmha_fp16_128_64_sm80(Launch_params &launch_params, const bool configure) { - - auto kernel = launch_params.is_training ? &fmha_fprop_fp16_128_64_sm80_kernel : &fmha_fprop_fp16_128_64_sm80_kernel; - - constexpr int smem_size = fmha::get_dynamic_smem_size(); - - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - - const int sm_count = launch_params.props->multiProcessorCount; - int ctas_per_sm; - FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size)); - int total_ctas = sm_count * ctas_per_sm; - - if(configure) { - const int heads_total = launch_params.params.b * launch_params.params.h; - std::tie(launch_params.num_full_heads, - launch_params.num_main_groups, - launch_params.heads_last_wave, - launch_params.main_steps, - launch_params.rest_steps, - launch_params.elts_per_thread) = fmha::work_dist(total_ctas, heads_total); - return; - } - - dim3 grid(total_ctas); - kernel<<>>( - launch_params.params, - launch_params.num_full_heads, - launch_params.num_main_groups, - launch_params.heads_last_wave, - launch_params.main_steps, - launch_params.rest_steps); - - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - -} - diff --git a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu deleted file mode 100644 index 448b9ad..0000000 --- a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu +++ /dev/null @@ -1,84 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#include "fmha.h" -#include "fmha_fprop_kernel_1xN.h" - -using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>; - -template -__global__ -void fmha_fprop_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params, - const int num_full_heads, - const int num_main_groups, - const int main_group_size, - const int main_steps, - const int rest_steps) { - - fmha::device_1xN( - params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps); -} - -void run_fmha_fp16_256_64_sm80(Launch_params &launch_params, const bool configure) { - - auto kernel = launch_params.is_training ? &fmha_fprop_fp16_256_64_sm80_kernel : &fmha_fprop_fp16_256_64_sm80_kernel; - - constexpr int smem_size = fmha::get_dynamic_smem_size(); - - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - - const int sm_count = launch_params.props->multiProcessorCount; - int ctas_per_sm; - FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size)); - int total_ctas = sm_count * ctas_per_sm; - - if(configure) { - const int heads_total = launch_params.params.b * launch_params.params.h; - std::tie(launch_params.num_full_heads, - launch_params.num_main_groups, - launch_params.heads_last_wave, - launch_params.main_steps, - launch_params.rest_steps, - launch_params.elts_per_thread) = fmha::work_dist(total_ctas, heads_total); - return; - } - - dim3 grid(total_ctas); - kernel<<>>( - launch_params.params, - launch_params.num_full_heads, - launch_params.num_main_groups, - launch_params.heads_last_wave, - launch_params.main_steps, - launch_params.rest_steps); - - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - -} - diff --git a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu deleted file mode 100644 index f1f21dc..0000000 --- a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu +++ /dev/null @@ -1,84 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#include "fmha.h" -#include "fmha_fprop_kernel_1xN.h" - -using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 4, 0x18u>; - -template -__global__ -void fmha_fprop_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params, - const int num_full_heads, - const int num_main_groups, - const int main_group_size, - const int main_steps, - const int rest_steps) { - - fmha::device_1xN( - params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps); -} - -void run_fmha_fp16_384_64_sm80(Launch_params &launch_params, const bool configure) { - - auto kernel = launch_params.is_training ? &fmha_fprop_fp16_384_64_sm80_kernel : &fmha_fprop_fp16_384_64_sm80_kernel; - - constexpr int smem_size = fmha::get_dynamic_smem_size(); - - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - - const int sm_count = launch_params.props->multiProcessorCount; - int ctas_per_sm; - FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size)); - int total_ctas = sm_count * ctas_per_sm; - - if(configure) { - const int heads_total = launch_params.params.b * launch_params.params.h; - std::tie(launch_params.num_full_heads, - launch_params.num_main_groups, - launch_params.heads_last_wave, - launch_params.main_steps, - launch_params.rest_steps, - launch_params.elts_per_thread) = fmha::work_dist(total_ctas, heads_total); - return; - } - - dim3 grid(total_ctas); - kernel<<>>( - launch_params.params, - launch_params.num_full_heads, - launch_params.num_main_groups, - launch_params.heads_last_wave, - launch_params.main_steps, - launch_params.rest_steps); - - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - -} - diff --git a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu b/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu deleted file mode 100644 index e37689e..0000000 --- a/apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu +++ /dev/null @@ -1,137 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#include "fmha.h" -#include "fmha_fprop_kernel_1xN.h" - -using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x00u>; - -template -__global__ -void fmha_fprop_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params, - const int total_heads) { - - fmha::device_1xN(params, total_heads); -} - -template -__global__ -void fmha_fprop_fp16_512_64_sm80_kernel_nl(Fused_multihead_attention_fprop_params params, - const int num_full_heads, - const int num_main_groups, - const int main_group_size, - const int main_steps, - const int rest_steps) { - - fmha::device_1xN( - params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps); -} - -void run_fmha_fp16_512_64_sm80_(Launch_params &launch_params, const bool configure) { - - auto kernel = launch_params.is_training ? &fmha_fprop_fp16_512_64_sm80_kernel : &fmha_fprop_fp16_512_64_sm80_kernel; - - constexpr int smem_size = fmha::get_dynamic_smem_size(); - - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - - const int sm_count = launch_params.props->multiProcessorCount; - int ctas_per_sm; - FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size)); - int total_ctas = sm_count * ctas_per_sm; - - const int heads_total = launch_params.params.b * launch_params.params.h; - if(configure) { - - using Mma_tile_p = fmha::Hmma_tile; - constexpr size_t STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M; - constexpr size_t MMAS_M = Mma_tile_p::MMAS_M; - constexpr size_t MMAS_N = Mma_tile_p::MMAS_N; - - size_t heads_per_cta = ((heads_total + total_ctas - 1) / total_ctas); - size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8; - launch_params.elts_per_thread = heads_per_cta * elts_per_head; - return; - } - - dim3 grid(total_ctas); - kernel<<>>( - launch_params.params, - heads_total); - - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - -} - -void run_fmha_fp16_512_64_sm80_nl_(Launch_params &launch_params, const bool configure) { - - auto kernel = launch_params.is_training ? &fmha_fprop_fp16_512_64_sm80_kernel_nl : &fmha_fprop_fp16_512_64_sm80_kernel_nl; - - constexpr int smem_size = fmha::get_dynamic_smem_size(); - - if( smem_size >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - - const int sm_count = launch_params.props->multiProcessorCount; - int ctas_per_sm; - FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size)); - int total_ctas = sm_count * ctas_per_sm; - - if(configure) { - const int heads_total = launch_params.params.b * launch_params.params.h; - std::tie(launch_params.num_full_heads, - launch_params.num_main_groups, - launch_params.heads_last_wave, - launch_params.main_steps, - launch_params.rest_steps, - launch_params.elts_per_thread) = fmha::work_dist(total_ctas, heads_total); - return; - } - - dim3 grid(total_ctas); - kernel<<>>( - launch_params.params, - launch_params.num_full_heads, - launch_params.num_main_groups, - launch_params.heads_last_wave, - launch_params.main_steps, - launch_params.rest_steps); - - FMHA_CHECK_CUDA(cudaPeekAtLastError()); - -} - -void run_fmha_fp16_512_64_sm80(Launch_params &launch_params, const bool configure) { - if( launch_params.is_nl ) { - run_fmha_fp16_512_64_sm80_nl_(launch_params, configure); - } else { - run_fmha_fp16_512_64_sm80_(launch_params, configure); - } -} diff --git a/apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h b/apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h deleted file mode 100644 index 5a040cf..0000000 --- a/apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h +++ /dev/null @@ -1,531 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include "fmha_kernel.h" -#include -#include - -namespace fmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Gemm_Q_K_base { - using Smem_tile_o = typename Kernel_traits::Smem_tile_o; - using Smem_tile_q = typename Kernel_traits::Smem_tile_q; - using Smem_tile_k = typename Kernel_traits::Smem_tile_k; - using Fragment_q = typename Smem_tile_q::Fragment; - using Fragment_k = typename Smem_tile_k::Fragment; - - // The description of the CTA tile for the 1st batched GEMM. - using Cta_tile_p = typename Kernel_traits::Cta_tile_p; - - // The MMA tile for the 1st GEMM. - using Mma_tile_p = fmha::Hmma_tile; - - static constexpr int SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2; - - __device__ inline Gemm_Q_K_base(char * smem_ptr_q, char * smem_ptr_k, const int tidx) - : smem_q(smem_ptr_q, tidx) - , smem_k(smem_ptr_k, tidx) { - - } - - __device__ inline void load_q() { - smem_q.load(frag_q[0], 0); - } - - __device__ inline void reload_q() { - smem_q.load(frag_q[0], 0); - } - - Fragment_q frag_q[2][Mma_tile_p::MMAS_M]; - Smem_tile_q smem_q; - Smem_tile_k smem_k; -}; - -template -struct Gemm_Q_K : public Gemm_Q_K_base { - - using Base = Gemm_Q_K_base; - using Smem_tile_o = typename Base::Smem_tile_o; - using Smem_tile_q = typename Base::Smem_tile_q; - using Smem_tile_k = typename Base::Smem_tile_k; - using Fragment_k = typename Base::Fragment_k; - using Mma_tile_p = typename Base::Mma_tile_p; - - enum { SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V }; - - enum { SMEM_OFFSET_O = Smem_tile_q::BYTES_PER_TILE }; - enum { SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE) }; - - // Q | K / V - // | O | SOFTMAX - static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE - + std::max((SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE, - Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX); - - __device__ inline Gemm_Q_K(char * smem_, const int tidx) - : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) { - } - - __device__ inline void load_k(){ - #pragma unroll - for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) { - Base::smem_k.load(frag_k[ki], ki); - } - } - - template - __device__ inline void operator()(Acc (&acc_p)[M][N]){ - // Do this part of P^T = (Q * K^T)^T. - #pragma unroll - for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - Base::smem_q.load(Base::frag_q[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]); - } - // Do the final stage of math. - { - int ki = Mma_tile_p::MMAS_K; - fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]); - } - } - - __device__ inline void reload_k(){ - // Noop. - } - - Fragment_k frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N]; -}; - - -template -struct Gemm_Q_K : public Gemm_Q_K_base { - using Base = Gemm_Q_K_base; - using Smem_tile_o = typename Base::Smem_tile_o; - using Smem_tile_q = typename Base::Smem_tile_q; - using Smem_tile_k = typename Base::Smem_tile_k; - using Smem_tile_v = typename Kernel_traits::Smem_tile_v; - using Fragment_k = typename Base::Fragment_k; - using Mma_tile_p = typename Base::Mma_tile_p; - Fragment_k frag_k[2][Mma_tile_p::MMAS_N]; - - enum { SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V }; - - enum { SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE) }; - static_assert(Smem_tile_v::BYTES_PER_TILE == (int) Smem_tile_k::BYTES_PER_TILE); - enum { SMEM_OFFSET_O = SMEM_OFFSET_V + Smem_tile_v::BYTES_PER_TILE }; - - // Q | K/V + O + SOFTMAX - static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE - + (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE - + Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX; - - __device__ inline Gemm_Q_K(char * smem_, const int tidx) - : Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) { - } - - __device__ inline void load_k(){ - Base::smem_k.load(frag_k[0], 0); - } - - template - __device__ inline void operator()(Acc (&acc_p)[M][N]){ - // Do this part of P^T = (Q * K^T)^T. - #pragma unroll - for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) { - // Trigger the load from shared memory for the next series of Q values. - Base::smem_q.load(Base::frag_q[ki & 1], ki); - Base::smem_k.load(frag_k[ki & 1], ki); - // Do the math for the values already in registers. - fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); - } - // Do the final stage of math. - { - int ki = Mma_tile_p::MMAS_K; - fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); - } - } - - __device__ inline void reload_k(){ - Base::smem_k.load(frag_k[0], 0); - } -}; - -template -constexpr size_t get_dynamic_smem_size(){ - return Gemm_Q_K::SMEM_BYTES; -} - -template -inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const int bidh, const int begin, const int steps, Prng & ph) { - - - // The description of the CTA tile for the 1st batched GEMM. - using Cta_tile_p = typename Kernel_traits::Cta_tile_p; - // The description of the CTA tile for the 2nd batched GEMM. - using Cta_tile_o = typename Kernel_traits::Cta_tile_o; - - // The MMA tile for the 1st GEMM. - using Mma_tile_p = fmha::Hmma_tile; - // The MMA tile for the 2nd GEMM. - using Mma_tile_o = fmha::Hmma_tile; - - // The global memory tile to load Q. - using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q; - - // The global memory tile to load K. - using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k; - - // The global memory tile to load V. - using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v; - // The shared memory tile to swizzle V. - using Smem_tile_v = typename Kernel_traits::Smem_tile_v; - - // The global memory tile to store O. - using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o; - // The shared memory tile to swizzle O. - using Smem_tile_o = typename Kernel_traits::Smem_tile_o; - - using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s; - - using Gemm1 = Gemm_Q_K; - - using Softmax = fmha::Softmax; - - - // The number of threads per row. - enum { THREADS_PER_ROW = 32 }; - - enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 }; - - // Shared memory. - extern __shared__ char smem_[]; - - // The thread index. - const int tidx = threadIdx.x; - - const BlockInfoPadded binfo(params, bidb, bidh, tidx); - if( binfo.stop_early() ) return; - - Gemm1 gemm_q_k(smem_, tidx); - // Allocate the global memory tile loader for Q. - Gmem_tile_q gmem_q(params, 0, binfo, tidx); - // Allocate the global memory tile loader for O. - Gmem_tile_o gmem_o(params, binfo, tidx); - // Allocate the global memory tile loader for S. - Gmem_tile_s gmem_s(params, binfo, tidx); - // Wind gmem tiles to the correct position. - for( int it = 0; it < begin; it++ ) { - gmem_q.move(); - gmem_s.move(); - gmem_o.move(); - } - - fmha::Mask mask(params, binfo, tidx); - - // Allocate the global memory tile loader for K. - Gmem_tile_k gmem_k(params, 1, binfo, tidx); - // Allocate the global memory tile loader for V. - Gmem_tile_v gmem_v(params, 2, binfo, tidx); - // The base pointer of smem_v; - char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V]; - - // Allocate the shared memory tile loader for V. We use the same as K so be careful!!! - Smem_tile_v smem_v(smem_v_, tidx); - - // Allocate the shared memory tile loader for O. We use the same as K so be careful!!! - Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx); - - // Trigger the loads for K. - gmem_k.load(gemm_q_k.smem_k); - // Trigger the loads for Q. - gmem_q.load(gemm_q_k.smem_q); - // Trigger the loads for V. - gmem_v.load(smem_v); - - const uint32_t scale_bmm1 = reinterpret_cast(params.scale_bmm1); - #pragma unroll - for(int it=0;it < Gmem_tile_k::LDGS;it++){ - gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]); - } - - - - // Commit the data for Q and V to shared memory. - gmem_q.commit(gemm_q_k.smem_q); - gmem_v.commit(smem_v); - - // Commit the data for K to shared memory. - if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) { - gmem_k.commit(gemm_q_k.smem_k); - } - - __syncthreads(); - - // Load the fragments for Q. - gemm_q_k.load_q(); - - // Load the fragments for V. We keep the data in registers during the entire kernel. - typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N]; - #pragma unroll - for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { - smem_v.load(frag_v[ki], ki); - } - - // Commit the data for V to shared memory if it has not been done already. - if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) { - // Make sure we are done loading the fragments for K. - __syncthreads(); - - // Commit the data to shared memory for V. - gmem_k.commit(gemm_q_k.smem_k); - - // Make sure the data is in shared memory. - __syncthreads(); - } - - // Load the fragments for K. - gemm_q_k.load_k(); - - // Create the object to do the softmax. - Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE], bidb, tidx); - - // Load over the entire sequence length. - for( int l = 0; l < steps; l++ ) { - if(begin + l * Cta_tile_p::M >= binfo.actual_seqlen) break; - - // Declare the accumulators for the 1st gemm. - fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N]; - fmha::Clear_accumulator::apply(acc_p); - - // Do this part of P^T = (Q * K^T)^T. - gemm_q_k(acc_p); - - // Trigger the load for the next Q values. - if( l < steps - 1) { - gemm_q_k.smem_q.move_to_next_write_buffer(); - gmem_q.move(); - gmem_q.load(gemm_q_k.smem_q); - } - - // Load the mask for that iteration. - mask.load(begin + l); - - // Convert from the accumulator type to FP32 for Softmax. - softmax.unpack_noscale(acc_p); - - // Apply the mask. - softmax.apply_mask(mask); - - if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) { - // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction - __syncthreads(); - } - // Compute the max. - float p_max[Mma_tile_p::MMAS_M * 2]; - //softmax.template reduce(p_max); - softmax.reduce_max(p_max); - - // Compute the exponential value. - softmax.apply_exp(p_max); - - // Compute the sum. - float p_sum[Mma_tile_p::MMAS_M * 2]; - softmax.reduce_sum(p_sum); - - // Finalize softmax on the accumulators of P^T. - softmax.scale(p_sum); - - using Frag_p = fmha::Fragment_a; - Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; - if( Is_training ) { - auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; }; - #pragma unroll - for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { - #pragma unroll - for( int ii = 0; ii < 2; ii++ ) { - #pragma unroll - for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { - float4 tmp = uniform4(ph()); - // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from pre-existing zeros - softmax.elt_[2 * mi + ii][4 * ni + 0] = - encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]); - softmax.elt_[2 * mi + ii][4 * ni + 1] = - encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]); - softmax.elt_[2 * mi + ii][4 * ni + 2] = - encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]); - softmax.elt_[2 * mi + ii][4 * ni + 3] = - encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]); - } - } - } - softmax.pack(frag_p); - gmem_s.store(frag_p, mask); - gmem_s.move(); - } else { - softmax.pack(frag_p); - } - - // Commit the values for Q into shared memory. - if(l < steps - 1) { - gmem_q.commit(gemm_q_k.smem_q); - } - - #pragma unroll - for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) { - #pragma unroll - for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) { - #pragma unroll - for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) { - //"Apply" the dropout. - frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout); - frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii)); - } - } - } - - // Declare the accumulators for the 1st gemm. - fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N]; - fmha::Clear_accumulator::apply(acc_o); - - // Do this part of O = P^T * V^T. - #pragma unroll - for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { - fmha::gemm(acc_o, frag_p[ki], frag_v[ki]); - } - - // Loop over MMAS_M. - #pragma unroll - for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) { - - // Swizzle the elements and do the final reduction. - smem_o.store(acc_o, ii); - - // Make sure the data is in shared memory. - __syncthreads(); - - // Load from shared memory. - uint4 out[Gmem_tile_o::STGS_PER_LOOP]; - smem_o.load(out); - - // Make sure the data was read from shared memory. - if( ii < Gmem_tile_o::LOOPS - 1 ) { - __syncthreads(); - } - - // Output the values. - gmem_o.store(out, ii); - } - - // Move to the next part of the output. - gmem_o.move(); - gemm_q_k.reload_k(); - - // Commit the values for Q into shared memory. - if(l < steps - 1) { - gemm_q_k.reload_q(); - } - - } // Outer loop over the sequence length. -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void device_1xN(const Params ¶ms, - const int num_full_heads, - const int num_main_groups, - const int main_group_size, - const int main_steps, - const int rest_steps) { - - constexpr int STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M; - const int tidx_global = blockIdx.x * gridDim.x + threadIdx.x; - auto seeds = at::cuda::philox::unpack(params.philox_args); - Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds)); - for( int it = 0; it < num_full_heads; it++ ) { - const int bidx = it * gridDim.x + blockIdx.x; - const int bidh = bidx % params.h; - const int bidb = bidx / params.h; - fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph); - __syncthreads(); - } - if( main_group_size == 0 ) - return; - const int head_offset = num_full_heads * gridDim.x; - - if( blockIdx.x < main_group_size * num_main_groups ) { - // process within heads - const int group = blockIdx.x % num_main_groups; - const int bidx = blockIdx.x / num_main_groups; - const int bidh = (head_offset + bidx) % params.h; - const int bidb = (head_offset + bidx) / params.h; - const int offset = group * main_steps; - fmha::device_1xN_(params, bidb, bidh, offset, main_steps, ph); - } else { - if(rest_steps == 0 ) return; - // process across heads - const int bidx = blockIdx.x - main_group_size * num_main_groups; - const int offset = num_main_groups * main_steps; - const int total_heads = params.b * params.h; - const int rest_ctas = gridDim.x - main_group_size * num_main_groups; - for( int it = head_offset + bidx; it < total_heads; it += rest_ctas ) { - const int bidh = it % params.h; - const int bidb = it / params.h; - fmha::device_1xN_(params, bidb, bidh, offset, rest_steps, ph); - __syncthreads(); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void device_1xN(const Params ¶ms, const int total_heads) { - - const int tidx_global = blockIdx.x * gridDim.x + threadIdx.x; - auto seeds = at::cuda::philox::unpack(params.philox_args); - Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds)); - constexpr int STEPS = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M; - - for(int bidx = blockIdx.x; bidx < total_heads; bidx += gridDim.x){ - const int bidh = bidx % params.h; - const int bidb = bidx / params.h; - fmha::device_1xN_(params, bidb, bidh, 0, STEPS, ph); - __syncthreads(); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace fmha - diff --git a/apex/contrib/csrc/fmha/src/fmha_kernel.h b/apex/contrib/csrc/fmha/src/fmha_kernel.h deleted file mode 100644 index 63180b0..0000000 --- a/apex/contrib/csrc/fmha/src/fmha_kernel.h +++ /dev/null @@ -1,179 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include - -#include -#include -#include -#include -#include -#include - -namespace fmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BlockInfoPadded { - - template - __device__ BlockInfoPadded(const Params ¶ms, - const int bidb, - const int bidh, - const int tidx) - : bidb(bidb), bidh(bidh), h(params.h) { - - // The block index. - sum_s = params.cu_seqlens[bidb]; - actual_seqlen = params.cu_seqlens[bidb + 1] - sum_s; - bidx = sum_s * params.h + bidh; - - tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx; - } - - __device__ bool stop_early() const { - return actual_seqlen == 0; - } - - int actual_seqlen; - int bidx; - int sum_s; - int bidh; - int bidb; - int tidx_global; - int h; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Noloop_traits{ - // Interpretation of Cta_tile dims, i.e. Cta_tile_p: - enum{ STEP = Cta_tile::M }; - enum{ SEQLEN = Cta_tile::N }; - - template - inline __device__ Noloop_traits(const int bidc, const Block_info& binfo) - : bidc_(bidc) { - const int seqlen = binfo.actual_seqlen; - const int steps = (seqlen + STEP - 1) / STEP; - const int steps_per_chunk = (steps + CHUNKS - 1) / CHUNKS; - - const int step_begin = bidc_ * steps_per_chunk; - const int step_end = min(steps, (bidc_ + 1) * steps_per_chunk); - const int actual_steps = max(0, step_end - step_begin); - loop_offset_ = step_begin; - num_steps_ = actual_steps; - - } - - template - inline __device__ void move_all(Tiles & ... tiles) const { - using expand_type = int[]; - for( int s = 0; s < loop_offset_; s++ ) { - expand_type{ (tiles.move(), 0)... }; - } - } - - inline __device__ int get_idx_dk() const { - //return bidc_; - return bidc_ * 2 + 0; - } - - inline __device__ int get_idx_dv() const { - //return CHUNKS + bidc_; - return bidc_ * 2 + 1; - } - - inline __device__ int offset_loop_count(const int l) { - // convert loop counter to position in the outer sequence - return (loop_offset_ + l) * STEP; - } - - const uint32_t bidc_; - int loop_offset_; - int num_steps_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -std::tuple work_dist(const int total_ctas, const int heads_total) { - - constexpr int STEPS_PER_HEAD = Kernel_traits::Cta_tile_p::N / Kernel_traits::Cta_tile_p::M; - - const int num_full_heads = heads_total / total_ctas; - const int heads_last_wave = heads_total % total_ctas; - - int num_main_groups = 0; - int main_steps = 0; - int rest_steps = 0; - if( heads_last_wave > 0 ) { - // Number of CTA groups that process within heads. - num_main_groups = total_ctas / heads_last_wave; - // Remaining CTAs that process between heads. - const int rest_ctas = total_ctas - (heads_last_wave * num_main_groups); - if(rest_ctas == 0) { - // We have exactly "num_main_groups" CTAs to process each of the remaining heads. - main_steps = (STEPS_PER_HEAD + num_main_groups - 1) / num_main_groups; - num_main_groups = STEPS_PER_HEAD / main_steps; // Here: main_step > 0 - rest_steps = STEPS_PER_HEAD % main_steps; - - } else { - // Ideal number of steps if we could load-balance as evenly as possible. - const int steps_ideal = (heads_last_wave * STEPS_PER_HEAD + total_ctas - 1) / total_ctas; - // Iterations that a "rest" CTA has to do at most. - const int max_rest_iters = (heads_last_wave + rest_ctas - 1) / rest_ctas; - // Find the first step distribution, s.t. the maximum work of the "rest" CTAs is less than the work of the main CTAs. - main_steps = steps_ideal; - rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups; - for( ; main_steps * num_main_groups < STEPS_PER_HEAD; main_steps++ ) { - rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups; - const int max_rest_total_steps = rest_steps * max_rest_iters; - if( max_rest_total_steps < main_steps ) - break; - } - rest_steps = STEPS_PER_HEAD - main_steps * num_main_groups; - } - } - - using Cta_tile_p = typename Kernel_traits::Cta_tile_p; - using Mma_tile_p = fmha::Hmma_tile; - - const int max_steps = STEPS_PER_HEAD * num_full_heads + std::max(main_steps, rest_steps); - const int elts_per_thread_per_step = Mma_tile_p::MMAS_M * Mma_tile_p::MMAS_N * 8; - const int elts_per_thread = max_steps * elts_per_thread_per_step; - - return {num_full_heads, num_main_groups, heads_last_wave, main_steps, rest_steps, elts_per_thread}; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace fmha diff --git a/apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu b/apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu deleted file mode 100644 index 8e4b9ef..0000000 --- a/apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu +++ /dev/null @@ -1,177 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#include "fmha.h" - -inline __device__ float4 ldg128(const void *ptr) { - return *static_cast(ptr); -} - -inline __device__ void stg128(void *ptr, const float4 &data) { - *static_cast(ptr) = data; -} - -template -__global__ __launch_bounds__(THREADS) void fmha_noloop_reduce_kernel(void *__restrict__ out, - const void *__restrict__ in, - const int *__restrict__ cu_seqlens, - const int batch_size) { - - enum { BYTES_PER_LDG = 16 }; - enum { NUM_ELTS = BYTES_PER_LDG / sizeof(T) }; - - // One CTA hidden vector for K and V - enum { BYTES_PER_ROW = HIDDEN_SIZE * sizeof(T) * 2 }; - // The stride in bytes in dQKV - enum { OUT_STRIDE_BYTES = 3 * HIDDEN_SIZE * sizeof(T) }; - // The offset in bytes in dQKV to the dKV part for non-interleaved heads - enum { OUT_OFFSET_KV_BYTES = HIDDEN_SIZE * sizeof(T) }; - - static_assert(BYTES_PER_ROW == HIDDEN_SIZE * 2 * sizeof(T)); - - // Size in bytes of the input tile - enum { BYTES_PER_TILE = CHUNKS * BYTES_PER_ROW }; - - enum { BYTES_PER_CTA = THREADS * BYTES_PER_LDG }; - - enum { LDGS = BYTES_PER_ROW / BYTES_PER_CTA }; - static_assert(BYTES_PER_CTA * LDGS == BYTES_PER_ROW); - - union Vec_t { - float4 raw; - T elt[NUM_ELTS]; - }; - - // ZERO-OUT invalid positions in dQKV - const int total = cu_seqlens[batch_size]; - if(blockIdx.x >= total){ - enum { BYTES_PER_QKV_ROW = 3 * HIDDEN_SIZE * sizeof(T) }; - enum { STGS = BYTES_PER_QKV_ROW / BYTES_PER_LDG }; - - const float4 zeros = make_float4(0.f, 0.f, 0.f, 0.f); - - char *base_ptr = static_cast(out) + blockIdx.x * OUT_STRIDE_BYTES; - - for(int tidx = threadIdx.x; tidx < STGS; tidx += THREADS){ - stg128(base_ptr + tidx * BYTES_PER_LDG, zeros); - } - - return; - } - - // SETUP - const int offset_in = blockIdx.x * BYTES_PER_TILE + threadIdx.x * BYTES_PER_LDG; - const char *ptr_in = static_cast(in) + offset_in; - - const int offset_out = blockIdx.x * OUT_STRIDE_BYTES + threadIdx.x * BYTES_PER_LDG; - char *ptr_out = static_cast(out) + OUT_OFFSET_KV_BYTES + offset_out; - - // LOAD - - Vec_t local_in[CHUNKS][LDGS]; - - #pragma unroll - for( int c = 0; c < CHUNKS; c++ ) { - #pragma unroll - for( int l = 0; l < LDGS; l++ ) { - int offset = c * BYTES_PER_ROW + l * BYTES_PER_CTA; - local_in[c][l].raw = ldg128(ptr_in + offset); - } - } - - // UNPACK - float acc[LDGS][NUM_ELTS]; - - #pragma unroll - for( int l = 0; l < LDGS; l++ ) { - #pragma unroll - for( int e = 0; e < NUM_ELTS; e++ ) { - acc[l][e] = float(local_in[0][l].elt[e]); - } - } - - // COMPUTE - #pragma unroll - for( int c = 1; c < CHUNKS; c++ ) { - #pragma unroll - for( int l = 0; l < LDGS; l++ ) { - #pragma unroll - for( int e = 0; e < NUM_ELTS; e++ ) { - acc[l][e] += float(local_in[c][l].elt[e]); - } - } - } - - // PACK - Vec_t local_out[LDGS]; - - #pragma unroll - for( int l = 0; l < LDGS; l++ ) { - #pragma unroll - for( int e = 0; e < NUM_ELTS; e++ ) { - local_out[l].elt[e] = T(acc[l][e]); - } - } - - // STORE - #pragma unroll - for( int l = 0; l < LDGS; l++ ) { - const int offset = l * BYTES_PER_CTA; - stg128(ptr_out + offset, local_out[l].raw); - } -} - -void fmha_run_noloop_reduce(void *out, - const void *in, - const int *cu_seqlens, - const int hidden_size, - const int batch_size, - const int total, - const int num_chunks, - cudaStream_t stream) { - - const int blocks = total; - - if(hidden_size == 1024){ - - constexpr int HIDDEN_SIZE = 1024; - constexpr int THREADS = 256; - - if( num_chunks == 2 ) { - fmha_noloop_reduce_kernel<<>>(out, in, cu_seqlens, batch_size); - } else if( num_chunks == 3 ) { - fmha_noloop_reduce_kernel<<>>(out, in, cu_seqlens, batch_size); - } else { - assert(false && "Unsupported num_chunks"); - } - - }else{ - assert(false && "Unsupported hidden_size"); - } - - FMHA_CHECK_CUDA(cudaPeekAtLastError()); -} diff --git a/apex/contrib/csrc/fmha/src/fmha_utils.h b/apex/contrib/csrc/fmha/src/fmha_utils.h deleted file mode 100644 index de07cc7..0000000 --- a/apex/contrib/csrc/fmha/src/fmha_utils.h +++ /dev/null @@ -1,92 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include -#include -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define FMHA_CHECK_CUDA( call ) \ - do { \ - cudaError_t status_ = call; \ - if( status_ != cudaSuccess ) { \ - fprintf( stderr, \ - "CUDA error (%s:%d): %s\n", \ - __FILE__, \ - __LINE__, \ - cudaGetErrorString( status_ ) ); \ - exit( 1 ); \ - } \ - } while( 0 ) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -enum Data_type { DATA_TYPE_FP16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 }; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline void set_alpha( uint32_t &alpha, float norm, Data_type dtype ) { - if( dtype == DATA_TYPE_FP16 ) { - half x = __float2half_rn( norm ); - uint16_t h = reinterpret_cast( x ); - ushort2 h2 = { h, h }; - alpha = reinterpret_cast( h2 ); - } else if( dtype == DATA_TYPE_FP32 ) { - alpha = reinterpret_cast( norm ); - } else if( dtype == DATA_TYPE_INT32 ) { - int32_t inorm = static_cast( norm ); - alpha = reinterpret_cast( inorm ); - } else { - assert( false ); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline size_t get_size_in_bytes( size_t n, Data_type dtype ) { - switch( dtype ) { - case DATA_TYPE_FP32: - return n * 4; - case DATA_TYPE_FP16: - return n * 2; - case DATA_TYPE_INT32: - return n * 4; - case DATA_TYPE_INT8: - return n; - default: - assert( false ); - return 0; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp b/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp deleted file mode 100644 index 15393fb..0000000 --- a/apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp +++ /dev/null @@ -1,70 +0,0 @@ -#include - -#include -#include - -// CUDA forward declarations - -std::vector focal_loss_forward_cuda( - const at::Tensor &cls_output, - const at::Tensor &cls_targets_at_level, - const at::Tensor &num_positives_sum, - const int64_t num_real_classes, - const float alpha, - const float gamma, - const float smoothing_factor); - -at::Tensor focal_loss_backward_cuda( - const at::Tensor &grad_output, - const at::Tensor &partial_grad, - const at::Tensor &num_positives_sum); - -// C++ interface - -#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ - AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - -std::vector focal_loss_forward( - const at::Tensor &cls_output, - const at::Tensor &cls_targets_at_level, - const at::Tensor &num_positives_sum, - const int64_t num_real_classes, - const float alpha, - const float gamma, - const float smoothing_factor -) { - CHECK_INPUT(cls_output); - CHECK_INPUT(cls_targets_at_level); - CHECK_INPUT(num_positives_sum); - - return focal_loss_forward_cuda( - cls_output, - cls_targets_at_level, - num_positives_sum, - num_real_classes, - alpha, - gamma, - smoothing_factor); -} - -at::Tensor focal_loss_backward( - const at::Tensor &grad_output, - const at::Tensor &partial_grad, - const at::Tensor &num_positives_sum -) { - CHECK_INPUT(grad_output); - CHECK_INPUT(partial_grad); - - return focal_loss_backward_cuda(grad_output, partial_grad, num_positives_sum); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &focal_loss_forward, - "Focal loss calculation forward (CUDA)"); - m.def("backward", &focal_loss_backward, - "Focal loss calculation backward (CUDA)"); -} diff --git a/apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu b/apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu deleted file mode 100644 index bda4f88..0000000 --- a/apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu +++ /dev/null @@ -1,267 +0,0 @@ -#include -#include -#include - - -#define ASSERT_UINT4_ALIGNED(PTR) \ - TORCH_INTERNAL_ASSERT(is_aligned(PTR), "Tensor " #PTR " is not uint4 aligned") - -template bool is_aligned(const void *ptr) noexcept { - auto iptr = reinterpret_cast(ptr); - return !(iptr % alignof(T)); -} - -template -__global__ void focal_loss_forward_cuda_kernel( - outscalar_t *loss, scalar_t *partial_grad, - const scalar_t *__restrict__ cls_output, - const labelscalar_t *__restrict__ cls_targets_at_level, - const float *__restrict__ num_positives_sum, const int64_t num_examples, - const int64_t num_classes, const int64_t num_real_classes, - const float alpha, const float gamma, const float smoothing_factor) { - extern __shared__ unsigned char shm[]; - accscalar_t *loss_shm = reinterpret_cast(shm); - loss_shm[threadIdx.x] = 0; - accscalar_t loss_acc = 0; - - accscalar_t one = accscalar_t(1.0); - accscalar_t K = accscalar_t(2.0); - accscalar_t normalizer = one / static_cast(num_positives_sum[0]); - accscalar_t nn_norm, np_norm, pn_norm, pp_norm; - - // *_norm is used for label smoothing only - if (SMOOTHING) { - nn_norm = one - smoothing_factor / K; - np_norm = smoothing_factor / K; - pn_norm = smoothing_factor - smoothing_factor / K; - pp_norm = one - smoothing_factor + smoothing_factor / K; - } - - uint4 p_vec, grad_vec; - - // Accumulate loss on each thread - for (int64_t i = (blockIdx.x * blockDim.x + threadIdx.x) * ILP; - i < num_examples * num_classes; i += gridDim.x * blockDim.x * ILP) { - int64_t idy = i / num_classes; - labelscalar_t y = cls_targets_at_level[idy]; - int64_t base_yid = i % num_classes; - - int64_t pos_idx = idy * num_classes + y; - p_vec = *(uint4 *)&cls_output[i]; - - // Skip ignored matches - if (y == -2) { -#pragma unroll - for (int j = 0; j < ILP; j++) { - *((scalar_t *)(&grad_vec) + j) = 0; - } - *(uint4 *)&partial_grad[i] = grad_vec; - continue; - } - -#pragma unroll - for (int j = 0; j < ILP; j++) { - // Skip the pad classes - if (base_yid + j >= num_real_classes) { - *((scalar_t *)(&grad_vec) + j) = 0; - continue; - } - - accscalar_t p = static_cast(*((scalar_t *)(&p_vec) + j)); - accscalar_t exp_np = ::exp(-p); - accscalar_t exp_pp = ::exp(p); - accscalar_t sigma = one / (one + exp_np); - accscalar_t logee = (p >= 0) ? exp_np : exp_pp; - accscalar_t addee = (p >= 0) ? 0 : -p; - accscalar_t off_a = addee + ::log(one + logee); - - // Negative matches - accscalar_t base = SMOOTHING ? nn_norm * p : p; - accscalar_t off_b = (SMOOTHING ? np_norm : 0) - sigma; - accscalar_t coeff_f1 = one - alpha; - accscalar_t coeff_f2 = sigma; - accscalar_t coeff_b1 = gamma; - accscalar_t coeff_b2 = one - sigma; - - // Positive matches - if (y >= 0 && (i + j == pos_idx)) { - base = SMOOTHING ? pn_norm * p : 0; - off_b = (SMOOTHING ? pp_norm : one) - sigma; - coeff_f1 = alpha; - coeff_f2 = one - sigma; - coeff_b1 = -gamma; - coeff_b2 = sigma; - } - - accscalar_t coeff_f = coeff_f1 * ::pow(coeff_f2, gamma); - accscalar_t coeff_b = coeff_b1 * coeff_b2; - - accscalar_t loss_t = coeff_f * (base + off_a); - accscalar_t grad = coeff_f * (coeff_b * (base + off_a) - off_b); - - // Delay the normalize of partial gradient by num_positives_sum to back - // propagation because scalar_t reduces precision. Focal loss is very - // sensitive to the small gradient. No worry on overflow here since - // gradient has relative smaller range than input. - loss_acc += loss_t; - *((scalar_t *)(&grad_vec) + j) = static_cast(grad); - } - - // This can't ensure to generate stg.128 and may be two stg.64. - *(uint4 *)&partial_grad[i] = grad_vec; - } - loss_shm[threadIdx.x] = loss_acc; - - // Intra-CTA reduction - __syncthreads(); - for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) { - if (threadIdx.x < s) { - loss_shm[threadIdx.x] += loss_shm[threadIdx.x + s]; - } - __syncthreads(); - } - - // Inter-CTA reduction - if (threadIdx.x == 0) { - loss_acc = loss_shm[0] * normalizer; - atomicAdd(loss, loss_acc); - } -} - -template -__global__ void focal_loss_backward_cuda_kernel( - scalar_t *partial_grad, const outscalar_t *__restrict__ grad_output, - const float *__restrict__ num_positives_sum, const uint64_t numel) { - int64_t idx = (blockIdx.x * blockDim.x + threadIdx.x) * ILP; - - accscalar_t normalizer = static_cast(grad_output[0]) / - static_cast(num_positives_sum[0]); - - // The input is enforced to pad to use vector load, thus there's no need to - // check whether the last element of ILP can out of bound. - if (idx >= numel) - return; - - uint4 grad_vec; - grad_vec = *(uint4 *)&partial_grad[idx]; -#pragma unroll(ILP) - for (int i = 0; i < ILP; i++) { - auto grad = static_cast(*((scalar_t *)(&grad_vec) + i)); - grad *= normalizer; - *((scalar_t *)(&grad_vec) + i) = static_cast(grad); - } - *(uint4 *)&partial_grad[idx] = grad_vec; -} - -std::vector focal_loss_forward_cuda( - const at::Tensor &cls_output, const at::Tensor &cls_targets_at_level, - const at::Tensor &num_positives_sum, const int64_t num_real_classes, - const float alpha, const float gamma, const float smoothing_factor) { - // Checks required for correctness - TORCH_INTERNAL_ASSERT(cls_output.size(-1) >= num_real_classes, - "Incorrect number of real classes."); - TORCH_INTERNAL_ASSERT(cls_targets_at_level.scalar_type() == at::kLong, - "Invalid label type."); - TORCH_INTERNAL_ASSERT( - (num_positives_sum.numel() == 1) && - (num_positives_sum.scalar_type() == at::kFloat), - "Expect num_positives_sum to be a float32 tensor with only one element."); - TORCH_INTERNAL_ASSERT(cls_output.dim() == cls_targets_at_level.dim() + 1, - "Mis-matched dimensions between class output and label."); - for (int64_t i = 0; i < cls_targets_at_level.dim(); i++) - TORCH_INTERNAL_ASSERT(cls_output.size(i) == cls_targets_at_level.size(i), - "Mis-matched shape between class output and label."); - - // Checks required for better performance - const int ILP = sizeof(uint4) / cls_output.element_size(); - ASSERT_UINT4_ALIGNED(cls_output.data_ptr()); - TORCH_INTERNAL_ASSERT(cls_output.size(-1) % ILP == 0, - "Pad number of classes first to take advantage of 128 bit load."); - TORCH_INTERNAL_ASSERT(num_real_classes >= ILP, "Too few classes."); - - int64_t num_classes = cls_output.size(-1); - int64_t num_examples = cls_output.numel() / num_classes; - at::Tensor loss = at::zeros({}, cls_output.options().dtype(at::kFloat)); - - // Compute the incompelete gradient during fprop since most of the heavy - // functions of bprop are the same as fprop, thus trade memory for compute - // helps with focal loss. - at::Tensor partial_grad = at::empty_like(cls_output); - - // The grid contains 2 CTA per SM, each CTA loop on input with stride till the - // last item. - cudaDeviceProp props; - cudaGetDeviceProperties(&props, at::cuda::current_device()); - dim3 block(512); - dim3 grid(2 * props.multiProcessorCount); - - // Specialize on label smoothing or not to reduce redundant operations - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (smoothing_factor == 0.0f) { - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - cls_output.scalar_type(), "focal_loss_fprop", [&] { - using accscalar_t = at::acc_type; - using labelscalar_t = int64_t; - using outscalar_t = float; - const int ILP = sizeof(uint4) / sizeof(scalar_t); - focal_loss_forward_cuda_kernel - <<>>( - loss.data_ptr(), - partial_grad.data_ptr(), - cls_output.data_ptr(), - cls_targets_at_level.data_ptr(), - num_positives_sum.data_ptr(), num_examples, - num_classes, num_real_classes, alpha, gamma, - smoothing_factor); - }); - } else { - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - cls_output.scalar_type(), "focal_loss_fprop", [&] { - using accscalar_t = at::acc_type; - using labelscalar_t = int64_t; - using outscalar_t = float; - const int ILP = sizeof(uint4) / sizeof(scalar_t); - focal_loss_forward_cuda_kernel - <<>>( - loss.data_ptr(), - partial_grad.data_ptr(), - cls_output.data_ptr(), - cls_targets_at_level.data_ptr(), - num_positives_sum.data_ptr(), num_examples, - num_classes, num_real_classes, alpha, gamma, - smoothing_factor); - }); - } - - AT_CUDA_CHECK(cudaGetLastError()); - return {loss, partial_grad}; -} - -at::Tensor focal_loss_backward_cuda(const at::Tensor &grad_output, - const at::Tensor &partial_grad, - const at::Tensor &num_positives_sum) { - // Each thread process ILP elements - const int ILP = sizeof(uint4) / partial_grad.element_size(); - dim3 block(512); - dim3 grid((partial_grad.numel() + block.x * ILP - 1) / (block.x * ILP)); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - partial_grad.scalar_type(), "focal_loss_bprop", [&] { - using accscalar_t = at::acc_type; - using outscalar_t = float; - const int ILP = sizeof(uint4) / sizeof(scalar_t); - focal_loss_backward_cuda_kernel - <<>>(partial_grad.data_ptr(), - grad_output.data_ptr(), - num_positives_sum.data_ptr(), - partial_grad.numel()); - }); - - AT_CUDA_CHECK(cudaGetLastError()); - return partial_grad; -} diff --git a/apex/contrib/csrc/groupbn/batch_norm.cu b/apex/contrib/csrc/groupbn/batch_norm.cu deleted file mode 100644 index 92eb11f..0000000 --- a/apex/contrib/csrc/groupbn/batch_norm.cu +++ /dev/null @@ -1,342 +0,0 @@ -#include -#include -#include - -#include "batch_norm.h" - -#include - -#include "compat.h" - -#define cudaCheckErrors(msg) \ - do { \ - cudaError_t __err = cudaGetLastError(); \ - if (__err != cudaSuccess) { \ - fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \ - msg, cudaGetErrorString(__err), \ - __FILE__, __LINE__); \ - fprintf(stderr, "*** FAILED - ABORTING\n"); \ - exit(1); \ - } \ - } while (0) - -static size_t round_up_to_multiple(size_t x, int multiple) { - return ((x + multiple - 1) / multiple) * multiple; -} - -struct Workspace { - Workspace(size_t size) : size(size), data(NULL) { - auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); - dataPtr = allocator.allocate(size); - data = dataPtr.get(); - } - Workspace(const Workspace&) = delete; - Workspace(Workspace&&) = default; - Workspace& operator=(Workspace&&) = default; - ~Workspace() = default; - - size_t size; - void* data; - c10::DataPtr dataPtr; -}; - -// Return {y} -at::Tensor nhwc_bn_fwd_train( - const at::Tensor& x, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& minibatch_mean, - const at::Tensor& minibatch_inv_var, - const at::Tensor& ret_cta, - const float momentum, - const float epsilon, - const bool fuse_relu, - void * my_data, - void * pair_data, - void * pair_data2, - void * pair_data3, - const int bn_group, - const at::Tensor& magic_tensor, - const int occupancy, - const int grid_dim_x, - const bool coop) { - - auto memory_format = x.suggest_memory_format(); - const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast); - const int N = x.size(0); - const int H = check_channels_last ? x.size(2) : x.size(1); - const int W = check_channels_last ? x.size(3) : x.size(2); - const int C = check_channels_last ? x.size(1) : x.size(3); - - // generating new magic number and use that for sync - int* magic = magic_tensor.DATA_PTR(); - *magic = (*magic + 1) & 0xff; - - // Allocate output tensor - at::Tensor y = check_channels_last ? at::empty({N, C, H, W}, x.options().memory_format(memory_format)) : at::empty({N, H, W, C}, x.options()); - - // Create wrapper - NhwcBatchNorm *bn = new NhwcBatchNorm(); - - bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group); - bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W); - - bn->setConstants(momentum, epsilon); - - // set pointers within the wrapper - bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR(), - nullptr, - y.contiguous(memory_format).DATA_PTR(), - nullptr); - - bn->setWeightPointers({scale.contiguous().DATA_PTR(), - bias.contiguous().DATA_PTR()}, {nullptr, nullptr}); - bn->setParameterPointers({running_mean.contiguous().DATA_PTR(), - running_inv_var.DATA_PTR()}); - - // deal with workspace(s) - auto workspace_bytes = bn->numWorkspaceBytes(); - // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset - // an allocated workspace for the others - size_t total_workspace_bytes = 0; - std::vector workspace_offsets; - - for (auto index = 3; index < workspace_bytes.size(); ++index) { - total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512); - workspace_offsets.push_back(total_workspace_bytes); - - auto alloc_bytes = workspace_bytes[index]; - total_workspace_bytes += alloc_bytes; - } - - // Allocate the workspace - Workspace ws(total_workspace_bytes); - - std::vector workspace; - workspace.push_back(minibatch_mean.contiguous().DATA_PTR()); - workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR()); - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - const int retired_cta_bytes = workspace_bytes[2]; - void* retired_ctas = ret_cta.contiguous().DATA_PTR(); - assert(ret_cta.size(0)>=retired_cta_bytes); - workspace.push_back(retired_ctas); - - for (auto index = 3; index < workspace_bytes.size(); ++index) { - void *ptr = reinterpret_cast(ws.data) + workspace_offsets[index-3]; - workspace.push_back(ptr); - } - - bn->setWorkspacePointers(workspace, workspace_bytes); - - // Don't fuse in ReLU for now at least - bn->fwd(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop); - - return y.contiguous(memory_format); -} - -at::Tensor nhwc_bn_fwd_eval( - const at::Tensor& x, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& ret_cta, - const int bn_group, - const float momentum, - const float epsilon, - const bool fuse_relu) { - - const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast); - auto memory_format = x.suggest_memory_format(); - const int N = x.size(0); - const int H = check_channels_last ? x.size(2) : x.size(1); - const int W = check_channels_last ? x.size(3) : x.size(2); - const int C = check_channels_last ? x.size(1) : x.size(3); - - // Allocate output tensor - at::Tensor y = check_channels_last ? at::empty({N, C, H, W}, x.options().memory_format(memory_format)) : at::empty({N, H, W, C}, x.options()); - - // Create wrapper - NhwcBatchNorm *bn = new NhwcBatchNorm(); - - bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group); - bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W); - - bn->setConstants(momentum, epsilon); - - // set pointers within the wrapper - bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR(), - nullptr, - y.contiguous(memory_format).DATA_PTR(), - nullptr); - - bn->setWeightPointers({scale.contiguous().DATA_PTR(), - bias.contiguous().DATA_PTR()}, {nullptr, nullptr}); - bn->setParameterPointers({running_mean.contiguous().DATA_PTR(), - running_inv_var.contiguous().DATA_PTR()}); - - // deal with workspace(s) - auto workspace_bytes = bn->numWorkspaceBytes(); - // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset - // an allocated workspace for the others - size_t total_workspace_bytes = 0; - std::vector workspace_offsets; - - for (auto index = 3; index < workspace_bytes.size(); ++index) { - total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512); - workspace_offsets.push_back(total_workspace_bytes); - - auto alloc_bytes = workspace_bytes[index]; - total_workspace_bytes += alloc_bytes; - } - - // Allocate the workspace - Workspace ws(total_workspace_bytes); - - std::vector workspace; - workspace.push_back(nullptr); - workspace.push_back(nullptr); - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - const int retired_cta_bytes = workspace_bytes[2]; - void* retired_ctas = ret_cta.contiguous().DATA_PTR(); - assert(ret_cta.size(0)>=retired_cta_bytes); - workspace.push_back(retired_ctas); - - for (auto index = 3; index < workspace_bytes.size(); ++index) { - void *ptr = reinterpret_cast(ws.data) + workspace_offsets[index-3]; - workspace.push_back(ptr); - } - - bn->setWorkspacePointers(workspace, workspace_bytes); - - // Don't fuse in ReLU for now at least - bn->fwdInference(stream, fuse_relu); - - return y.contiguous(memory_format); - -} - -std::vector nhwc_bn_bwd( - const at::Tensor& x, - const at::Tensor& dy, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& minibatch_mean, - const at::Tensor& minibatch_inv_var, - const at::Tensor& ret_cta, - const float momentum, - const float epsilon, - const bool fuse_relu, - void * my_data, - void * pair_data, - void * pair_data2, - void * pair_data3, - const int bn_group, - const at::Tensor& magic_tensor, - const int occupancy, - const int grid_dim_x, - const bool coop) { - // shape - const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast); - auto memory_format = x.suggest_memory_format(); - const int N = x.size(0); - const int H = check_channels_last ? x.size(2) : x.size(1); - const int W = check_channels_last ? x.size(3) : x.size(2); - const int C = check_channels_last ? x.size(1) : x.size(3); - - // generating new magic number and use that for sync - int* magic = magic_tensor.DATA_PTR(); - *magic = (*magic + 1) & 0xff; - - // outputs - at::Tensor x_grad, scale_grad, bias_grad; - - // Allocate outputs - x_grad = check_channels_last ? at::empty({N, C, H, W}, dy.options().memory_format(memory_format)) : at::empty_like(x); - scale_grad = at::empty_like(scale); - bias_grad = at::empty_like(bias); - - // Create wrapper - NhwcBatchNorm *bn = new NhwcBatchNorm(); - - bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group); - bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W); - - bn->setConstants(momentum, epsilon); - - // set pointers within the wrapper - bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR(), - x_grad.contiguous(memory_format).DATA_PTR(), - nullptr, - dy.contiguous(memory_format).DATA_PTR()); - - bn->setWeightPointers({scale.contiguous().DATA_PTR(), - bias.contiguous().DATA_PTR()}, - {scale_grad.DATA_PTR(), - bias_grad.DATA_PTR()}); - bn->setParameterPointers({running_mean.contiguous().DATA_PTR(), - running_inv_var.contiguous().DATA_PTR()}); - - // deal with workspace(s) - auto workspace_bytes = bn->numWorkspaceBytes(); - // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset - // an allocated workspace for the others - size_t total_workspace_bytes = 0; - std::vector workspace_offsets; - - for (auto index = 3; index < workspace_bytes.size(); ++index) { - total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512); - workspace_offsets.push_back(total_workspace_bytes); - - auto alloc_bytes = workspace_bytes[index]; - total_workspace_bytes += alloc_bytes; - } - - // Allocate the workspace - Workspace ws(total_workspace_bytes); - - std::vector workspace; - workspace.push_back(minibatch_mean.contiguous().DATA_PTR()); - workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR()); - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - const int retired_cta_bytes = workspace_bytes[2]; - void* retired_ctas = ret_cta.contiguous().DATA_PTR(); - assert(ret_cta.size(0)>=retired_cta_bytes); - workspace.push_back(retired_ctas); - - for (auto index = 3; index < workspace_bytes.size(); ++index) { - void *ptr = reinterpret_cast(ws.data) + workspace_offsets[index-3]; - workspace.push_back(ptr); - } - - bn->setWorkspacePointers(workspace, workspace_bytes); - - bn->dgrad(stream, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop); - - return std::vector{x_grad.contiguous(memory_format), scale_grad, bias_grad}; -} - -int nhwc_bn_fwd_occupancy() { - int device_id=-1; - cudaGetDevice(&device_id); - - //max occupancy supported by the code is 2 - return NhwcBatchNorm::smem_driven_fwd_occupancy(device_id, 2); -} - -int nhwc_bn_bwd_occupancy() { - int device_id=-1; - cudaGetDevice(&device_id); - - //max occupancy supported by the code is 2 - return NhwcBatchNorm::smem_driven_bwd_occupancy(device_id, 2); -} - - diff --git a/apex/contrib/csrc/groupbn/batch_norm.h b/apex/contrib/csrc/groupbn/batch_norm.h deleted file mode 100644 index 5f56dd9..0000000 --- a/apex/contrib/csrc/groupbn/batch_norm.h +++ /dev/null @@ -1,901 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2018 by Contributors - * \file nhwc_batch_norm.h - * \brief CUDA NHWC Batch Normalization code - * \author Shankara Rao Thejaswi Nanditale, Dick Carter, Evgeni Krimer -*/ -#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_ -#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_ - -#include "dnn.h" - -#include -#include -#include -#include - -#include "nhwc_batch_norm_kernel.h" -#include "cuda_utils.h" -#include "c10/macros/Macros.h" - - -#define VERBOSE_DEFAULT false - -class NhwcBatchNorm { - public: - NhwcBatchNorm() { - name_ = "nhwc_batchnorm"; - createTensorDescriptor(&X_tensor_desc_); - createTensorDescriptor(&Y_tensor_desc_); - } - - ~NhwcBatchNorm() { - destroyTensorDescriptor(X_tensor_desc_); - destroyTensorDescriptor(Y_tensor_desc_); - } - - void die() { - std::cerr << "batchnorm not initialized" << std::endl; - exit(-1); - } - - void fwd(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop); - void dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop); - void fwdInference(cudaStream_t stream, bool use_relu); - dim3 calc_fwd_grid(int *loop, const int grid_dim_x); - dim3 calc_bwd_grid(int *loop, const int grid_dim_x); - - void setInputDescriptor(const dnnTensorFormat_t format, - const dnnDataType_t data_type, - int n, int c, int h, int w, int bn_group) { - m_ = n * h * w; - int m_bn_adjusted = m_ * bn_group; - c_ = c; - // factor to scale sum of squared errors to get saved variance. Must be 1/nhw. - svar_inv_count_ = 1.f / m_bn_adjusted; - // factor to scale sum of squared errors to get running variance. Should be 1/(nhw-1). - int divisor = m_bn_adjusted - 1; - // nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs. - rvar_inv_count_ = divisor == 0 ? 1.f : 1.f / divisor; - setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w); - } - - void setOutputDescriptor(const dnnTensorFormat_t format, - const dnnDataType_t data_type, - int n, int c, int h, int w) { - setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w); - } - - const std::vector numWorkspaceBytes() const; - - void setWorkspacePointers( - const std::vector& workspace, - const std::vector& num_workspace_bytes); - - void setInputOutputPointers(void* X, void* dX, void* Y, void *dY) { - X_ = X; - dX_ = dX; - Y_ = Y; - dY_ = dY; - } - - // Sets the pointers for the scale and weight (in that order) data and derivative buffers. - void setWeightPointers(const std::vector& weight_pointers, - const std::vector& deriv_pointers) { - assert(weight_pointers.size() == 2); - assert(deriv_pointers.size() == 2); - scale_ = static_cast(weight_pointers[0]); - bias_ = static_cast(weight_pointers[1]); - dscale_ = static_cast(deriv_pointers[0]); - dbias_ = static_cast(deriv_pointers[1]); - } - - // Sets the pointers for the population mean and variance buffers, in that order. - void setParameterPointers(const std::vector& param_pointers) { - assert(param_pointers.size() == 2); - population_mean_ = static_cast(param_pointers[0]); - population_variance_ = static_cast(param_pointers[1]); - } - - void setConstants(const double exp_avg_factor, const double eps) { - exp_avg_factor_ = exp_avg_factor; - eps_ = eps; - } - - void processCudnnStatus(const dnnStatus_t& status, - const std::string& string = std::string(), - bool verbose = VERBOSE_DEFAULT) { -#ifdef __HIP_PLATFORM_HCC__ - if (status != DNN_STATUS_SUCCESS) - LOG(FATAL) << string << " " << miopenGetErrorString(status); - else if (verbose) - LOG(INFO) << string << " " << miopenGetErrorString(status); -#else - if (status != DNN_STATUS_SUCCESS) - LOG(FATAL) << string << " " << cudnnGetErrorString(status); - else if (verbose) - LOG(INFO) << string << " " << cudnnGetErrorString(status); -#endif - } - - void checkCudaStatus(const std::string& string = std::string(), - bool verbose = VERBOSE_DEFAULT) { - cudaError_t status = cudaGetLastError(); - if (status != cudaSuccess) - LOG(FATAL) << string << " " << cudaGetErrorString(status); - else if (verbose) - LOG(INFO) << string << " " << cudaGetErrorString(status); - } - - size_t size_retired_ctas(int grid_y) const { - // Note that the value of max_grid_y to handle known GPUs is about 160. - const int max_grid_y = 1024; - if (grid_y > max_grid_y) - LOG(INFO) << "GPU capabilities exceeds assumptions."; - const int retired_cta_bytes = max_grid_y * 2 * sizeof(int); - // Since the region will be initialized once and used for many kernels, - // the idea is to return an ample size that will cover all uses. - return retired_cta_bytes; - } - - dnnTensorDescriptor_t X_tensor_desc_ = nullptr; - dnnTensorDescriptor_t Y_tensor_desc_ = nullptr; - - void* X_ = nullptr; - void* dX_ = nullptr; - void* Y_ = nullptr; - void* dY_ = nullptr; - - // Learned scale and bias weights. - float* scale_ = nullptr; - float* dscale_ = nullptr; - float* bias_ = nullptr; - float* dbias_ = nullptr; - - // Computed population mean and variance parameters. - float* population_mean_ = nullptr; - float* population_variance_ = nullptr; - - // Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd). - float* minibatch_mean_ = nullptr; - float* minibatch_variance_ = nullptr; - - int m_ = 0; // Number of values per channel that BN is normalizing. - int c_ = 0; // Number of channels over which BN is normalizing. - - float svar_inv_count_ = 0.f; // factor to scale sum of squared errors to get saved variance - float rvar_inv_count_ = 0.f; // factor to scale sum of squared errors to get running variance - - double exp_avg_factor_ = 0.; - double eps_ = 0.; - std::string name_; - - private: - void setTensorDescriptor(dnnTensorDescriptor_t descriptor, - dnnTensorFormat_t format, - dnnDataType_t data_type, - int n, int c, int h, int w) { - dnnStatus_t status = DNN_STATUS_SUCCESS; -#ifdef __HIP_PLATFORM_HCC__ - status = miopenSet4dTensorDescriptor(descriptor, data_type, n, c, h, w); -#else - status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w); -#endif - processCudnnStatus(status, "set tensor descriptor"); - } - - void createTensorDescriptor(dnnTensorDescriptor_t *descriptor) { - dnnStatus_t status = DNN_STATUS_SUCCESS; -#ifdef __HIP_PLATFORM_HCC__ - status = miopenCreateTensorDescriptor(descriptor); -#else - status = cudnnCreateTensorDescriptor(descriptor); -#endif - processCudnnStatus(status, "create tensor_descriptor"); - } - - void destroyTensorDescriptor(dnnTensorDescriptor_t descriptor) { - dnnStatus_t status = DNN_STATUS_SUCCESS; -#ifdef __HIP_PLATFORM_HCC__ - status = miopenDestroyTensorDescriptor(descriptor); -#else - status = cudnnDestroyTensorDescriptor(descriptor); -#endif - processCudnnStatus(status, "destroy tensor_descriptor"); - } - - protected: - float *partial_sums_ = nullptr; - int *partial_counts_ = nullptr; - int *retired_ctas_ = nullptr; - - void _setFwdParams(NhwcBatchNormFwdParams *params) const; - void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const; - void _setBwdParams(NhwcBatchNormBwdParams *params) const; - - // @todo: ability to configure these? - // Kernel params - static const int USE_ONLINE_APPROACH = 1; - static const int THREADS_PER_CTA = 512; - static const int THREADS_PER_PIXEL = 32; - static const int C_ELEMENTS_PER_CTA = 128; - static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL; - static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024; - - typedef uint16_t StorageType; - //typedef float StorageType; - // increasing this to 6 causes spills in fwd kernel! - static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 1; - static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 1; - static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 0; - static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 0; - - static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \ - PIXELS_PER_THREAD_IN_SMEM_FWD; - static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + \ - PIXELS_PER_THREAD_IN_SMEM_BWD; - static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4; - - // Derived params - static const size_t SMEM_SIZE_FWD = PIXELS_PER_THREAD_IN_SMEM_FWD*THREADS_PER_CTA*\ - ELEMENTS_PER_LDG*sizeof(StorageType); - static const size_t SMEM_SIZE_BWD = PIXELS_PER_THREAD_IN_SMEM_BWD*THREADS_PER_CTA*\ - ELEMENTS_PER_LDG*2*sizeof(StorageType); - static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; - static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \ - PIXELS_PER_THREAD_FWD; - static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \ - PIXELS_PER_THREAD_BWD; - static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA/THREADS_PER_PIXEL * \ - PIXELS_PER_THREAD_FWD_INFERENCE; - - // max grid.y in case of group bn is limited by exchange buffer size - static const int MAX_GBN_BLOCK_Y = 256; - - // Helper function to launch the forward kernel. - - // We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel - // version that was compiled with that occupancy in its launch bounds. This way, we avoid - // needless register spills. - void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params, - dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) { - -#ifdef __HIP_PLATFORM_HCC__ -#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \ - do { \ - CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \ - auto fwd_func = nhwc_batch_norm_fwd< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_FWD, \ - PIXELS_PER_THREAD_IN_SMEM_FWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - USE_RELU, \ - USE_ADD_RELU, \ - COMPILED_FOR_OCCUPANCY>; \ - if (COMPILED_FOR_OCCUPANCY > 1) { \ - hipFuncSetAttribute((void *) fwd_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \ - checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \ - } \ - void *params_ptr = static_cast(¶ms); \ - using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_FWD, \ - PIXELS_PER_THREAD_IN_SMEM_FWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - USE_RELU, \ - USE_ADD_RELU, \ - COMPILED_FOR_OCCUPANCY>); \ - if (COOP) { \ - hipLaunchCooperativeKernel(fwd_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_FWD, \ - stream); \ - } else { \ - hipLaunchKernel((void *) fwd_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_FWD, \ - stream); \ - } \ - checkCudaStatus(name_ + " fwd ser coop kernel"); \ - } while (0) -#else -#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \ - do { \ - CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \ - auto fwd_func = nhwc_batch_norm_fwd< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_FWD, \ - PIXELS_PER_THREAD_IN_SMEM_FWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - USE_RELU, \ - USE_ADD_RELU, \ - COMPILED_FOR_OCCUPANCY>; \ - if (COMPILED_FOR_OCCUPANCY > 1) { \ - cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \ - checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \ - } \ - void *params_ptr = static_cast(¶ms); \ - using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_FWD, \ - PIXELS_PER_THREAD_IN_SMEM_FWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - USE_RELU, \ - USE_ADD_RELU, \ - COMPILED_FOR_OCCUPANCY>); \ - if (COOP) { \ - cudaLaunchCooperativeKernel(fwd_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_FWD, \ - stream); \ - } else { \ - cudaLaunchKernel(fwd_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_FWD, \ - stream); \ - } \ - checkCudaStatus(name_ + " fwd ser coop kernel"); \ - } while (0) -#endif - - // Don't try for an occupancy > 2 as this will squeeze register use and create spills. - if (outer_loops == 1 && use_relu) { - if (occupancy >= 2) - LAUNCH_FWD_KERNEL(1, true, false, 2, coop); - else - LAUNCH_FWD_KERNEL(1, true, false, 1, coop); - } else if (outer_loops == 1 && !use_relu) { - if (occupancy >= 2) - LAUNCH_FWD_KERNEL(1, false, false, 2, coop); - else - LAUNCH_FWD_KERNEL(1, false, false, 1, coop); - } else if (use_relu) { - if (occupancy >= 2) - LAUNCH_FWD_KERNEL(0, true, false, 2, coop); - else - LAUNCH_FWD_KERNEL(0, true, false, 1, coop); - } else { - if (occupancy >= 2) - LAUNCH_FWD_KERNEL(0, false, false, 2, coop); - else - LAUNCH_FWD_KERNEL(0, false, false, 1, coop); - } -#undef LAUNCH_FWD_KERNEL - } - - // Helper function to launch the backward kernel. - - void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params, - dim3 grid_dim, int outer_loops, bool use_relu, const int occupancy, const bool coop) { -#ifdef __HIP_PLATFORM_HCC__ -#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \ - do { \ - CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \ - auto bwd_func = nhwc_batch_norm_bwd< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ - PIXELS_PER_THREAD_IN_SMEM_BWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - COMPILED_FOR_OCCUPANCY>; \ - if (COMPILED_FOR_OCCUPANCY > 1) { \ - hipFuncSetAttribute((void *) bwd_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \ - checkCudaStatus(name_ + " bwd coop serial kernel (cudaFuncSetAttribute carveout)"); \ - } \ - void *params_ptr = static_cast(¶ms); \ - using BWD_FUNC = decltype(nhwc_batch_norm_bwd< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ - PIXELS_PER_THREAD_IN_SMEM_BWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - COMPILED_FOR_OCCUPANCY>); \ - if (COOP) { \ - hipLaunchCooperativeKernel(bwd_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_BWD, \ - stream); \ - } else { \ - hipLaunchKernel((void *) bwd_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_BWD, \ - stream); \ - } \ - checkCudaStatus(name_ + " bwd coop serial kernel"); \ - } while (0) - -#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \ - do { \ - CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \ - auto bwd_relu_func = nhwc_batch_norm_bwd_relu< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ - PIXELS_PER_THREAD_IN_SMEM_BWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - COMPILED_FOR_OCCUPANCY>; \ - if (COMPILED_FOR_OCCUPANCY > 1) { \ - hipFuncSetAttribute((void *) bwd_relu_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \ - checkCudaStatus(name_ + " bwd-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \ - } \ - void *params_ptr = static_cast(¶ms); \ - using BWD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_relu< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ - PIXELS_PER_THREAD_IN_SMEM_BWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - COMPILED_FOR_OCCUPANCY>); \ - if (COOP) { \ - hipLaunchCooperativeKernel(bwd_relu_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_BWD, \ - stream); \ - } else { \ - hipLaunchKernel((void *) bwd_relu_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_BWD, \ - stream); \ - } \ - checkCudaStatus(name_ + " bwd-relu coop serial kernel"); \ - } while (0) -#else -#define LAUNCH_BWD_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \ - do { \ - CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \ - auto bwd_func = nhwc_batch_norm_bwd< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ - PIXELS_PER_THREAD_IN_SMEM_BWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - COMPILED_FOR_OCCUPANCY>; \ - if (COMPILED_FOR_OCCUPANCY > 1) { \ - cudaFuncSetAttribute(bwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \ - checkCudaStatus(name_ + " bwd coop serial kernel (cudaFuncSetAttribute carveout)"); \ - } \ - void *params_ptr = static_cast(¶ms); \ - using BWD_FUNC = decltype(nhwc_batch_norm_bwd< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ - PIXELS_PER_THREAD_IN_SMEM_BWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - COMPILED_FOR_OCCUPANCY>); \ - if (COOP) { \ - cudaLaunchCooperativeKernel(bwd_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_BWD, \ - stream); \ - } else { \ - cudaLaunchKernel(bwd_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_BWD, \ - stream); \ - } \ - checkCudaStatus(name_ + " bwd coop serial kernel"); \ - } while (0) - -#define LAUNCH_BWD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \ - do { \ - CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << "Nhwc batchnorm kernel smem too big."; \ - auto bwd_relu_func = nhwc_batch_norm_bwd_relu< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ - PIXELS_PER_THREAD_IN_SMEM_BWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - COMPILED_FOR_OCCUPANCY>; \ - if (COMPILED_FOR_OCCUPANCY > 1) { \ - cudaFuncSetAttribute(bwd_relu_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \ - checkCudaStatus(name_ + " bwd-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \ - } \ - void *params_ptr = static_cast(¶ms); \ - using BWD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_relu< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ - PIXELS_PER_THREAD_IN_SMEM_BWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - COMPILED_FOR_OCCUPANCY>); \ - if (COOP) { \ - cudaLaunchCooperativeKernel(bwd_relu_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_BWD, \ - stream); \ - } else { \ - cudaLaunchKernel(bwd_relu_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_BWD, \ - stream); \ - } \ - checkCudaStatus(name_ + " bwd-relu coop serial kernel"); \ - } while (0) -#endif - - // Don't try for an occupancy > 2 as this will squeeze register use and create spills. - if (outer_loops == 1 && use_relu) { - if (occupancy >= 2) - LAUNCH_BWD_RELU_KERNEL(1, 2, coop); - else - LAUNCH_BWD_RELU_KERNEL(1, 1, coop); - } else if (outer_loops == 1 && !use_relu) { - if (occupancy >= 2) - LAUNCH_BWD_KERNEL(1, 2, coop); - else - LAUNCH_BWD_KERNEL(1, 1, coop); - } else if (use_relu) { - if (occupancy >= 2) - LAUNCH_BWD_RELU_KERNEL(0, 2, coop); - else - LAUNCH_BWD_RELU_KERNEL(0, 1, coop); - } else { - if (occupancy >= 2) - LAUNCH_BWD_KERNEL(0, 2, coop); - else - LAUNCH_BWD_KERNEL(0, 1, coop); - } -#undef LAUNCH_BWD_KERNEL - } - - public: - - // Calculate the expected fwd kernel occupancy, as dictated by shared memory usage. - static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) { - using namespace at::cuda::utils; - int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float); - int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes; - int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes; - return std::min(max_cta_per_sm, occupancy); - } - - // Calculate the expected bwd kernel occupancy, as dictated by shared memory usage. - static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) { - using namespace at::cuda::utils; - int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float); - int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes; - int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes; - return std::min(max_cta_per_sm, occupancy); - } -}; - -const std::vector NhwcBatchNorm::numWorkspaceBytes() const { - assert(c_ > 0); - - // choose the max memory required between fwd/bwd passes - int grid_x_fwd = div_up(m_, PIXELS_PER_CTA_FWD); - int grid_x_bwd = div_up(m_, PIXELS_PER_CTA_BWD); - int grid_x = max(grid_x_fwd, grid_x_bwd); - int grid_y = div_up(c_, C_ELEMENTS_PER_CTA); - - const size_t num_mean_bytes = c_ * sizeof(float); - const size_t num_variance_bytes = num_mean_bytes; - const size_t size_sums = grid_y*grid_x*THREADS_PER_PIXEL*\ - ELEMENTS_PER_LDG*2*sizeof(float); - const size_t size_counts = grid_y*grid_x*sizeof(int); - - return {num_mean_bytes, num_variance_bytes, - size_retired_ctas(grid_y), size_sums, size_counts}; -} - -void NhwcBatchNorm::setWorkspacePointers( - const std::vector& workspace, - const std::vector& num_workspace_bytes) { - assert(workspace.size() == 5); - assert(num_workspace_bytes.size() == 5); - - minibatch_mean_ = static_cast(workspace[0]); - minibatch_variance_ = static_cast(workspace[1]); - retired_ctas_ = static_cast(workspace[2]); - partial_sums_ = static_cast(workspace[3]); - partial_counts_ = static_cast(workspace[4]); -} - -void NhwcBatchNorm::_setFwdParams(NhwcBatchNormFwdParams *params) const { - params->gmem_src = static_cast(X_); - params->gmem_dst = static_cast(Y_); - params->gmem_src1 = nullptr; - params->gmem_bias = bias_; - params->gmem_scale = scale_; - params->gmem_running_mean = population_mean_; - params->gmem_running_var = population_variance_; - params->gmem_saved_mean = minibatch_mean_; - params->gmem_saved_var = minibatch_variance_; - params->gmem_relu_bitmask = nullptr; - params->nhw = m_; - params->c = c_; - params->svar_inv_count = svar_inv_count_; - params->rvar_inv_count = rvar_inv_count_; - params->gmem_sums = partial_sums_; - params->gmem_counts = partial_counts_; - params->gmem_retired_ctas = retired_ctas_; - params->var_eps = eps_; - params->outer_loops = 0; - params->exp_avg_factor = static_cast(exp_avg_factor_); - params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA); -} - -void NhwcBatchNorm::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams - *params) const { - params->gmem_src = static_cast(X_); - params->gmem_dst = static_cast(Y_); - params->gmem_src1 = nullptr; - params->gmem_bias = bias_; - params->gmem_scale = scale_; - params->gmem_mean = population_mean_; - params->gmem_var = population_variance_; - params->nhw = m_; - params->c = c_; - params->var_eps = eps_; -} - -void NhwcBatchNorm::_setBwdParams(NhwcBatchNormBwdParams *params) const { - params->gmem_src = static_cast(X_); - params->gmem_dy = static_cast(dY_); - params->gmem_dst = static_cast(dX_); - params->gmem_dst1 = nullptr; - params->gmem_relu_bitmask = nullptr; - params->gmem_dscale = dscale_; - params->gmem_dbias = dbias_; - params->gmem_scale = scale_; - params->gmem_bias = bias_; - params->gmem_saved_mean = minibatch_mean_; - params->gmem_saved_var = minibatch_variance_; - params->nhw = m_; - params->c = c_; - params->svar_inv_count = svar_inv_count_; - params->gmem_sums = partial_sums_; - params->gmem_retired_ctas = retired_ctas_; - params->outer_loops = 0; - params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA); -} - -void NhwcBatchNorm::fwdInference(cudaStream_t stream, bool use_relu) { - bool ptrs_are_set = - X_tensor_desc_ != nullptr - && Y_tensor_desc_ != nullptr - && scale_ != nullptr - && bias_ != nullptr - // && minibatch_mean_ != nullptr - // && minibatch_variance_ != nullptr - && population_mean_ != nullptr - && population_variance_ != nullptr - && X_ != nullptr - // && dX_ != nullptr - && Y_ != nullptr - // && dY_ != nullptr - // && dscale_ != nullptr - // && dbias_ != nullptr - && partial_sums_ != nullptr - && partial_counts_ != nullptr; - - if (!ptrs_are_set) - die(); - - dim3 grid_dim; - grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD_INFERENCE); - grid_dim.y = div_up(c_, C_ELEMENTS_PER_CTA); - - // @todo: maybe just move this inside initialize routine? - NhwcBatchNormFwdInferenceParams params; - _setFwdInferenceParams(¶ms); - - if (use_relu) { - nhwc_batch_norm_fwd_inference - - <<>>(params); - checkCudaStatus(name_ + " fwd_inference-relu kernel"); - } else { - nhwc_batch_norm_fwd_inference - - <<>>(params); - checkCudaStatus(name_ + " fwd_inference kernel"); - } -} - -dim3 NhwcBatchNorm::calc_fwd_grid(int *loop, const int grid_dim_x) { - dim3 grid_dim; - grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD); - int c_blks = div_up(c_, C_ELEMENTS_PER_CTA); - unsigned int max_grid_x = grid_dim_x; - if (grid_dim.x <= max_grid_x) { - *loop = 1; - if (max_grid_x / grid_dim.x > 1) { - grid_dim.y = std::min(c_blks, static_cast(max_grid_x / grid_dim.x)); - assert(grid_dim.y 1) { - grid_dim.y = std::min(c_blks, static_cast(max_grid_x / grid_dim.x)); - assert(grid_dim.y> 1); - - dim3 grid_dim = calc_fwd_grid(¶ms.outer_loops, grid_dim_x); - _fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop); -} - -void NhwcBatchNorm::dgrad(cudaStream_t stream, bool use_relu, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, - const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) { - bool ptrs_are_set = - X_tensor_desc_ != nullptr - && Y_tensor_desc_ != nullptr - && scale_ != nullptr - && (bias_ != nullptr || !use_relu) - && minibatch_mean_ != nullptr - && minibatch_variance_ != nullptr - // && population_mean_ != nullptr - // && population_variance_ != nullptr - && X_ != nullptr - && dX_ != nullptr - // && Y_ != nullptr - && dY_ != nullptr - && dscale_ != nullptr - && dbias_ != nullptr; - - if (!ptrs_are_set) - die(); - - // reset of retired_cta_count no longer needed - - NhwcBatchNormBwdParams params; - _setBwdParams(¶ms); - params.my_data = my_data; - params.pair_datas[0] = pair_data; - params.pair_datas[1] = pair_data2; - params.pair_datas[2] = pair_data3; - params.magic = magic; - params.sync_iters = (bn_group==8)?3:(bn_group >> 1); - params.wgrad_coeff = 1.0 / bn_group; - - dim3 grid_dim = calc_bwd_grid(¶ms.outer_loops, grid_dim_x); - _bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, use_relu, occupancy, coop); -} - -#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_H_ diff --git a/apex/contrib/csrc/groupbn/batch_norm_add_relu.cu b/apex/contrib/csrc/groupbn/batch_norm_add_relu.cu deleted file mode 100644 index d3cc615..0000000 --- a/apex/contrib/csrc/groupbn/batch_norm_add_relu.cu +++ /dev/null @@ -1,353 +0,0 @@ -#include -#include -#include - -#include "batch_norm_add_relu.h" - -#include - -#include "compat.h" - -//FIXME move the common stuff to common h file -#define cudaCheckErrors(msg) \ - do { \ - cudaError_t __err = cudaGetLastError(); \ - if (__err != cudaSuccess) { \ - fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \ - msg, cudaGetErrorString(__err), \ - __FILE__, __LINE__); \ - fprintf(stderr, "*** FAILED - ABORTING\n"); \ - exit(1); \ - } \ - } while (0) - -static size_t round_up_to_multiple(size_t x, int multiple) { - return ((x + multiple - 1) / multiple) * multiple; -} - -struct Workspace { - Workspace(size_t size) : size(size), data(NULL) { - auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); - dataPtr = allocator.allocate(size); - data = dataPtr.get(); - } - Workspace(const Workspace&) = delete; - Workspace(Workspace&&) = default; - Workspace& operator=(Workspace&&) = default; - ~Workspace() = default; - - size_t size; - void* data; - c10::DataPtr dataPtr; -}; - -// Return {y} -at::Tensor nhwc_bn_addrelu_fwd_train( - const at::Tensor& x, - const at::Tensor& z, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& minibatch_mean, - const at::Tensor& minibatch_inv_var, - const at::Tensor& bitmask, - const at::Tensor& ret_cta, - const float momentum, - const float epsilon, - void * my_data, - void * pair_data, - void * pair_data2, - void * pair_data3, - const int bn_group, - const at::Tensor& magic_tensor, - const int occupancy, - const int grid_dim_x, - const bool coop) { - - auto memory_format = x.suggest_memory_format(); - const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast); - const int N = x.size(0); - const int H = check_channels_last ? x.size(2) : x.size(1); - const int W = check_channels_last ? x.size(3) : x.size(2); - const int C = check_channels_last ? x.size(1) : x.size(3); - - // generating new magic number and use that for sync - int* magic = magic_tensor.DATA_PTR(); - *magic = (*magic + 1) & 0xff; - - // Allocate output tensor - at::Tensor y = check_channels_last? at::empty({N, C, H, W}, x.options().memory_format(memory_format)) : at::empty({N, H, W, C}, x.options()); - - // Create wrapper - NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu(); - - bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group); - bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W); - - bn->setConstants(momentum, epsilon); - - // set pointers within the wrapper - bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR(), - nullptr, - y.contiguous(memory_format).DATA_PTR(), - nullptr, - z.contiguous(memory_format).DATA_PTR(), - nullptr); - - bn->setWeightPointers({scale.contiguous().DATA_PTR(), - bias.contiguous().DATA_PTR()}, {nullptr, nullptr}); - bn->setParameterPointers({running_mean.contiguous().DATA_PTR(), - running_inv_var.contiguous().DATA_PTR()}); - - // deal with workspace(s) - auto workspace_bytes = bn->numWorkspaceBytes(); - // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset - // an allocated workspace for the others - size_t total_workspace_bytes = 0; - std::vector workspace_offsets; - - for (auto index = 4; index < workspace_bytes.size(); ++index) { - total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512); - workspace_offsets.push_back(total_workspace_bytes); - - auto alloc_bytes = workspace_bytes[index]; - total_workspace_bytes += alloc_bytes; - } - - // Allocate the workspace - Workspace ws(total_workspace_bytes); - - std::vector workspace; - workspace.push_back(minibatch_mean.contiguous().DATA_PTR()); - workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR()); - workspace.push_back(bitmask.contiguous().DATA_PTR()); - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - const int retired_cta_bytes = workspace_bytes[3]; - void* retired_ctas = ret_cta.contiguous().DATA_PTR(); - assert(ret_cta.size(0)>=retired_cta_bytes); - - workspace.push_back(retired_ctas); - - for (auto index = 4; index < workspace_bytes.size(); ++index) { - void *ptr = reinterpret_cast(ws.data) + workspace_offsets[index-4]; - workspace.push_back(ptr); - } - - bn->setWorkspacePointers(workspace, workspace_bytes); - - // Don't fuse in ReLU for now at least - bn->fwd(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop); - - return y.contiguous(memory_format); -} - -at::Tensor nhwc_bn_addrelu_fwd_eval( - const at::Tensor& x, - const at::Tensor& z, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& ret_cta, - const int bn_group, - const float momentum, - const float epsilon) { - - auto memory_format = x.suggest_memory_format(); - const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast); - const int N = x.size(0); - const int H = check_channels_last ? x.size(2) : x.size(1); - const int W = check_channels_last ? x.size(3) : x.size(2); - const int C = check_channels_last ? x.size(1) : x.size(3); - - // Allocate output tensor - at::Tensor y = check_channels_last? at::empty({N, C, H, W}, x.options().memory_format(memory_format)): at::empty({N, H, W, C}, x.options()); - - // Create wrapper - NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu(); - - bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group); - bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W); - - bn->setConstants(momentum, epsilon); - - // set pointers within the wrapper - bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR(), - nullptr, - y.contiguous(memory_format).DATA_PTR(), - nullptr, - z.contiguous(memory_format).DATA_PTR(), - nullptr); - - bn->setWeightPointers({scale.contiguous().DATA_PTR(), - bias.contiguous().DATA_PTR()}, {nullptr, nullptr}); - bn->setParameterPointers({running_mean.contiguous().DATA_PTR(), - running_inv_var.contiguous().DATA_PTR()}); - - // deal with workspace(s) - auto workspace_bytes = bn->numWorkspaceBytes(); - // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset - // an allocated workspace for the others - size_t total_workspace_bytes = 0; - std::vector workspace_offsets; - - for (auto index = 4; index < workspace_bytes.size(); ++index) { - total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512); - workspace_offsets.push_back(total_workspace_bytes); - - auto alloc_bytes = workspace_bytes[index]; - total_workspace_bytes += alloc_bytes; - } - - // Allocate the workspace - Workspace ws(total_workspace_bytes); - - std::vector workspace; - workspace.push_back(nullptr); - workspace.push_back(nullptr); - workspace.push_back(nullptr); - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - const int retired_cta_bytes = workspace_bytes[3]; - void* retired_ctas = ret_cta.contiguous().DATA_PTR(); - assert(ret_cta.size(0)>=retired_cta_bytes); - workspace.push_back(retired_ctas); - - for (auto index = 4; index < workspace_bytes.size(); ++index) { - void *ptr = reinterpret_cast(ws.data) + workspace_offsets[index-4]; - workspace.push_back(ptr); - } - - bn->setWorkspacePointers(workspace, workspace_bytes); - - // Don't fuse in ReLU for now at least - bn->fwdInference(stream); - - return y.contiguous(memory_format); - -} - -std::vector nhwc_bn_addrelu_bwd( - const at::Tensor& x, - const at::Tensor& dy, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& minibatch_mean, - const at::Tensor& minibatch_inv_var, - const at::Tensor& bitmask, - const at::Tensor& ret_cta, - const float momentum, - const float epsilon, - void * my_data, - void * pair_data, - void * pair_data2, - void * pair_data3, - const int bn_group, - const at::Tensor& magic_tensor, - const int occupancy, - const int grid_dim_x, - const bool coop) { - // shape - auto memory_format = x.suggest_memory_format(); - const bool check_channels_last = x.is_contiguous(at::MemoryFormat::ChannelsLast); - const int N = x.size(0); - const int H = check_channels_last ? x.size(2) : x.size(1); - const int W = check_channels_last ? x.size(3) : x.size(2); - const int C = check_channels_last ? x.size(1) : x.size(3); - - // generating new magic number and use that for sync - int* magic = magic_tensor.DATA_PTR(); - *magic = (*magic + 1) & 0xff; - - // outputs - at::Tensor x_grad, z_grad, scale_grad, bias_grad; - - // Allocate outputs - x_grad = check_channels_last ? at::empty({N, C, H, W}, dy.options().memory_format(memory_format)) : at::empty_like(x); - z_grad = check_channels_last ? at::empty({N, C, H, W}, dy.options().memory_format(memory_format)) : at::empty_like(x); - scale_grad = at::empty_like(scale); - bias_grad = at::empty_like(bias); - - // Create wrapper - NhwcBatchNormAddRelu *bn = new NhwcBatchNormAddRelu(); - - bn->setInputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W, bn_group); - bn->setOutputDescriptor(DNN_TENSOR_FORMAT, DNN_DATA_HALF, N, C, H, W); - - bn->setConstants(momentum, epsilon); - - // set pointers within the wrapper - bn->setInputOutputPointers(x.contiguous(memory_format).DATA_PTR(), - x_grad.contiguous(memory_format).DATA_PTR(), - nullptr, - dy.contiguous(memory_format).DATA_PTR(), - nullptr, - z_grad.contiguous(memory_format).DATA_PTR()); - - bn->setWeightPointers({scale.contiguous().DATA_PTR(), - bias.contiguous().DATA_PTR()}, - {scale_grad.DATA_PTR(), bias_grad.DATA_PTR()}); - bn->setParameterPointers({running_mean.contiguous().DATA_PTR(), - running_inv_var.contiguous().DATA_PTR()}); - - // deal with workspace(s) - auto workspace_bytes = bn->numWorkspaceBytes(); - // We'll create explicit tensors for the first 2 workspace ptrs, then allocate & offset - // an allocated workspace for the others - size_t total_workspace_bytes = 0; - std::vector workspace_offsets; - - for (auto index = 4; index < workspace_bytes.size(); ++index) { - total_workspace_bytes = round_up_to_multiple(total_workspace_bytes, 512); - workspace_offsets.push_back(total_workspace_bytes); - - auto alloc_bytes = workspace_bytes[index]; - total_workspace_bytes += alloc_bytes; - } - - // Allocate the workspace - Workspace ws(total_workspace_bytes); - - std::vector workspace; - workspace.push_back(minibatch_mean.contiguous().DATA_PTR()); - workspace.push_back(minibatch_inv_var.contiguous().DATA_PTR()); - workspace.push_back(bitmask.contiguous().DATA_PTR()); - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - const int retired_cta_bytes = workspace_bytes[3]; - void* retired_ctas = ret_cta.contiguous().DATA_PTR(); - assert(ret_cta.size(0)>=retired_cta_bytes); - workspace.push_back(retired_ctas); - - for (auto index = 4; index < workspace_bytes.size(); ++index) { - void *ptr = reinterpret_cast(ws.data) + workspace_offsets[index-4]; - workspace.push_back(ptr); - } - - bn->setWorkspacePointers(workspace, workspace_bytes); - - bn->dgrad(stream, my_data, pair_data, pair_data2, pair_data3, bn_group, *magic, occupancy, grid_dim_x, coop); - - return std::vector{x_grad.contiguous(memory_format), z_grad.contiguous(memory_format), scale_grad, bias_grad}; -} - -int nhwc_bn_addrelu_fwd_occupancy() { - int device_id=-1; - cudaGetDevice(&device_id); - - //max occupancy supported by the code is 2 - return NhwcBatchNormAddRelu::smem_driven_fwd_occupancy(device_id, 2); -} - -int nhwc_bn_addrelu_bwd_occupancy() { - int device_id=-1; - cudaGetDevice(&device_id); - - //max occupancy supported by the code is 2 - return NhwcBatchNormAddRelu::smem_driven_bwd_occupancy(device_id, 2); -} - diff --git a/apex/contrib/csrc/groupbn/batch_norm_add_relu.h b/apex/contrib/csrc/groupbn/batch_norm_add_relu.h deleted file mode 100644 index 4dcb600..0000000 --- a/apex/contrib/csrc/groupbn/batch_norm_add_relu.h +++ /dev/null @@ -1,816 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2018 by Contributors - * \file nhwc_batch_norm_add_relu.h - * \brief CUDA NHWC Batch Normalization code with fused addition - * \author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer -*/ -#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_ -#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_ - -#include "dnn.h" - -#include -#include -#include -#include - -#include "nhwc_batch_norm_kernel.h" -#include "cuda_utils.h" -#include "c10/macros/Macros.h" - -#ifdef __HIP_PLATFORM_HCC__ -using bitmask_t = uint64_t; -using bitmask_pyt_t = int64_t; -#else -using bitmask_t = unsigned int; -using bitmask_pyt_t = int32_t; -#endif - -#define VERBOSE_DEFAULT false - -class NhwcBatchNormAddRelu { - public: - NhwcBatchNormAddRelu() { - name_ = "nhwc_batchnormaddrelu"; - createTensorDescriptor(&X_tensor_desc_); - createTensorDescriptor(&Y_tensor_desc_); - } - - ~NhwcBatchNormAddRelu() { - destroyTensorDescriptor(X_tensor_desc_); - destroyTensorDescriptor(Y_tensor_desc_); - } - - void die() { - std::cerr << "batchnormaddrelu not initialized" << std::endl; - exit(-1); - } - - void fwd(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop); - void dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop); - void fwdInference(cudaStream_t stream); - dim3 calc_fwd_grid(int *loop, const int grid_dim_x); - dim3 calc_bwd_grid(int *loop, const int grid_dim_x); - - void setInputDescriptor(const dnnTensorFormat_t format, - const dnnDataType_t data_type, - int n, int c, int h, int w, int bn_group) { - m_ = n * h * w; - int m_bn_adjusted = m_ * bn_group; - c_ = c; - // factor to scale sum of squared errors to get saved variance. Must be 1/nhw. - svar_inv_count_ = 1.f / m_bn_adjusted; - // factor to scale sum of squared errors to get running variance. Should be 1/(nhw-1). - int divisor = m_bn_adjusted - 1; - // nhw == 1 is unlikely, but by setting the rvar_inv_count_ == 1.f, we avoid running var infs. - rvar_inv_count_ = divisor == 0 ? 1.f : 1.f / divisor; - setTensorDescriptor(X_tensor_desc_, format, data_type, n, c, h, w); - } - - void setOutputDescriptor(const dnnTensorFormat_t format, - const dnnDataType_t data_type, - int n, int c, int h, int w) { - setTensorDescriptor(Y_tensor_desc_, format, data_type, n, c, h, w); - } - - const std::vector numWorkspaceBytes() const; - - void setWorkspacePointers( - const std::vector& workspace, - const std::vector& num_workspace_bytes); - - void setInputOutputPointers(void* X, void* dX, void* Y, void *dY, void* addend, void* dAddend) { - X_ = X; - dX_ = dX; - Y_ = Y; - dY_ = dY; - addend_ = addend; - dAddend_ = dAddend; - } - - // Sets the pointers for the scale and weight (in that order) data and derivative buffers. - void setWeightPointers(const std::vector& weight_pointers, - const std::vector& deriv_pointers) { - assert(weight_pointers.size() == 2); - assert(deriv_pointers.size() == 2); - scale_ = static_cast(weight_pointers[0]); - bias_ = static_cast(weight_pointers[1]); - dscale_ = static_cast(deriv_pointers[0]); - dbias_ = static_cast(deriv_pointers[1]); - } - - // Sets the pointers for the population mean and variance buffers, in that order. - void setParameterPointers(const std::vector& param_pointers) { - assert(param_pointers.size() == 2); - population_mean_ = static_cast(param_pointers[0]); - population_variance_ = static_cast(param_pointers[1]); - } - - void setConstants(const double exp_avg_factor, const double eps) { - exp_avg_factor_ = exp_avg_factor; - eps_ = eps; - } - - void processCudnnStatus(const dnnStatus_t& status, - const std::string& string = std::string(), - bool verbose = VERBOSE_DEFAULT) { -#ifdef __HIP_PLATFORM_HCC__ - if (status != DNN_STATUS_SUCCESS) - LOG(FATAL) << string << " " << miopenGetErrorString(status); - else if (verbose) - LOG(INFO) << string << " " << miopenGetErrorString(status); -#else - if (status != DNN_STATUS_SUCCESS) - LOG(FATAL) << string << " " << cudnnGetErrorString(status); - else if (verbose) - LOG(INFO) << string << " " << cudnnGetErrorString(status); -#endif - } - - void checkCudaStatus(const std::string& string = std::string(), - bool verbose = VERBOSE_DEFAULT) { - cudaError_t status = cudaGetLastError(); - if (status != cudaSuccess) - LOG(FATAL) << string << " " << cudaGetErrorString(status); - else if (verbose) - LOG(INFO) << string << " " << cudaGetErrorString(status); - } - - size_t size_retired_ctas(int grid_y) const { - // Note that the value of max_grid_y to handle known GPUs is about 160. - const int max_grid_y = 1024; - if (grid_y > max_grid_y) - LOG(INFO) << "GPU capabilities exceeds assumptions."; - const int retired_cta_bytes = max_grid_y * 2 * sizeof(int); - // Since the region will be initialized once and used for many kernels, - // the idea is to return an ample size that will cover all uses. - return retired_cta_bytes; - } - - dnnTensorDescriptor_t X_tensor_desc_ = nullptr; - dnnTensorDescriptor_t Y_tensor_desc_ = nullptr; - - void* X_ = nullptr; - void* dX_ = nullptr; - void* Y_ = nullptr; - void* dY_ = nullptr; - void* addend_ = nullptr; - void* dAddend_ = nullptr; - - // Learned scale and bias weights. - float* scale_ = nullptr; - float* dscale_ = nullptr; - float* bias_ = nullptr; - float* dbias_ = nullptr; - - // Computed population mean and variance parameters. - float* population_mean_ = nullptr; - float* population_variance_ = nullptr; - - // Workspace buffers for minibatch mean and variance (computed in fwd, needed by bwd). - float* minibatch_mean_ = nullptr; - float* minibatch_variance_ = nullptr; - - int m_ = 0; // Number of values per channel that BN is normalizing. - int c_ = 0; // Number of channels over which BN is normalizing. - - float svar_inv_count_ = 0.f; // factor to scale sum of squared errors to get saved variance - float rvar_inv_count_ = 0.f; // factor to scale sum of squared errors to get running variance - - double exp_avg_factor_ = 0.; - double eps_ = 0.; - std::string name_; - - private: - void setTensorDescriptor(dnnTensorDescriptor_t descriptor, - dnnTensorFormat_t format, - dnnDataType_t data_type, - int n, int c, int h, int w) { - dnnStatus_t status = DNN_STATUS_SUCCESS; -#ifdef __HIP_PLATFORM_HCC__ - status = miopenSet4dTensorDescriptor(descriptor, data_type, n, c, h, w); -#else - status = cudnnSetTensor4dDescriptor(descriptor, format, data_type, n, c, h, w); -#endif - processCudnnStatus(status, "set tensor descriptor"); - } - - void createTensorDescriptor(dnnTensorDescriptor_t *descriptor) { - dnnStatus_t status = DNN_STATUS_SUCCESS; -#ifdef __HIP_PLATFORM_HCC__ - status = miopenCreateTensorDescriptor(descriptor); -#else - status = cudnnCreateTensorDescriptor(descriptor); -#endif - processCudnnStatus(status, "create tensor_descriptor"); - } - - void destroyTensorDescriptor(dnnTensorDescriptor_t descriptor) { - dnnStatus_t status = DNN_STATUS_SUCCESS; -#ifdef __HIP_PLATFORM_HCC__ - status = miopenDestroyTensorDescriptor(descriptor); -#else - status = cudnnDestroyTensorDescriptor(descriptor); -#endif - processCudnnStatus(status, "destroy tensor_descriptor"); - } - - protected: - float *partial_sums_ = nullptr; - int *partial_counts_ = nullptr; - int *retired_ctas_ = nullptr; - bitmask_t *relu_bitmask_ = nullptr; - - void _setFwdParams(NhwcBatchNormFwdParams *params) const; - void _setFwdInferenceParams(NhwcBatchNormFwdInferenceParams *params) const; - void _setBwdParams(NhwcBatchNormBwdParams *params) const; - - // @todo: ability to configure these? - // Kernel params - static const int USE_ONLINE_APPROACH = 1; - static const int THREADS_PER_CTA = 512; - static const int THREADS_PER_PIXEL = 32; - static const int C_ELEMENTS_PER_CTA = 128; - static const int ELEMENTS_PER_LDG = C_ELEMENTS_PER_CTA / THREADS_PER_PIXEL; - static const int MAX_SMEM_WITHOUT_OPT_IN = 48 * 1024; - - typedef uint16_t StorageType; - // increasing this to 6 causes spills in fwd kernel! - static const int PIXELS_PER_THREAD_IN_REGISTERS_FWD = 1; - static const int PIXELS_PER_THREAD_IN_REGISTERS_BWD = 1; - static const int PIXELS_PER_THREAD_IN_SMEM_FWD = 0; - static const int PIXELS_PER_THREAD_IN_SMEM_BWD = 0; - - static const int PIXELS_PER_THREAD_FWD = PIXELS_PER_THREAD_IN_REGISTERS_FWD + \ - PIXELS_PER_THREAD_IN_SMEM_FWD; - static const int PIXELS_PER_THREAD_BWD = PIXELS_PER_THREAD_IN_REGISTERS_BWD + \ - PIXELS_PER_THREAD_IN_SMEM_BWD; - static const int PIXELS_PER_THREAD_FWD_INFERENCE = 4; - - // Derived params - static const size_t SMEM_SIZE_FWD = PIXELS_PER_THREAD_IN_SMEM_FWD*THREADS_PER_CTA*\ - ELEMENTS_PER_LDG*sizeof(StorageType); - static const size_t SMEM_SIZE_BWD = PIXELS_PER_THREAD_IN_SMEM_BWD*THREADS_PER_CTA*\ - ELEMENTS_PER_LDG*2*sizeof(StorageType); - static const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; - static const int PIXELS_PER_CTA_FWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \ - PIXELS_PER_THREAD_FWD; - static const int PIXELS_PER_CTA_BWD = THREADS_PER_CTA/THREADS_PER_PIXEL * \ - PIXELS_PER_THREAD_BWD; - static const int PIXELS_PER_CTA_FWD_INFERENCE = THREADS_PER_CTA/THREADS_PER_PIXEL * \ - PIXELS_PER_THREAD_FWD_INFERENCE; - - // max grid.y in case of group bn is limited by exchange buffer size - static const int MAX_GBN_BLOCK_Y = 256; - - // Helper function to launch the forward kernel. - - // We calculate (based on smem usage) the achievable occupancy and make sure we run a kernel - // version that was compiled with that occupancy in its launch bounds. This way, we avoid - // needless register spills. - void _fwdKernelLauncher(cudaStream_t stream, NhwcBatchNormFwdParams params, - dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) { -#ifdef __HIP_PLATFORM_HCC__ -#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \ - do { \ - CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \ - "Nhwc batchnormaddrelu kernel smem too big."; \ - auto fwd_func = nhwc_batch_norm_fwd< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_FWD, \ - PIXELS_PER_THREAD_IN_SMEM_FWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - USE_RELU, \ - USE_ADD_RELU, \ - COMPILED_FOR_OCCUPANCY>; \ - if (COMPILED_FOR_OCCUPANCY > 1) { \ - hipFuncSetAttribute((void *) fwd_func, hipFuncAttributePreferredSharedMemoryCarveout, 100); \ - checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \ - } \ - void *params_ptr = static_cast(¶ms); \ - using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_FWD, \ - PIXELS_PER_THREAD_IN_SMEM_FWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - USE_RELU, \ - USE_ADD_RELU, \ - COMPILED_FOR_OCCUPANCY>); \ - if (COOP) { \ - hipLaunchCooperativeKernel(fwd_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_FWD, \ - stream); \ - } else { \ - hipLaunchKernel((void *) fwd_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_FWD, \ - stream); \ - } \ - checkCudaStatus(name_ + " fwd ser coop kernel"); \ - } while (0) -#else -#define LAUNCH_FWD_KERNEL(OUTER_LOOPS, USE_RELU, USE_ADD_RELU, COMPILED_FOR_OCCUPANCY, COOP) \ - do { \ - CHECK(SMEM_SIZE_FWD <= MAX_SMEM_WITHOUT_OPT_IN) << \ - "Nhwc batchnormaddrelu kernel smem too big."; \ - auto fwd_func = nhwc_batch_norm_fwd< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_FWD, \ - PIXELS_PER_THREAD_IN_SMEM_FWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - USE_RELU, \ - USE_ADD_RELU, \ - COMPILED_FOR_OCCUPANCY>; \ - if (COMPILED_FOR_OCCUPANCY > 1) { \ - cudaFuncSetAttribute(fwd_func, cudaFuncAttributePreferredSharedMemoryCarveout, 100); \ - checkCudaStatus(name_ + " fwd ser coop kernel (cudaFuncSetAttribute carveout)"); \ - } \ - void *params_ptr = static_cast(¶ms); \ - using FWD_FUNC = decltype(nhwc_batch_norm_fwd< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_FWD, \ - PIXELS_PER_THREAD_IN_SMEM_FWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - USE_RELU, \ - USE_ADD_RELU, \ - COMPILED_FOR_OCCUPANCY>); \ - if (COOP) { \ - cudaLaunchCooperativeKernel(fwd_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_FWD, \ - stream); \ - } else { \ - cudaLaunchKernel(fwd_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_FWD, \ - stream); \ - } \ - checkCudaStatus(name_ + " fwd ser coop kernel"); \ - } while (0) -#endif - - // Don't try for an occupancy > 2 as this will squeeze register use and create spills. - if (outer_loops == 1) { - if (occupancy >= 2) - LAUNCH_FWD_KERNEL(1, false, true, 2, coop); - else - LAUNCH_FWD_KERNEL(1, false, true, 1, coop); - } else { - if (occupancy >= 2) - LAUNCH_FWD_KERNEL(0, false, true, 2, coop); - else - LAUNCH_FWD_KERNEL(0, false, true, 1, coop); - } -#undef LAUNCH_FWD_KERNEL - } - - // Helper function to launch the backward kernel. - - void _bwdKernelLauncher(cudaStream_t stream, NhwcBatchNormBwdParams params, - dim3 grid_dim, int outer_loops, const int occupancy, const bool coop) { -#ifdef __HIP_PLATFORM_HCC__ -#define LAUNCH_BWD_ADD_RELU_KERNEL(OUTER_LOOPS, COMPILED_FOR_OCCUPANCY, COOP) \ - do { \ - CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \ - "Nhwc batchnormaddrelu kernel smem too big."; \ - auto bwd_add_relu_func = nhwc_batch_norm_bwd_add_relu< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ - PIXELS_PER_THREAD_IN_SMEM_BWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - COMPILED_FOR_OCCUPANCY>; \ - if (COMPILED_FOR_OCCUPANCY > 1) { \ - hipFuncSetAttribute((void *) bwd_add_relu_func, \ - hipFuncAttributePreferredSharedMemoryCarveout, 100); \ - checkCudaStatus(name_ + \ - " bwd-add-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \ - } \ - void *params_ptr = static_cast(¶ms); \ - using BWD_ADD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_add_relu< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ - PIXELS_PER_THREAD_IN_SMEM_BWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - COMPILED_FOR_OCCUPANCY>); \ - if (COOP) { \ - hipLaunchCooperativeKernel(bwd_add_relu_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_BWD, \ - stream); \ - } else { \ - hipLaunchKernel((void *) bwd_add_relu_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_BWD, \ - stream); \ - } \ - checkCudaStatus(name_ + " bwd-add-relu coop serial kernel"); \ - } while (0) -#else - do { \ - CHECK(SMEM_SIZE_BWD <= MAX_SMEM_WITHOUT_OPT_IN) << \ - "Nhwc batchnormaddrelu kernel smem too big."; \ - auto bwd_add_relu_func = nhwc_batch_norm_bwd_add_relu< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ - PIXELS_PER_THREAD_IN_SMEM_BWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - COMPILED_FOR_OCCUPANCY>; \ - if (COMPILED_FOR_OCCUPANCY > 1) { \ - cudaFuncSetAttribute(bwd_add_relu_func, \ - cudaFuncAttributePreferredSharedMemoryCarveout, 100); \ - checkCudaStatus(name_ + \ - " bwd-add-relu coop serial kernel (cudaFuncSetAttribute carveout)"); \ - } \ - void *params_ptr = static_cast(¶ms); \ - using BWD_ADD_RELU_FUNC = decltype(nhwc_batch_norm_bwd_add_relu< \ - StorageType, \ - THREADS_PER_CTA, \ - THREADS_PER_PIXEL, \ - PIXELS_PER_THREAD_IN_REGISTERS_BWD, \ - PIXELS_PER_THREAD_IN_SMEM_BWD, \ - ELEMENTS_PER_LDG, \ - USE_ONLINE_APPROACH, \ - OUTER_LOOPS, \ - COMPILED_FOR_OCCUPANCY>); \ - if (COOP) { \ - cudaLaunchCooperativeKernel(bwd_add_relu_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_BWD, \ - stream); \ - } else { \ - cudaLaunchKernel(bwd_add_relu_func, \ - grid_dim, \ - THREADS_PER_CTA, \ - ¶ms_ptr, \ - SMEM_SIZE_BWD, \ - stream); \ - } \ - checkCudaStatus(name_ + " bwd-add-relu coop serial kernel"); \ - } while (0) -#endif - - // Don't try for an occupancy > 2 as this will squeeze register use and create spills. - if (outer_loops == 1) { - if (occupancy >= 2) - LAUNCH_BWD_ADD_RELU_KERNEL(1, 2, coop); - else - LAUNCH_BWD_ADD_RELU_KERNEL(1, 1, coop); - } else { - if (occupancy >= 2) - LAUNCH_BWD_ADD_RELU_KERNEL(0, 2, coop); - else - LAUNCH_BWD_ADD_RELU_KERNEL(0, 1, coop); - } -#undef LAUNCH_BWD_KERNEL - } - - public: - // Calculate the expected fwd kernel occupancy, as dictated by shared memory usage. - static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) { - using namespace at::cuda::utils; - int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float); - int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes; - int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes; - return std::min(max_cta_per_sm, occupancy); - } - - // Calculate the expected bwd kernel occupancy, as dictated by shared memory usage. - static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) { - using namespace at::cuda::utils; - int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float); - int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes; - int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes; - return std::min(max_cta_per_sm, occupancy); - } -}; - -const std::vector NhwcBatchNormAddRelu::numWorkspaceBytes() const { - assert(c_ > 0); - - // choose the max memory required between fwd/bwd passes - int grid_x_fwd = div_up(m_, PIXELS_PER_CTA_FWD); - int grid_x_bwd = div_up(m_, PIXELS_PER_CTA_BWD); - int grid_x = max(grid_x_fwd, grid_x_bwd); - int grid_y = div_up(c_, C_ELEMENTS_PER_CTA); - - const size_t num_mean_bytes = c_ * sizeof(float); - const size_t num_variance_bytes = num_mean_bytes; - -#ifdef __HIP_PLATFORM_HCC__ - int elems_per_group = ((m_ + 3) & ~3) * 2; -#else - int elems_per_group = ((m_ + 31) & ~31) * 2; -#endif - int group_count = div_up(c_, C_ELEMENTS_PER_CTA); - const size_t bitmask_bytes = elems_per_group * group_count * sizeof(bitmask_t); - - const size_t size_sums = grid_y*grid_x*THREADS_PER_PIXEL*\ - ELEMENTS_PER_LDG*2*sizeof(float); - const size_t size_counts = grid_y*grid_x*sizeof(int); - - return {num_mean_bytes, num_variance_bytes, bitmask_bytes, - size_retired_ctas(grid_y), size_sums, size_counts}; -} - -void NhwcBatchNormAddRelu::setWorkspacePointers( - const std::vector& workspace, - const std::vector& num_workspace_bytes) { - assert(workspace.size() == 6); - assert(num_workspace_bytes.size() == 6); - - minibatch_mean_ = static_cast(workspace[0]); - minibatch_variance_ = static_cast(workspace[1]); - relu_bitmask_ = static_cast(workspace[2]); - retired_ctas_ = static_cast(workspace[3]); - partial_sums_ = static_cast(workspace[4]); - partial_counts_ = static_cast(workspace[5]); -} - -void NhwcBatchNormAddRelu::_setFwdParams(NhwcBatchNormFwdParams *params) const { - params->gmem_src = static_cast(X_); - params->gmem_dst = static_cast(Y_); - params->gmem_src1 = static_cast(addend_); - params->gmem_bias = bias_; - params->gmem_scale = scale_; - params->gmem_running_mean = population_mean_; - params->gmem_running_var = population_variance_; - params->gmem_saved_mean = minibatch_mean_; - params->gmem_saved_var = minibatch_variance_; - params->gmem_relu_bitmask = relu_bitmask_; - params->nhw = m_; - params->c = c_; - params->svar_inv_count = svar_inv_count_; - params->rvar_inv_count = rvar_inv_count_; - params->gmem_sums = partial_sums_; - params->gmem_counts = partial_counts_; - params->gmem_retired_ctas = retired_ctas_; - params->var_eps = eps_; - params->outer_loops = 0; - params->exp_avg_factor = static_cast(exp_avg_factor_); - params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA); -} - -void NhwcBatchNormAddRelu::_setFwdInferenceParams(NhwcBatchNormFwdInferenceParams - *params) const { - params->gmem_src = static_cast(X_); - params->gmem_dst = static_cast(Y_); - params->gmem_src1 = static_cast(addend_); - params->gmem_bias = bias_; - params->gmem_scale = scale_; - params->gmem_mean = population_mean_; - params->gmem_var = population_variance_; - params->nhw = m_; - params->c = c_; - params->var_eps = eps_; -} - -void NhwcBatchNormAddRelu::_setBwdParams(NhwcBatchNormBwdParams *params) const { - params->gmem_src = static_cast(X_); - params->gmem_dy = static_cast(dY_); - params->gmem_dst = static_cast(dX_); - params->gmem_dst1 = static_cast(dAddend_); - params->gmem_relu_bitmask = relu_bitmask_; - params->gmem_dscale = dscale_; - params->gmem_dbias = dbias_; - params->gmem_scale = scale_; - params->gmem_bias = bias_; - params->gmem_saved_mean = minibatch_mean_; - params->gmem_saved_var = minibatch_variance_; - params->nhw = m_; - params->c = c_; - params->svar_inv_count = svar_inv_count_; - params->gmem_sums = partial_sums_; - params->gmem_retired_ctas = retired_ctas_; - params->outer_loops = 0; - params->c_blks = div_up(c_, C_ELEMENTS_PER_CTA); -} - -void NhwcBatchNormAddRelu::fwdInference(cudaStream_t stream) { - bool ptrs_are_set = - X_tensor_desc_ != nullptr - && Y_tensor_desc_ != nullptr - && scale_ != nullptr - && bias_ != nullptr - // && minibatch_mean_ != nullptr - // && minibatch_variance_ != nullptr - && population_mean_ != nullptr - && population_variance_ != nullptr - && X_ != nullptr - // && dX_ != nullptr - && Y_ != nullptr - && addend_ != nullptr - // && dY_ != nullptr - // && dscale_ != nullptr - // && dbias_ != nullptr - && partial_sums_ != nullptr - && partial_counts_ != nullptr; - - if (!ptrs_are_set) - die(); - - dim3 grid_dim; - grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD_INFERENCE); - grid_dim.y = div_up(c_, C_ELEMENTS_PER_CTA); - - // @todo: maybe just move this inside initialize routine? - NhwcBatchNormFwdInferenceParams params; - _setFwdInferenceParams(¶ms); - - nhwc_batch_norm_fwd_inference - - <<>>(params); - checkCudaStatus(name_ + " fwd_inference-relu kernel"); -} - -dim3 NhwcBatchNormAddRelu::calc_fwd_grid(int *loop, const int grid_dim_x) { - dim3 grid_dim; - grid_dim.x = div_up(m_, PIXELS_PER_CTA_FWD); - int c_blks = div_up(c_, C_ELEMENTS_PER_CTA); - unsigned int max_grid_x = grid_dim_x; - if (grid_dim.x <= max_grid_x) { - *loop = 1; - if (max_grid_x / grid_dim.x > 1) { - grid_dim.y = std::min(c_blks, static_cast(max_grid_x / grid_dim.x)); - assert(grid_dim.y 1) { - grid_dim.y = std::min(c_blks, static_cast(max_grid_x / grid_dim.x)); - assert(grid_dim.y> 1); - - dim3 grid_dim = calc_fwd_grid(¶ms.outer_loops, grid_dim_x); - _fwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop); -} - -void NhwcBatchNormAddRelu::dgrad(cudaStream_t stream, void* my_data, void* pair_data, void* pair_data2, void* pair_data3, - const int bn_group, const int magic, const int occupancy, const int grid_dim_x, const bool coop) { - bool ptrs_are_set = - X_tensor_desc_ != nullptr - && Y_tensor_desc_ != nullptr - && scale_ != nullptr - && bias_ != nullptr - && minibatch_mean_ != nullptr - && minibatch_variance_ != nullptr - && relu_bitmask_ != nullptr - // && population_mean_ != nullptr - // && population_variance_ != nullptr - && X_ != nullptr - && dX_ != nullptr - // && Y_ != nullptr - && dY_ != nullptr - && dAddend_ != nullptr - && dscale_ != nullptr - && dbias_ != nullptr - && retired_ctas_ != nullptr; - - if (!ptrs_are_set) - die(); - - // reset of retired_cta_count no longer needed - - NhwcBatchNormBwdParams params; - _setBwdParams(¶ms); - - params.my_data = my_data; - params.pair_datas[0] = pair_data; - params.pair_datas[1] = pair_data2; - params.pair_datas[2] = pair_data3; - params.magic = magic; - params.sync_iters = (bn_group==8)?3:(bn_group >> 1); - params.wgrad_coeff = 1.0 / bn_group; - - dim3 grid_dim = calc_bwd_grid(¶ms.outer_loops, grid_dim_x); - _bwdKernelLauncher(stream, params, grid_dim, params.outer_loops, occupancy, coop); -} - -#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_ADD_RELU_H_ diff --git a/apex/contrib/csrc/groupbn/cuda_utils.h b/apex/contrib/csrc/groupbn/cuda_utils.h deleted file mode 100644 index fa172f9..0000000 --- a/apex/contrib/csrc/groupbn/cuda_utils.h +++ /dev/null @@ -1,28 +0,0 @@ -#ifdef __HIP_PLATFORM_HCC__ -#include -#else -#include -#endif -#ifndef CUDA_UTILS_H -#define CUDA_UTILS_H - -namespace at { -namespace cuda { - -namespace utils { - -static inline int MaxSharedMemoryPerMultiprocessor(int device_id) { -#ifdef __HIP_PLATFORM_HCC__ - return getDeviceProperties(device_id)->maxSharedMemoryPerMultiProcessor; -#else - return getDeviceProperties(device_id)->sharedMemPerMultiprocessor; -#endif -} - - -} -} -} - - -#endif diff --git a/apex/contrib/csrc/groupbn/dnn.h b/apex/contrib/csrc/groupbn/dnn.h deleted file mode 100644 index 642a473..0000000 --- a/apex/contrib/csrc/groupbn/dnn.h +++ /dev/null @@ -1,26 +0,0 @@ -#ifndef DNN_H -#define DNN_H - -#ifdef __HIP_PLATFORM_HCC__ -#include -#define DNN_STATUS_SUCCESS miopenStatusSuccess -#define DNN_DATA_HALF miopenHalf -#define DNN_TENSOR_FORMAT 0 - -using dnnTensorFormat_t = int; -using dnnDataType_t = miopenDataType_t; -using dnnStatus_t = miopenStatus_t; -using dnnTensorDescriptor_t = miopenTensorDescriptor_t; -#else -#include -#define DNN_STATUS_SUCCESS CUDNN_STATUS_SUCCESS -#define DNN_DATA_HALF CUDNN_DATA_HALF -#define DNN_TENSOR_FORMAT CUDNN_TENSOR_NHWC - -using dnnTensorFormat_t = cudnnTensorFormat_t; -using dnnDataType_t = cudnnDataType_t; -using dnnStatus_t = cudnnStatus_t; -using dnnTensorDescriptor_t = cudnnTensorDescriptor_t; -#endif - -#endif // DNN_H diff --git a/apex/contrib/csrc/groupbn/interface.cpp b/apex/contrib/csrc/groupbn/interface.cpp deleted file mode 100644 index 8cea5f9..0000000 --- a/apex/contrib/csrc/groupbn/interface.cpp +++ /dev/null @@ -1,175 +0,0 @@ -#include -#include -#include - -#include -#include -#include -#include -#include "ATen/Scalar.h" -#ifndef VERSION_GE_1_1 -#include "ATen/Type.h" -#endif -#include "ATen/Tensor.h" -#include "ATen/Storage.h" -#include "ATen/Generator.h" - - -namespace py = pybind11; - -int64_t get_buffer_size( - const int bn_sync_steps); - -void* get_data_ptr( - const at::Tensor& data); - -void* get_remote_data_ptr( - const at::Tensor& handle, - const int64_t offset); - -void close_remote_data( - const at::Tensor& handle); - -at::Tensor nhwc_bn_fwd_train( - const at::Tensor& x, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& minibatch_mean, - const at::Tensor& minibatch_inv_var, - const at::Tensor& ret_cta, - const float momentum, - const float epsilon, - const bool fuse_relu, - void* my_data, - void* pair_data, - void* pair_data2, - void* pair_data3, - const int bn_group, - const at::Tensor& magic_tensor, - const int occupancy, - const int grid_dim_x, - const bool coop); - -at::Tensor nhwc_bn_fwd_eval( - const at::Tensor& x, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& ret_cta, - const int bn_group, - const float momentum, - const float epsilon, - const bool fuse_relu); - -std::vector nhwc_bn_bwd( - const at::Tensor& x, - const at::Tensor& dy, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& minibatch_mean, - const at::Tensor& minibatch_inv_var, - const at::Tensor& ret_cta, - const float momentum, - const float epsilon, - const bool fuse_relu, - void* my_data, - void* pair_data, - void* pair_data2, - void* pair_data3, - const int bn_group, - const at::Tensor& magic_tensor, - const int occupancy, - const int grid_dim_x, - const bool coop); - -at::Tensor nhwc_bn_addrelu_fwd_train( - const at::Tensor& x, - const at::Tensor& z, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& minibatch_mean, - const at::Tensor& minibatch_inv_var, - const at::Tensor& bitmask, - const at::Tensor& ret_cta, - const float momentum, - const float epsilon, - void* my_data, - void* pair_data, - void* pair_data2, - void* pair_data3, - const int bn_group, - const at::Tensor& magic_tensor, - const int occupancy, - const int grid_dim_x, - const bool coop); - -at::Tensor nhwc_bn_addrelu_fwd_eval( - const at::Tensor& x, - const at::Tensor& z, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& ret_cta, - const int bn_group, - const float momentum, - const float epsilon); - -std::vector nhwc_bn_addrelu_bwd( - const at::Tensor& x, - const at::Tensor& dy, - const at::Tensor& scale, - const at::Tensor& bias, - const at::Tensor& running_mean, - const at::Tensor& running_inv_var, - const at::Tensor& minibatch_mean, - const at::Tensor& minibatch_inv_var, - const at::Tensor& bitmask, - const at::Tensor& ret_cta, - const float momentum, - const float epsilon, - void* my_data, - void* pair_data, - void* pair_data2, - void* pair_data3, - const int bn_group, - const at::Tensor& magic_tensor, - const int occupancy, - const int grid_dim_x, - const bool coop); - -int nhwc_bn_fwd_occupancy(); -int nhwc_bn_bwd_occupancy(); - -int nhwc_bn_addrelu_fwd_occupancy(); -int nhwc_bn_addrelu_bwd_occupancy(); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - - m.def("get_buffer_size", &get_buffer_size, "get_buffer_size"); - m.def("get_data_ptr", &get_data_ptr, "get_data_ptr"); - m.def("get_remote_data_ptr", &get_remote_data_ptr, "get_remote_data_ptr"); - m.def("close_remote_data", &close_remote_data, "close_remote_data"); - - m.def("bn_fwd_nhwc", &nhwc_bn_fwd_train, "bn_fwd_nhwc"); - m.def("bn_fwd_eval_nhwc", &nhwc_bn_fwd_eval, "bn_fwd_eval_nhwc"); - m.def("bn_bwd_nhwc", &nhwc_bn_bwd, "bn_bwd_nhwc"); - - m.def("bn_fwd_nhwc_occupancy", &nhwc_bn_fwd_occupancy, "bn_fwd_nhwc_occupancy"); - m.def("bn_bwd_nhwc_occupancy", &nhwc_bn_bwd_occupancy, "bn_bwd_nhwc_occupancy"); - - m.def("bn_addrelu_fwd_nhwc", &nhwc_bn_addrelu_fwd_train, "bn_addrelu_fwd_nhwc"); - m.def("bn_addrelu_fwd_eval_nhwc", &nhwc_bn_addrelu_fwd_eval, "bn_addrelu_fwd_eval_nhwc"); - m.def("bn_addrelu_bwd_nhwc", &nhwc_bn_addrelu_bwd, "bn_addrelu_bwd_nhwc"); - - m.def("bn_addrelu_fwd_nhwc_occupancy", &nhwc_bn_addrelu_fwd_occupancy, "bn_addrelu_fwd_nhwc_occupancy"); - m.def("bn_addrelu_bwd_nhwc_occupancy", &nhwc_bn_addrelu_bwd_occupancy, "bn_addrelu_bwd_nhwc_occupancy"); -} - diff --git a/apex/contrib/csrc/groupbn/ipc.cu b/apex/contrib/csrc/groupbn/ipc.cu deleted file mode 100644 index 6b152a0..0000000 --- a/apex/contrib/csrc/groupbn/ipc.cu +++ /dev/null @@ -1,129 +0,0 @@ -#include -#include - -#include - -#include "compat.h" - - -#define cudaCheckErrors(msg) \ - do { \ - cudaError_t __err = cudaGetLastError(); \ - if (__err != cudaSuccess) { \ - fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \ - msg, cudaGetErrorString(__err), \ - __FILE__, __LINE__); \ - fprintf(stderr, "*** FAILED - ABORTING\n"); \ - exit(1); \ - } \ - } while (0) - -template<> -struct std::hash { - size_t operator() (const cudaIpcMemHandle_t& handle) const { - size_t hash = 0; - uint8_t* ptr = (uint8_t*)&handle; - assert(sizeof(uint8_t) == 1); - for (int i=0; i -struct std::equal_to { - bool operator() (const cudaIpcMemHandle_t &lhs, - const cudaIpcMemHandle_t &rhs) const { - return (std::memcmp((void*) &lhs, - (void*) &rhs, - sizeof(cudaIpcMemHandle_t)) == 0); - } -}; - -namespace { - -namespace gpuipc { -//from: src/operator/nn/cudnn/nhwc_batch_norm_kernel.h -// The number of threads per pixel. -const int THREADS_PER_PIXEL = 16; -// The number of elements per ldg. -const int ELEMENTS_PER_LDG = 4; -// The number of reducing ops, each uses its own space : mean, var, dscale, dbias -const int REDUCE_OPS = 4; -// Maximum block.y supported - limited due to buffer allocation -const int MAX_BLOCK_Y = 256; -const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y; -const int BYTES_PER_ELEM = 4; -// Buffer size per sync step -const int SINGLE_SYNC_BUFFER_BYTES = MAX_OFFSET*THREADS_PER_PIXEL*2*ELEMENTS_PER_LDG*BYTES_PER_ELEM; -}; - -class IpcMemHandleRegistry { -public: - void* getPtr(const cudaIpcMemHandle_t& handle, int64_t offset) { - if (registry_.count(handle) == 0) { - registry_.insert(std::make_pair(handle, RegistryEntry())); - registry_[handle].dev_ptr = ipcOpenMem(handle); - } - registry_[handle].ref_count++; - return (((uint8_t*)registry_[handle].dev_ptr) + offset); - } - - void releasePtr(const cudaIpcMemHandle_t& handle) { - if (registry_.count(handle) == 0) { - } - if (--registry_[handle].ref_count == 0) { - ipcCloseMem(registry_[handle].dev_ptr); - registry_.erase(handle); - } - } - - struct RegistryEntry { - void* dev_ptr; - int ref_count; - RegistryEntry() : dev_ptr(NULL) , ref_count(0) {} - }; - -protected: - std::unordered_map registry_; - - void* ipcOpenMem(const cudaIpcMemHandle_t& handle) { - void *data; - cudaIpcOpenMemHandle(&data, handle, cudaIpcMemLazyEnablePeerAccess); - cudaCheckErrors("ipc init"); - return data; - } - - void ipcCloseMem(void* dev_ptr) { - cudaIpcCloseMemHandle(dev_ptr); - cudaCheckErrors("ipc close"); - } - -}; - -} - -static IpcMemHandleRegistry ipc_mem_registry; - -int64_t get_buffer_size(const int bn_sync_steps) { - return bn_sync_steps * gpuipc::SINGLE_SYNC_BUFFER_BYTES; -} - -void* get_remote_data_ptr(const at::Tensor& handle, const int64_t offset) { - cudaIpcMemHandle_t my_handle; - memcpy((unsigned char *)(&my_handle), handle.DATA_PTR(), sizeof(my_handle)); - return ipc_mem_registry.getPtr(my_handle, offset); -} - -void close_remote_data(const at::Tensor& handle) { - cudaIpcMemHandle_t my_handle; - memcpy((unsigned char *)(&my_handle), handle.DATA_PTR(), sizeof(my_handle)); - ipc_mem_registry.releasePtr(my_handle); -} - -void* get_data_ptr( - const at::Tensor& data) { - return data.DATA_PTR(); -} diff --git a/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h b/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h deleted file mode 100644 index 683a4c1..0000000 --- a/apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h +++ /dev/null @@ -1,3021 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2018 by Contributors - * \file nhwc_batch_norm_kernel.h - * \brief CUDA NHWC Batch Normalization code - * \author Shankara Rao Thejaswi Nanditale, Dick Carter, Maxim Milakov, Evgeni Krimer -*/ -#ifndef MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_ -#define MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_ - -#ifdef __HIP_PLATFORM_HCC__ -#include -#include -#include -#endif -#include -#include - -#ifdef __HIP_PLATFORM_HCC__ -using bitmask_t = uint64_t; -#define BITMASK_OFFSET 2 -#define ONE_BITMASK 1UL -#else -using bitmask_t = unsigned int; -#define BITMASK_OFFSET 2 -#define ONE_BITMASK 1U -#endif - -#define DEVICE_FUNCTION static inline __device__ - -// CTA margin used by cooperative launch. Can be overridden by env var NHWC_BATCHNORM_LAUNCH_MARGIN. -#define NHWC_BATCHNORM_LAUNCH_MARGIN_MIN 3 -#define NHWC_BATCHNORM_LAUNCH_MARGIN_DEFAULT NHWC_BATCHNORM_LAUNCH_MARGIN_MIN - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void syncwarp() { -#ifdef __HIP_PLATFORM_HCC__ - __builtin_amdgcn_wave_barrier(); -#else - __syncwarp(); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -DEVICE_FUNCTION T shfl_sync(T var, int src_lane) { -#ifdef __HIP_PLATFORM_HCC__ - return __shfl(var, src_lane); -#else - return __shfl_sync(0xFFFFFFFFU, var, src_lane); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION bitmask_t ballot(int predicate) { -#ifdef __HIP_PLATFORM_HCC__ - return __ballot(predicate); -#else - return __ballot_sync(0xFFFFFFFFU, predicate); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename T, int ELEMENTS_PER_LDG > -struct PackedStorage { - enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG }; - typedef T Type; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int ELEMENTS_PER_LDG > -struct PackedStorage { - enum { PACKED_ELEMENTS_PER_LDG = ELEMENTS_PER_LDG/2 }; - typedef int Type; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -DEVICE_FUNCTION void from_float(int (&dst)[N], const float (&src)[2*N]) { - // Convert from two f32s to two f16s (mantissa LSB rounds to nearest even) - // (From 64-bit to 32-bit) - half *dst_ = (half *) dst; - #pragma unroll - for (int i = 0; i < N; ++i) { -#ifdef __HIP_PLATFORM_HCC__ - dst_[2*i] = __float2half(src[2*i]); - dst_[2*i+1] = __float2half(src[2*i+1]); -#else - uint16_t lo, hi; - asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(lo) : "f"(src[2*i+0])); - asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(hi) : "f"(src[2*i+1])); - asm volatile("mov.b32 %0, {%1, %2};" : "=r"(dst[i]) : "h"(lo), "h"(hi)); -#endif - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -DEVICE_FUNCTION void from_float(float (&dst)[N], const float (&src)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { - dst[i] = src[i]; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -DEVICE_FUNCTION void to_float(float (&dst)[2*N], int (&src)[N]) { - // Convert from two f16s to two f32s (From 32-bit to 64-bit) - #pragma unroll - for (int i = 0; i < N; ++i) { -#ifdef __HIP_PLATFORM_HCC__ - half *src_ = (half *) src; - dst[2*i] = __half2float(src_[2*i]); - dst[2*i+1] = __half2float(src_[2*i+1]); -#else - uint16_t lo, hi; - asm volatile("mov.b32 {%0, %1}, %2;" : "=h"(lo), "=h"(hi) : "r"(src[i])); - asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2*i+0]) : "h"(lo)); - asm volatile("cvt.f32.f16 %0, %1;" : "=f"(dst[2*i+1]) : "h"(hi)); -#endif - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -DEVICE_FUNCTION void to_float(float (&dst)[N], float (&src)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { - dst[i] = src[i]; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void ldg(int (&dst)[1], const uint16_t *gmem) { - dst[0] = __ldg((const int*) gmem); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void ldg_stream(int (&dst)[1], const uint16_t *gmem) { -#ifdef __HIP_PLATFORM_HCC__ - dst[0] = __ldg((const int*) gmem); -#else - unsigned int tmp; - asm volatile ("ld.global.cs.nc.s32 %0, [%1];" : "=r"(tmp) : "l" ((const uint *)gmem)); - dst[0] = tmp; -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void ldg(int (&dst)[2], const uint16_t *gmem) { - int2 tmp = __ldg((const int2*) gmem); - dst[0] = tmp.x; - dst[1] = tmp.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void ldg_stream(int (&dst)[2], const uint16_t *gmem) { -#ifdef __HIP_PLATFORM_HCC__ - int2 tmp = __ldg((const int2*) gmem); - dst[0] = tmp.x; - dst[1] = tmp.y; -#else - int2 tmp; - asm volatile ("ld.global.cs.nc.v2.s32 {%0,%1}, [%2];" - : "=r"(tmp.x), "=r"(tmp.y) : "l"((const int2 *)gmem)); - dst[0] = tmp.x; - dst[1] = tmp.y; -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -DEVICE_FUNCTION void ldg(float (&dst)[N], const uint16_t *gmem) { - int tmp[N/2]; - ldg(tmp, gmem); - to_float(dst, tmp); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -DEVICE_FUNCTION void ldg_stream(float (&dst)[N], const uint16_t *gmem) { - int tmp[N/2]; - ldg_stream(tmp, gmem); - to_float(dst, tmp); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[1]) { - reinterpret_cast(gmem)[0] = src[0]; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[1]) { -#ifdef __HIP_PLATFORM_HCC__ - reinterpret_cast(gmem)[0] = src[0]; -#else - unsigned int tmp = src[0]; - asm volatile ("st.global.cs.s32 [%0], %1;" - :: "l"((uint *)gmem) , "r"(tmp)); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void stg(uint16_t *gmem, int (&src)[2]) { -#ifdef __HIP_PLATFORM_HCC__ - half *gmem_ = (half *) gmem; - half *src_ = (half *) src; - for (int i = 0; i < 4; i++) { - gmem_[i] = src_[i]; - } -#else - reinterpret_cast(gmem)[0] = make_int2(src[0], src[1]); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void stg_stream(uint16_t *gmem, int (&src)[2]) { -#ifdef __HIP_PLATFORM_HCC__ - half *gmem_ = (half *) gmem; - half *src_ = (half *) src; - for (int i = 0; i < 4; i++) { - gmem_[i] = src_[i]; - } -#else - asm volatile ("st.global.cs.v2.s32 [%0], {%1,%2};" - :: "l"((uint *)gmem) , "r"(src[0]), "r"( src[1])); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -DEVICE_FUNCTION void stg(uint16_t *gmem, float (&src)[N]) { - int tmp[N/2]; - from_float(tmp, src); - stg(gmem, tmp); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -DEVICE_FUNCTION void stg_stream(uint16_t *gmem, float (&src)[N]) { - int tmp[N/2]; - from_float(tmp, src); - stg_stream(gmem, tmp); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef __HIP_PLATFORM_HCC__ -DEVICE_FUNCTION void stg(uint16_t *gmem, float (&src)[4]) { - half *gmem_ = (half *) gmem; - gmem_[0] = __float2half(src[0]); - gmem_[1] = __float2half(src[1]); - gmem_[2] = __float2half(src[2]); - gmem_[3] = __float2half(src[3]); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void stg_stream(uint16_t *gmem, float (&src)[4]) { - half *gmem_ = (half *) gmem; - gmem_[0] = __float2half(src[0]); - gmem_[1] = __float2half(src[1]); - gmem_[2] = __float2half(src[2]); - gmem_[3] = __float2half(src[3]); -} -#endif - -DEVICE_FUNCTION void read_from_gmem(float (&dst)[2], const float *gmem, int idx) { -#ifdef __HIP_PLATFORM_HCC__ - dst[0] = gmem[2*idx]; - dst[1] = gmem[2*idx+1]; -#else - float2 tmp = __ldg(reinterpret_cast(&gmem[2*idx])); - dst[0] = tmp.x; - dst[1] = tmp.y; -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void read_from_gmem(float (&dst)[4], const float *gmem, int idx) { -#ifdef __HIP_PLATFORM_HCC__ - dst[0] = gmem[4*idx]; - dst[1] = gmem[4*idx+1]; - dst[2] = gmem[4*idx+2]; - dst[3] = gmem[4*idx+3]; -#else - float4 tmp = __ldg(reinterpret_cast(&gmem[4*idx])); - dst[0] = tmp.x; - dst[1] = tmp.y; - dst[2] = tmp.z; - dst[3] = tmp.w; -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void read_from_smem(float (&x)[2], const float *smem, int idx) { -#ifdef __HIP_PLATFORM_HCC__ - x[0] = smem[2*idx]; - x[1] = smem[2*idx+1]; -#else - float2 tmp = *(const float2*) &smem[2*idx]; - x[0] = tmp.x; - x[1] = tmp.y; -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void read_from_smem(int (&x)[1], const int *smem, int idx) { - x[0] = smem[idx]; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void read_from_smem(float (&x)[4], const float *smem, int idx) { -#ifdef __HIP_PLATFORM_HCC__ - x[0] = smem[4*idx]; - x[1] = smem[4*idx+1]; - x[2] = smem[4*idx+2]; - x[3] = smem[4*idx+3]; -#else - float4 tmp = *(const float4*) &smem[4*idx]; - x[0] = tmp.x; - x[1] = tmp.y; - x[2] = tmp.z; - x[3] = tmp.w; -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void read_from_smem(int (&x)[2], const int *smem, int idx) { -#ifdef __HIP_PLATFORM_HCC__ - x[0] = smem[2*idx]; - x[1] = smem[2*idx+1]; -#else - int2 tmp = *(const int2*) &smem[2*idx]; - x[0] = tmp.x; - x[1] = tmp.y; -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[2]) { -#ifdef __HIP_PLATFORM_HCC__ - gmem[2*idx] = src[0]; - gmem[2*idx+1] = src[1]; -#else - reinterpret_cast(&gmem[2*idx])[0] = make_float2(src[0], src[1]); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void write_to_gmem(float *gmem, int idx, const float (&src)[4]) { -#ifdef __HIP_PLATFORM_HCC__ - gmem[4*idx] = src[0]; - gmem[4*idx+1] = src[1]; - gmem[4*idx+2] = src[2]; - gmem[4*idx+3] = src[3]; -#else - reinterpret_cast(&gmem[4*idx])[0] = make_float4(src[0], src[1], src[2], src[3]); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void scaled_write_to_gmem(float *gmem, int idx, const float (&src)[4], const float coeff) { -#ifdef __HIP_PLATFORM_HCC__ - gmem[4*idx] = src[0]*coeff; - gmem[4*idx+1] = src[1]*coeff; - gmem[4*idx+2] = src[2]*coeff; - gmem[4*idx+3] = src[3]*coeff; -#else - reinterpret_cast(&gmem[4*idx])[0] = make_float4(src[0]*coeff, src[1]*coeff, src[2]*coeff, src[3]*coeff); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[2]) { -#ifdef __HIP_PLATFORM_HCC__ - smem[2*idx] = x[0]; - smem[2*idx+1] = x[1]; -#else - reinterpret_cast(&smem[2*idx])[0] = make_float2(x[0], x[1]); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[1]) { - smem[idx] = x[0]; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void write_to_smem(float *smem, int idx, const float (&x)[4]) { -#ifdef __HIP_PLATFORM_HCC__ - smem[4*idx] = x[0]; - smem[4*idx+1] = x[1]; - smem[4*idx+2] = x[2]; - smem[4*idx+3] = x[3]; -#else - reinterpret_cast(&smem[4*idx])[0] = make_float4(x[0], x[1], x[2], x[3]); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -DEVICE_FUNCTION void write_to_smem(int *smem, int idx, const int (&x)[2]) { -#ifdef __HIP_PLATFORM_HCC__ - smem[2*idx] = x[0]; - smem[2*idx+1] = x[1]; -#else - reinterpret_cast(&smem[2*idx])[0] = make_int2(x[0], x[1]); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -DEVICE_FUNCTION void zero_array(int (&dst)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { - dst[i] = 0; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int N > -DEVICE_FUNCTION void zero_array(float (&dst)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { - dst[i] = 0.f; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -DEVICE_FUNCTION void add(float (&x)[N], const float (&y)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { - x[i] += y[i]; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -DEVICE_FUNCTION void multiply(float (&x)[N], const float (&y)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { - x[i] *= y[i]; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -DEVICE_FUNCTION void scale_(float (&x)[N], float scalar) { - #pragma unroll - for (int i = 0; i < N; ++i) { - x[i] *= scalar; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -DEVICE_FUNCTION void normalize(float (&x)[N], const float (&bias)[N], - const float (&scale)[N], const float (&m1)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { - x[i] = bias[i] + scale[i] * (x[i] - m1[i]); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -DEVICE_FUNCTION Storage relu(Storage in) { - Storage zero = (Storage)0.f; - return (in < zero)? zero : in; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -DEVICE_FUNCTION void relu_activation(float (&x)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { - x[i] = relu(x[i]); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -template< int THREADS_PER_CTA > -DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw, - void* params_my_data, void** params_pair_datas, int off, - const int magic, - const int sync_iters) { - // The size of a warp. -#ifdef __HIP_PLATFORM_HCC__ - const int THREADS_PER_WARP = 64; -#else - const int THREADS_PER_WARP = 32; -#endif - // The number of warps in a CTA. - const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP; - // The number of threads per pixel. - const int THREADS_PER_PIXEL = 16; - // The number of elements per ldg. - const int ELEMENTS_PER_LDG = 4; - // The number of reducing ops, each uses its own space : mean, var, dscale, dbias - const int REDUCE_OPS = 4; - // Maximum block.y supported - limited due to buffer allocation - const int MAX_BLOCK_Y = 256; - const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y; - // The warp decomposition. - const int warp_id = threadIdx.x / THREADS_PER_WARP; - const int lane_id = threadIdx.x % THREADS_PER_WARP; - // total size of data per sync iter - const int data_total = MAX_OFFSET*THREADS_PER_PIXEL*ELEMENTS_PER_LDG*2; - -#ifdef __HIP_PLATFORM_HCC__ - for (int offset = THREADS_PER_PIXEL; offset <= THREADS_PER_WARP >> 1; offset <<= 1) { - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - x[i] += shfl_sync(x[i], offset + lane_id); - } - } -#else - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id); - } -#endif - - - // The warp leaders, write to SMEM. - if (lane_id < THREADS_PER_PIXEL) { - write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x); - } - - // The data is in SMEM. Do the final reduction. - __syncthreads(); - - // The 1st warp does all the work. - // We do the final reduction each half-warp sequentially reduces the final values. - if (warp_id == 0) { - read_from_smem(x, smem, threadIdx.x); - - #pragma unroll - for (int offset = 1; - offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) { - float y[ELEMENTS_PER_LDG]; - // Read the mean and variance from the other pixel. - read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP); - // Compute the updated sum. - add(x, y); - } - -#ifdef __HIP_PLATFORM_HCC__ - for (int offset = THREADS_PER_WARP >> 1; offset >= THREADS_PER_PIXEL; offset >>= 1) { for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - x[i] += shfl_sync(x[i], offset + lane_id); - } - } -#else - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id); - } -#endif - - // Make sure the data was read from SMEM. - syncwarp(); - - // Store the final values. - if (threadIdx.x < THREADS_PER_PIXEL) { - // probably could do it earlier, before sync - -#ifndef __HIP_PLATFORM_HCC__ // bn_group > 1 is not enabled on HIP - for (int sync_iter=0; sync_iter < sync_iters; ++sync_iter) { - //float* params_pair_data = (reinterpret_cast(params_pair_datas))[sync_iter]; - void* params_pair_data = params_pair_datas[sync_iter]; - - // skip the space consumed by previous sync iterations - const int xbuf_offset = sync_iter*data_total; - // data starts after flags, but have to skip previous - const int data_offset = xbuf_offset - + off*ELEMENTS_PER_LDG*THREADS_PER_PIXEL*2 - + ELEMENTS_PER_LDG*threadIdx.x*2; - - // after sums for this GPU were computed, let CTA0 broadcast the sum to over GPU - if (blockIdx.x == 0) { - volatile float * write_data = - &((reinterpret_cast(params_pair_data))[data_offset]); - - // write the data to memory region to be reflected to other GPU - asm volatile ("st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};" - :: "l"(write_data) , "f"(x[0]), "r"(magic), "f"(x[2]), "r"(magic)); - - asm volatile ("st.global.wt.v4.b32 [%0], {%1,%2,%3,%4};" - :: "l"(write_data+4) , "f"(x[1]), "r"(magic), "f"(x[3]), "r"(magic)); - } - - // now each CTA (on each GPU) reads the data written by CTA 0 of the other GPU - volatile float * read_data = - &((reinterpret_cast(params_my_data))[data_offset]); - - float other[4]; - uint32_t other_flag_a, other_flag_b; - do { - asm volatile ("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];" - : "=f"(other[0]), "=r"(other_flag_a), "=f"(other[2]), "=r"(other_flag_b) : "l"(read_data)); - } while ((other_flag_a != magic) || (other_flag_b != magic)); - - do { - asm volatile ("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];" - : "=f"(other[1]), "=r"(other_flag_a), "=f"(other[3]), "=r"(other_flag_b) : "l"(read_data+4)); - } while ((other_flag_a != magic) || (other_flag_b != magic)); - - add(x, other); - } -#endif - // finally, after syncing up and accounting for partial sums from - // other GPUs as required, write the result - - - write_to_smem(smem, threadIdx.x, x); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int THREADS_PER_CTA > -DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) { - // The size of a warp. -#ifdef __HIP_PLATFORM_HCC__ - const int THREADS_PER_WARP = 64; -#else - const int THREADS_PER_WARP = 32; -#endif - // The number of warps in a CTA. - const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP; - // The number of threads per pixel. - const int THREADS_PER_PIXEL = 8; - // The number of elements per ldg. - const int ELEMENTS_PER_LDG = 4; - // The warp decomposition. - const int warp_id = threadIdx.x / THREADS_PER_WARP; - const int lane_id = threadIdx.x % THREADS_PER_WARP; - - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id); - x[i] += shfl_sync(x[i], THREADS_PER_PIXEL*2+lane_id); - } - - // The warp leaders, write to SMEM. - if (lane_id < THREADS_PER_PIXEL) { - write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x); - } - - // The data is in SMEM. Do the final reduction. - __syncthreads(); - - // The 1st warp does all the work. - // We do the final reduction each half-warp sequentially reduces the final values. - if (warp_id == 0) { - read_from_smem(x, smem, threadIdx.x); - - #pragma unroll - for (int offset = 1; - offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) { - float y[ELEMENTS_PER_LDG]; - // Read the mean and variance from the other pixel. - read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP); - // Compute the updated sum. - add(x, y); - } - - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id); - x[i] += shfl_sync(x[i], THREADS_PER_PIXEL*2+lane_id); - } - - // Make sure the data was read from SMEM. - syncwarp(); - - // Store the final values. - if (threadIdx.x < THREADS_PER_PIXEL) { - write_to_smem(smem, threadIdx.x, x); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int THREADS_PER_CTA, int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG > -DEVICE_FUNCTION void parallel_sums(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) { - // The size of a warp. -#ifdef __HIP_PLATFORM_HCC__ - const int THREADS_PER_WARP = 64; -#else - const int THREADS_PER_WARP = 32; -#endif - const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP; - // The warp decomposition. - const int warp_id = threadIdx.x / THREADS_PER_WARP; - const int lane_id = threadIdx.x % THREADS_PER_WARP; - // total size of data per sync iter - -#ifdef __HIP_PLATFORM_HCC__ - for (int offset = THREADS_PER_PIXEL; offset <= THREADS_PER_WARP >> 1; offset <<= 1) { - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - x[i] += shfl_sync(x[i], offset + lane_id); - } - } -#else - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id); - } -#endif - - - // The warp leaders, write to SMEM. - if (lane_id < THREADS_PER_PIXEL) { - write_to_smem(smem, warp_id*THREADS_PER_PIXEL + lane_id, x); - } - - // The data is in SMEM. Do the final reduction. - __syncthreads(); - - // The 1st warp does all the work. - // We do the final reduction each half-warp sequentially reduces the final values. - if (warp_id == 0) { - read_from_smem(x, smem, threadIdx.x); - - #pragma unroll - for (int offset = 1; - offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) { - float y[ELEMENTS_PER_LDG]; - // Read the mean and variance from the other pixel. - read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP); - // Compute the updated sum. - add(x, y); - } - -#ifdef __HIP_PLATFORM_HCC__ - for (int offset = THREADS_PER_WARP >> 1; offset >= THREADS_PER_PIXEL; offset >>= 1) { - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - x[i] += shfl_sync(x[i], offset + lane_id); - } - } -#else - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - x[i] += shfl_sync(x[i], THREADS_PER_PIXEL+lane_id); - } -#endif - - // Make sure the data was read from SMEM. - syncwarp(); - - // Store the final values. - if (threadIdx.x < THREADS_PER_PIXEL) { - // probably could do it earlier, before sync - write_to_smem(smem, threadIdx.x, x); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG > -struct ParallelSums { - template< int THREADS_PER_CTA > - DEVICE_FUNCTION void dispatch(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) { - parallel_sums(smem, x, nhw); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// -/* -template<> -struct ParallelSums<16, 4> { - template< int THREADS_PER_CTA > - DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) { - parallel_sums_16x2(smem, x, nhw, 0, 0, 0, 0, 0); - } - - template< int THREADS_PER_CTA > - DEVICE_FUNCTION void dispatchX(float *smem, float (&x)[4], int nhw, void* params_my_data, void** params_pair_datas, int off, const int magic, const unsigned int& sync_iters) { - parallel_sums_16x2(smem, x, nhw, params_my_data, params_pair_datas, off, magic, sync_iters); - } -}; - -template<> -struct ParallelSums<8, 4> { - template< int THREADS_PER_CTA > - DEVICE_FUNCTION void dispatch(float *smem, float (&x)[4], int nhw) { - parallel_sums_8x4(smem, x, nhw); - } -}; -*/ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static inline int div_up(int m, int n) { - return (m + n - 1) / n; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// It is expected that all threads in the CTA enter this function! -DEVICE_FUNCTION void inter_block_sync(int* gmem_retired_ctas, int expected_count, bool master) { - - // Register the CTA. - if (threadIdx.x == 0) { - // Issue the membar. - __threadfence(); - // Notify that the CTA is done. - int val_to_add = 1; - if (master) { - val_to_add = -(expected_count - 1); - } - atomicAdd(gmem_retired_ctas, val_to_add); - } - - // Are all CTAs done? - if (threadIdx.x == 0) { - int retired_ctas = -1; - do { - __threadfence(); -#ifdef __HIP_PLATFORM_HCC__ - retired_ctas = __ldg((const int*) gmem_retired_ctas); -#else - asm volatile ("ld.global.cg.b32 %0, [%1];" - : "=r"(retired_ctas) : "l"(gmem_retired_ctas)); -#endif - } while (retired_ctas != 0); - } - __syncthreads(); - -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct NhwcBatchNormFwdInferenceParams { - // The input/output tensors. - uint16_t *gmem_src, *gmem_dst, *gmem_src1; - // the final mean and variance as calculated during the training process - float *gmem_mean, *gmem_var; - // The bias/scale. - float *gmem_bias, *gmem_scale; - // The dimensions. - int nhw, c; - // epsilon - float var_eps; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// No DESIRED_OCCUPANCY launch bounds needed, as this is not launched cooperatively -template< - typename Storage, - int THREADS_PER_CTA, - int THREADS_PER_PIXEL, - int ELEMENTS_PER_LDG, - bool USE_RELU, - bool USE_ADD_RELU -> -__global__ __launch_bounds__(THREADS_PER_CTA) - void nhwc_batch_norm_fwd_inference(NhwcBatchNormFwdInferenceParams params) { - // The number of pixels loaded in a single LDG. - const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; - // The number of C elements per CTA. - const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG; - - // The start position in the NHW dimension where the CTA starts. - const int cta_nhw_stride = gridDim.x * PIXELS_PER_LDG; - // Compute the NHW coordinate of the thread in the CTA. - const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL; - // thread's starting point in NHW - const int thread_nhw = thread_in_cta_nhw + blockIdx.x * PIXELS_PER_LDG; - - // The position in the C dimension where the CTA starts. - const int cta_c = blockIdx.y * C_ELEMENTS_PER_CTA; - // Compute the C coordinate of the thread in the CTA. - const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL; - // Compute the C coordinate of the thread. - const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG; - - // Is the thread working on a valid C dimension? - const int is_valid_c = thread_c < params.c; - - float mean[ELEMENTS_PER_LDG], var[ELEMENTS_PER_LDG]; - float scale[ELEMENTS_PER_LDG], bias[ELEMENTS_PER_LDG]; - zero_array(mean); - zero_array(var); - zero_array(scale); - zero_array(bias); - if (is_valid_c) { - read_from_gmem(var, ¶ms.gmem_var[cta_c], thread_in_cta_c); - read_from_gmem(scale, ¶ms.gmem_scale[cta_c], thread_in_cta_c); - read_from_gmem(mean, ¶ms.gmem_mean[cta_c], thread_in_cta_c); - read_from_gmem(bias, ¶ms.gmem_bias[cta_c], thread_in_cta_c); - } - - // Update the scale with the stddev and eps. - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - scale[i] *= rsqrtf(var[i] + params.var_eps); - } - - // The base pointers for reading/writing - uint16_t *const gmem_src = ¶ms.gmem_src[thread_c]; - uint16_t *const gmem_dst = ¶ms.gmem_dst[thread_c]; - const uint16_t *gmem_src1 = nullptr; - if (USE_ADD_RELU) { - gmem_src1 = ¶ms.gmem_src1[thread_c]; - } - - // apply BN - for (int nhw = thread_nhw; nhw < params.nhw; nhw += cta_nhw_stride) { - float x_math[ELEMENTS_PER_LDG]; - zero_array(x_math); - if (is_valid_c) { - ldg(x_math, &gmem_src[nhw*params.c]); - } - - // Normalize and apply activation function - normalize(x_math, bias, scale, mean); - if (USE_ADD_RELU) { - float x1_math[ELEMENTS_PER_LDG]; - ldg(x1_math, &gmem_src1[nhw*params.c]); - add(x_math, x1_math); - relu_activation(x_math); - } else if (USE_RELU) { - relu_activation(x_math); - } - - if (is_valid_c) { - stg(&gmem_dst[nhw*params.c], x_math); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct NhwcBatchNormFwdParams { - // The input/output tensors. - uint16_t *gmem_src, *gmem_dst, *gmem_src1; - // The bias/scale. - float *gmem_bias, *gmem_scale; - // running mean/var (refer BN API from cudnn doc) - float *gmem_running_mean, *gmem_running_var; - // saved mean/var (refer BN API from cudnn doc) - float *gmem_saved_mean, *gmem_saved_var; - // ReLU bitmask - bitmask_t *gmem_relu_bitmask; - // The dimensions. - int nhw, c; - // factor to scale sum of squared errors to get saved variance. Must be 1/nhw. - float svar_inv_count; - // factor to scale sum of squared errors to get running variance. Should be 1/nhw or 1/(nhw-1). - float rvar_inv_count; - // The buffer to do the reduction for mean, stddev and count. - float *gmem_sums; - // The buffer to count items in the different CTAs. - int *gmem_counts; - // The counters of retired CTAs. - int *gmem_retired_ctas; - // The epsilon to apply to the computation of the variance. - float var_eps; - // outer loop count - int outer_loops; - // exponential average factor - float exp_avg_factor; - // number of CTAs along .x dimension - int c_blks; - - void* my_data; - void* pair_datas[4]; - int magic; - int sync_iters; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - typename Storage, - int THREADS_PER_CTA, - int THREADS_PER_PIXEL, - int PIXELS_PER_THREAD_IN_REGISTERS, - int PIXELS_PER_THREAD_IN_SMEM, - int ELEMENTS_PER_LDG, - int USE_ONLINE_APPROACH, - int OUTER_LOOPS_, - bool USE_RELU, - bool USE_ADD_RELU, - int DESIRED_OCCUPANCY -> -__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) - void nhwc_batch_norm_fwd(NhwcBatchNormFwdParams params) { - // The number of pixels loaded in a single LDG. - const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; - // The number of pixels computed per CTA stored in registers. - const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG; - // The number of pixels computed per CTA stored in SMEM. - const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG; - // The number of C elements per CTA. - const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG; - - // Shared memory to do CTA-wide parallel sums. - __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG]; - - // Compute the NHW coordinate of the thread in the CTA. - const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL; - - // The adapter for the storage. - typedef PackedStorage PackedStorage_; - // The data type for packed storage in SMEM. - typedef typename PackedStorage_::Type PackedStorageType; - // The number of elements in the packed storage. - const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG; - // Registers to keep the data live for the persistent approach. - PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; - - // Shared memory buffer to store the extra pixels. - extern __shared__ PackedStorageType smem_storage_packed[]; - -#ifdef __HIP_PLATFORM_HCC__ - const half zero_h = __float2half(0.0F); -#endif - - for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) { - // The position in the NHW dimension where the CTA starts. - int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS; - // The position in the NHW dimension where the CTA starts for the portion in SMEM. - int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM; - - // The position in the C dimension where the CTA starts. - const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA; - // Compute the C coordinate of the thread in the CTA. - const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL; - // Compute the C coordinate of the thread. - int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG; - - // Is the thread working on a valid C dimension? - const int is_valid_c = thread_c < params.c; - - // Clamp thread_c so that we load from valid locations even if we don't use the value - if (!is_valid_c) - thread_c = params.c - 4; - - // Single pass numerically stable algorithm, see: - // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm - // - // n = 0, mean = 0.0, M2 = 0.0 - // - // for x in data: - // n += 1 - // delta = x - mean - // mean += delta/n - // delta2 = x - mean - // M2 += delta*delta2 - // - // if n < 2: - // return float('nan') - // else: - // return M2 / (n - 1) - - // Register to store the number of elements read so far. - float count = 0.f, mean[ELEMENTS_PER_LDG], m2[ELEMENTS_PER_LDG]; - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - mean[i] = 0.f; - m2[i] = 0.f; - } - - // The number of elements loaded by this CTA. - int cta_count = 0; - // The base pointer to load from. - const uint16_t *gmem_src = ¶ms.gmem_src[thread_c]; - - // outer loops - int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops; - // Load the batch of elements. Compute the mean/var across those elements. - const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x; - - if (OUTER_LOOPS_ != 1) { - // We cannot load everything to store persistently, so let's makes sure registers and - // smem are fully utilized, offset is evenly divisible by 32 - int offset = (pixels_per_iteration * OUTER_LOOPS + - PIXELS_PER_CTA_IN_SMEM * gridDim.x - params.nhw) & ~31; - cta_nhw_regs -= offset; - cta_nhw_smem -= offset; - } - - #pragma unroll 1 - for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) { - // The nhw position. - int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration; - // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!! - cta_count += max(min(nhw_regs + PIXELS_PER_CTA_IN_REGISTERS, params.nhw) - - max(nhw_regs, 0), 0); - - // Load the data and compute the local mean/sum and the variance. - if (USE_ONLINE_APPROACH) { - // Read the elements from memory. - float is_valid[PIXELS_PER_THREAD_IN_REGISTERS]; - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG; - zero_array(x_storage[i]); - is_valid[i] = 0.f; - if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { -#ifndef __HIP_PLATFORM_HCC__ - if (loop_i == OUTER_LOOPS - 1) { - ldg_stream(x_storage[i], &gmem_src[idx*params.c]); - } else { -#endif - ldg(x_storage[i], &gmem_src[idx*params.c]); -#ifndef __HIP_PLATFORM_HCC__ - } -#endif - is_valid[i] = 1.f; - } - } - - // Do the math. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - // Convert to float. - float x_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage[i]); - - // Update the count. - count += is_valid[i]; - // Invert the count. - float inv_count = is_valid[i] ? 1.f / count : 0.f; - - // Update the mean and m2 using deltas. - #pragma unroll - for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { - float delta0 = x_math[j] - mean[j]; - mean[j] += delta0 * inv_count; - float delta1 = x_math[j] - mean[j]; - m2[j] += delta0 * delta1 * is_valid[i]; - } - } - } else { - // Read the elements from memory. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG; - zero_array(x_storage[i]); - if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { - if (loop_i == OUTER_LOOPS - 1) { - ldg_stream(x_storage[i], &gmem_src[idx*params.c]); - } else { - ldg(x_storage[i], &gmem_src[idx*params.c]); - } - count += 1.f; - } - } - - // Sum the elements in registers. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - // Convert to float. - float x_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage[i]); - - // Update the mean and m2 using deltas. - #pragma unroll - for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { - mean[j] += x_math[j]; - } - } - - // Compute the mean. - float inv_count = 1.f / count; - #pragma unroll - for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { - mean[j] *= inv_count; - } - - // Compute the variance. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - // Convert to float. - float x_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage[i]); - - // Is it a valid pixel? - float is_valid = i < static_cast(count) ? 1.f : 0.f; - // Update the mean and m2 using deltas. - #pragma unroll - for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { - m2[j] += (x_math[j] - mean[j]) * (x_math[j] - mean[j]) * is_valid; - } - } - } - } - - // The elements to load and store in SMEM. - int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem; - // Load elements from SMEM, update the CTA count. - int pixels_in_smem = min(smem_nhw + PIXELS_PER_CTA_IN_SMEM, params.nhw) - max(smem_nhw, 0); - if (pixels_in_smem > 0) { - cta_count += pixels_in_smem; - for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { - const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - float is_pixel_valid = (((unsigned int)idx < - (unsigned int)params.nhw) && is_valid_c) ? 1.f : 0.f; - - PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG]; - ldg_stream(x_storage_local, &gmem_src[(is_pixel_valid ? idx : 0)*params.c]); - - // The offset to store in SMEM. - const int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - // Store in SMEM. - write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local); - // Update the count. - count += is_pixel_valid; - // Invert the count. - float inv_count = is_pixel_valid ? 1.f / count : 0.f; - - float x_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage_local); - // Update the mean and m2 using deltas. - #pragma unroll - for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { - float delta0 = x_math[j] - mean[j]; - mean[j] += delta0 * inv_count; - float delta1 = x_math[j] - mean[j]; - m2[j] += delta0 * delta1 * is_pixel_valid; - } - } - } - - // We scale the mean by the number of elements. It brings more stability. - float m1[ELEMENTS_PER_LDG]; - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - m1[i] = mean[i] * count; - } - - // Run the parallel sum accross the CTA to get the local sum. -#ifdef __HIP_PLATFORM_HCC__ - ParallelSums::template dispatch( -#else - ParallelSums::dispatch( -#endif - smem, m1, thread_in_cta_nhw); - __syncthreads(); - - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(m1, smem, thread_in_cta_c); - __syncthreads(); - - // Adjust the variance. - float inv_cta_count = 1.f / static_cast(cta_count); - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - float mean_diff = m1[i]*inv_cta_count - mean[i]; - m2[i] = m2[i] + mean_diff * mean_diff * count; - } - - // Run the parallel sum accross the CTA to get the local adjusted variance. -#ifdef __HIP_PLATFORM_HCC__ - ParallelSums::template dispatch( -#else - ParallelSums::dispatch( -#endif - smem, m2, thread_in_cta_nhw); - - // The workspace in global memory is distributed across the different CTA. - int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2; - - // Write the data for the CTA to global memory. - float *gmem_sums = ¶ms.gmem_sums[gmem_sums_offset]; - if (threadIdx.x < THREADS_PER_PIXEL) { - const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x; - write_to_gmem(&gmem_sums[ 0], idx, m1); - write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, m2); - } - - // The memory location to store the number of pixels per CTA. - int *gmem_counts = ¶ms.gmem_counts[c_blk_index*gridDim.x]; - if (threadIdx.x == 0) { - gmem_counts[blockIdx.x] = cta_count; - } - - // Read the bias and scale. - float bias[ELEMENTS_PER_LDG], scale[ELEMENTS_PER_LDG]; - if (is_valid_c) { - read_from_gmem(bias, ¶ms.gmem_bias[cta_c], thread_in_cta_c); - read_from_gmem(scale, ¶ms.gmem_scale[cta_c], thread_in_cta_c); - } - - // The counters to count how many CTAs have retired at this point. - // A given cta uses the same counter every other time through the outer loop. - int *gmem_retired_ctas = ¶ms.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)]; - inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0); - - // Reset the mean to compute the global mean. - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - m1[i] = 0.f; - } - - // Build the global mean. - #pragma unroll 1 - for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) { - float tmp[ELEMENTS_PER_LDG]; - read_from_gmem(tmp, gmem_sums, idx); - add(m1, tmp); - } - -#ifndef __HIP_PLATFORM_HCC__ - if (params.sync_iters>0) - { - ParallelSums::dispatchX( - smem, m1, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+3, params.magic, params.sync_iters); - } else { -#endif -#ifdef __HIP_PLATFORM_HCC__ - ParallelSums::template dispatch( -#else - ParallelSums::dispatch( -#endif - smem, m1, thread_in_cta_nhw); -#ifndef __HIP_PLATFORM_HCC__ - } -#endif - __syncthreads(); - - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(m1, smem, thread_in_cta_c); - __syncthreads(); - - // Normalize the mean. - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - m1[i] = m1[i] * params.svar_inv_count; - } - - // Reset the variance. - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - m2[i] = 0.f; - } - - // for add+relu fusion - const uint16_t *gmem_src1 = nullptr; - if (USE_ADD_RELU) { - gmem_src1 = ¶ms.gmem_src1[thread_c]; - } - - // Build the global variance. - #pragma unroll 1 - for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) { - // Read the means computed by different CTAs (again). Reuse tmp if we have 1 iteration. - float tmp_mean[ELEMENTS_PER_LDG], tmp_var[ELEMENTS_PER_LDG]; - read_from_gmem(tmp_mean, &gmem_sums[ 0], idx); - read_from_gmem(tmp_var, &gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx); - - // Read the number of pixels visited by a given CTA. - cta_count = __ldg(&gmem_counts[idx / THREADS_PER_PIXEL]); - - // Compute the diff to update the variance. - float mean_diff[ELEMENTS_PER_LDG], inv_cta_count = 1.f / static_cast(cta_count); - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - mean_diff[i] = m1[i] - tmp_mean[i]*inv_cta_count; - } - - // Update the variance. - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - m2[i] += tmp_var[i] + mean_diff[i]*mean_diff[i]*static_cast(cta_count); - } - } - -#ifndef __HIP_PLATFORM_HCC__ - if (params.sync_iters>0) - { - ParallelSums::dispatchX( - smem, m2, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+2, params.magic, params.sync_iters); - } else { -#endif -#ifdef __HIP_PLATFORM_HCC__ - ParallelSums::template dispatch( -#else - ParallelSums::dispatch( -#endif - smem, m2, thread_in_cta_nhw); -#ifndef __HIP_PLATFORM_HCC__ - } -#endif - __syncthreads(); - - read_from_smem(m2, smem, thread_in_cta_c); - - // Finalize the stddev. - // becasue saved var and running var may have different denominator, we don't do it here - // scale_(m2, inv_count); - - // store the saved mean/var - float svarinv[ELEMENTS_PER_LDG]; - bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0; - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - svarinv[i] = rsqrtf(m2[i] * params.svar_inv_count + params.var_eps); - } - if (is_valid_for_saving) { - write_to_gmem(params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG, m1); - write_to_gmem(params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG, svarinv); - } - - // store the running mean/var - float rmean[ELEMENTS_PER_LDG], rvar[ELEMENTS_PER_LDG]; - zero_array(rmean); - zero_array(rvar); - if (params.exp_avg_factor != 1.f && is_valid_for_saving) { - read_from_gmem(rmean, params.gmem_running_mean, thread_c/ELEMENTS_PER_LDG); - read_from_gmem(rvar, params.gmem_running_var, thread_c/ELEMENTS_PER_LDG); - } - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - rmean[i] = (1.f - params.exp_avg_factor) * rmean[i] + \ - params.exp_avg_factor * m1[i]; - rvar[i] = (1.f - params.exp_avg_factor) * rvar[i] + \ - params.exp_avg_factor * (m2[i] * params.rvar_inv_count); - } - if (is_valid_for_saving) { - write_to_gmem(params.gmem_running_mean, thread_c/ELEMENTS_PER_LDG, rmean); - write_to_gmem(params.gmem_running_var, thread_c/ELEMENTS_PER_LDG, rvar); - } - - // Update the scale with the stddev and eps. - multiply(scale, svarinv); - - // The base pointer to write to. - uint16_t *const gmem_dst = ¶ms.gmem_dst[thread_c]; - - bitmask_t *const gmem_relu_bitmask = params.gmem_relu_bitmask + -#ifdef __HIP_PLATFORM_HCC__ - ((params.nhw + 3) & ~3) * 2 * c_blk_index; -#else - ((params.nhw + 31) & ~31) * 2 * c_blk_index; -#endif - - // Store the elements in registers. - #pragma unroll 1 - for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) { - // The value for nhw. - int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration; - - // Normalize the elements and write to memory. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - const bool is_valid_nhw = - static_cast(idx) < static_cast(params.nhw); - const bool is_valid = is_valid_nhw && is_valid_c; - // Convert to float. - float x_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage[i]); - - // Normalize and apply activation function - normalize(x_math, bias, scale, m1); - if (USE_ADD_RELU) { - float x1_math[ELEMENTS_PER_LDG]; - ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0)*params.c]); - add(x_math, x1_math); - bitmask_t relu_mask; -#ifdef __HIP_PLATFORM_HCC__ - int lane_id = threadIdx.x & 63; -#else - int lane_id = threadIdx.x & 31; -#endif - #pragma unroll - for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { -#ifdef __HIP_PLATFORM_HCC__ - bool rectified = __hle(__float2half(x_math[j]), zero_h); -#else - bool rectified = x_math[j] < 0; -#endif - bitmask_t local_relu_mask = ballot(rectified); - if (lane_id == j) { - // Thread 0 remembers the relu_mask from the first time through this - // loop, Thread 1 the next, Thread 2 the next, and Thread 3 the last. - relu_mask = local_relu_mask; - } - if (rectified) { - x_math[j] = 0.0F; - } - } - if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) { - gmem_relu_bitmask[idx * BITMASK_OFFSET + lane_id] = relu_mask; - } - } else if (USE_RELU) { - relu_activation(x_math); - } - - // Write back. - if (is_valid) { - stg_stream(&gmem_dst[idx*params.c], x_math); - } - } - - // The next value of nhw. - out_nhw -= pixels_per_iteration; - - // Read the next elements from memory. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { - ldg_stream(x_storage[i], &gmem_src[idx*params.c]); - } - } - } - - // Normalize the elements from SMEM and write them out. - if (pixels_in_smem > 0) { - #pragma unroll 2 - for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { - const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - const bool is_valid_nhw = - static_cast(idx) < static_cast(params.nhw); - const bool is_valid = is_valid_nhw && is_valid_c; - - // Read from SMEM. - const int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG]; - read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x); - float x_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage_local); - - // Normalize and apply activation function - normalize(x_math, bias, scale, m1); - if (USE_ADD_RELU) { - float x1_math[ELEMENTS_PER_LDG]; - ldg_stream(x1_math, &gmem_src1[(is_valid ? idx : 0)*params.c]); - add(x_math, x1_math); - bitmask_t relu_mask; -#ifdef __HIP_PLATFORM_HCC__ - int lane_id = threadIdx.x & 63; -#else - int lane_id = threadIdx.x & 31; -#endif - #pragma unroll - for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { -#ifdef __HIP_PLATFORM_HCC__ - bool rectified = __hle(__float2half(x_math[j]), zero_h); -#else - bool rectified = x_math[j] < 0; -#endif - bitmask_t local_relu_mask = ballot(rectified); - if (lane_id == j) { - relu_mask = local_relu_mask; - } - if (rectified) { - x_math[j] = 0.0F; - } - } - if (is_valid_nhw && (lane_id < ELEMENTS_PER_LDG)) { - gmem_relu_bitmask[idx * BITMASK_OFFSET + lane_id] = relu_mask; - } - } else if (USE_RELU) { - relu_activation(x_math); - } - - // Write back. - if (is_valid) { - stg_stream(&gmem_dst[idx*params.c], x_math); - } - } - } - // We're about to start on the next c-blk. Needed? - __syncthreads(); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct NhwcBatchNormBwdParams { - // The input/output tensors. - uint16_t *gmem_src, *gmem_dy, *gmem_dst, *gmem_dst1; - // dscale/dbias - float *gmem_dscale, *gmem_dbias; - // The scale and bias. - float *gmem_scale, *gmem_bias; - // The mean/inv-var saved from fwd pass - float *gmem_saved_mean, *gmem_saved_var; - // ReLU bitmask - bitmask_t *gmem_relu_bitmask; - // The dimensions. - int nhw, c; - // factor to scale sum of squared errors to get saved variance. Must be 1/nhw. - float svar_inv_count; - // The buffer to do the reduction for dscale and dbias - float *gmem_sums; - // The counters of retired CTAs. - int *gmem_retired_ctas; - // outer loop count - int outer_loops; - // number of CTAs along .x dimension - int c_blks; - - void* my_data; - void* pair_datas[4]; - int magic; - int sync_iters; - float wgrad_coeff; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float (&x)[N], - const float (&mean_var_scale_bias)[N], - const float (&var_scale)[N], bool valid_data) { - #pragma unroll - for (int j = 0; j < N; ++j) { - float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j]; - if ((y <= 0.f) && valid_data) { - dy[j] = 0.f; - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const float (&y)[N], bool valid_data) { - #pragma unroll - for (int j = 0; j < N; ++j) { - if ((y[j] <= 0.f) && valid_data) { - dy[j] = 0.f; - } - } -} - -template -DEVICE_FUNCTION void relu_bwd(float (&dy)[N], const bool (&rectified)[N], bool valid_data) { - #pragma unroll - for (int j = 0; j < N; ++j) { - if (rectified[j] && valid_data) { - dy[j] = 0.f; - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -DEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N], - const float (&x)[N], - const float (&mean_var_scale_bias)[N], - const float (&var_scale)[N]) { - #pragma unroll - for (int j = 0; j < N; ++j) { - float y = (x[j] * var_scale[j]) + mean_var_scale_bias[j]; - if (y <= 0.f) { - dy[j] = 0.f; - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -DEVICE_FUNCTION void relu_bwd_for_dx(float (&dy)[N], const float (&y)[N]) { - #pragma unroll - for (int j = 0; j < N; ++j) { - if (y[j] <= 0.f) { - dy[j] = 0.f; - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -DEVICE_FUNCTION void bwd_update(float (&dscale)[N], float (&dbias)[N], - const float (&dy)[N], const float (&x)[N], - const float (&mean)[N], float inv_count) { - #pragma unroll - for (int j = 0; j < N; ++j) { - float delta0 = dy[j] - dbias[j]; - dbias[j] += delta0 * inv_count; - delta0 = (dy[j] * (x[j] - mean[j])) - dscale[j]; - dscale[j] += delta0 * inv_count; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -DEVICE_FUNCTION void bwd_dx(float (&dx)[N], const float (&dy)[N], - const float (&var)[N], const float (&x)[N], const float (&mean)[N], - const float (&dscale)[N], const float (&dbias)[N], float inv_count) { - #pragma unroll - for (int j = 0; j < N; ++j) { - float tmp1 = dy[j] - (dbias[j]* inv_count); - float tmp2 = dscale[j] * inv_count; - float tmp3 = x[j] - mean[j]; - dx[j] = var[j] * (tmp1 - (tmp2 * tmp3)); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - typename Storage, - int THREADS_PER_CTA, - int THREADS_PER_PIXEL, - int PIXELS_PER_THREAD_IN_REGISTERS, - int PIXELS_PER_THREAD_IN_SMEM, - int ELEMENTS_PER_LDG, - int USE_ONLINE_APPROACH, - int OUTER_LOOPS_, - int DESIRED_OCCUPANCY -> -__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) - void nhwc_batch_norm_bwd(NhwcBatchNormBwdParams params) { - // The number of pixels loaded in a single LDG. - const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; - // The number of pixels computed per CTA stored in registers. - const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG; - // The number of pixels computed per CTA stored in SMEM. - const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG; - // The number of C elements per CTA. - const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG; - - // Shared memory to do CTA-wide parallel sums. - __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG]; - - // The adapter for the storage. - typedef PackedStorage PackedStorage_; - // The data type for packed storage in SMEM. - typedef typename PackedStorage_::Type PackedStorageType; - // The number of elements in the packed storage. - const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG; - // Registers to keep the data live for the persistent approach. - PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; - PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; - - // Shared memory buffer to store the extra pixels. - extern __shared__ PackedStorageType smem_storage_packed[]; - - for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) { - // The position in the NHW dimension where the CTA starts. - int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS; - // The position in the NHW dimension where the CTA starts for the portion in SMEM. - int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM; - // Compute the NHW coordinate of the thread in the CTA. - const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL; - - // The position in the C dimension where the CTA starts. - const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA; - // Compute the C coordinate of the thread in the CTA. - const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL; - // Compute the C coordinate of the thread. - const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG; - - // Is the thread working on a valid C dimension? - const int is_valid_c = thread_c < params.c; - - // Registers to store the mean used for entire duration - float mean[ELEMENTS_PER_LDG]; - zero_array(mean); - if (is_valid_c) { - read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG); - } - - // accumulation related registers - float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG]; - zero_array(dscale); - zero_array(dbias); - - // The number of elements loaded by this CTA. - int cta_count = 0; - // The base pointers to load from. - const uint16_t *gmem_src = ¶ms.gmem_src[thread_c]; - const uint16_t *gmem_dy = ¶ms.gmem_dy[thread_c]; - - // outer loops - int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops; - // Load the batch of elements. Compute sum across them - const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x; - - if (OUTER_LOOPS_ != 1) { - // We cannot load everything to store persistently, so let's makes sure registers and - // smem are fully utilized - int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS - - PIXELS_PER_CTA_IN_SMEM * gridDim.x; - cta_nhw_regs += offset; - cta_nhw_smem += offset; - } - - #pragma unroll 1 - for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) { - // The nhw position. - int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration; - // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!! - cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs)); - - // Read the elements from memory. - float is_valid[PIXELS_PER_THREAD_IN_REGISTERS]; - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG; - zero_array(x_storage[i]); - zero_array(dy_storage[i]); - is_valid[i] = 0.f; - if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { - if (loop_i == OUTER_LOOPS - 1) { - ldg_stream(x_storage[i], &gmem_src[idx*params.c]); - ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]); - } else { - ldg(x_storage[i], &gmem_src[idx*params.c]); - ldg(dy_storage[i], &gmem_dy[idx*params.c]); - } - is_valid[i] = 1.f; - } - } - - // Do the math. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - // Convert to float and update - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage[i]); - to_float(dy_math, dy_storage[i]); - - // Update the count. - count += is_valid[i]; - // Invert the count. - float inv_count = is_valid[i] ? 1.f / count : 0.f; - - bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count); - } - } - - // The elements to load and store in SMEM. - int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem; - // Load elements from SMEM, update the CTA count. - int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw-smem_nhw); - if (pixels_in_smem > 0) { - cta_count += pixels_in_smem; - for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { - const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - bool is_pixel_valid = (((unsigned int)idx < - (unsigned int)params.nhw) && is_valid_c); - PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], - dy_storage_local[PACKED_ELEMENTS_PER_LDG]; - zero_array(x_storage_local); - zero_array(dy_storage_local); - if (is_pixel_valid) { - ldg_stream(x_storage_local, &gmem_src[idx*params.c]); - ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]); - } - - // The offset to store in SMEM. - int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - // Store in SMEM. - write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local); - offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local); - // Update the count. - count += is_pixel_valid; - // Invert the count. - float inv_count = is_pixel_valid ? 1.f / count : 0.f; - - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage_local); - to_float(dy_math, dy_storage_local); - - bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count); - } - } - - // We scale the mean by the number of elements. It brings more stability. - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - dbias[i] *= count; - dscale[i] *= count; - } - - // dscale parallel sum -#ifdef __HIP_PLATFORM_HCC__ - ParallelSums::template dispatch( -#else - ParallelSums::dispatch( -#endif - smem, dscale, thread_in_cta_nhw); - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dscale, smem, thread_in_cta_c); - __syncthreads(); - - // dbias parallel sum -#ifdef __HIP_PLATFORM_HCC__ - ParallelSums::template dispatch( -#else - ParallelSums::dispatch( -#endif - smem, dbias, thread_in_cta_nhw); - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dbias, smem, thread_in_cta_c); - __syncthreads(); - - // The workspace in global memory is distributed across the different CTA. - int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2; - // Write the data for the CTA to global memory. - float *gmem_sums = ¶ms.gmem_sums[gmem_sums_offset]; - if (threadIdx.x < THREADS_PER_PIXEL) { - const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x; - write_to_gmem(&gmem_sums[ 0], idx, dscale); - write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias); - } - - // The counters to count how many CTAs have retired at this point. - // A given cta uses the same counter every other time through the outer loop. - int *gmem_retired_ctas = ¶ms.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)]; - inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0); - - // Reset the accumulators for global summation - zero_array(dscale); - zero_array(dbias); - - // Build the global accumulation - #pragma unroll 1 - for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) { - float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG]; - read_from_gmem(tmp1, gmem_sums, idx); - read_from_gmem(tmp2, gmem_sums+C_ELEMENTS_PER_CTA*gridDim.x, idx); - - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - dscale[i] += tmp1[i]; - dbias[i] += tmp2[i]; - } - } - - // dscale parallel sum -#ifndef __HIP_PLATFORM_HCC__ - if (params.sync_iters>0) { - ParallelSums::dispatchX( - smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters); - } else { -#endif -#ifdef __HIP_PLATFORM_HCC__ - ParallelSums::template dispatch( -#else - ParallelSums::dispatch( -#endif - smem, dscale, thread_in_cta_nhw); -#ifndef __HIP_PLATFORM_HCC__ - } -#endif - - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dscale, smem, thread_in_cta_c); - __syncthreads(); - - // dbias parallel sum -#ifndef __HIP_PLATFORM_HCC__ - if (params.sync_iters>0) { - ParallelSums::dispatchX( - smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters); - } else { -#endif -#ifdef __HIP_PLATFORM_HCC__ - ParallelSums::template dispatch( -#else - ParallelSums::dispatch( -#endif - smem, dbias, thread_in_cta_nhw); -#ifndef __HIP_PLATFORM_HCC__ - } -#endif - - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dbias, smem, thread_in_cta_c); - - // inv-var - float var[ELEMENTS_PER_LDG]; - zero_array(var); - if (is_valid_c) { - read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG); - } - - // Normalize the dscale. - multiply(dscale, var); - - // store dscale/dbias - bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0; - if (is_valid_for_saving) { - if (params.sync_iters>0) - { - scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff); - scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff); - } else { - write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale); - write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias); - } - } - - // scale - float scale[ELEMENTS_PER_LDG]; - zero_array(scale); - if (is_valid_c) { - read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG); - } - - // Further normalize the dscale to be used in dx calculation - multiply(dscale, var); - // scale the inv-var as well, afterwards - multiply(var, scale); - - // inverse count - float inv_count = params.svar_inv_count; - - // The base pointer to write to. - uint16_t *const gmem_dst = ¶ms.gmem_dst[thread_c]; - - // Store the elements in registers. - #pragma unroll 1 - for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) { - // The value for nhw. - int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration; - - // Normalize the elements and write to memory. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - // Convert to float. - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage[i]); - to_float(dy_math, dy_storage[i]); - - float dx[ELEMENTS_PER_LDG]; - bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count); - - // Write back. - const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { - stg_stream(&gmem_dst[idx*params.c], dx); - } - } - - // The next value of nhw. - out_nhw -= pixels_per_iteration; - - // Read the next elements from memory. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { - ldg_stream(x_storage[i], &gmem_src[idx*params.c]); - ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]); - } - } - } - - // Normalize the elements from SMEM and write them out. - if (pixels_in_smem > 0) { - for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { - const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c; - if (is_valid) { - // Read from SMEM. - int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], - dy_storage_local[PACKED_ELEMENTS_PER_LDG]; - read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x); - offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x); - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage_local); - to_float(dy_math, dy_storage_local); - - float dx[ELEMENTS_PER_LDG]; - bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count); - - // Write back. - stg_stream(&gmem_dst[idx*params.c], dx); - } - } - } - // We're about to start on the next c-blk. Needed? - __syncthreads(); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - typename Storage, - int THREADS_PER_CTA, - int THREADS_PER_PIXEL, - int PIXELS_PER_THREAD_IN_REGISTERS, - int PIXELS_PER_THREAD_IN_SMEM, - int ELEMENTS_PER_LDG, - int USE_ONLINE_APPROACH, - int OUTER_LOOPS_, - int DESIRED_OCCUPANCY -> -__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) - void nhwc_batch_norm_bwd_relu(NhwcBatchNormBwdParams params) { - // The number of pixels loaded in a single LDG. - const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; - // The number of pixels computed per CTA stored in registers. - const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG; - // The number of pixels computed per CTA stored in SMEM. - const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG; - // The number of C elements per CTA. - const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG; - - // Shared memory to do CTA-wide parallel sums. - __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG]; - - // The adapter for the storage. - typedef PackedStorage PackedStorage_; - // The data type for packed storage in SMEM. - typedef typename PackedStorage_::Type PackedStorageType; - // The number of elements in the packed storage. - const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG; - // Registers to keep the data live for the persistent approach. - PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; - PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; - - // Shared memory buffer to store the extra pixels. - extern __shared__ PackedStorageType smem_storage_packed[]; - - for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) { - // The position in the NHW dimension where the CTA starts. - int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS; - // The position in the NHW dimension where the CTA starts for the portion in SMEM. - int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM; - // Compute the NHW coordinate of the thread in the CTA. - const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL; - - // The position in the C dimension where the CTA starts. - const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA; - // Compute the C coordinate of the thread in the CTA. - const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL; - // Compute the C coordinate of the thread. - const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG; - - // Is the thread working on a valid C dimension? - const int is_valid_c = thread_c < params.c; - - - // Registers to store the mean/var/scale/bias used for the entire duration - // Register usage optimizations: - // 1. Can combine bias - (mean * var * scale) into a single register - // 2. Can combine var * scale into a single register - float varscale[ELEMENTS_PER_LDG]; - zero_array(varscale); - if (is_valid_c) { - read_from_gmem(varscale, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG); - } - float tmp[ELEMENTS_PER_LDG]; - zero_array(tmp); - if (is_valid_c) { - read_from_gmem(tmp, params.gmem_scale, thread_c/ELEMENTS_PER_LDG); - } - multiply(varscale, tmp); - float mean[ELEMENTS_PER_LDG]; - zero_array(mean); - if (is_valid_c) { - read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG); - } - zero_array(tmp); - if (is_valid_c) { - read_from_gmem(tmp, params.gmem_bias, thread_c/ELEMENTS_PER_LDG); - } - float mean_var_scale_bias[ELEMENTS_PER_LDG]; - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - mean_var_scale_bias[i] = tmp[i] - (mean[i] * varscale[i]); - } - - // accumulation related registers - float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG]; - zero_array(dscale); - zero_array(dbias); - - // The number of elements loaded by this CTA. - int cta_count = 0; - // The base pointers to load from. - const uint16_t *gmem_src = ¶ms.gmem_src[thread_c]; - const uint16_t *gmem_dy = ¶ms.gmem_dy[thread_c]; - - // outer loops - int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops; - // Load the batch of elements. Compute sum across them - const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x; - - if (OUTER_LOOPS_ != 1) { - // We cannot load everything to store persistently, so let's makes sure registers and - // smem are fully utilized - int offset = params.nhw - pixels_per_iteration * OUTER_LOOPS - - PIXELS_PER_CTA_IN_SMEM * gridDim.x; - cta_nhw_regs += offset; - cta_nhw_smem += offset; - } - - #pragma unroll 1 - for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) { - // The nhw position. - int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration; - // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!! - cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs)); - - // Read the elements from memory. - float is_valid[PIXELS_PER_THREAD_IN_REGISTERS]; - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG; - zero_array(x_storage[i]); - zero_array(dy_storage[i]); - is_valid[i] = 0.f; - if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { - if (loop_i == OUTER_LOOPS - 1) { - ldg_stream(x_storage[i], &gmem_src[idx*params.c]); - ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]); - } else { - ldg(x_storage[i], &gmem_src[idx*params.c]); - ldg(dy_storage[i], &gmem_dy[idx*params.c]); - } - is_valid[i] = 1.f; - } - } - - // Do the math. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - // Convert to float and update - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage[i]); - to_float(dy_math, dy_storage[i]); - - // Update the count. - count += is_valid[i]; - // Invert the count. - float inv_count = is_valid[i] ? 1.f / count : 0.f; - - relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_valid[i]); - bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count); - } - } - - // The elements to load and store in SMEM. - int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem; - // Load elements from SMEM, update the CTA count. - int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw-smem_nhw); - if (pixels_in_smem > 0) { - cta_count += pixels_in_smem; - for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { - const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - bool is_pixel_valid = (((unsigned int)idx < - (unsigned int)params.nhw) && is_valid_c); - PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], - dy_storage_local[PACKED_ELEMENTS_PER_LDG]; - zero_array(x_storage_local); - zero_array(dy_storage_local); - if (is_pixel_valid) { - ldg_stream(x_storage_local, &gmem_src[idx*params.c]); - ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]); - } - - // The offset to store in SMEM. - int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - // Store in SMEM. - write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local); - offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local); - // Update the count. - count += is_pixel_valid; - // Invert the count. - float inv_count = is_pixel_valid ? 1.f / count : 0.f; - - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage_local); - to_float(dy_math, dy_storage_local); - - relu_bwd(dy_math, x_math, mean_var_scale_bias, varscale, is_pixel_valid); - bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count); - } - } - - // We scale the mean by the number of elements. It brings more stability. - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - dbias[i] *= count; - dscale[i] *= count; - } - - // dscale parallel sum -#ifdef __HIP_PLATFORM_HCC__ - ParallelSums::template dispatch( -#else - ParallelSums::dispatch( -#endif - smem, dscale, thread_in_cta_nhw); - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dscale, smem, thread_in_cta_c); - __syncthreads(); - - // dbias parallel sum -#ifdef __HIP_PLATFORM_HCC__ - ParallelSums::template dispatch( -#else - ParallelSums::dispatch( -#endif - smem, dbias, thread_in_cta_nhw); - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dbias, smem, thread_in_cta_c); - __syncthreads(); - - // The workspace in global memory is distributed across the different CTA. - int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2; - // Write the data for the CTA to global memory. - float *gmem_sums = ¶ms.gmem_sums[gmem_sums_offset]; - if (threadIdx.x < THREADS_PER_PIXEL) { - const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x; - write_to_gmem(&gmem_sums[ 0], idx, dscale); - write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias); - } - - // The counters to count how many CTAs have retired at this point. - // A given cta uses the same counter every other time through the outer loop. - int *gmem_retired_ctas = ¶ms.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)]; - inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0); - - // Reset the accumulators for global summation - zero_array(dscale); - zero_array(dbias); - - // Build the global accumulation - #pragma unroll 1 - for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) { - float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG]; - read_from_gmem(tmp1, gmem_sums, idx); - read_from_gmem(tmp2, gmem_sums+C_ELEMENTS_PER_CTA*gridDim.x, idx); - - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - dscale[i] += tmp1[i]; - dbias[i] += tmp2[i]; - } - } - - // dscale parallel sum -#ifndef __HIP_PLATFORM_HCC__ - if (params.sync_iters>0) { - ParallelSums::dispatchX( - smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters); - } else { -#endif -#ifdef __HIP_PLATFORM_HCC__ - ParallelSums::template dispatch( -#else - ParallelSums::dispatch( -#endif - smem, dscale, thread_in_cta_nhw); -#ifndef __HIP_PLATFORM_HCC__ - } -#endif - - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dscale, smem, thread_in_cta_c); - __syncthreads(); - - // dbias parallel sum -#ifndef __HIP_PLATFORM_HCC__ - if (params.sync_iters>0) { - ParallelSums::dispatchX( - smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters); - } else { -#endif -#ifdef __HIP_PLATFORM_HCC__ - ParallelSums::template dispatch( -#else - ParallelSums::dispatch( -#endif - smem, dbias, thread_in_cta_nhw); -#ifndef __HIP_PLATFORM_HCC__ - } -#endif - - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dbias, smem, thread_in_cta_c); - - // Normalize the dscale. - float var[ELEMENTS_PER_LDG]; - zero_array(var); - if (is_valid_c) { - read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG); - } - multiply(dscale, var); - - // store dscale/dbias - bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0; - if (is_valid_for_saving) { - if (params.sync_iters>0) - { - scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff); - scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff); - } else { - write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale); - write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias); - } - } - - // Further normalize the dscale to be used in dx calculation - float scale[ELEMENTS_PER_LDG]; - zero_array(scale); - if (is_valid_c) { - read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG); - } - multiply(dscale, var); - // scale the inv-var as well, afterwards - multiply(var, scale); - - // inverse count - float inv_count = params.svar_inv_count; - - // The base pointer to write to. - uint16_t *const gmem_dst = ¶ms.gmem_dst[thread_c]; - - // Store the elements in registers. - #pragma unroll 1 - for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) { - // The value for nhw. - int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration; - - // Normalize the elements and write to memory. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - // Convert to float. - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage[i]); - to_float(dy_math, dy_storage[i]); - relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var); - - float dx[ELEMENTS_PER_LDG]; - bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count); - - // Write back. - const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { - stg_stream(&gmem_dst[idx*params.c], dx); - } - } - - // The next value of nhw. - out_nhw -= pixels_per_iteration; - - // Read the next elements from memory. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { - ldg_stream(x_storage[i], &gmem_src[idx*params.c]); - ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]); - } - } - } - - // Normalize the elements from SMEM and write them out. - if (pixels_in_smem > 0) { - for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { - const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c; - if (is_valid) { - // Read from SMEM. - int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], - dy_storage_local[PACKED_ELEMENTS_PER_LDG]; - read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x); - offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x); - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage_local); - to_float(dy_math, dy_storage_local); - relu_bwd_for_dx(dy_math, x_math, mean_var_scale_bias, var); - - float dx[ELEMENTS_PER_LDG]; - bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count); - - // Write back. - stg_stream(&gmem_dst[idx*params.c], dx); - } - } - } - // We're about to start on the next c-blk. Needed? - __syncthreads(); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - typename Storage, - int THREADS_PER_CTA, - int THREADS_PER_PIXEL, - int PIXELS_PER_THREAD_IN_REGISTERS, - int PIXELS_PER_THREAD_IN_SMEM, - int ELEMENTS_PER_LDG, - int USE_ONLINE_APPROACH, - int OUTER_LOOPS_, - int DESIRED_OCCUPANCY -> -__global__ __launch_bounds__(THREADS_PER_CTA, DESIRED_OCCUPANCY) - void nhwc_batch_norm_bwd_add_relu(NhwcBatchNormBwdParams params) { - // The number of pixels loaded in a single LDG. - const int PIXELS_PER_LDG = THREADS_PER_CTA / THREADS_PER_PIXEL; - // The number of pixels computed per CTA stored in registers. - const int PIXELS_PER_CTA_IN_REGISTERS = PIXELS_PER_THREAD_IN_REGISTERS * PIXELS_PER_LDG; - // The number of pixels computed per CTA stored in SMEM. - const int PIXELS_PER_CTA_IN_SMEM = PIXELS_PER_THREAD_IN_SMEM*PIXELS_PER_LDG; - // The number of C elements per CTA. - const int C_ELEMENTS_PER_CTA = THREADS_PER_PIXEL*ELEMENTS_PER_LDG; - - // Shared memory to do CTA-wide parallel sums. - __shared__ float smem[THREADS_PER_PIXEL*(THREADS_PER_CTA/warpSize)*ELEMENTS_PER_LDG]; - - // The adapter for the storage. - typedef PackedStorage PackedStorage_; - // The data type for packed storage in SMEM. - typedef typename PackedStorage_::Type PackedStorageType; - // The number of elements in the packed storage. - const int PACKED_ELEMENTS_PER_LDG = PackedStorage_::PACKED_ELEMENTS_PER_LDG; - // Registers to keep the data live for the persistent approach. - PackedStorageType x_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; - PackedStorageType dy_storage[PIXELS_PER_THREAD_IN_REGISTERS][PACKED_ELEMENTS_PER_LDG]; - - // Shared memory buffer to store the extra pixels. - extern __shared__ PackedStorageType smem_storage_packed[]; - - for (int c_blk_index = blockIdx.y; c_blk_index < params.c_blks; c_blk_index += gridDim.y) { - // The position in the NHW dimension where the CTA starts. - int cta_nhw_regs = blockIdx.x * PIXELS_PER_CTA_IN_REGISTERS; - // The position in the NHW dimension where the CTA starts for the portion in SMEM. - int cta_nhw_smem = blockIdx.x * PIXELS_PER_CTA_IN_SMEM; - // Compute the NHW coordinate of the thread in the CTA. - const int thread_in_cta_nhw = threadIdx.x / THREADS_PER_PIXEL; - - // The position in the C dimension where the CTA starts. - const int cta_c = c_blk_index * C_ELEMENTS_PER_CTA; - // Compute the C coordinate of the thread in the CTA. - const int thread_in_cta_c = threadIdx.x % THREADS_PER_PIXEL; - // Compute the C coordinate of the thread. - const int thread_c = cta_c + thread_in_cta_c*ELEMENTS_PER_LDG; - - // Is the thread working on a valid C dimension? - const int is_valid_c = thread_c < params.c; - - float mean[ELEMENTS_PER_LDG]; - zero_array(mean); - if (is_valid_c) { - read_from_gmem(mean, params.gmem_saved_mean, thread_c/ELEMENTS_PER_LDG); - } - - // accumulation related registers - float count = 0.f, dscale[ELEMENTS_PER_LDG], dbias[ELEMENTS_PER_LDG]; - zero_array(dscale); - zero_array(dbias); - - // The number of elements loaded by this CTA. - int cta_count = 0; - // The base pointers to load from. - const uint16_t *gmem_src = ¶ms.gmem_src[thread_c]; - const uint16_t *gmem_dy = ¶ms.gmem_dy[thread_c]; - uint16_t *gmem_dst1 = ¶ms.gmem_dst1[thread_c]; - - // outer loops - int OUTER_LOOPS = OUTER_LOOPS_ == 1? 1 : params.outer_loops; - // Load the batch of elements. Compute sum across them - const int pixels_per_iteration = PIXELS_PER_CTA_IN_REGISTERS*gridDim.x; - - if (OUTER_LOOPS_ != 1) { - // We cannot load everything to store persistently, so let's makes sure registers and - // smem are fully utilized, offset is evenly divisible by 32 - int offset = (pixels_per_iteration * OUTER_LOOPS + PIXELS_PER_CTA_IN_SMEM * gridDim.x - - params.nhw) & ~31; - cta_nhw_regs -= offset; - cta_nhw_smem -= offset; - } - - const bitmask_t *const gmem_relu_bitmask = params.gmem_relu_bitmask + -#ifdef __HIP_PLATFORM_HCC__ - ((params.nhw + 3) & ~3) * 2 * c_blk_index; -#else - ((params.nhw + 31) & ~31) * 2 * c_blk_index; -#endif - - #pragma unroll 1 - for (int loop_i = 0; loop_i < OUTER_LOOPS; ++loop_i) { - // The nhw position. - int nhw_regs = cta_nhw_regs + loop_i*pixels_per_iteration; - // Update the number of elements loaded by this CTA. TODO: Skip if <= 0!!! - cta_count += max(0, min(PIXELS_PER_CTA_IN_REGISTERS, params.nhw-nhw_regs)); - -#ifdef __HIP_PLATFORM_HCC__ - int lane_id = threadIdx.x & 63; -#else - int lane_id = threadIdx.x & 31; -#endif - - // Read the elements from memory. - float is_valid[PIXELS_PER_THREAD_IN_REGISTERS]; - bitmask_t relu_mask[PIXELS_PER_THREAD_IN_REGISTERS]; - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG; - zero_array(x_storage[i]); - zero_array(dy_storage[i]); - is_valid[i] = 0.f; - const bool is_valid_nhw = - static_cast(idx) < static_cast(params.nhw); - if (is_valid_nhw) { - if (is_valid_c) { - if (loop_i == OUTER_LOOPS - 1) { - ldg_stream(x_storage[i], &gmem_src[idx*params.c]); - ldg_stream(dy_storage[i], &gmem_dy[idx*params.c]); - } else { - ldg(x_storage[i], &gmem_src[idx*params.c]); - ldg(dy_storage[i], &gmem_dy[idx*params.c]); - } - is_valid[i] = 1.f; - } - - if (lane_id < ELEMENTS_PER_LDG) { - relu_mask[i] = gmem_relu_bitmask[idx * BITMASK_OFFSET + lane_id]; - } - } - } - - // Do the math. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = nhw_regs + thread_in_cta_nhw + i*PIXELS_PER_LDG; - // Convert to float and update - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - bool rectified[ELEMENTS_PER_LDG]; - #pragma unroll - for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { - rectified[j] = ((shfl_sync(relu_mask[i], j) & - (ONE_BITMASK << lane_id)) != 0); - } - to_float(x_math, x_storage[i]); - to_float(dy_math, dy_storage[i]); - - // Update the count. - count += is_valid[i]; - // Invert the count. - float inv_count = is_valid[i] ? 1.f / count : 0.f; - - relu_bwd(dy_math, rectified, is_valid[i]); - bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count); - - // Lastly we need 'dy' only for BN, so store the 'relu-dgrad'ed version - from_float(dy_storage[i], dy_math); - - // dZ for elementwise add - if (is_valid[i]) { - if (loop_i == OUTER_LOOPS - 1) { - stg_stream(&gmem_dst1[idx*params.c], dy_storage[i]); - } else { - stg(&gmem_dst1[idx*params.c], dy_storage[i]); - } - } - } - } - - // The elements to load and store in SMEM. - int smem_nhw = OUTER_LOOPS*pixels_per_iteration + cta_nhw_smem; - // Load elements from SMEM, update the CTA count. - int pixels_in_smem = min(PIXELS_PER_CTA_IN_SMEM, params.nhw-smem_nhw); - if (pixels_in_smem > 0) { - cta_count += pixels_in_smem; - for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { - const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - const bool is_pixel_valid_nhw = - static_cast(idx) < static_cast(params.nhw); - const bool is_pixel_valid = is_pixel_valid_nhw && is_valid_c; - PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], - dy_storage_local[PACKED_ELEMENTS_PER_LDG]; - bitmask_t relu_mask; -#ifdef __HIP_PLATFORM_HCC__ - int lane_id = threadIdx.x & 63; -#else - int lane_id = threadIdx.x & 31; -#endif - zero_array(x_storage_local); - zero_array(dy_storage_local); - if (is_pixel_valid_nhw) { - if (is_valid_c) { - ldg_stream(x_storage_local, &gmem_src[idx*params.c]); - ldg_stream(dy_storage_local, &gmem_dy[idx*params.c]); - } - if (lane_id < ELEMENTS_PER_LDG) { - relu_mask = gmem_relu_bitmask[idx * BITMASK_OFFSET + lane_id]; - } - } - bool rectified[ELEMENTS_PER_LDG]; - #pragma unroll - for (int j = 0; j < ELEMENTS_PER_LDG; ++j) { - rectified[j] = ((shfl_sync(relu_mask, j) & - (ONE_BITMASK << lane_id)) != 0); - } - - // The offset to store in SMEM. - int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - // Store in SMEM. - write_to_smem(&smem_storage_packed[offset], threadIdx.x, x_storage_local); - offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - // Update the count. - count += is_pixel_valid; - // Invert the count. - float inv_count = is_pixel_valid ? 1.f / count : 0.f; - - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage_local); - to_float(dy_math, dy_storage_local); - - relu_bwd(dy_math, rectified, is_pixel_valid); - bwd_update(dscale, dbias, dy_math, x_math, mean, inv_count); - - from_float(dy_storage_local, dy_math); - // dZ for elementwise add - if (is_pixel_valid) { - stg_stream(&gmem_dst1[idx*params.c], dy_storage_local); - } - // only store the 'relu-dgrad'ed version! - write_to_smem(&smem_storage_packed[offset], threadIdx.x, dy_storage_local); - } - } - - // We scale the mean by the number of elements. It brings more stability. - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - dbias[i] *= count; - dscale[i] *= count; - } - - // dscale parallel sum -#ifdef __HIP_PLATFORM_HCC__ - ParallelSums::template dispatch( -#else - ParallelSums::dispatch( -#endif - smem, dscale, thread_in_cta_nhw); - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dscale, smem, thread_in_cta_c); - __syncthreads(); - - // dbias parallel sum -#ifdef __HIP_PLATFORM_HCC__ - ParallelSums::template dispatch( -#else - ParallelSums::dispatch( -#endif - smem, dbias, thread_in_cta_nhw); - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dbias, smem, thread_in_cta_c); - __syncthreads(); - - // The workspace in global memory is distributed across the different CTA. - int gmem_sums_offset = c_blk_index*gridDim.x*C_ELEMENTS_PER_CTA*2; - // Write the data for the CTA to global memory. - float *gmem_sums = ¶ms.gmem_sums[gmem_sums_offset]; - if (threadIdx.x < THREADS_PER_PIXEL) { - const int idx = blockIdx.x*THREADS_PER_PIXEL + threadIdx.x; - write_to_gmem(&gmem_sums[ 0], idx, dscale); - write_to_gmem(&gmem_sums[C_ELEMENTS_PER_CTA*gridDim.x], idx, dbias); - } - - // The counters to count how many CTAs have retired at this point. - // A given cta uses the same counter every other time through the outer loop. - int *gmem_retired_ctas = ¶ms.gmem_retired_ctas[c_blk_index % (2 * gridDim.y)]; - inter_block_sync(gmem_retired_ctas, gridDim.x, blockIdx.x == 0); - - // Reset the accumulators for global summation - zero_array(dscale); - zero_array(dbias); - - // Build the global accumulation - #pragma unroll 1 - for (int idx = threadIdx.x; idx < THREADS_PER_PIXEL*gridDim.x; idx += THREADS_PER_CTA) { - float tmp1[ELEMENTS_PER_LDG], tmp2[ELEMENTS_PER_LDG]; - read_from_gmem(tmp1, gmem_sums, idx); - read_from_gmem(tmp2, gmem_sums+C_ELEMENTS_PER_CTA*gridDim.x, idx); - - #pragma unroll - for (int i = 0; i < ELEMENTS_PER_LDG; ++i) { - dscale[i] += tmp1[i]; - dbias[i] += tmp2[i]; - } - } - - // dscale parallel sum -#ifndef __HIP_PLATFORM_HCC__ - if (params.sync_iters>0) { - ParallelSums::dispatchX( - smem, dscale, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+1, params.magic, params.sync_iters); - } else { -#endif -#ifdef __HIP_PLATFORM_HCC__ - ParallelSums::template dispatch( -#else - ParallelSums::dispatch( -#endif - smem, dscale, thread_in_cta_nhw); -#ifndef __HIP_PLATFORM_HCC__ - } -#endif - - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dscale, smem, thread_in_cta_c); - __syncthreads(); - - // dbias parallel sum -#ifndef __HIP_PLATFORM_HCC__ - if (params.sync_iters>0) { - ParallelSums::dispatchX( - smem, dbias, thread_in_cta_nhw, params.my_data, params.pair_datas, 4*c_blk_index+0, params.magic, params.sync_iters); - } else { -#endif -#ifdef __HIP_PLATFORM_HCC__ - ParallelSums::template dispatch( -#else - ParallelSums::dispatch( -#endif - smem, dbias, thread_in_cta_nhw); -#ifndef __HIP_PLATFORM_HCC__ - } -#endif - - __syncthreads(); - // The values in shared memory correspond to the CTA-wide sums. - read_from_smem(dbias, smem, thread_in_cta_c); - - // Normalize the dscale. - float var[ELEMENTS_PER_LDG]; - zero_array(var); - if (is_valid_c) { - read_from_gmem(var, params.gmem_saved_var, thread_c/ELEMENTS_PER_LDG); - } - multiply(dscale, var); - - // store dscale/dbias - bool is_valid_for_saving = is_valid_c && blockIdx.x == 0 && thread_in_cta_nhw == 0; - if (is_valid_for_saving) { - if (params.sync_iters>0) - { - scaled_write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale, params.wgrad_coeff); - scaled_write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias, params.wgrad_coeff); - } else { - write_to_gmem(params.gmem_dscale, thread_c/ELEMENTS_PER_LDG, dscale); - write_to_gmem(params.gmem_dbias, thread_c/ELEMENTS_PER_LDG, dbias); - } - } - - // Further normalize the dscale to be used in dx calculation - float scale[ELEMENTS_PER_LDG]; - zero_array(scale); - if (is_valid_c) { - read_from_gmem(scale, params.gmem_scale, thread_c/ELEMENTS_PER_LDG); - } - multiply(dscale, var); - // scale the inv-var as well, afterwards - multiply(var, scale); - - // inverse count - float inv_count = params.svar_inv_count; - - // The base pointer to write to. - uint16_t *const gmem_dst = ¶ms.gmem_dst[thread_c]; - - // Store the elements in registers. - #pragma unroll 1 - for (int loop_i = OUTER_LOOPS-1; loop_i >= 0; --loop_i) { - // The value for nhw. - int out_nhw = cta_nhw_regs + loop_i*pixels_per_iteration; - - // Normalize the elements and write to memory. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c; - // Convert to float. - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage[i]); - to_float(dy_math, dy_storage[i]); - - float dx[ELEMENTS_PER_LDG]; - bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count); - - // Write back. - if (is_valid) { - stg_stream(&gmem_dst[idx*params.c], dx); - } - } - - // The next value of nhw. - out_nhw -= pixels_per_iteration; - - // Read the next elements from memory. - #pragma unroll - for (int i = 0; i < PIXELS_PER_THREAD_IN_REGISTERS; ++i) { - const int idx = out_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - float y[ELEMENTS_PER_LDG]; - zero_array(y); - if (((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c) { - ldg_stream(x_storage[i], &gmem_src[idx*params.c]); - ldg_stream(dy_storage[i], &gmem_dst1[idx*params.c]); - } - } - } - - // Normalize the elements from SMEM and write them out. - if (pixels_in_smem > 0) { - for (int i = 0; i < PIXELS_PER_THREAD_IN_SMEM; ++i) { - const int idx = smem_nhw + thread_in_cta_nhw + i*PIXELS_PER_LDG; - const bool is_valid = ((unsigned int)idx < (unsigned int)params.nhw) && is_valid_c; - if (is_valid) { - // Read from SMEM. - int offset = i*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - PackedStorageType x_storage_local[PACKED_ELEMENTS_PER_LDG], - dy_storage_local[PACKED_ELEMENTS_PER_LDG]; - read_from_smem(x_storage_local, &smem_storage_packed[offset], threadIdx.x); - offset += PIXELS_PER_THREAD_IN_SMEM*THREADS_PER_CTA*PACKED_ELEMENTS_PER_LDG; - read_from_smem(dy_storage_local, &smem_storage_packed[offset], threadIdx.x); - float x_math[ELEMENTS_PER_LDG], dy_math[ELEMENTS_PER_LDG]; - to_float(x_math, x_storage_local); - to_float(dy_math, dy_storage_local); - - float dx[ELEMENTS_PER_LDG]; - bwd_dx(dx, dy_math, var, x_math, mean, dscale, dbias, inv_count); - - // Write back. - stg_stream(&gmem_dst[idx*params.c], dx); - } - } - } - // We're about to start on the next c-blk. Needed? - __syncthreads(); - } -} - -#endif // MXNET_OPERATOR_NN_CUDNN_NHWC_BATCH_NORM_KERNEL_H_ diff --git a/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp b/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp deleted file mode 100644 index b026acf..0000000 --- a/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp +++ /dev/null @@ -1,139 +0,0 @@ -#include - -#include -#include - -void index_mul_2d_float_foward_cuda(at::Tensor &out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1); - -void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1); - -void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out, - at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &grad_grad_in1, - const at::Tensor &grad_grad_in2, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1); - -void index_mul_2d_half_foward_cuda(at::Tensor &out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1); - -void index_mul_2d_half_backward_cuda(at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1); - -void index_mul_2d_half_backward_backward_cuda(at::Tensor &grad_grad_out, - at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &grad_grad_in1, - const at::Tensor &grad_grad_in2, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1); - -#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ - AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - -void index_mul_2d_float_forward( - at::Tensor &out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) -{ - return index_mul_2d_float_foward_cuda(out, in1, in2, idx1); -} - -void index_mul_2d_float_backward( - at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) -{ - return index_mul_2d_float_backward_cuda(grad_in1, grad_in2, grad_out, in1, in2, idx1); -} - -void index_mul_2d_float_backwrad_backward( - at::Tensor &grad_grad_out, - at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &grad_grad_in1, - const at::Tensor &grad_grad_in2, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) -{ - return index_mul_2d_float_backward_backward_cuda(grad_grad_out, grad_in1, grad_in2, grad_out, grad_grad_in1, grad_grad_in2, in1, in2, idx1); -} - -void index_mul_2d_half_forward( - at::Tensor &out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) -{ - return index_mul_2d_half_foward_cuda(out, in1, in2, idx1); -} - -void index_mul_2d_half_backward( - at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) -{ - return index_mul_2d_half_backward_cuda(grad_in1, grad_in2, grad_out, in1, in2, idx1); -} - -void index_mul_2d_half_backwrad_backward( - at::Tensor &grad_grad_out, - at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &grad_grad_in1, - const at::Tensor &grad_grad_in2, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) -{ - return index_mul_2d_half_backward_backward_cuda(grad_grad_out, grad_in1, grad_in2, grad_out, grad_grad_in1, grad_grad_in2, in1, in2, idx1); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("float_forward", &index_mul_2d_float_forward, - "index mul float calculation forward (CUDA)"); - m.def("float_backward", &index_mul_2d_float_backward, - "index mul float calculation backward (CUDA)"); - m.def("float_backward_backward", &index_mul_2d_float_backwrad_backward, - "index mul float calculation backward backward (CUDA)"); - m.def("half_forward", &index_mul_2d_half_forward, - "index mul half calculation forward (CUDA)"); - m.def("half_backward", &index_mul_2d_half_backward, - "index mul half calculation backward (CUDA)"); - m.def("half_backward_backward", &index_mul_2d_half_backwrad_backward, - "index mul half calculation backward backward (CUDA)"); -} - diff --git a/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu b/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu deleted file mode 100644 index 4f18da3..0000000 --- a/apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu +++ /dev/null @@ -1,492 +0,0 @@ -#include -#include -#include -#ifdef ATEN_ATOMIC_HEADER - #include -#else - #include -#endif - - -__global__ void index_mul_2d_float_dim64( - float *out, - const float *in1, - const float *in2, - const int64_t *idx1, - const int64_t size) -{ - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int bidx = blockIdx.x; - const int start_idx = bidx * blockDim.y + tidy; - constexpr int fea_dim = 64; - - if (start_idx < size) { - int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx; - int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx; - - float4 res, src1, src2; - src1 = reinterpret_cast(in1)[vec_idx1]; - src2 = reinterpret_cast(in2)[vec_idx2]; - res.x = src1.x * src2.x; - res.y = src1.y * src2.y; - res.z = src1.z * src2.z; - res.w = src1.w * src2.w; - reinterpret_cast(out)[vec_idx2] = res; - } -} - -__global__ void index_mul_2d_float( - float *out, - const float *in1, - const float *in2, - const int64_t *idx1, - const int64_t size, - const int64_t fea_dim) -{ - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int bidx = blockIdx.x; - const int start_idx = bidx * blockDim.y + tidy; - const int stride = blockDim.x; - - if (start_idx < size) { - int64_t vec_idx1 = (idx1[start_idx] * fea_dim); - int64_t vec_idx2 = (start_idx * fea_dim); - - for (int i = tidx; i < fea_dim; i += stride) { - out[vec_idx2 + i] = in1[vec_idx1 + i] * in2[vec_idx2 + i]; - } - } -} - -__global__ void index_mul_2d_half( - at::Half *out, - const at::Half *in1, - const at::Half *in2, - const int64_t *idx1, - const int64_t size, - const int64_t fea_dim) -{ - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int bidx = blockIdx.x; - const int start_idx = bidx * blockDim.y + tidy; - const int stride = blockDim.x; - - if (start_idx < size) { - int64_t vec_idx1 = (idx1[start_idx] * fea_dim); - int64_t vec_idx2 = (start_idx * fea_dim); - - for (int i = tidx; i < fea_dim; i += stride) { - out[vec_idx2 + i] = at::Half(static_cast(in1[vec_idx1 + i]) * static_cast(in2[vec_idx2 + i])); - } - } -} - -__global__ void index_mul_2d_grad_float_dim64( - float *grad_in1, - float *grad_in2, - const float *grad_out, - const float *in1, - const float *in2, - const int64_t *idx1, - const int64_t size) -{ - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int bidx = blockIdx.x; - const int start_idx = bidx * blockDim.y + tidy; - constexpr int fea_dim = 64; - - if (start_idx < size) { - int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx; - int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx; - - float4 src_in1, src_in2, src_grad_out, dst_grad_in2; - src_grad_out = reinterpret_cast(grad_out)[vec_idx2]; - src_in1 = reinterpret_cast(in1)[vec_idx1]; - src_in2 = reinterpret_cast(in2)[vec_idx2]; - int64_t grad_in1_base_idx = idx1[start_idx] * fea_dim + tidx * 4; - gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 0, src_grad_out.x * src_in2.x); - gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 1, src_grad_out.y * src_in2.y); - gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 2, src_grad_out.z * src_in2.z); - gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 3, src_grad_out.w * src_in2.w); - dst_grad_in2.x = src_grad_out.x * src_in1.x; - dst_grad_in2.y = src_grad_out.y * src_in1.y; - dst_grad_in2.z = src_grad_out.z * src_in1.z; - dst_grad_in2.w = src_grad_out.w * src_in1.w; - reinterpret_cast(grad_in2)[vec_idx2] = dst_grad_in2; - } -} - -__global__ void index_mul_2d_grad_float( - float *grad_in1, - float *grad_in2, - const float *grad_out, - const float *in1, - const float *in2, - const int64_t *idx1, - const int64_t size, - const int64_t fea_dim) -{ - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int bidx = blockIdx.x; - const int start_idx = bidx * blockDim.y + tidy; - const int stride = blockDim.x; - - if (start_idx < size) { - int64_t vec_idx1 = idx1[start_idx] * fea_dim; - int64_t vec_idx2 = start_idx * fea_dim; - - for (int i = tidx; i < fea_dim; i += stride) { - float src_in1 = in1[vec_idx1 + i]; - float src_in2 = in2[vec_idx2 + i]; - float src_grad_out = grad_out[vec_idx2 + i]; - grad_in2[vec_idx2 + i] = src_grad_out * src_in1; - gpuAtomicAdd(grad_in1 + vec_idx1 + i, src_grad_out * src_in2); - } - } -} - -__global__ void index_mul_2d_grad_half( - at::Half *grad_in1, - at::Half *grad_in2, - const at::Half *grad_out, - const at::Half *in1, - const at::Half *in2, - const int64_t *idx1, - const int64_t size, - const int64_t fea_dim) -{ - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int bidx = blockIdx.x; - const int start_idx = bidx * blockDim.y + tidy; - const int stride = blockDim.x; - - if (start_idx < size) { - int64_t vec_idx1 = idx1[start_idx] * fea_dim; - int64_t vec_idx2 = start_idx * fea_dim; - - for (int i = tidx; i < fea_dim; i += stride) { - float src_in1 = static_cast(in1[vec_idx1 + i]); - float src_in2 = static_cast(in2[vec_idx2 + i]); - float src_grad_out = static_cast(grad_out[vec_idx2 + i]); - grad_in2[vec_idx2 + i] = at::Half(src_grad_out * src_in1); - gpuAtomicAdd(grad_in1 + vec_idx1 + i, at::Half(src_grad_out * src_in2)); - } - } -} - -__global__ void index_mul_2d_grad_grad_float_dim64( - float *grad_grad_out, - float *grad_in1, - float *grad_in2, - const float *grad_out, - const float *grad_grad_in1, - const float *grad_grad_in2, - const float *in1, - const float *in2, - const int64_t *idx1, - const int64_t size) -{ - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int bidx = blockIdx.x; - const int start_idx = bidx * blockDim.y + tidy; - constexpr int fea_dim = 64; - - if (start_idx < size) { - int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx; - int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx; - - float4 src_grad_grad_in1, src_in1, src_grad_grad_in2, src_in2, src_grad_out; - float4 dst_grad_grad_out, dst_grad_in2; - src_grad_grad_in1 = reinterpret_cast(grad_grad_in1)[vec_idx1]; - src_in1 = reinterpret_cast(in1)[vec_idx1]; - src_grad_grad_in2 = reinterpret_cast(grad_grad_in2)[vec_idx2]; - src_in2 = reinterpret_cast(in2)[vec_idx2]; - dst_grad_grad_out.x = src_grad_grad_in1.x * src_in2.x + src_grad_grad_in2.x * src_in1.x; - dst_grad_grad_out.y = src_grad_grad_in1.y * src_in2.y + src_grad_grad_in2.y * src_in1.y; - dst_grad_grad_out.z = src_grad_grad_in1.z * src_in2.z + src_grad_grad_in2.z * src_in1.z; - dst_grad_grad_out.w = src_grad_grad_in1.w * src_in2.w + src_grad_grad_in2.w * src_in1.w; - reinterpret_cast(grad_grad_out)[vec_idx2] = dst_grad_grad_out; - src_grad_out = reinterpret_cast(grad_out)[vec_idx2]; - int64_t grad_in1_base_idx = idx1[start_idx] * fea_dim + tidx * 4; - gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 0, src_grad_grad_in2.x * src_grad_out.x); - gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 1, src_grad_grad_in2.y * src_grad_out.y); - gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 2, src_grad_grad_in2.z * src_grad_out.z); - gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 3, src_grad_grad_in2.w * src_grad_out.w); - dst_grad_in2.x = src_grad_grad_in1.x * src_grad_out.x; - dst_grad_in2.y = src_grad_grad_in1.y * src_grad_out.y; - dst_grad_in2.z = src_grad_grad_in1.z * src_grad_out.z; - dst_grad_in2.w = src_grad_grad_in1.w * src_grad_out.w; - reinterpret_cast(grad_in2)[vec_idx2] = dst_grad_in2; - } -} - -__global__ void index_mul_2d_grad_grad_float( - float *grad_grad_out, - float *grad_in1, - float *grad_in2, - const float *grad_out, - const float *grad_grad_in1, - const float *grad_grad_in2, - const float *in1, - const float *in2, - const int64_t *idx1, - const int64_t size, - const int64_t fea_dim) -{ - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int bidx = blockIdx.x; - const int start_idx = bidx * blockDim.y + tidy; - const int stride = blockDim.x; - - if (start_idx < size) { - int64_t vec_idx1 = idx1[start_idx] * fea_dim; - int64_t vec_idx2 = start_idx * fea_dim; - - for (int i = tidx; i < fea_dim; i += stride) { - float src_grad_grad_in1 = grad_grad_in1[vec_idx1 + i]; - float src_grad_grad_in2 = grad_grad_in2[vec_idx2 + i]; - float src_in1 = in1[vec_idx1 + i]; - float src_in2 = in2[vec_idx2 + i]; - float src_grad_out = grad_out[vec_idx2 + i]; - grad_grad_out[vec_idx2 + i] = src_grad_grad_in1 * src_in2 + src_grad_grad_in2 * src_in1; - grad_in2[vec_idx2 + i] = src_grad_grad_in1 * src_grad_out; - gpuAtomicAdd(grad_in1 + vec_idx1 + i, src_grad_grad_in2 * src_grad_out); - } - } -} - -__global__ void index_mul_2d_grad_grad_half( - at::Half *grad_grad_out, - at::Half *grad_in1, - at::Half *grad_in2, - const at::Half *grad_out, - const at::Half *grad_grad_in1, - const at::Half *grad_grad_in2, - const at::Half *in1, - const at::Half *in2, - const int64_t *idx1, - const int64_t size, - const int64_t fea_dim) -{ - const int tidx = threadIdx.x; - const int tidy = threadIdx.y; - const int bidx = blockIdx.x; - const int start_idx = bidx * blockDim.y + tidy; - const int stride = blockDim.x; - - if (start_idx < size) { - int64_t vec_idx1 = idx1[start_idx] * fea_dim; - int64_t vec_idx2 = start_idx * fea_dim; - - for (int i = tidx; i < fea_dim; i += stride) { - float src_grad_grad_in1 = static_cast(grad_grad_in1[vec_idx1 + i]); - float src_grad_grad_in2 = static_cast(grad_grad_in2[vec_idx2 + i]); - float src_in1 = static_cast(in1[vec_idx1 + i]); - float src_in2 = static_cast(in2[vec_idx2 + i]); - float src_grad_out = static_cast(grad_out[vec_idx2 + i]); - grad_grad_out[vec_idx2 + i] = at::Half(src_grad_grad_in1 * src_in2 + src_grad_grad_in2 * src_in1); - grad_in2[vec_idx2 + i] = at::Half(src_grad_grad_in1 * src_grad_out); - gpuAtomicAdd(grad_in1 + vec_idx1 + i, at::Half(src_grad_grad_in2 * src_grad_out)); - } - } -} - -void index_mul_2d_float_foward_cuda(at::Tensor &out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) { - const int64_t size = in2.size(0); - const int64_t fea_dim = in2.size(1); - if (size < 0){ - return; - } - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if (fea_dim == 64) { - const int BLOCK_THREADS_DIMX = 16; - const int BLOCK_THREADS_DIMY = 16; - const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; - dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); - - index_mul_2d_float_dim64<<>>( - out.data_ptr(), in1.data_ptr(), in2.data_ptr(), - idx1.data_ptr(), size); - } else { - const int BLOCK_THREADS_DIMX = 32; - const int BLOCK_THREADS_DIMY = 8; - const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; - dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); - - index_mul_2d_float<<>>( - out.data_ptr(), in1.data_ptr(), in2.data_ptr(), - idx1.data_ptr(), size, fea_dim); - } - - AT_CUDA_CHECK(cudaGetLastError()); -} - -void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) { - const int64_t size = in2.size(0); - const int64_t fea_dim = in2.size(1); - if (size < 0){ - return; - } - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if (fea_dim == 64) { - const int BLOCK_THREADS_DIMX = 16; - const int BLOCK_THREADS_DIMY = 16; - const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; - dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); - - index_mul_2d_grad_float_dim64<<>>( - grad_in1.data_ptr(), grad_in2.data_ptr(), grad_out.data_ptr(), - in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size); - - AT_CUDA_CHECK(cudaGetLastError()); - } else { - const int BLOCK_THREADS_DIMX = 32; - const int BLOCK_THREADS_DIMY = 8; - const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; - dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); - - index_mul_2d_grad_float<<>>( - grad_in1.data_ptr(), grad_in2.data_ptr(), grad_out.data_ptr(), - in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); - } -} - -void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out, - at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &grad_grad_in1, - const at::Tensor &grad_grad_in2, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) { - const int64_t size = in2.size(0); - const int64_t fea_dim = in2.size(1); - if (size < 0){ - return; - } - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if (fea_dim == 64) { - const int BLOCK_THREADS_DIMX = 16; - const int BLOCK_THREADS_DIMY = 16; - const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; - dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); - - index_mul_2d_grad_grad_float_dim64<<>>( - grad_grad_out.data_ptr(), grad_in1.data_ptr(), grad_in2.data_ptr(), - grad_out.data_ptr(), grad_grad_in1.data_ptr(), grad_grad_in2.data_ptr(), - in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size); - } else { - const int BLOCK_THREADS_DIMX = 32; - const int BLOCK_THREADS_DIMY = 8; - const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; - dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); - - index_mul_2d_grad_grad_float<<>>( - grad_grad_out.data_ptr(), grad_in1.data_ptr(), grad_in2.data_ptr(), - grad_out.data_ptr(), grad_grad_in1.data_ptr(), grad_grad_in2.data_ptr(), - in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); - } - - AT_CUDA_CHECK(cudaGetLastError()); -} - -void index_mul_2d_half_foward_cuda(at::Tensor &out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) { - const int64_t size = in2.size(0); - const int64_t fea_dim = in2.size(1); - if (size < 0){ - return; - } - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - const int BLOCK_THREADS_DIMX = 32; - const int BLOCK_THREADS_DIMY = 8; - const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; - dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); - - index_mul_2d_half<<>>( - out.data_ptr(), in1.data_ptr(), in2.data_ptr(), - idx1.data_ptr(), size, fea_dim); - - AT_CUDA_CHECK(cudaGetLastError()); -} - -void index_mul_2d_half_backward_cuda(at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) { - const int64_t size = in2.size(0); - const int64_t fea_dim = in2.size(1); - if (size < 0){ - return; - } - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - const int BLOCK_THREADS_DIMX = 32; - const int BLOCK_THREADS_DIMY = 8; - const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; - dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); - - index_mul_2d_grad_half<<>>( - grad_in1.data_ptr(), grad_in2.data_ptr(), grad_out.data_ptr(), - in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); -} - -void index_mul_2d_half_backward_backward_cuda(at::Tensor &grad_grad_out, - at::Tensor &grad_in1, - at::Tensor &grad_in2, - const at::Tensor &grad_out, - const at::Tensor &grad_grad_in1, - const at::Tensor &grad_grad_in2, - const at::Tensor &in1, - const at::Tensor &in2, - const at::Tensor &idx1) { - const int64_t size = in2.size(0); - const int64_t fea_dim = in2.size(1); - if (size < 0){ - return; - } - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - const int BLOCK_THREADS_DIMX = 32; - const int BLOCK_THREADS_DIMY = 8; - const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY; - dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1); - - index_mul_2d_grad_grad_half<<>>( - grad_grad_out.data_ptr(), grad_in1.data_ptr(), grad_in2.data_ptr(), - grad_out.data_ptr(), grad_grad_in1.data_ptr(), grad_grad_in2.data_ptr(), - in1.data_ptr(), in2.data_ptr(), idx1.data_ptr(), size, fea_dim); - - AT_CUDA_CHECK(cudaGetLastError()); -} diff --git a/apex/contrib/csrc/layer_norm/ln.h b/apex/contrib/csrc/layer_norm/ln.h deleted file mode 100644 index f843820..0000000 --- a/apex/contrib/csrc/layer_norm/ln.h +++ /dev/null @@ -1,210 +0,0 @@ -#pragma once - -#include -#include -#if defined(__HIP_PLATFORM_HCC__) -#include "hip/hip_fp16.h" -#include "hip/hip_bfloat16.h" -#else -#include -#include -#endif - -namespace layer_norm { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct LaunchParams{ - - size_t workspace_bytes; - size_t barrier_size; - - cudaDeviceProp * props; - - cudaStream_t stream; - - Params params; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct ParamsBase { - ParamsBase() - : ctas_per_col(0) - , rows(0) - , cols(0) - , x(nullptr) - , mu(nullptr) - , rs(nullptr) - , gamma(nullptr) - , workspace(nullptr) - , barrier(nullptr) - { - } - - // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. - int ctas_per_col; - - // Input is interpreted as matrix. We normalize across columns. - int rows; - int cols; - - // Common data pointers. - void *x; - void *mu; - void *rs; - void *gamma; - - // Multi-CTA workspace in gmem. - void *workspace; - - // Multi-CTA sync barriers in gmem. - int *barrier; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct FwdParams : public ParamsBase { - FwdParams() - : ParamsBase() - , z(nullptr) - , beta(nullptr) - , epsilon(0.f) - { - } - - // Output of LN FWD. - void *z; - void *beta; - float epsilon; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct BwdParams : public ParamsBase { - BwdParams() - : ParamsBase() - , dz(nullptr) - , dbeta_part(nullptr) - , dgamma_part(nullptr) - , dx(nullptr) - , dbeta(nullptr) - , dgamma(nullptr) - { - } - - // Input: gradient wrt. LN FWD output. - void *dz; - - // Workspace for Wgrad pre-reduction. - void *dbeta_part; - void *dgamma_part; - - // Output: Dgrad. - void *dx; - // Output: Wgrad. - void *dbeta; - void *dgamma; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -using FwdFunction = std::function&, const bool)>; -using BwdFunction = std::function&, const bool)>; -using FunctionKey = uint64_t; -using FwdRegistry = std::unordered_map; -using BwdRegistry = std::unordered_map; - -extern FwdRegistry FWD_FUNCS; -extern BwdRegistry BWD_FUNCS; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -using fp32 = float; -using fp16 = half; -#if defined(__HIP_PLATFORM_HCC__) -using bf16 = hip_bfloat16; -#else -using bf16 = nv_bfloat16; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TypeId{}; - -template<> -struct TypeId{ - constexpr static uint32_t Value = 0; -}; - -template<> -struct TypeId{ - constexpr static uint32_t Value = 1; -}; - -template<> -struct TypeId{ - constexpr static uint32_t Value = 2; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Type2Key{ - constexpr static uint32_t Value = TypeId::Value << S; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct WeightType2Key : public Type2Key{}; - -template -struct InputType2Key : public Type2Key{}; - -template -struct OutputType2Key : public Type2Key{}; - -template -struct ComputeType2Key : public Type2Key{}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Types2Key{ - constexpr static uint32_t Value = WeightType2Key::Value | InputType2Key::Value | OutputType2Key::Value | ComputeType2Key::Value; - constexpr static inline uint64_t get(const uint64_t hidden_size){ - constexpr uint64_t type_key = Value; - return (type_key << 32) | hidden_size; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FwdRegistrar{ - FwdRegistrar(FwdFunction f){ - uint64_t key = Types2Key::get(HIDDEN_SIZE); - FWD_FUNCS.insert({ key, f }); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BwdRegistrar{ - BwdRegistrar(BwdFunction f){ - uint64_t key = Types2Key::get(HIDDEN_SIZE); - BWD_FUNCS.insert({ key, f }); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace layer_norm diff --git a/apex/contrib/csrc/layer_norm/ln_api.cpp b/apex/contrib/csrc/layer_norm/ln_api.cpp deleted file mode 100644 index 3893dd2..0000000 --- a/apex/contrib/csrc/layer_norm/ln_api.cpp +++ /dev/null @@ -1,246 +0,0 @@ -#include -#include "ATen/cuda/CUDAContext.h" - -#include "ln.h" - -/* - -Supported Type combinations: - -input compute weights output -======================================= -fp32 fp32 fp32 fp32 -fp16 fp32 fp16 fp16 -bf16 fp32 bf16 bf16 -fp32 fp32 fp16 fp16 -fp32 fp32 bf16 bf16 - -Remarks: -Output type = Weight type -Compute always in FP32 - -*/ - -namespace layer_norm { - -// Create registries and provide runtime versions of config hash functions. - -// FwdRegistry FWD_FUNCS; -// BwdRegistry BWD_FUNCS; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -uint32_t get_type_id(torch::Dtype dtype){ - if( dtype == torch::kFloat16 ) { - return TypeId::Value; - } else if( dtype == torch::kBFloat16 ) { - return TypeId::Value; - } else if( dtype == torch::kFloat32 ) { - return TypeId::Value; - } else { - TORCH_CHECK(false, "Type not supported: ", dtype); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -uint64_t get_key(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint64_t hidden_size) { - using namespace layer_norm; - uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(otype) << 4) | (get_type_id(ctype) << 6); - uint64_t launcher_key = (type_key << 32) | hidden_size; - return launcher_key; -} - -} // namespace layer_norm - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -layer_norm::FwdFunction & get_fwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) { - auto iter = layer_norm::FWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size)); - if( iter != layer_norm::FWD_FUNCS.end() ) { - return iter->second; - } else { - TORCH_CHECK(false, "FWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, otype, ctype); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -layer_norm::BwdFunction & get_bwd_launcher(torch::Dtype wtype, torch::Dtype itype, torch::Dtype otype, torch::Dtype ctype, uint32_t hidden_size) { - auto iter = layer_norm::BWD_FUNCS.find(layer_norm::get_key(wtype, itype, otype, ctype, hidden_size)); - if( iter != layer_norm::BWD_FUNCS.end() ) { - return iter->second; - } else { - TORCH_CHECK(false, "BWD: Unsupported hidden_size or types: ", hidden_size, wtype, itype, otype, ctype); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -std::vector ln_fwd(const at::Tensor &x, // BxSxhidden_size - const at::Tensor &gamma, // hidden_size - const at::Tensor &beta, // hidden_size - const float epsilon -) { - auto itype = x.scalar_type(); - auto wtype = gamma.scalar_type(); - auto otype = wtype; - auto ctype = torch::kFloat32; - - TORCH_CHECK(beta.scalar_type() == wtype); - - TORCH_CHECK(x.is_cuda()) - TORCH_CHECK(gamma.is_cuda()) - TORCH_CHECK(beta.is_cuda()) - - TORCH_CHECK(x.is_contiguous()); - auto sizes = x.sizes(); - TORCH_CHECK(sizes.size() == 2); - - const int rows = sizes[0]; - const int cols = sizes[1]; - auto hidden_size = gamma.numel(); - - TORCH_CHECK(gamma.sizes() == beta.sizes()); - TORCH_CHECK(hidden_size == cols); - - TORCH_CHECK(epsilon >= 0.f); - - auto opts = x.options(); - - auto z = torch::empty(sizes, opts.dtype(otype)); - - auto mu = torch::empty({ rows }, opts.dtype(ctype)); - auto rsigma = torch::empty({ rows }, opts.dtype(ctype)); - - layer_norm::LaunchParams launch_params; - - launch_params.props = at::cuda::getCurrentDeviceProperties(); - launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); - - // Request the kernel launcher. - auto launcher = get_fwd_launcher(wtype, itype, otype, ctype, hidden_size); - - // Query the kernel-specific launch parameters. - launcher(launch_params, true); - - at::Tensor workspace, barrier; - - // Set the kernel runtime parameters. - layer_norm::FwdParams ¶ms = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data_ptr(); - params.mu = mu.data_ptr(); - params.rs = rsigma.data_ptr(); - params.gamma = gamma.data_ptr(); - params.beta = beta.data_ptr(); - params.z = z.data_ptr(); - params.epsilon = epsilon; - - if( launch_params.barrier_size > 0 ) { - auto options = x.options(); - barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32)); - workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar)); - params.workspace = workspace.data_ptr(); - params.barrier = barrier.data_ptr(); - } - - // Launch the kernel. - launcher(launch_params, false); - - return { z, mu, rsigma }; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -std::vector ln_bwd(const at::Tensor &dz, // BxSxhidden_size - const at::Tensor &x, // BxSxhidden_size - const at::Tensor &mu, // BxS, FP32! - const at::Tensor &rsigma, // BxS, FP32! - const at::Tensor &gamma // hidden_size -) { - - auto itype = x.scalar_type(); - auto wtype = gamma.scalar_type(); - auto otype = wtype; - auto ctype = torch::kFloat32; - - TORCH_CHECK(dz.dtype() == otype); - TORCH_CHECK(mu.dtype() == ctype); - TORCH_CHECK(rsigma.dtype() == ctype); - - TORCH_CHECK(x.is_cuda()); - TORCH_CHECK(dz.is_cuda()); - TORCH_CHECK(mu.is_cuda()); - TORCH_CHECK(rsigma.is_cuda()); - TORCH_CHECK(gamma.is_cuda()); - - TORCH_CHECK(x.is_contiguous()); - TORCH_CHECK(dz.is_contiguous()); - - auto sizes = x.sizes(); - TORCH_CHECK(sizes.size() == 2); - TORCH_CHECK(dz.sizes() == sizes); - auto rows = sizes[0]; - auto cols = sizes[1]; - - auto hidden_size = gamma.numel(); - - TORCH_CHECK(mu.numel() == rows); - TORCH_CHECK(mu.sizes() == rsigma.sizes()); - - TORCH_CHECK(gamma.numel() == cols); - - auto options = x.options(); - - auto dx = torch::empty_like(x); - auto dgamma = torch::empty_like(gamma); - auto dbeta = torch::empty_like(gamma); - - layer_norm::LaunchParams launch_params; - launch_params.stream = at::cuda::getCurrentCUDAStream().stream(); - launch_params.props = at::cuda::getCurrentDeviceProperties(); - - auto launcher = get_bwd_launcher(wtype, itype, otype, ctype, hidden_size); - - launcher(launch_params, true); - - auto dgamma_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, options.dtype(ctype)); - auto dbeta_part = torch::empty({ launch_params.params.ctas_per_col, hidden_size }, options.dtype(ctype)); - at::Tensor workspace, barrier; - - layer_norm::BwdParams ¶ms = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data_ptr(); - params.mu = mu.data_ptr(); - params.rs = rsigma.data_ptr(); - params.gamma = gamma.data_ptr(); - params.dz = dz.data_ptr(); - params.dx = dx.data_ptr(); - params.dbeta = dbeta.data_ptr(); - params.dgamma = dgamma.data_ptr(); - params.dbeta_part = dbeta_part.data_ptr(); - params.dgamma_part = dgamma_part.data_ptr(); - - if( launch_params.barrier_size > 0 ) { - // TODO Any way to avoid this? - barrier = torch::zeros(launch_params.barrier_size, options.dtype(torch::kInt32)); - workspace = torch::empty(launch_params.workspace_bytes, options.dtype(torch::kChar)); - params.workspace = workspace.data_ptr(); - params.barrier = barrier.data_ptr(); - } - - launcher(launch_params, false); - - return { dx, dgamma, dbeta, dgamma_part, dbeta_part }; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.doc() = "CUDA LayerNorm"; - m.def("ln_fwd", &ln_fwd, "Run LayerNorm forward kernel"); - m.def("ln_bwd", &ln_bwd, "Run LayerNorm backward kernel"); -} diff --git a/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh b/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh deleted file mode 100644 index 8595f5e..0000000 --- a/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh +++ /dev/null @@ -1,315 +0,0 @@ -#pragma once - -namespace layer_norm { - -template -__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) -void ln_bwd_kernel(layer_norm::BwdParams params) { - - enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; - enum { WARPS_M = Ktraits::WARPS_M }; - enum { WARPS_N = Ktraits::WARPS_N }; - enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; - enum { COLS = Ktraits::COLS }; - enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; - enum { LDGS = Ktraits::LDGS }; - enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; - enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP }; - enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; - - using compute_t = typename Ktraits::compute_t; - using index_t = typename Ktraits::index_t; - using Ivec = typename Ktraits::Ivec; - using Ovec = typename Ktraits::Ovec; - using Wvec = typename Ktraits::Wvec; - using Cvec = typename Ktraits::Cvec; - using Reducer = typename Ktraits::Reducer; - using reduce_t = typename Reducer::Type; - - extern __shared__ char smem_[]; - - const index_t tidx = threadIdx.x; - const index_t bidn = blockIdx.x % CTAS_PER_ROW; - const index_t bidm = blockIdx.x / CTAS_PER_ROW; - const index_t lane = tidx % THREADS_PER_WARP; - const index_t warp = tidx / THREADS_PER_WARP; - const index_t warp_m = warp / Ktraits::WARPS_N; - const index_t warp_n = warp % Ktraits::WARPS_N; - const index_t tid_r = warp_n * THREADS_PER_WARP + lane; - - const index_t r = bidm * Ktraits::ROWS_PER_CTA + warp_m; - const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; - - static_assert(COLS == THREADS_PER_ROW * LDGS * NUM_ELTS * CTAS_PER_ROW); - - Cvec dzy_sum[LDGS]; - Cvec dz_sum[LDGS]; - - memset(dzy_sum, 0, sizeof(dzy_sum)); - memset(dz_sum, 0, sizeof(dz_sum)); - - compute_t * smem_wgrad = reinterpret_cast(smem_); - char *smem_dgrad = smem_ + Ktraits::SMEM_BYTES_WGRAD; - - Reducer reducer(params, bidm, bidn, warp_m, warp_n, lane, smem_dgrad); - - Sum sum; - - constexpr float rn = 1.f / float(COLS); - Wvec gamma[LDGS]; - index_t idx = c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - gamma[it].load_from(params.gamma, idx); - idx += Ktraits::VEC_COLS_PER_LDG; - } - // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the - // last blocks with syncthreads! - // grid stride over rows - #pragma unroll 1 - for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { - const compute_t mu_r = static_cast(params.mu)[row]; - const compute_t rs_r = static_cast(params.rs)[row]; - Ivec x[LDGS]; - Ovec dz[LDGS]; - index_t idx = row * Ktraits::VEC_COLS + c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - dz[it].load_from(params.dz, idx); - x[it].load_from(params.x, idx); - idx += Ktraits::VEC_COLS_PER_LDG; - } - - compute_t dy[LDGS * NUM_ELTS]; - compute_t y[LDGS * NUM_ELTS]; - - compute_t mdy_local = 0.f; - compute_t mdyy_local = 0.f; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - #pragma unroll - for( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t x_tmp = x[it].data.elt[jt]; - compute_t y_tmp = rs_r * (x_tmp - mu_r); - compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]); - dy_tmp *= compute_t(dz[it].data.elt[jt]); - compute_t dz_tmp = dz[it].data.elt[jt]; - - mdy_local += dy_tmp; - mdyy_local += dy_tmp * y_tmp; - - dy[it * NUM_ELTS + jt] = dy_tmp; - y[it * NUM_ELTS + jt] = y_tmp; - - dzy_sum[it].data.elt[jt] += dz_tmp * y_tmp; - dz_sum[it].data.elt[jt] += dz_tmp; - } - } - - reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum); - mdy_local = layer_norm::Get<0>::of(result) * rn; - mdyy_local = layer_norm::Get<1>::of(result) * rn; - - Ivec dx[LDGS]; - idx = row * Ktraits::VEC_COLS + c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - #pragma unroll - for( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t dy_tmp = dy[it * NUM_ELTS + jt]; - compute_t y_tmp = y[it * NUM_ELTS + jt]; - compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp + mdy_local)); - dx[it].data.elt[jt] = dx_tmp; - } - dx[it].store_to(params.dx, idx); - idx += Ktraits::VEC_COLS_PER_LDG; - } - - } // end: grid stride loop - - if( WARPS_M == 1 ) { - idx = r * Ktraits::VEC_COLS + c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - dz_sum[it].store_to(params.dbeta_part, idx); - dzy_sum[it].store_to(params.dgamma_part, idx); - idx += Ktraits::VEC_COLS_PER_LDG; - } - } else { - static_assert(WARPS_M == 1 || Ktraits::CTAS_PER_ROW == 1, "Multiple rows per CTA not supported for Multi-CTA."); - // Finalize reduction of part dgamma and dbeta for this CTA - // by reducing over the rows held across the WARPS_M warps - - // Assumption: blockSize divides hidden size. - enum { NUM_RES = COLS / Ktraits::THREADS_PER_CTA }; - static_assert(NUM_RES * Ktraits::THREADS_PER_CTA == COLS, ""); - - idx = warp_m * Ktraits::VEC_COLS + tid_r; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - dz_sum[it].store_to(smem_wgrad, idx); - idx += THREADS_PER_ROW; - } - __syncthreads(); - compute_t cta_dz_sum[NUM_RES]; - memset(cta_dz_sum, 0, sizeof(compute_t) * NUM_RES); - for( int it = 0; it < ROWS_PER_CTA; it++ ) { - for( int jt = 0; jt < NUM_RES; jt++ ) { - cta_dz_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; - } - } - __syncthreads(); - - idx = warp_m * Ktraits::VEC_COLS + tid_r; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - dzy_sum[it].store_to(smem_wgrad, idx); - idx += THREADS_PER_ROW; - } - __syncthreads(); - compute_t cta_dzy_sum[NUM_RES]; - memset(cta_dzy_sum, 0, sizeof(compute_t) * NUM_RES); - for( int it = 0; it < ROWS_PER_CTA; it++ ) { - for( int jt = 0; jt < NUM_RES; jt++ ) { - cta_dzy_sum[jt] += smem_wgrad[it * COLS + tidx + jt * Ktraits::THREADS_PER_CTA]; - } - } - - compute_t *dgamma_part = static_cast(params.dgamma_part) + bidm * COLS + tidx; - for( int jt = 0; jt < NUM_RES; jt++ ) { - *dgamma_part = cta_dzy_sum[jt]; - dgamma_part += Ktraits::THREADS_PER_CTA; - } - - compute_t *dbeta_part = static_cast(params.dbeta_part) + bidm * COLS + tidx; - for( int jt = 0; jt < NUM_RES; jt++ ) { - *dbeta_part = cta_dz_sum[jt]; - dbeta_part += Ktraits::THREADS_PER_CTA; - } - } -} - -template -__global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) -void ln_bwd_finalize_kernel(BwdParams params) -{ - - using compute_t = typename Kernel_traits::compute_t; - using weight_t = typename Kernel_traits::weight_t; - using index_t = typename Kernel_traits::index_t; - using Reducer = typename Kernel_traits::Reducer; - using reduce_t = typename Reducer::Type; - - Sum sum; - enum { NUM_ELT = Kernel_traits::ELTS_PER_LDG }; - enum { THREADS_PER_WARP = Kernel_traits::THREADS_PER_WARP }; - - __shared__ char smem_[Kernel_traits::SMEM_BYTES_PER_CTA]; - - constexpr uint32_t bidm = 0; - - const uint32_t bidn = blockIdx.x; - const uint32_t tidx = threadIdx.x; - const uint32_t warp = tidx / THREADS_PER_WARP; - const uint32_t lane = tidx % THREADS_PER_WARP; - - Reducer reducer(params, bidm, bidn, 0, 0, lane, smem_); - - const uint32_t c = bidn * THREADS_PER_WARP + lane; - const uint32_t c_out = bidn * THREADS_PER_WARP / 2 + lane; - constexpr uint32_t COL_STRIDE = Kernel_traits::CTAS * THREADS_PER_WARP; - for( uint32_t col = c, col_out = c_out; col < Kernel_traits::COLS; col += COL_STRIDE, col_out += COL_STRIDE / 2 ) { - // Each thread sums over NUM_ELT columns. - Vec dbeta_local, dgamma_local; - memset(&dgamma_local, 0, sizeof(dgamma_local)); - memset(&dbeta_local, 0, sizeof(dbeta_local)); - for( uint32_t row = warp; row < params.ctas_per_col; row += Kernel_traits::ROWS_PER_CTA ) { - index_t idx = row * Kernel_traits::COLS + col; - - Vec dbeta_part, dgamma_part; - dbeta_part.load_from(params.dbeta_part, idx); - dgamma_part.load_from(params.dgamma_part, idx); - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - dgamma_local.data.elt[it] += dgamma_part.data.elt[it]; - dbeta_local.data.elt[it] += dbeta_part.data.elt[it]; - } - } - - void * smem_gamma = smem_; - void * smem_beta = &smem_[Kernel_traits::SMEM_BYTES_TRANSPOSE]; - - const int write_row = warp; - const int write_col = lane ^ write_row; - const int write_idx = write_row * THREADS_PER_WARP + write_col; - - dgamma_local.store_to(smem_gamma, write_idx); - dbeta_local.store_to(smem_beta, write_idx); - - __syncthreads(); - - // It would be probably safe to reuse the first row of smem_beta and smem_gamma - void * smem_gamma_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE]; - void * smem_beta_out = &smem_[2 * Kernel_traits::SMEM_BYTES_TRANSPOSE + Kernel_traits::SMEM_BYTES_OUTPUT]; - - - // More than one iter iff ROWS_PER_CTA < 32. - for( int w = warp; w < THREADS_PER_WARP; w += Kernel_traits::ROWS_PER_CTA ) { - const int read_row = lane; - const int read_col = w ^ read_row; - const int read_idx = read_row * THREADS_PER_WARP + read_col; - - memset(&dbeta_local, 0, sizeof(dbeta_local)); - memset(&dgamma_local, 0, sizeof(dgamma_local)); - - // Load beta and gamma transposed - if(read_row < Kernel_traits::ROWS_PER_CTA){ - dbeta_local.load_from(smem_beta, read_idx); - dgamma_local.load_from(smem_gamma, read_idx); - } - - // Call reducer on the loaded value(s) and convert. - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - compute_t b_i = dbeta_local.data.elt[it]; - compute_t g_i = dgamma_local.data.elt[it]; - b_i = reducer.allreduce(b_i, sum); - g_i = reducer.allreduce(g_i, sum); - - dgamma_local.data.elt[it] = g_i; - dbeta_local.data.elt[it] = b_i; - } - - // Leader stores the result at the current column. - if(lane == 0){ - dgamma_local.store_to(smem_gamma_out, w); - dbeta_local.store_to(smem_beta_out, w); - } - - } - - // All writes done. - __syncthreads(); - - // Pack and store: 2-wide stores with half the threads. - if( warp == Kernel_traits::ROWS_PER_CTA - 1 && lane < THREADS_PER_WARP / 2 ) { - - using src_t = typename TypeToVec2::Type; - using dst_t = typename TypeToVec2::Type; - Vec dbeta_vec2, dgamma_vec2; - Vec dbeta_out2, dgamma_out2; - - dgamma_vec2.load_from(smem_gamma_out, lane); - dbeta_vec2.load_from(smem_beta_out, lane); - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - dgamma_out2.data.elt[it] = Converter::convert(dgamma_vec2.data.elt[it]); - dbeta_out2.data.elt[it] = Converter::convert(dbeta_vec2.data.elt[it]); - } - dgamma_out2.store_to(params.dgamma, col_out); - dbeta_out2.store_to(params.dbeta, col_out); - - } - } -} -} // namespace layer_norm diff --git a/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu b/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu deleted file mode 100644 index 8c7f904..0000000 --- a/apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu +++ /dev/null @@ -1,250 +0,0 @@ -#include "ln.h" -#include "ln_utils.cuh" -#include "ln_kernel_traits.h" -#include "ln_bwd_kernels.cuh" - -using namespace layer_norm; - -BwdRegistry layer_norm::BWD_FUNCS; - -template< - typename weight_t, - typename input_t, - typename output_t, - typename compute_t, - typename index_t, - int HIDDEN_SIZE, - int CTAS_PER_ROW, - int WARPS_M, - int WARPS_N, - int BYTES_PER_LDG_MAIN, - int BYTES_PER_LDG_FINAL -> -void launch_(LaunchParams &launch_params, const bool configure_params){ - - using Kernel_traits = Kernel_traits; - auto kernel = &ln_bwd_kernel; - - if( configure_params ) { - int ctas_per_sm; - cudaError_t status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); - launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if(Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col - * Kernel_traits::WARPS_M - * Kernel_traits::CTAS_PER_ROW - * sizeof(typename Kernel_traits::reduce_t) - * 2; - } - return; - } - - if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) { - #if defined(__HIP_PLATFORM_HCC__) - CHECK_CUDA(hipFuncSetAttribute((const void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); - #else - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES)); - #endif - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - - if( Kernel_traits::CTAS_PER_ROW == 1 ) { - kernel<<>>(launch_params.params); - } else { - dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = (void *)&launch_params.params; - #if defined(__HIP_PLATFORM_HCC__) - hipLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES, stream); - #else - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES, stream); - #endif - } - - using Kernel_traits_f = layer_norm::Kernel_traits_finalize; - - auto kernel_f = &layer_norm::ln_bwd_finalize_kernel; - kernel_f<<>>(launch_params.params); -} - -// Create backward launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 768, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -// REGISTER_BWD_LAUNCHER( 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -// REGISTER_BWD_LAUNCHER( 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_LAUNCHER( 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -// REGISTER_BWD_LAUNCHER( 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -// REGISTER_BWD_LAUNCHER( 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER( 1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); -// REGISTER_BWD_LAUNCHER( 1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER( 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER( 2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); -// REGISTER_BWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); -// REGISTER_BWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); - -REGISTER_BWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER( 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); -// REGISTER_BWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); -// REGISTER_BWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); - -REGISTER_BWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER( 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER( 4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER( 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER( 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); - -REGISTER_BWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4); -// REGISTER_BWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); -// REGISTER_BWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); -// REGISTER_BWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); - -REGISTER_BWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4); -REGISTER_BWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); -// REGISTER_BWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4); -// REGISTER_BWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); - -REGISTER_BWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); -// REGISTER_BWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); - -REGISTER_BWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4); -// REGISTER_BWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); - -REGISTER_BWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4); -REGISTER_BWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4); -REGISTER_BWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4); -// REGISTER_BWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4); -// REGISTER_BWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4); - -REGISTER_BWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); - -REGISTER_BWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); - -REGISTER_BWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); - -REGISTER_BWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); -// REGISTER_BWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); - diff --git a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu b/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu deleted file mode 100644 index 660e4a0..0000000 --- a/apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu +++ /dev/null @@ -1,235 +0,0 @@ -#include "ln.h" -#include "ln_utils.cuh" -#include "ln_kernel_traits.h" -#include "ln_fwd_kernels.cuh" - -using namespace layer_norm; - -FwdRegistry layer_norm::FWD_FUNCS; - -template< - typename weight_t, - typename input_t, - typename output_t, - typename compute_t, - typename index_t, - int HIDDEN_SIZE, - int CTAS_PER_ROW, - int WARPS_M, - int WARPS_N, - int BYTES_PER_LDG -> -void launch_(LaunchParams &launch_params, const bool configure_params){ - - using Kernel_traits = Kernel_traits; - auto kernel = &ln_fwd_kernel; - - if( configure_params ) { - int ctas_per_sm; - cudaError_t status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); - launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if(Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col - * Kernel_traits::WARPS_M - * Kernel_traits::CTAS_PER_ROW - * sizeof(typename Kernel_traits::Stats::stats_t) - * 2; - } - return; - } - - if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) { - #if defined(__HIP_PLATFORM_HCC__) - CHECK_CUDA(hipFuncSetAttribute((const void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); - #else - CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); - #endif - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - - if( Kernel_traits::CTAS_PER_ROW == 1 ) { - kernel<<>>(launch_params.params); - } else { - dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = (void *)&launch_params.params; - #if defined(__HIP_PLATFORM_HCC__) - hipLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream); - #else - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream); - #endif - } - -} - - -REGISTER_FWD_LAUNCHER( 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 768, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -// REGISTER_FWD_LAUNCHER( 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -// REGISTER_FWD_LAUNCHER( 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_LAUNCHER( 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -// REGISTER_FWD_LAUNCHER( 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -// REGISTER_FWD_LAUNCHER( 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_LAUNCHER( 1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 1536, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -// REGISTER_FWD_LAUNCHER( 1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -// REGISTER_FWD_LAUNCHER( 1536, bf16, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_LAUNCHER( 2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2048, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -// REGISTER_FWD_LAUNCHER( 2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -// REGISTER_FWD_LAUNCHER( 2048, bf16, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_LAUNCHER( 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_LAUNCHER( 2304, fp16, fp32, fp16, fp32, 1, 4, 1, 16); -// REGISTER_FWD_LAUNCHER( 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -// REGISTER_FWD_LAUNCHER( 2304, bf16, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_LAUNCHER( 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -// REGISTER_FWD_LAUNCHER( 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -// REGISTER_FWD_LAUNCHER( 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_LAUNCHER( 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4); -REGISTER_FWD_LAUNCHER( 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4); -REGISTER_FWD_LAUNCHER( 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 4); -// REGISTER_FWD_LAUNCHER( 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4); -// REGISTER_FWD_LAUNCHER( 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 4); - -REGISTER_FWD_LAUNCHER( 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -// REGISTER_FWD_LAUNCHER( 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -// REGISTER_FWD_LAUNCHER( 4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_LAUNCHER( 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -// REGISTER_FWD_LAUNCHER( 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -// REGISTER_FWD_LAUNCHER( 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_LAUNCHER( 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 6144, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -// REGISTER_FWD_LAUNCHER( 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -// REGISTER_FWD_LAUNCHER( 6144, bf16, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_LAUNCHER( 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER( 8192, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -// REGISTER_FWD_LAUNCHER( 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -// REGISTER_FWD_LAUNCHER( 8192, bf16, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_LAUNCHER(10240, fp16, fp32, fp16, fp32, 1, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(10240, bf16, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(12288, fp16, fp32, fp16, fp32, 2, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(12288, bf16, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4); -REGISTER_FWD_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4); -REGISTER_FWD_LAUNCHER(12800, fp16, fp32, fp16, fp32, 2, 1, 4, 4); -// REGISTER_FWD_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4); -// REGISTER_FWD_LAUNCHER(12800, bf16, fp32, bf16, fp32, 2, 1, 4, 4); - -REGISTER_FWD_LAUNCHER(14336, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(14336, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(14336, fp16, fp32, fp16, fp32, 2, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(14336, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(14336, bf16, fp32, bf16, fp32, 2, 1, 4, 8); - -REGISTER_FWD_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8); -REGISTER_FWD_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_LAUNCHER(15360, fp16, fp32, fp16, fp32, 2, 1, 4, 8); -// REGISTER_FWD_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8); -// REGISTER_FWD_LAUNCHER(15360, bf16, fp32, bf16, fp32, 2, 1, 4, 8); - -REGISTER_FWD_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(16384, bf16, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(18432, fp16, fp32, fp16, fp32, 2, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(20480, fp16, fp32, fp16, fp32, 2, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(20480, bf16, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_LAUNCHER(24576, fp16, fp32, fp16, fp32, 2, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(24576, bf16, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_LAUNCHER(25600, fp16, fp32, fp16, fp32, 4, 1, 4, 4); -// REGISTER_FWD_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8); -// REGISTER_FWD_LAUNCHER(25600, bf16, fp32, bf16, fp32, 4, 1, 4, 4); - -REGISTER_FWD_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 4, 4); -// REGISTER_FWD_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4); -// REGISTER_FWD_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 4, 4); - -REGISTER_FWD_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16); -REGISTER_FWD_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_LAUNCHER(49152, fp16, fp32, fp16, fp32, 4, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(49152, bf16, fp32, bf16, fp32, 4, 1, 4, 16); - -REGISTER_FWD_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16); -REGISTER_FWD_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16); -REGISTER_FWD_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16); -// REGISTER_FWD_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 4, 16); - diff --git a/apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh b/apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh deleted file mode 100644 index b6210d0..0000000 --- a/apex/contrib/csrc/layer_norm/ln_fwd_kernels.cuh +++ /dev/null @@ -1,114 +0,0 @@ -#pragma once - -#if defined(__HIP_PLATFORM_HCC__) -#include "ln_utils.cuh" -#else -#include "ln.h" -#endif - -namespace layer_norm { - -template -__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) -void ln_fwd_kernel(FwdParams params) { - - enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; - enum { WARPS_N = Ktraits::WARPS_N }; - enum { WARPS_M = Ktraits::WARPS_M }; - enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; - enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; - enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; - enum { LDGS = Ktraits::LDGS }; - enum { NUM_ELTS = Ktraits::NUM_ELTS }; - enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; - - using output_t = typename Ktraits::output_t; - using index_t = typename Ktraits::index_t; - using compute_t = typename Ktraits::compute_t; - using Ivec = typename Ktraits::Ivec; - using Ovec = typename Ktraits::Ovec; - using Wvec = typename Ktraits::Wvec; - using Cvec = typename Ktraits::Cvec; - - using Stats = typename Ktraits::Stats; - using stats_t = typename Stats::stats_t; - - extern __shared__ char smem_[]; - - const index_t tidx = threadIdx.x; - const index_t bidn = blockIdx.x % CTAS_PER_ROW; - const index_t bidm = blockIdx.x / CTAS_PER_ROW; - const index_t lane = tidx % THREADS_PER_WARP; - const index_t warp = tidx / THREADS_PER_WARP; - const index_t warp_m = warp / WARPS_N; - const index_t warp_n = warp % WARPS_N; - - const index_t r = bidm * ROWS_PER_CTA + warp_m; - const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; - - Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_); - - compute_t *mu_ptr = static_cast(params.mu); - compute_t *rs_ptr = static_cast(params.rs); - - Wvec gamma[LDGS]; - Wvec beta[LDGS]; - index_t idx = c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - gamma[it].load_from(params.gamma, idx); - beta[it].load_from(params.beta, idx); - idx += VEC_COLS_PER_LDG; - } - - constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS); - - for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { - Ivec x[LDGS]; - index_t idx = row * Ktraits::VEC_COLS + c; - compute_t xf[LDGS * NUM_ELTS]; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - x[it].load_from(params.x, idx); - #pragma unroll - for( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t x_ij = compute_t(x[it].data.elt[jt]); - xf[it * NUM_ELTS + jt] = x_ij; - } - idx += VEC_COLS_PER_LDG; - } - - stats_t s = stats.compute(xf, rn); - - compute_t mu = layer_norm::Get<0>::of(s); - compute_t m2 = layer_norm::Get<1>::of(s); - - if( bidn == 0 && warp_n == 0 && lane == 0 ) { - mu_ptr[row] = mu; - } - - compute_t rs = rsqrtf(rn * m2 + params.epsilon); - - if( bidn == 0 && warp_n == 0 && lane == 0 ) { - rs_ptr[row] = rs; - } - - Ovec z[LDGS]; - idx = row * Ktraits::VEC_COLS + c; - #pragma unroll - for( int it = 0; it < LDGS; it++ ) { - #pragma unroll - for( int jt = 0; jt < NUM_ELTS; jt++ ) { - output_t y_ij = output_t(rs * (xf[it * NUM_ELTS + jt] - mu)); - output_t g_ij = gamma[it].data.elt[jt]; - output_t b_ij = beta[it].data.elt[jt]; - z[it].data.elt[jt] = (g_ij * y_ij + b_ij); - } - z[it].store_to(params.z, idx); - idx += VEC_COLS_PER_LDG; - } - - } -} - -} // namespace layer_norm diff --git a/apex/contrib/csrc/layer_norm/ln_kernel_traits.h b/apex/contrib/csrc/layer_norm/ln_kernel_traits.h deleted file mode 100644 index ed745c5..0000000 --- a/apex/contrib/csrc/layer_norm/ln_kernel_traits.h +++ /dev/null @@ -1,159 +0,0 @@ -#pragma once - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace layer_norm { -template< - uint32_t HIDDEN_SIZE_, - typename weight_t_, - typename input_t_, - typename output_t_, - typename compute_t_, - typename index_t_, - uint32_t THREADS_PER_CTA_ -> -struct Kernel_traits_base { - - using weight_t = weight_t_; - using input_t = input_t_; - using output_t = output_t_; - using compute_t = compute_t_; - using index_t = index_t_; - - enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; - enum { THREADS_PER_CTA = THREADS_PER_CTA_ }; - enum { THREADS_PER_WARP = 32 }; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - uint32_t HIDDEN_SIZE_, - typename weight_t_, - typename input_t_, - typename output_t_, - typename compute_t_, - typename index_t_, - uint32_t THREADS_PER_CTA_, - uint32_t BYTES_PER_LDG_, - typename Base = Kernel_traits_base -> -struct Kernel_traits_finalize : public Base { - enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP }; - static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP); - // Bytes per global load from the input. - enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; - // Number of elements fetched by a global load. - enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) }; - // Bytes per global store of the weights. - enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) }; - static_assert(sizeof(BYTES_PER_LDG) == 4, "Conflict-free smem transpose only implemented for 4B compute type!"); - static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!"); - // The total number of BYTES_PER_LDG-wide words in a hidden vector. - enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG }; - static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_)); - - // Shared memory size to transpose the CTA result. - enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG }; - // Shared memory size to coalsece the CTA result. - enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG }; - // Shared memory requirement per CTA. - enum { SMEM_BYTES_PER_CTA = 2 * SMEM_BYTES_TRANSPOSE + 2 * SMEM_BYTES_OUTPUT }; - - // The type of the reducer. - using Reducer = layer_norm::Reducer; - - // Condition for the whole CTA to participate in syncthreads. - static_assert(COLS % Base::THREADS_PER_WARP == 0); - enum { CTAS = COLS / Base::THREADS_PER_WARP }; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - - -template< - typename weight_t_, - typename input_t_, - typename output_t_, - typename compute_t_, - typename index_t_, - uint32_t HIDDEN_SIZE_, - uint32_t CTAS_PER_ROW_, - uint32_t WARPS_M_, - uint32_t WARPS_N_, - uint32_t BYTES_PER_LDG_ = 16, - typename Base = Kernel_traits_base< - HIDDEN_SIZE_, - weight_t_, - input_t_, - output_t_, - compute_t_, - index_t_, - WARPS_M_*WARPS_N_*THREADS_PER_WARP - > -> -struct Kernel_traits : public Base { - - using input_t = typename Base::input_t; - using weight_t = typename Base::weight_t; - using compute_t = typename Base::compute_t; - using output_t = typename Base::output_t; - using index_t = typename Base::index_t; - - enum { CTAS_PER_ROW = CTAS_PER_ROW_ }; - enum { WARPS_M = WARPS_M_ }; - enum { WARPS_N = WARPS_N_ }; - enum { COLS = HIDDEN_SIZE_ }; - enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; - enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; - enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) }; - - enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP }; - enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW }; - enum { ROWS_PER_CTA = WARPS_M }; - - enum { BYTES_PER_ROW = COLS * sizeof(input_t) }; - enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG }; - // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed - enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) }; - static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1); - - using reduce_t = typename layer_norm::TypeToVec2::Type; - using Reducer = layer_norm::Reducer; - - enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES }; - enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD }; - - using Ivec = layer_norm::Vec; - using Ovec = layer_norm::Vec; - using Wvec = layer_norm::Vec; - using Cvec = layer_norm::Vec; - enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) }; - - // Assume that each thread can handle the same number of elements in the output and weights as in the input. - static_assert(sizeof(input_t) >= sizeof(output_t)); - static_assert(sizeof(input_t) >= sizeof(weight_t)); - // The number of columns fetched per load from input: one per thread. - enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW }; - // The total number of vectorized loads/stores per hidden vector. - enum { VEC_COLS = COLS / ELTS_PER_LDG }; - // The number of loads per thread for the input. - enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG }; - static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS); - //static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, ""); - - using Stats = layer_norm::Stats; - enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES }; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace layer_norm diff --git a/apex/contrib/csrc/layer_norm/ln_utils.cuh b/apex/contrib/csrc/layer_norm/ln_utils.cuh deleted file mode 100644 index 317848c..0000000 --- a/apex/contrib/csrc/layer_norm/ln_utils.cuh +++ /dev/null @@ -1,793 +0,0 @@ -#pragma once - -#include - -#if defined(__HIP_PLATFORM_HCC__) -#include "hip/hip_fp16.h" -#include "hip/hip_bfloat16.h" -#else -#include -#include -#endif - -#include "ln.h" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -constexpr uint32_t THREADS_PER_WARP = 32; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline void check_cuda_(cudaError_t status, const char *file, int line) { - if( status != cudaSuccess ) { - fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line); - exit(status); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define CHECK_CUDA(ans) \ - { check_cuda_((ans), __FILE__, __LINE__); } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define DIVUP(x, y) (((x) + ((y)-1)) / (y)) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \ - void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ - const bool configure_params) { \ - launch_( \ - launch_params, configure_params); \ - } \ - static FwdRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_BWD_LAUNCHER( \ - HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ - const bool configure_params) { \ - launch_(launch_params, configure_params); \ - } \ - static BwdRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 operator+(const float2 & a, const float2 & b){ - return {a.x + b.x, a.y + b.y}; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void operator+=(float2 & a, const float2 & b){ - a.x += b.x; - a.y += b.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Sum { - inline __device__ Sum(){} - inline __device__ T operator()(const T &a, const T &b){ - return a + b; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){ - #if defined(__HIP_PLATFORM_HCC__) - return __shfl_xor(x, idx); - #else - return __shfl_xor_sync(uint32_t(-1), x, idx); - #endif -} - -template<> -inline __device__ float2 warp_shuffle_xor(const float2 & x, uint32_t idx){ - return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) }; -} - -template -inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){ - #if defined(__HIP_PLATFORM_HCC__) - return __shfl_down(x, idx); - #else - return __shfl_down_sync(uint32_t(-1), x, idx); - #endif -} - -template<> -inline __device__ float2 warp_shuffle_down(const float2 & x, uint32_t idx){ - return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) }; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace layer_norm { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct uint16 { - uint4 u; - uint4 v; - uint4 s; - uint4 t; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct uint8 { - uint4 u; - uint4 v; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BytesToType {}; - -template<> -struct BytesToType<64> { - using Type = uint16; - static_assert(sizeof(Type) == 64); -}; - -template<> -struct BytesToType<32> { - using Type = uint8; - static_assert(sizeof(Type) == 32); -}; - -template<> -struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); -}; - -template<> -struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); -}; - -template<> -struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); -}; - -template<> -struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); -}; - -template<> -struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TypeToVec2 {}; - -template<> -struct TypeToVec2 { - using Type = float2; -}; - -template<> -struct TypeToVec2 { - using Type = half2; -}; - -#if 0 -template<> -struct TypeToVec2 { - using Type = nv_bfloat162; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Get { - template - static inline __device__ R of(const T &vec); -}; - -template<> -template -inline __device__ R Get<0>::of(const T &vec) { - return vec.x; -} - -template<> -template -inline __device__ R Get<1>::of(const T &vec) { - return vec.y; -} - -template<> -template -inline __device__ R Get<2>::of(const T &vec) { - return vec.z; -} - -template<> -template -inline __device__ R Get<3>::of(const T &vec) { - return vec.w; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Converter{ - static inline __device__ Dst convert(const Src &from) { - return Dst(from); - } -}; - -template<> -struct Converter{ - static inline __device__ half2 convert(const float2 &x) { - return __float22half2_rn(x); - } -}; - -#if defined(__HIP_PLATFORM_HCC__) -template<> -struct Converter{ - static inline __device__ half convert(const float &x) { - return __float2half(x); - } -}; - -template<> -struct Converter{ - static inline __device__ float convert(const half &x) { - return __half2float(x); - } -}; - -template<> -struct Converter{ - static inline __device__ hip_bfloat16 convert(const float &x) { - return hip_bfloat16::round_to_bfloat16(x); - } -}; - -template<> -struct Converter{ - static inline __device__ float convert(const hip_bfloat16 &x) { - return float(x); - } -}; -#endif - -#if 0 -template<> -struct Converter{ - static inline __device__ nv_bfloat162 convert(const float2 &x) { -#if __CUDA_ARCH__ >= 800 - return __float22bfloat162_rn(x); -#else - union { - nv_bfloat162 raw; - nv_bfloat16 x; - nv_bfloat16 y; - } tmp; - tmp.x = __float2bfloat16_rn(x.x); - tmp.y = __float2bfloat16_rn(x.y); - return tmp.raw; -#endif - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Zeros{ - static inline __device__ T get() { - return T(0.f); - } -}; - -template<> -struct Zeros{ - static inline __device__ float2 get() { - return make_float2(0.f, 0.f); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Vec { - - enum { BYTES = NUM_ELT * sizeof(Elt_type) }; - - using Vec_type = typename BytesToType::Type; - - using Alias_type = union { - Vec_type vec; - Elt_type elt[NUM_ELT]; - }; - - Alias_type data; - - template - inline __device__ void to(Vec &other) { - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - other.data.elt[it] = S(this->data.elt[it]); - } - } - - template - inline __device__ void assign(const Op &op) { - #pragma unroll - for( int it = 0; it < NUM_ELT; it++ ) { - this->data.elt[it] = op(it); - } - } - - inline __device__ void load_from(const void *base_ptr, const size_t idx) { - this->data.vec = static_cast(base_ptr)[idx]; - } - - inline __device__ void store_to(void *base_ptr, const size_t idx) { - static_cast(base_ptr)[idx] = this->data.vec; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct InterCTASync { - - template - inline __device__ InterCTASync(Params & params, uint32_t bidm, uint32_t bidn) - : phase_counter_(0) - , b0_(params.barrier + bidm) // The barrier for this group of CTAs. - , b1_(params.barrier + bidm + params.ctas_per_col) // The barrier for this group of CTAs. - { - // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0! - } - - inline __device__ void spin_wait_(int *barrier, int step, int expected) { - #if defined(__HIP_PLATFORM_HCC__) - atomicAdd(barrier, step); - for( int found = -1; found != expected; ) { - // asm volatile("global_load_dword %0, %1, off;" : "=v"(found) : "v"(barrier)); - found = atomicCAS(barrier, expected, expected); - } - #else - asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step)); - for( int found = -1; found != expected; ) { - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); - } - #endif - } - - inline __device__ void sync(){ - // ALL THREADS MUST ENTER! - - // We switch barrier every iteration. - int *barrier = phase_counter_ & 0x1 ? b1_ : b0_; - // We decrement every other iteration. - bool dec = phase_counter_ & 0x2; - int step = dec ? -1 : 1; - int expected = dec ? 0 : CTAS_PER_ROW; - // There are only 4 phases: up/down for b0/b1. - phase_counter_ = (phase_counter_ + 1) & 0x3; - - if( threadIdx.x == 0 ) { - spin_wait_(barrier, step, expected); - } - // CTA waits for thread 0 - __syncthreads(); - } - - int phase_counter_; - int * b0_; - int * b1_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Reducer : public Reducer { - - using InterCTASync = InterCTASync; - using Base = Reducer; - using Type = typename Base::Type; - - enum { SMEM_BYTES = Base::SMEM_BYTES }; - - enum { WS_BARRIER_BYTES = 2 * sizeof(int) }; - enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) }; - - // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total) - enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES }; - - template - inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) - , inter_cta_(params, bidm, bidn) - , bidn_(bidn) // CTA id within the group. - , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW) - , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) - { - } - - template - inline __device__ T allreduce(T data, Op &op) { - data = Base::reduce(data, op); - // We switch workspace every iteration. - T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; - - // Warp leaders 0 hold the CTA-local results. - if( this->warp_n_ == 0 && this->lane_ == 0 ) { - workspace[bidn_] = data; - } - inter_cta_.sync(); - static_assert(CTAS_PER_ROW <= 32); - T total = Zeros::get(); - if(this->lane_ < CTAS_PER_ROW){ - total = workspace[this->lane_]; - } - total = Reducer::allreduce_(total, op); - - return total; - } - - InterCTASync inter_cta_; - - T *w0_; - T *w1_; - int bidn_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Reducer { - - using Type = T; - enum { SMEM_BYTES = 0 }; - enum { WORKSPACE_BYTES_PER_GROUP = 0 }; - - enum { THREADS_PER_WARP = 32 }; - - template - inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : warp_n_(warp_n) - , lane_(lane) - { - } - - template - static inline __device__ T allreduce_(T data, Op &op) { - #pragma unroll - for( int it = 1; it < THREADS_PER_WARP; it *= 2 ) { - data = op(data, warp_shuffle_xor(data, it)); - } - return data; - } - - template - inline __device__ T allreduce(T data, Op &op) { - return allreduce_(data, op); - } - - template - inline __device__ T reduce(T data, Op &op){ - // only lane 0 holds the result! - #pragma unroll - for( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) { - data = op(data, warp_shuffle_down(data, it)); - } - return data; - } - int warp_n_; - int lane_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Reducer : public Reducer { - - using Base = Reducer; - - using Type = T; - - enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 }; - enum { WORKSPACE_BYTES_PER_GROUP = 0 }; - - enum { THREADS_PER_WARP = 32 }; - - template - inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) - , use0_(true) - { - smem0_ = &static_cast(smem)[warp_m * WARPS_N]; - smem1_ = smem0_ + WARPS_M * WARPS_N; - } - - template - inline __device__ T allreduce(T data, Op & op) { - T * smem = use0_ ? smem0_ : smem1_; - use0_ = !use0_; - data = Base::reduce(data, op); - if( this->lane_ == 0 ) { - smem[this->warp_n_] = data; - } - __syncthreads(); - T out = Zeros::get(); - #pragma unroll - for( int it = 0; it < WARPS_N; it++ ) { - out = op(out, smem[it]); - } - return out; - } - - template - inline __device__ T reduce(T data, Op &op) { - T * smem = use0_ ? smem0_ : smem1_; - use0_ = !use0_; - // only intra-CTA group leader holds the result! - data = Base::reduce(data, op); - if( this->lane_ == 0 ) { - smem[this->warp_n_] = data; - } - __syncthreads(); - T out = Zeros::get(); - if( this->warp_n_ == 0 && this->lane_ == 0 ) { - #pragma unroll - for( int it = 0; it < WARPS_N; it++ ) { - out = op(out, smem[it]); - } - } - return out; - } - - T * smem0_; - T * smem1_; - bool use0_; - -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, T &n_a, int num_active){ - //Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise) - int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1); - - #pragma unroll - for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) { - // Exchange - T n_b = warp_shuffle_down(n_a, step); - T m_b = warp_shuffle_down(m_a, step); - T m2_b = warp_shuffle_down(m2_a, step); - - // Update - const T n_ab = n_a + n_b; // We can handle one of them being 0, not both. - const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :( - const T delta = m_a - m_b; - const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; - const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab; - - n_a = n_ab; - m_a = m_ab; - m2_a = m2_ab; - } - // Intra-warp broadcast (only lane 0 has valid stats). - #if defined(__HIP_PLATFORM_HCC__) - m_a = __shfl(m_a, 0); - m2_a = __shfl(m2_a, 0); - #else - m_a = __shfl_sync(uint32_t(-1), m_a, 0); - m2_a = __shfl_sync(uint32_t(-1), m2_a, 0); - #endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Stats { - // This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields. - - using InterCTASync = InterCTASync; - using BlockStats = Stats; - using stats_t = typename BlockStats::stats_t; - - enum { SMEM_BYTES = BlockStats::SMEM_BYTES }; - - template - inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : inter_cta_(params, bidm, bidn) - , block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) - , bidn_(bidn) // CTA id within the group. - , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW) - , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) - , warp_n_(warp_n) - , lane_(lane) - { - } - - template - inline __device__ stats_t compute(const T (&elts)[N], const T rn) { - constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP; - // TODO rn is not really needed here.. - constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA); - stats_t block_stats = block_stats_.compute(elts, block_rn); - - stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; - - if( warp_n_ == 0 && lane_ == 0 ) { - workspace[bidn_] = block_stats; - } - - // Wait for all CTAS_PER_ROW CTAS in the group to have written their result. - inter_cta_.sync(); - - T n = Zeros::get(); - T m = Zeros::get(); - T m2 = Zeros::get(); - - // Assume CTA group size in N less than 32, such that we can finalize with a single warp. - static_assert(CTAS_PER_ROW <= 32); - - // Every warp does the final reduction locally. - if( lane_ < CTAS_PER_ROW ) { - stats_t result = workspace[lane_]; - n = ELTS_PER_ROW_PER_CTA; - m = layer_norm::Get<0>::of(result); - m2 = layer_norm::Get<1>::of(result); - } - - warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW); - - return { m, m2 }; - } - - InterCTASync inter_cta_; - BlockStats block_stats_; - - stats_t *w0_; - stats_t *w1_; - int bidn_; - int warp_n_; - int lane_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Stats { - - using WarpStats = Stats; - using stats_t = typename WarpStats::stats_t; - - enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 }; - - template - inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) - , use0_(true) - { - smem0_ = static_cast(smem) + warp_m * WARPS_N; - smem1_ = smem0_ + WARPS_M * WARPS_N; - } - - template - inline __device__ stats_t compute(const T (&elts)[N], const T rn) { - stats_t * smem = use0_ ? smem0_ : smem1_; - use0_ = !use0_; - // Compute warp local for all WARPS_N - constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP); - stats_t warp_stats = warp_stats_.compute(elts, warp_rn); - - //Each warp warp leader stores its stats - const auto warp_n = warp_stats_.reducer_.warp_n_; - const auto lane = warp_stats_.reducer_.lane_; - if( lane == 0 ) { - smem[warp_n] = warp_stats; - } - __syncthreads(); - - T n = Zeros::get(); - T m = Zeros::get(); - T m2 = Zeros::get(); - - // Assume that there are less than 32 warps, such that we can finalize with a single warp - static_assert(WARPS_N <= 32); - if(lane < WARPS_N){ - stats_t result = smem[lane]; - n = N * THREADS_PER_WARP; - m = layer_norm::Get<0>::of(result); - m2 = layer_norm::Get<1>::of(result); - } - - warp_chan_upd_dynamic(m, m2, n, WARPS_N); - - return { m, m2 }; - } - WarpStats warp_stats_; - stats_t * smem0_; - stats_t * smem1_; - bool use0_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Stats { - - using stats_t = typename TypeToVec2::Type; - // The simple Warp reducer. - using Reducer = Reducer; - - enum { SMEM_BYTES = 0 }; - - template - inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) - : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) - { - } - - template - inline __device__ stats_t compute(const T (&elts)[N], const T rn) { - - auto sum = Sum(); - - T m = Zeros::get(); - #pragma unroll - for( int it = 0; it < N; it++ ) { - m += elts[it]; - } - m = reducer_.allreduce(m, sum) * rn; - - T m2 = Zeros::get(); - #pragma unroll - for( int it = 0; it < N; it++ ) { - T diff = (elts[it] - m); - m2 += diff * diff; - } - m2 = reducer_.allreduce(m2, sum); - - return {m, m2}; - } - - Reducer reducer_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace layer_norm diff --git a/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu b/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu deleted file mode 100644 index 8177f53..0000000 --- a/apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu +++ /dev/null @@ -1,113 +0,0 @@ -#include -#include -#include - -#include -#include -//#include -#include - -#include -#include -#include - -#include "dropout.cuh" -#include "softmax.cuh" - -// symbol to be automatically resolved by PyTorch libs - -namespace multihead_attn { -namespace fused_softmax { -namespace additive_mask_softmax_dropout { - -std::vector fwd_cuda(bool is_training, int heads, - torch::Tensor const &input, - const half *pad_mask, float dropout_prob) { - const int attn_batches = input.size(0); - const int sequences = attn_batches / heads; - const int q_seq_len = input.size(1); - const int k_seq_len = q_seq_len; - // const int dropout_elems = attn_batches * q_seq_len * k_seq_len; - - // There is no reason to use more than one stream as every kernel is - // sequentially dependent - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cublasSetStream(handle, stream); - - // 3 Intermediate Results + Output (Note: dropout intermediates are generated - // by ATen library code) - auto act_options = input.options().requires_grad(false); - auto mask_options = act_options.dtype(torch::kUInt8); - - torch::Tensor softmax_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = - torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - - // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) - void *input_ptr = static_cast(input.data_ptr()); - void *softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - // Padded Softmax - bool softmax_success = false; - if (pad_mask == nullptr) { - softmax_success = dispatch_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), k_seq_len, k_seq_len, - attn_batches * q_seq_len); - } else { - softmax_success = dispatch_additive_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), pad_mask, k_seq_len, - k_seq_len, attn_batches * q_seq_len, - attn_batches * q_seq_len / sequences); - } - - if (is_training) { - // use at:: function so that C++ version generates the same random mask as - // python version - auto dropout_tuple = - at::_fused_dropout(softmax_results, 1.0f - dropout_prob); - dropout_results = std::get<0>(dropout_tuple); - dropout_mask = std::get<1>(dropout_tuple); - } - - // Matmul2 - - return {dropout_results, dropout_mask, softmax_results}; -} - -torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads, - torch::Tensor const &softmax_results, - torch::Tensor const &dropout_mask, float dropout_prob) { - const int attn_batches = output_grads.size(0); - const int q_seq_len = output_grads.size(1); - const int k_seq_len = q_seq_len; - // const int dropout_elems = attn_batches * q_seq_len * k_seq_len; - // TODO: Streams can be used in Backprop but I haven't added more than one - // in my first attempt to create the code - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cublasSetStream(handle, stream); - - // Output Tensor Allocations - // torch::Tensor input_grads = torch::empty_like(output_grads); - - // Apply Dropout Mask and Scale by Dropout Probability - // Softmax Grad - dispatch_masked_scale_softmax_backward_stream( - static_cast(output_grads.data_ptr()), - static_cast(output_grads.data_ptr()), - reinterpret_cast(softmax_results.data_ptr()), - static_cast(dropout_mask.data_ptr()), - 1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, - attn_batches * q_seq_len, stream); - // backward pass is completely in-place - return output_grads; -} -} // namespace additive_mask_softmax_dropout -} // namespace fused_softmax -} // namespace multihead_attn diff --git a/apex/contrib/csrc/multihead_attn/dropout.cuh b/apex/contrib/csrc/multihead_attn/dropout.cuh deleted file mode 100644 index 6f3922a..0000000 --- a/apex/contrib/csrc/multihead_attn/dropout.cuh +++ /dev/null @@ -1,272 +0,0 @@ -#pragma once -#include - -#ifdef OLD_GENERATOR_PATH -#include -#else -#include -#endif - -#include -#include - -namespace { -constexpr int UNROLL = 4; -} // namespace - -template -__global__ void -apex_fused_dropout_kernel(scalar_t const *inputs, scalar_t *outputs, - uint8_t *mask, IndexType totalElements, accscalar_t p, - std::pair seeds) { - accscalar_t pinv = accscalar_t(1) / p; - IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; - - curandStatePhilox4_32_10_t state; - curand_init(seeds.first, idx, seeds.second, &state); - - IndexType rounded_size = - ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * - blockDim.x * gridDim.x * UNROLL; - for (IndexType linearIndex = idx; linearIndex < rounded_size; - linearIndex += gridDim.x * blockDim.x * UNROLL) { - float4 rand = curand_uniform4(&state); - scalar_t src[UNROLL]; - rand.x = rand.x <= p; - rand.y = rand.y <= p; - rand.z = rand.z <= p; - rand.w = rand.w <= p; - - for (int ii = 0; ii < UNROLL; ii++) { - IndexType li = linearIndex + blockDim.x * gridDim.x * ii; - if (li < totalElements) { - src[ii] = inputs[li]; - } - } - for (int ii = 0; ii < UNROLL; ii++) { - IndexType li = linearIndex + blockDim.x * gridDim.x * ii; - if (li < totalElements) { - outputs[li] = src[ii] * (&rand.x)[ii] * pinv; - mask[li] = (uint8_t)(&rand.x)[ii]; - } - } - __syncthreads(); - } -} - -template -__global__ void apex_dropout_add_kernel(scalar_t const *inputs, - scalar_t const *add_inputs, - scalar_t *outputs, uint8_t *mask, - IndexType totalElements, accscalar_t p, - std::pair seeds) { - accscalar_t pinv = accscalar_t(1) / p; - IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; - - curandStatePhilox4_32_10_t state; - curand_init(seeds.first, idx, seeds.second, &state); - - IndexType rounded_size = - ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * - blockDim.x * gridDim.x * UNROLL; - for (IndexType linearIndex = idx; linearIndex < rounded_size; - linearIndex += gridDim.x * blockDim.x * UNROLL) { - float4 rand = curand_uniform4(&state); - scalar_t src[UNROLL]; - scalar_t add_src[UNROLL]; - rand.x = rand.x <= p; - rand.y = rand.y <= p; - rand.z = rand.z <= p; - rand.w = rand.w <= p; - for (int ii = 0; ii < UNROLL; ii++) { - IndexType li = linearIndex + blockDim.x * gridDim.x * ii; - if (li < totalElements) { - src[ii] = inputs[li]; - add_src[ii] = add_inputs[li]; - } - } - for (int ii = 0; ii < UNROLL; ii++) { - IndexType li = linearIndex + blockDim.x * gridDim.x * ii; - if (li < totalElements) { - accscalar_t int1 = src[ii] * (&rand.x)[ii] * pinv; - outputs[li] = - static_cast(static_cast(add_src[ii]) + int1); - mask[li] = (uint8_t)(&rand.x)[ii]; - } - } - __syncthreads(); - } -} - -template -__global__ void apex_add_kernel(scalar_t const *inputs, - scalar_t const *add_inputs, scalar_t *outputs, - IndexType totalElements) { - IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; - IndexType rounded_size = - ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * - blockDim.x * gridDim.x * UNROLL; - for (IndexType linearIndex = idx; linearIndex < rounded_size; - linearIndex += gridDim.x * blockDim.x * UNROLL) { - scalar_t src[UNROLL]; - scalar_t add_src[UNROLL]; - for (int ii = 0; ii < UNROLL; ii++) { - IndexType li = linearIndex + blockDim.x * gridDim.x * ii; - if (li < totalElements) { - src[ii] = inputs[li]; - add_src[ii] = add_inputs[li]; - } - } - for (int ii = 0; ii < UNROLL; ii++) { - IndexType li = linearIndex + blockDim.x * gridDim.x * ii; - if (li < totalElements) { - outputs[li] = src[ii] + add_src[ii]; - } - } - __syncthreads(); - } -} - -template -__global__ void apex_masked_scale_kernel(scalar_t const *inputs, - scalar_t *outputs, uint8_t const *mask, - IndexType totalElements, - accscalar_t scale) { - IndexType idx = blockIdx.x * blockDim.x + threadIdx.x; - IndexType rounded_size = - ((totalElements - 1) / (blockDim.x * gridDim.x * UNROLL) + 1) * - blockDim.x * gridDim.x * UNROLL; - for (IndexType linearIndex = idx; linearIndex < rounded_size; - linearIndex += gridDim.x * blockDim.x * UNROLL) { - scalar_t src[UNROLL]; - scalar_t msk[UNROLL]; - for (int ii = 0; ii < UNROLL; ii++) { - IndexType li = linearIndex + blockDim.x * gridDim.x * ii; - if (li < totalElements) { - src[ii] = static_cast(inputs[li]); - msk[ii] = static_cast(mask[li]); - } - } - for (int ii = 0; ii < UNROLL; ii++) { - IndexType li = linearIndex + blockDim.x * gridDim.x * ii; - if (li < totalElements) { - outputs[li] = static_cast(src[ii]) * scale * - static_cast(msk[ii]); - } - } - } -} - -template -void apex_fused_dropout_cuda(scalar_t const *inputs, scalar_t *outputs, - uint8_t *mask, IndexType totalElements, - accscalar_t p) { - auto gen = at::cuda::detail::getDefaultCUDAGenerator(); - - int block_size = 256; - dim3 dim_block(block_size); - dim3 grid((totalElements + block_size - 1) / block_size); - unsigned int blocks_per_sm = - at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / - block_size; - grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties() - ->multiProcessorCount * - blocks_per_sm, - grid.x); - - // number of times random will be generated per thread, to offset philox - // counter in the random state - int64_t counter_offset = - ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL; - std::pair rng_engine_inputs; - { - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen.mutex()); - rng_engine_inputs = - at::check_generator(gen)->philox_engine_inputs( - counter_offset); - } - - apex_fused_dropout_kernel - <<>>( - inputs, outputs, mask, totalElements, p, rng_engine_inputs); - C10_CUDA_CHECK(cudaGetLastError()); -} - -template -void apex_dropout_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs, - scalar_t *outputs, uint8_t *mask, - IndexType totalElements, accscalar_t p) { - auto gen = at::cuda::detail::getDefaultCUDAGenerator(); - - int block_size = 256; - dim3 dim_block(block_size); - dim3 grid((totalElements + block_size - 1) / block_size); - unsigned int blocks_per_sm = - at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / - block_size; - grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties() - ->multiProcessorCount * - blocks_per_sm, - grid.x); - - // number of times random will be generated per thread, to offset philox - // counter in the random state - int64_t counter_offset = - ((totalElements - 1) / (block_size * grid.x * UNROLL) + 1) * UNROLL; - std::pair rng_engine_inputs; - { - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen.mutex()); - rng_engine_inputs = - at::check_generator(gen)->philox_engine_inputs( - counter_offset); - } - - apex_dropout_add_kernel - <<>>( - inputs, add_inputs, outputs, mask, totalElements, p, - rng_engine_inputs); - C10_CUDA_CHECK(cudaGetLastError()); -} - -template -void apex_add_cuda(scalar_t const *inputs, scalar_t const *add_inputs, - scalar_t *outputs, IndexType totalElements) { - int block_size = 256; - dim3 dim_block(block_size); - dim3 grid((totalElements + block_size - 1) / block_size); - unsigned int blocks_per_sm = - at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / - block_size; - grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties() - ->multiProcessorCount * - blocks_per_sm, - grid.x); - - apex_add_kernel - <<>>( - inputs, add_inputs, outputs, totalElements); - C10_CUDA_CHECK(cudaGetLastError()); -} - -template -void apex_masked_scale_cuda(scalar_t const *inputs, scalar_t *outputs, - uint8_t const *mask, IndexType totalElements, - accscalar_t scale) { - int block_size = 256; - dim3 dim_block(block_size); - dim3 grid((totalElements + block_size - 1) / block_size); - unsigned int blocks_per_sm = - at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / - block_size; - grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties() - ->multiProcessorCount * - blocks_per_sm, - grid.x); - - apex_masked_scale_kernel - <<>>( - inputs, outputs, mask, totalElements, scale); - C10_CUDA_CHECK(cudaGetLastError()); -} diff --git a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu deleted file mode 100644 index 850b24d..0000000 --- a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu +++ /dev/null @@ -1,611 +0,0 @@ -#include -#include -#include - -#include -#include -//#include -#include - -#include -#include -#include - -#include "dropout.cuh" -#include "softmax.cuh" -#include "strided_batched_gemm.cuh" - -namespace multihead_attn { -namespace encdec { -namespace rocblas_gemmex { - -std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, torch::Tensor const &inputs_q, - torch::Tensor const &inputs_kv, - torch::Tensor const &input_weights_q, - torch::Tensor const &input_weights_kv, - torch::Tensor const &output_weights, - const uint8_t *pad_mask, - float dropout_prob) { - const int embed_dim = inputs_q.size(2); - const int sequences = inputs_q.size(1); - const int q_seq_len = inputs_q.size(0); - const int k_seq_len = inputs_kv.size(0); - const int batches_q = sequences * q_seq_len; - const int batches_kv = sequences * k_seq_len; - const int head_dim = embed_dim / heads; - const int output_lin_q_dim = embed_dim; - const int output_lin_kv_dim = 2 * embed_dim; - const int attn_batches = heads * sequences; - const int lead_dim_q = attn_batches * head_dim; - const int lead_dim_kv = attn_batches * 2 * head_dim; - const int batch_stride_q = head_dim; - const int batch_stride_kv = 2 * head_dim; - const int dropout_elems = attn_batches * q_seq_len * k_seq_len; - const float alpha = 1.0; - const float beta = 0.0; - const float scale = 1.0 / sqrt(static_cast(head_dim)); - - // There is no reason to use more than one stream as every kernel is - // sequentially dependent - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cublasSetStream(handle, stream); - - // 3 Intermediate Results + Output (Note: dropout intermediates are generated - // by ATen library code) - auto act_options = inputs_q.options().requires_grad(false); - auto mask_options = act_options.dtype(torch::kUInt8); - - torch::Tensor input_lin_q_results = - torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options); - torch::Tensor input_lin_kv_results = - torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options); - torch::Tensor softmax_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = - torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = - torch::empty({q_seq_len, attn_batches, head_dim}, act_options); - torch::Tensor outputs = torch::empty_like(inputs_q, act_options); - - // Input Linear Results Pointers to Q, K, and V of interviewed activations - void *q_lin_results_ptr = static_cast(input_lin_q_results.data_ptr()); - void *k_lin_results_ptr = - static_cast(input_lin_kv_results.data_ptr()); - void *v_lin_results_ptr = static_cast( - static_cast(input_lin_kv_results.data_ptr()) + head_dim); - - // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) - void *softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - char a_layout_t{'t'}; - char a_layout_n{'n'}; - char b_layout_n{'n'}; - - rocblas_int flags = 0; - - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - - // Input Linear Q Fwd - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - output_lin_q_dim, - batches_q, - embed_dim, - static_cast(&alpha), - static_cast(input_weights_q.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(inputs_q.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(&beta), - q_lin_results_ptr, - rocblas_datatype_f16_r, - output_lin_q_dim, - q_lin_results_ptr, - rocblas_datatype_f16_r, - output_lin_q_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // Input Linear KV Fwd - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - output_lin_kv_dim, - batches_kv, - embed_dim, - static_cast(&alpha), - static_cast(input_weights_kv.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(inputs_kv.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(&beta), - k_lin_results_ptr, - rocblas_datatype_f16_r, - output_lin_kv_dim, - k_lin_results_ptr, - rocblas_datatype_f16_r, - output_lin_kv_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( a_layout_t, - b_layout_n, - k_seq_len, - q_seq_len, - head_dim, - scale, - static_cast(k_lin_results_ptr), - lead_dim_kv, - batch_stride_kv, - static_cast(q_lin_results_ptr), - lead_dim_q, - batch_stride_q, - beta, - static_cast(softmax_results_ptr), - k_seq_len, - k_seq_len*q_seq_len, - static_cast(softmax_results_ptr), - k_seq_len, - k_seq_len*q_seq_len, - attn_batches, - flags); - - // Padded Softmax - bool softmax_success = false; - if (pad_mask == nullptr) { - softmax_success = dispatch_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), k_seq_len, - k_seq_len, attn_batches * q_seq_len); - } else { - if (use_time_mask) { - softmax_success = dispatch_time_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), pad_mask, - k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len); - } else { - softmax_success = dispatch_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), pad_mask, - k_seq_len, k_seq_len, attn_batches * q_seq_len, - attn_batches * q_seq_len / sequences); - } - } - assert(softmax_success); - - if (is_training) { - apex_fused_dropout_cuda( - static_cast(softmax_results.data_ptr()), - static_cast(dropout_results.data_ptr()), - static_cast(dropout_mask.data_ptr()), dropout_elems, - (1.0f - dropout_prob)); - } - - // Matmul2 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - alpha, - static_cast(v_lin_results_ptr), - lead_dim_kv, - batch_stride_kv, - (is_training) ? static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()) , - k_seq_len, - k_seq_len*q_seq_len, - beta, - static_cast(matmul2_results.data_ptr()), - head_dim*attn_batches, - head_dim, - static_cast(matmul2_results.data_ptr()), - head_dim*attn_batches, - head_dim, - attn_batches, - flags); - - // Output Linear - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - embed_dim, - batches_q, - embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(matmul2_results.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(&beta), - static_cast(outputs.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(outputs.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - - return {input_lin_q_results, - input_lin_kv_results, - softmax_results, - dropout_results, - dropout_mask, - matmul2_results, - outputs}; -} - -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_q_results, - torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q, - torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q, - torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, float dropout_prob) { - const int embed_dim = inputs_q.size(2); - const int sequences = inputs_q.size(1); - const int q_seq_len = inputs_q.size(0); - const int k_seq_len = inputs_kv.size(0); - const int batches_q = sequences * q_seq_len; - const int batches_kv = sequences * k_seq_len; - const int head_dim = embed_dim / heads; - const int output_lin_q_dim = embed_dim; - const int output_lin_kv_dim = 2 * embed_dim; - const int attn_batches = heads * sequences; - const int lead_dim_q = attn_batches * head_dim; - const int lead_dim_kv = attn_batches * 2 * head_dim; - const int batch_stride_q = head_dim; - const int batch_stride_kv = 2 * head_dim; - const int dropout_elems = attn_batches * q_seq_len * k_seq_len; - const float alpha = 1.0; - const float beta = 0.0; - const float scale = 1.0 / sqrt(static_cast(head_dim)); - - // TODO: Streams can be used in Backprop but I haven't added more than one - // in my first attempt to create the code - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cublasSetStream(handle, stream); - - // Output Tensor Allocations - torch::Tensor input_q_grads = torch::empty_like(inputs_q); - torch::Tensor input_kv_grads = torch::empty_like(inputs_kv); - torch::Tensor input_weight_q_grads = torch::empty_like(input_weights_q); - torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv); - torch::Tensor output_weight_grads = torch::empty_like(output_weights); - // Intermediate Tensor Allocations - at::Tensor output_lin_grads = torch::empty_like(matmul2_results); - at::Tensor matmul2_grads = torch::empty_like(dropout_results); - at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results); - at::Tensor input_lin_kv_output_grads = - torch::empty_like(input_lin_kv_results); - - auto q_lin_results_ptr = static_cast(input_lin_q_results.data_ptr()); - auto k_lin_results_ptr = static_cast(input_lin_kv_results.data_ptr()); - auto v_lin_results_ptr = - static_cast(input_lin_kv_results.data_ptr()) + head_dim; - - auto q_lin_grads_ptr = - static_cast(input_lin_q_output_grads.data_ptr()); - auto k_lin_grads_ptr = - static_cast(input_lin_kv_output_grads.data_ptr()); - auto v_lin_grads_ptr = - static_cast(input_lin_kv_output_grads.data_ptr()) + head_dim; - - char a_layout_n{'n'}; - char a_layout_t{'t'}; - char b_layout_n{'n'}; - char b_layout_t{'t'}; - - rocblas_int flags = 0; - - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - #ifdef __HIP_PLATFORM_HCC__ - #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) - #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) - #if USE_GEMM_FLAGS_FP16_ALT_IMPL - #ifdef BACKWARD_PASS_GUARD - flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0; - #endif - #endif - #endif - - // Output Linear Dgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - embed_dim, - batches_q, - embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(output_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(&beta), - static_cast(output_lin_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(output_lin_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // Output Linear Wgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - embed_dim, - embed_dim, - batches_q, - static_cast(&alpha), - static_cast(matmul2_results.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(output_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(&beta), - static_cast(output_weight_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(output_weight_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // MatMul2 Dgrad1 - gemm_switch_fp32accum( a_layout_t, - b_layout_n, - k_seq_len, - q_seq_len, - head_dim, - alpha, - static_cast(v_lin_results_ptr), - lead_dim_kv, - batch_stride_kv, - static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, - beta, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - attn_batches, - flags); - - // Matmul2 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - alpha, - static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, - static_cast(dropout_results.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - v_lin_grads_ptr, - lead_dim_kv, - batch_stride_kv, - v_lin_grads_ptr, - lead_dim_kv, - batch_stride_kv, - attn_batches, - flags); - - // Apply Dropout Mask and Scale by Dropout Probability - apex_masked_scale_cuda( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - static_cast(dropout_mask.data_ptr()), - dropout_elems, - (1.0 / (1.0 - dropout_prob))); - - // Softmax Grad - bool softmax_success = false; - softmax_success = dispatch_softmax_backward( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - reinterpret_cast(softmax_results.data_ptr()), k_seq_len, - k_seq_len, attn_batches * q_seq_len); - assert(softmax_success); - - // Matmul1 Dgrad1 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - scale, - k_lin_results_ptr, - lead_dim_kv, - batch_stride_kv, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - q_lin_grads_ptr, - lead_dim_q, - batch_stride_q, - q_lin_grads_ptr, - lead_dim_q, - batch_stride_q, - attn_batches, - flags); - - // Matmul1 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - scale, - q_lin_results_ptr, - lead_dim_q, - batch_stride_q, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - k_lin_grads_ptr, - lead_dim_kv, - batch_stride_kv, - k_lin_grads_ptr, - lead_dim_kv, - batch_stride_kv, - attn_batches, - flags); - - // Input Linear Q Dgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - embed_dim, - batches_q, - output_lin_q_dim, - static_cast(&alpha), - static_cast(input_weights_q.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(q_lin_grads_ptr), - rocblas_datatype_f16_r, - output_lin_q_dim, - static_cast(&beta), - static_cast(input_q_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(input_q_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // Input Linear Q Wgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - embed_dim, - output_lin_q_dim, - batches_q, - static_cast(&alpha), - static_cast(inputs_q.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(q_lin_grads_ptr), - rocblas_datatype_f16_r, - output_lin_q_dim, - static_cast(&beta), - static_cast(input_weight_q_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(input_weight_q_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // Input Linear KV Dgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - embed_dim, - batches_kv, - output_lin_kv_dim, - static_cast(&alpha), - static_cast(input_weights_kv.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(k_lin_grads_ptr), - rocblas_datatype_f16_r, - output_lin_kv_dim, - static_cast(&beta), - static_cast(input_kv_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(input_kv_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // Input Linear KV Wgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - embed_dim, - output_lin_kv_dim, - batches_kv, - static_cast(&alpha), - static_cast(inputs_kv.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(k_lin_grads_ptr), - rocblas_datatype_f16_r, - output_lin_kv_dim, - static_cast(&beta), - static_cast(input_weight_kv_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(input_weight_kv_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - // TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - return { - input_q_grads, - input_kv_grads, - input_weight_q_grads, - input_weight_kv_grads, - output_weight_grads - }; -} - -} // end namespace rocblas_gemmex -} // end namespace encdec -} // end namespace multihead_attn - diff --git a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu b/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu deleted file mode 100644 index 063c2d6..0000000 --- a/apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu +++ /dev/null @@ -1,690 +0,0 @@ -#include -#include -#include - -#include -#include -//#include -#include - -#include -#include -#include - -#include "dropout.cuh" -#include "layer_norm.cuh" -#include "softmax.cuh" -#include "strided_batched_gemm.cuh" - -namespace multihead_attn { -namespace encdec_norm_add { -namespace rocblas_gemmex { - -std::vector fwd_cuda( - bool use_time_mask, - bool is_training, - int heads, - torch::Tensor const& inputs_q, - torch::Tensor const& inputs_kv, - torch::Tensor const& lyr_nrm_gamma_weights, - torch::Tensor const& lyr_nrm_beta_weights, - torch::Tensor const& input_weights_q, - torch::Tensor const& input_weights_kv, - torch::Tensor const& output_weights, - const uint8_t* pad_mask, - float dropout_prob - ) -{ - const int embed_dim = inputs_q.size(2); - const int sequences = inputs_q.size(1); - const int q_seq_len = inputs_q.size(0); - const int k_seq_len = inputs_kv.size(0); - const int batches_q = sequences * q_seq_len; - const int batches_kv = sequences * k_seq_len; - const int total_tokens_q = batches_q * embed_dim; - const int head_dim = embed_dim / heads; - const int output_lin_q_dim = embed_dim; - const int output_lin_kv_dim = 2 * embed_dim; - const int attn_batches = heads * sequences; - const int lead_dim_q = attn_batches * head_dim; - const int lead_dim_kv = attn_batches * 2 *head_dim; - const int batch_stride_q = head_dim; - const int batch_stride_kv = 2 * head_dim; - const int dropout_elems = attn_batches * q_seq_len * k_seq_len; - const float alpha = 1.0; - const float beta = 0.0; - const float scale = 1.0 / sqrt(static_cast(head_dim)); - - // There is no reason to use more than one stream as every kernel is - // sequentially dependent - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cublasSetStream(handle, stream); - - // 3 Intermediate Results + Output (Note: dropout intermediates are generated - // by ATen library code) - auto act_options = inputs_q.options().requires_grad(false); - auto lyr_nrm_options = act_options.dtype(torch::kFloat32); - auto mask_options = act_options.dtype(torch::kUInt8); - - torch::Tensor lyr_nrm_mean = torch::empty({batches_q}, lyr_nrm_options); - torch::Tensor lyr_nrm_invvar = torch::empty({batches_q}, lyr_nrm_options); - torch::Tensor lyr_nrm_results = torch::empty_like(inputs_q, act_options); - - torch::Tensor input_lin_q_results = - torch::empty({q_seq_len, sequences, output_lin_q_dim}, act_options); - torch::Tensor input_lin_kv_results = - torch::empty({k_seq_len, sequences, output_lin_kv_dim}, act_options); - torch::Tensor softmax_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = - torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = - torch::empty({q_seq_len, attn_batches, head_dim}, act_options); - torch::Tensor output_lin_results = torch::empty_like(inputs_q, act_options); - torch::Tensor dropout_add_mask = torch::empty_like(inputs_q, mask_options); - torch::Tensor outputs = torch::empty_like(inputs_q, act_options); - - // Input Linear Results Pointers to Q, K, and V of interviewed activations - void *q_lin_results_ptr = static_cast(input_lin_q_results.data_ptr()); - void *k_lin_results_ptr = - static_cast(input_lin_kv_results.data_ptr()); - void *v_lin_results_ptr = static_cast( - static_cast(input_lin_kv_results.data_ptr()) + head_dim); - - // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) - void *softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - char a_layout_t{'t'}; - char a_layout_n{'n'}; - char b_layout_n{'n'}; - - rocblas_int flags = 0; - - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - // Layer Norm - HostApplyLayerNorm( - static_cast(lyr_nrm_results.data_ptr()), - static_cast(lyr_nrm_mean.data_ptr()), - static_cast(lyr_nrm_invvar.data_ptr()), - static_cast(inputs_q.data_ptr()), - static_cast(batches_q), // n1 - static_cast(embed_dim), // n2 - 1.0e-5, static_cast(lyr_nrm_gamma_weights.data_ptr()), - static_cast(lyr_nrm_beta_weights.data_ptr())); - - // Input Linear Q Fwd - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - output_lin_q_dim, - batches_q, - embed_dim, - static_cast(&alpha), - static_cast(input_weights_q.data_ptr()), - rocblas_datatype_f16_r /*a_type*/, - embed_dim, - //static_cast(inputs_q.data_ptr()), - static_cast(lyr_nrm_results.data_ptr()), - rocblas_datatype_f16_r /*b_type*/, - embed_dim, - static_cast(&beta), - q_lin_results_ptr, - rocblas_datatype_f16_r /*c_type*/, - output_lin_q_dim, - q_lin_results_ptr, - rocblas_datatype_f16_r /*d_type*/, - output_lin_q_dim, - rocblas_datatype_f32_r /*compute_type*/, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // Input Linear KV Fwd - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - output_lin_kv_dim, - batches_kv, - embed_dim, - static_cast(&alpha), - static_cast(input_weights_kv.data_ptr()), - rocblas_datatype_f16_r /*a_type*/, - embed_dim, - static_cast(inputs_kv.data_ptr()), - rocblas_datatype_f16_r /*b_type*/, - embed_dim, - static_cast(&beta), - k_lin_results_ptr, - rocblas_datatype_f16_r /*c_type*/, - output_lin_kv_dim, - k_lin_results_ptr, - rocblas_datatype_f16_r /*d_type*/, - output_lin_kv_dim, - rocblas_datatype_f32_r /*compute_type*/, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( a_layout_t, - b_layout_n, - k_seq_len, - q_seq_len, - head_dim, - scale, - static_cast(k_lin_results_ptr), - lead_dim_kv, - batch_stride_kv, - static_cast(q_lin_results_ptr), - lead_dim_q, - batch_stride_q, - beta, - static_cast(softmax_results_ptr), - k_seq_len, - k_seq_len*q_seq_len, - static_cast(softmax_results_ptr), - k_seq_len, - k_seq_len*q_seq_len, - attn_batches, - flags); - - // Padded Softmax - bool softmax_success = false; - if (pad_mask == nullptr) { - softmax_success = dispatch_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), k_seq_len, - k_seq_len, attn_batches * q_seq_len); - } else { - if (use_time_mask) { - softmax_success = dispatch_time_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), pad_mask, - k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len); - } else { - softmax_success = dispatch_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), pad_mask, - k_seq_len, k_seq_len, attn_batches * q_seq_len, - attn_batches * q_seq_len / sequences); - } - } - assert(softmax_success); - - if (is_training) { - apex_fused_dropout_cuda( - static_cast(softmax_results.data_ptr()), - static_cast(dropout_results.data_ptr()), - static_cast(dropout_mask.data_ptr()), dropout_elems, - (1.0f - dropout_prob)); - } - - // Matmul2 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - alpha, - static_cast(v_lin_results_ptr), - lead_dim_kv, - batch_stride_kv, - (is_training) ? static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()), - //static_cast(dropout_results.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - static_cast(matmul2_results.data_ptr()), - head_dim*attn_batches, - head_dim, - static_cast(matmul2_results.data_ptr()), - head_dim*attn_batches, - head_dim, - attn_batches, - flags); - - // Output Linear - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - embed_dim, - batches_q, - embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), - rocblas_datatype_f16_r /*a_type*/, - embed_dim, - static_cast(matmul2_results.data_ptr()), - rocblas_datatype_f16_r /*b_type*/, - embed_dim, - static_cast(&beta), - static_cast(output_lin_results.data_ptr()), - rocblas_datatype_f16_r /*c_type*/, - embed_dim, - static_cast(output_lin_results.data_ptr()), - rocblas_datatype_f16_r /*d_type*/, - embed_dim, - rocblas_datatype_f32_r /*compute_type*/, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // End-of-block Dropout-Add - if (is_training) { - apex_dropout_add_cuda( - static_cast(output_lin_results.data_ptr()), - static_cast(inputs_q.data_ptr()), - static_cast(outputs.data_ptr()), - static_cast(dropout_add_mask.data_ptr()), total_tokens_q, - (1.0f - dropout_prob)); - } else { - apex_add_cuda( - static_cast(output_lin_results.data_ptr()), - static_cast(inputs_q.data_ptr()), - static_cast(outputs.data_ptr()), total_tokens_q); - } - - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - - return {lyr_nrm_results, - lyr_nrm_mean, - lyr_nrm_invvar, - input_lin_q_results, - input_lin_kv_results, - softmax_results, - dropout_results, - dropout_mask, - matmul2_results, - dropout_add_mask, - outputs}; -} - -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_q_results, - torch::Tensor const &input_lin_kv_results, - torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean, - torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q, - torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv, - torch::Tensor const &output_weights, torch::Tensor const &dropout_mask, - torch::Tensor const &dropout_add_mask, float dropout_prob) { - const int embed_dim = inputs_q.size(2); - const int sequences = inputs_q.size(1); - const int q_seq_len = inputs_q.size(0); - const int k_seq_len = inputs_kv.size(0); - const int batches_q = sequences * q_seq_len; - const int batches_kv = sequences * k_seq_len; - const int total_tokens_q = batches_q * embed_dim; - const int head_dim = embed_dim / heads; - const int output_lin_q_dim = embed_dim; - const int output_lin_kv_dim = 2 * embed_dim; - const int attn_batches = heads * sequences; - const int lead_dim_q = attn_batches * head_dim; - const int lead_dim_kv = attn_batches * 2 * head_dim; - const int batch_stride_q = head_dim; - const int batch_stride_kv = 2 * head_dim; - const int dropout_elems = attn_batches * q_seq_len * k_seq_len; - const float alpha = 1.0; - const float beta = 0.0; - const float scale = 1.0 / sqrt(static_cast(head_dim)); - - // TODO: Streams can be used in Backprop but I haven't added more than one - // in my first attempt to create the code - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cublasSetStream(handle, stream); - - // Output Tensor Allocations - torch::Tensor input_q_grads = torch::empty_like(inputs_q); - torch::Tensor input_kv_grads = torch::empty_like(inputs_kv); - torch::Tensor lyr_nrm_gamma_grads = torch::empty_like(lyr_nrm_gamma_weights); - torch::Tensor lyr_nrm_beta_grads = torch::empty_like(lyr_nrm_beta_weights); - torch::Tensor input_weight_q_grads = torch::empty_like(input_weights_q); - torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv); - torch::Tensor output_weight_grads = torch::empty_like(output_weights); - // Intermediate Tensor Allocations - at::Tensor dropout_add_grads = torch::empty_like(output_grads); - at::Tensor output_lin_grads = torch::empty_like(matmul2_results); - at::Tensor matmul2_grads = torch::empty_like(dropout_results); - at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results); - at::Tensor input_lin_kv_output_grads = - torch::empty_like(input_lin_kv_results); - at::Tensor input_lin_q_grads = torch::empty_like(inputs_q); - - auto q_lin_results_ptr = static_cast(input_lin_q_results.data_ptr()); - auto k_lin_results_ptr = static_cast(input_lin_kv_results.data_ptr()); - auto v_lin_results_ptr = - static_cast(input_lin_kv_results.data_ptr()) + head_dim; - - auto q_lin_grads_ptr = - static_cast(input_lin_q_output_grads.data_ptr()); - auto k_lin_grads_ptr = - static_cast(input_lin_kv_output_grads.data_ptr()); - auto v_lin_grads_ptr = - static_cast(input_lin_kv_output_grads.data_ptr()) + head_dim; - - char a_layout_n{'n'}; - char a_layout_t{'t'}; - char b_layout_n{'n'}; - char b_layout_t{'t'}; - - rocblas_int flags = 0; - - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - #ifdef __HIP_PLATFORM_HCC__ - #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) - #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) - #if USE_GEMM_FLAGS_FP16_ALT_IMPL - #ifdef BACKWARD_PASS_GUARD - flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0; - #endif - #endif - #endif - - // Dropout Add Backward - apex_masked_scale_cuda( - static_cast(output_grads.data_ptr()), - static_cast(dropout_add_grads.data_ptr()), - static_cast(dropout_add_mask.data_ptr()), - total_tokens_q, - (1.0 / (1.0 - dropout_prob))); - - // Output Linear Dgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - embed_dim, - batches_q, - embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), - rocblas_datatype_f16_r /*a_type*/, - embed_dim, - static_cast(dropout_add_grads.data_ptr()), - rocblas_datatype_f16_r /*b_type*/, - embed_dim, - static_cast(&beta), - static_cast(output_lin_grads.data_ptr()), - rocblas_datatype_f16_r /*c_type*/, - embed_dim, - static_cast(output_lin_grads.data_ptr()), - rocblas_datatype_f16_r /*d_type*/, - embed_dim, - rocblas_datatype_f32_r /*compute_type*/, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // Output Linear Wgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - embed_dim, - embed_dim, - batches_q, - static_cast(&alpha), - static_cast(matmul2_results.data_ptr()), - rocblas_datatype_f16_r /*a_type*/, - embed_dim, - static_cast(dropout_add_grads.data_ptr()), - rocblas_datatype_f16_r /*b_type*/, - embed_dim, - static_cast(&beta), - static_cast(output_weight_grads.data_ptr()), - rocblas_datatype_f16_r /*c_type*/, - embed_dim, - static_cast(output_weight_grads.data_ptr()), - rocblas_datatype_f16_r /*d_type*/, - embed_dim, - rocblas_datatype_f32_r /*compute_type*/, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // MatMul2 Dgrad1 - gemm_switch_fp32accum( a_layout_t, - b_layout_n, - k_seq_len, - q_seq_len, - head_dim, - alpha, - static_cast(v_lin_results_ptr), - lead_dim_kv, - batch_stride_kv, - static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, - beta, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - attn_batches, - flags); - - // Matmul2 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - alpha, - static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, - static_cast(dropout_results.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - v_lin_grads_ptr, - lead_dim_kv, - batch_stride_kv, - v_lin_grads_ptr, - lead_dim_kv, - batch_stride_kv, - attn_batches, - flags); - - // Apply Dropout Mask and Scale by Dropout Probability - apex_masked_scale_cuda( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - static_cast(dropout_mask.data_ptr()), - dropout_elems, - (1.0 / (1.0 - dropout_prob))); - - // Softmax Grad - bool softmax_success = false; - softmax_success = dispatch_softmax_backward( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - reinterpret_cast(softmax_results.data_ptr()), k_seq_len, - k_seq_len, attn_batches * q_seq_len); - assert(softmax_success); - - // Matmul1 Dgrad1 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - scale, - k_lin_results_ptr, - lead_dim_kv, - batch_stride_kv, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - q_lin_grads_ptr, - lead_dim_q, - batch_stride_q, - q_lin_grads_ptr, - lead_dim_q, - batch_stride_q, - attn_batches, - flags); - - // Matmul1 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - scale, - q_lin_results_ptr, - lead_dim_q, - batch_stride_q, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - k_lin_grads_ptr, - lead_dim_kv, - batch_stride_kv, - k_lin_grads_ptr, - lead_dim_kv, - batch_stride_kv, - attn_batches, - flags); - - // Input Linear Q Dgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - embed_dim, - batches_q, - output_lin_q_dim, - static_cast(&alpha), - static_cast(input_weights_q.data_ptr()), - rocblas_datatype_f16_r /*a_type*/, - embed_dim, - static_cast(q_lin_grads_ptr), - rocblas_datatype_f16_r /*b_type*/, - output_lin_q_dim, - static_cast(&beta), - //static_cast(input_q_grads.data_ptr()), - static_cast(input_lin_q_grads.data_ptr()), - rocblas_datatype_f16_r /*c_type*/, - embed_dim, - static_cast(input_lin_q_grads.data_ptr()), - rocblas_datatype_f16_r /*d_type*/, - embed_dim, - rocblas_datatype_f32_r /*compute_type*/, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // Input Linear Q Wgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - embed_dim, - output_lin_q_dim, - batches_q, - static_cast(&alpha), - static_cast(inputs_q.data_ptr()), - rocblas_datatype_f16_r /*a_type*/, - embed_dim, - static_cast(q_lin_grads_ptr), - rocblas_datatype_f16_r /*b_type*/, - output_lin_q_dim, - static_cast(&beta), - static_cast(input_weight_q_grads.data_ptr()), - rocblas_datatype_f16_r /*c_type*/, - embed_dim, - static_cast(input_weight_q_grads.data_ptr()), - rocblas_datatype_f16_r /*d_type*/, - embed_dim, - rocblas_datatype_f32_r /*compute_type*/, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // Input Linear KV Dgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - embed_dim, - batches_kv, - output_lin_kv_dim, - static_cast(&alpha), - static_cast(input_weights_kv.data_ptr()), - rocblas_datatype_f16_r /*a_type*/, - embed_dim, - static_cast(k_lin_grads_ptr), - rocblas_datatype_f16_r /*b_type*/, - output_lin_kv_dim, - static_cast(&beta), - static_cast(input_kv_grads.data_ptr()), - rocblas_datatype_f16_r /*c_type*/, - embed_dim, - static_cast(input_kv_grads.data_ptr()), - rocblas_datatype_f16_r /*d_type*/, - embed_dim, - rocblas_datatype_f32_r /*compute_type*/, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // Input Linear KV Wgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - embed_dim, - output_lin_kv_dim, - batches_kv, - static_cast(&alpha), - static_cast(inputs_kv.data_ptr()), - rocblas_datatype_f16_r /*a_type*/, - embed_dim, - static_cast(k_lin_grads_ptr), - rocblas_datatype_f16_r /*b_type*/, - output_lin_kv_dim, - static_cast(&beta), - static_cast(input_weight_kv_grads.data_ptr()), - rocblas_datatype_f16_r /*c_type*/, - embed_dim, - static_cast(input_weight_kv_grads.data_ptr()), - rocblas_datatype_f16_r /*d_type*/, - embed_dim, - rocblas_datatype_f32_r /*compute_type*/, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // Fused Layer Norm Bwd with Residual Add - HostLayerNormGradient( - static_cast(input_lin_q_grads.data_ptr()), - static_cast(output_grads.data_ptr()), - static_cast(lyr_nrm_mean.data_ptr()), - static_cast(lyr_nrm_invvar.data_ptr()), - inputs_q, - static_cast(batches_q), // n1 - static_cast(embed_dim), // n2 - static_cast(lyr_nrm_gamma_weights.data_ptr()), - static_cast(lyr_nrm_beta_weights.data_ptr()), - 1.0e-5, - static_cast(input_q_grads.data_ptr()), - static_cast(lyr_nrm_gamma_grads.data_ptr()), - static_cast(lyr_nrm_beta_grads.data_ptr()) - ); - - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - - return {input_q_grads, input_kv_grads, lyr_nrm_gamma_grads, - lyr_nrm_beta_grads, input_weight_q_grads, input_weight_kv_grads, - output_weight_grads}; -} - -} // end namespace rocblas_gemmex -} // end namespace encdec_norm_add -} // end namespace multihead_attn \ No newline at end of file diff --git a/apex/contrib/csrc/multihead_attn/layer_norm.cuh b/apex/contrib/csrc/multihead_attn/layer_norm.cuh deleted file mode 100644 index 277323c..0000000 --- a/apex/contrib/csrc/multihead_attn/layer_norm.cuh +++ /dev/null @@ -1,649 +0,0 @@ -#pragma once -#include -#include -#include -#include - -namespace { -template -__device__ void cuWelfordOnlineSum(const U curr, U &mu, U &sigma2, U &count) { - count = count + U(1); - U delta = curr - mu; - U lmean = mu + delta / count; - mu = lmean; - U delta2 = curr - lmean; - sigma2 = sigma2 + delta * delta2; -} - -template -__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB, - U &mu, U &sigma2, U &count) { - U delta = muB - mu; - U nA = count; - U nB = countB; - count = count + countB; - U nX = count; - if (nX > U(0)) { - nA = nA / nX; - nB = nB / nX; - mu = nA * mu + nB * muB; - sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; - } else { - mu = U(0); - sigma2 = U(0); - } -} - -template -__device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1, - const int n2, const int i1, U &mu, U &sigma2, - U *buf) { - // Assumptions: - // 1) blockDim.x == warpSize - // 2) Tensor is contiguous - // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. - // - // compute variance and mean over n2 - U count = U(0); - mu = U(0); - sigma2 = U(0); - if (i1 < n1) { - // one warp normalizes one n1 index, - // synchronization is implicit - // initialize with standard Welford algorithm - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - const T *lvals = vals + i1 * n2; - int l = 4 * thrx; - for (; l + 3 < n2; l += 4 * numx) { - for (int k = 0; k < 4; ++k) { - U curr = static_cast(lvals[l + k]); - cuWelfordOnlineSum(curr, mu, sigma2, count); - } - } - for (; l < n2; ++l) { - U curr = static_cast(lvals[l]); - cuWelfordOnlineSum(curr, mu, sigma2, count); - } - // intra-warp reductions - for (int l = 0; l <= 4; ++l) { - int srcLaneB = (threadIdx.x + (1 << l)) & 31; - U muB = WARP_SHFL(mu, srcLaneB, 32); - U countB = WARP_SHFL(count, srcLaneB, 32); - U sigma2B = WARP_SHFL(sigma2, srcLaneB, 32); - cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); - } - // threadIdx.x == 0 has correct values for each warp - // inter-warp reductions - if (blockDim.y > 1) { - U *ubuf = (U *)buf; - U *ibuf = (U *)(ubuf + blockDim.y); - for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { - // upper half of warps write to shared - if (threadIdx.x == 0 && threadIdx.y >= offset && - threadIdx.y < 2 * offset) { - const int wrt_y = threadIdx.y - offset; - ubuf[2 * wrt_y] = mu; - ubuf[2 * wrt_y + 1] = sigma2; - ibuf[wrt_y] = count; - } - __syncthreads(); - // lower half merges - if (threadIdx.x == 0 && threadIdx.y < offset) { - U muB = ubuf[2 * threadIdx.y]; - U sigma2B = ubuf[2 * threadIdx.y + 1]; - U countB = ibuf[threadIdx.y]; - cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); - } - __syncthreads(); - } - // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values - if (threadIdx.x == 0 && threadIdx.y == 0) { - ubuf[0] = mu; - ubuf[1] = sigma2; - } - __syncthreads(); - mu = ubuf[0]; - sigma2 = ubuf[1] / U(n2); - // don't care about final value of count, we know count == n2 - } else { - mu = WARP_SHFL(mu, 0, 32); - sigma2 = WARP_SHFL(sigma2 / U(n2), 0, 32); - } - } -} - -template <> -__device__ void cuWelfordMuSigma2(const at::Half *__restrict__ vals, - const int n1, const int n2, const int i1, - float &mu, float &sigma2, float *buf) { - // Assumptions: - // 1) blockDim.x == warpSize - // 2) Tensor is contiguous - // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. - // - // compute variance and mean over n2 - float count = 0.0f; - mu = float(0); - sigma2 = float(0); - - if (i1 < n1) { - // one warp normalizes one n1 index, - // synchronization is implicit - // initialize with standard Welford algorithm - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - const at::Half *lvals = vals + i1 * n2; - int l = 8 * thrx; - if ((((size_t)lvals) & 3) != 0) { - // 16 bit alignment - // first thread consumes first point - if (thrx == 0) { - float curr = static_cast(lvals[0]); - cuWelfordOnlineSum(curr, mu, sigma2, count); - } - ++l; - } - // at this point, lvals[l] are 32 bit aligned for all threads. - for (; l + 7 < n2; l += 8 * numx) { - for (int k = 0; k < 8; k += 2) { - float2 curr = __half22float2(*((__half2 *)(lvals + l + k))); - cuWelfordOnlineSum(curr.x, mu, sigma2, count); - cuWelfordOnlineSum(curr.y, mu, sigma2, count); - } - } - for (; l < n2; ++l) { - float curr = static_cast(lvals[l]); - cuWelfordOnlineSum(curr, mu, sigma2, count); - } - // intra-warp reductions - for (int l = 0; l <= 4; ++l) { - int srcLaneB = (threadIdx.x + (1 << l)) & 31; - float muB = WARP_SHFL(mu, srcLaneB, 32); - float countB = WARP_SHFL(count, srcLaneB, 32); - float sigma2B = WARP_SHFL(sigma2, srcLaneB, 32); - cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); - } - // threadIdx.x == 0 has correct values for each warp - // inter-warp reductions - if (blockDim.y > 1) { - float *ubuf = (float *)buf; - float *ibuf = (float *)(ubuf + blockDim.y); - for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { - // upper half of warps write to shared - if (threadIdx.x == 0 && threadIdx.y >= offset && - threadIdx.y < 2 * offset) { - const int wrt_y = threadIdx.y - offset; - ubuf[2 * wrt_y] = mu; - ubuf[2 * wrt_y + 1] = sigma2; - ibuf[wrt_y] = count; - } - __syncthreads(); - // lower half merges - if (threadIdx.x == 0 && threadIdx.y < offset) { - float muB = ubuf[2 * threadIdx.y]; - float sigma2B = ubuf[2 * threadIdx.y + 1]; - float countB = ibuf[threadIdx.y]; - cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); - } - __syncthreads(); - } - // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values - if (threadIdx.x == 0 && threadIdx.y == 0) { - ubuf[0] = mu; - ubuf[1] = sigma2; - } - __syncthreads(); - mu = ubuf[0]; - sigma2 = ubuf[1] / float(n2); - // don't care about final value of count, we know count == n2 - } else { - mu = WARP_SHFL(mu, 0, 32); - sigma2 = WARP_SHFL(sigma2 / float(n2), 0, 32); - } - } -} - -template U rsqrt(U v) { - return U(1) / sqrt(v); -} -//template<> float rsqrt(float v) { -// return rsqrtf(v); -//} - -#if defined __HIP_PLATFORM_HCC__ -__device__ float rsqrt(float v) { return rsqrtf(v); } -#else -template<> float rsqrt(float v) { return rsqrtf(v); } -#endif -template<> double rsqrt(double v) { return rsqrt(v); } -// template __device__ U rsqrt(U v) { return U(1) / sqrt(v); } -// template <> __device__ float rsqrt(float v) { return rsqrtf(v); } -// template <> __device__ double rsqrt(double v) { return rsqrt(v); } - -// This is the un-specialized struct. Note that we prevent instantiation of -// this struct by putting an undefined symbol in the function body so it won't -// compile. -// template -// struct SharedMemory -// { -// // Ensure that we won't compile any un-specialized types -// __device__ T *getPointer() -// { -// extern __device__ void error(void); -// error(); -// return NULL; -// } -// }; -// https://github.com/NVIDIA/apex/issues/246 -template struct SharedMemory; -template <> struct SharedMemory { - __device__ float *getPointer() { - extern __shared__ float s_float[]; - return s_float; - } -}; - -template <> struct SharedMemory { - __device__ double *getPointer() { - extern __shared__ double s_double[]; - return s_double; - } -}; - -template -__global__ void -cuApplyLayerNorm(T *__restrict__ output_vals, U *__restrict__ mean, - U *__restrict__ invvar, const T *__restrict__ vals, - const int n1, const int n2, const U epsilon, - const T *__restrict__ gamma, const T *__restrict__ beta) { - // Assumptions: - // 1) blockDim.x == warpSize - // 2) Tensors are contiguous - // - for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { - SharedMemory shared; - U *buf = shared.getPointer(); - U mu, sigma2; - cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf); - const T *lvals = vals + i1 * n2; - T *ovals = output_vals + i1 * n2; - U c_invvar = rsqrt(sigma2 + epsilon); - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - if (gamma != NULL && beta != NULL) { - for (int i = thrx; i < n2; i += numx) { - U curr = static_cast(lvals[i]); - ovals[i] = gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; - } - } else { - for (int i = thrx; i < n2; i += numx) { - U curr = static_cast(lvals[i]); - ovals[i] = static_cast(c_invvar * (curr - mu)); - } - } - if (threadIdx.x == 0 && threadIdx.y == 0) { - mean[i1] = mu; - invvar[i1] = c_invvar; - } - } -} - -template -__device__ void cuLoadWriteStridedInputs( - const int i1_block, const int thr_load_row_off, const int thr_load_col_off, - const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2, - const T *input, const T *dout, const int i1_end, const int n2, - const U *__restrict__ mean, const U *__restrict__ invvar) { - int i1 = i1_block + thr_load_row_off; - if (i1 < i1_end) { - U curr_mean = mean[i1]; - U curr_invvar = invvar[i1]; - for (int k = 0; k < blockDim.y; ++k) { - int i2 = i2_off + k; - int load_idx = i1 * n2 + i2; - int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; - if (i2 < n2) { - U curr_input = static_cast(input[load_idx]); - U curr_dout = static_cast(dout[load_idx]); - warp_buf1[write_idx] = curr_dout; - warp_buf2[write_idx] = - curr_dout * (curr_input - curr_mean) * curr_invvar; - } else { - warp_buf1[write_idx] = U(0); - warp_buf2[write_idx] = U(0); - } - } - } else { - for (int k = 0; k < blockDim.y; ++k) { - int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; - warp_buf1[write_idx] = U(0); - warp_buf2[write_idx] = U(0); - } - } -} - -template -__device__ void cuLoadAddStridedInputs( - const int i1_block, const int thr_load_row_off, const int thr_load_col_off, - const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2, - const T *input, const T *dout, const int i1_end, const int n2, - const U *__restrict__ mean, const U *__restrict__ invvar) { - int i1 = i1_block + thr_load_row_off; - if (i1 < i1_end) { - U curr_mean = mean[i1]; - U curr_invvar = invvar[i1]; - for (int k = 0; k < blockDim.y; ++k) { - int i2 = i2_off + k; - int load_idx = i1 * n2 + i2; - int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k; - if (i2 < n2) { - U curr_input = static_cast(input[load_idx]); - U curr_dout = static_cast(dout[load_idx]); - warp_buf1[write_idx] += curr_dout; - warp_buf2[write_idx] += - curr_dout * (curr_input - curr_mean) * curr_invvar; - } - } - } -} - -template -__global__ void cuComputePartGradGammaBeta( - const T *__restrict__ dout, const T *__restrict__ input, const int n1, - const int n2, const U *__restrict__ mean, const U *__restrict__ invvar, - U epsilon, U *part_grad_gamma, U *part_grad_beta) { - const int numsegs_n1 = - (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y); - const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; - const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y; - const int i1_beg_plus_one = - (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y; - const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1; - const int row_stride = blockDim.x + 1; - const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1); - const int thr_load_row_off = - (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y; - const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; - SharedMemory shared; - U *buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * - // blockDim.y + (blockDim.y - - // 1)*(blockDim.x/blockDim.y) elements - U *warp_buf1 = (U *)buf; - U *warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; - // compute partial sums from strided inputs - // do this to increase number of loads in flight - cuLoadWriteStridedInputs(i1_beg, thr_load_row_off, thr_load_col_off, i2_off, - row_stride, warp_buf1, warp_buf2, input, dout, - i1_end, n2, mean, invvar); - for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end; - i1_block += blockDim.y * blockDim.y) { - cuLoadAddStridedInputs(i1_block, thr_load_row_off, thr_load_col_off, i2_off, - row_stride, warp_buf1, warp_buf2, input, dout, - i1_end, n2, mean, invvar); - } - __syncthreads(); - // inter-warp reductions - // sum within each warp - U acc1 = U(0); - U acc2 = U(0); - for (int k = 0; k < blockDim.y; ++k) { - int row1 = threadIdx.y + k * blockDim.y; - int idx1 = row1 * row_stride + threadIdx.x; - acc1 += warp_buf1[idx1]; - acc2 += warp_buf2[idx1]; - } - warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1; - warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2; - __syncthreads(); - // sum all warps - for (int offset = blockDim.y / 2; offset > 1; offset /= 2) { - if (threadIdx.y < offset) { - int row1 = threadIdx.y; - int row2 = threadIdx.y + offset; - int idx1 = row1 * row_stride + threadIdx.x; - int idx2 = row2 * row_stride + threadIdx.x; - warp_buf1[idx1] += warp_buf1[idx2]; - warp_buf2[idx1] += warp_buf2[idx2]; - } - __syncthreads(); - } - int i2 = blockIdx.x * blockDim.x + threadIdx.x; - if (threadIdx.y == 0 && i2 < n2) { - int row1 = threadIdx.y; - int row2 = threadIdx.y + 1; - int idx1 = row1 * row_stride + threadIdx.x; - int idx2 = row2 * row_stride + threadIdx.x; - part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2]; - part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2]; - } -} - -template -__global__ void -cuComputeGradGammaBeta(const U *part_grad_gamma, const U *part_grad_beta, - const int part_size, const int n1, const int n2, - T *grad_gamma, T *grad_beta) { - // sum partial gradients for gamma and beta - SharedMemory shared; - U *buf = shared.getPointer(); - int i2 = blockIdx.x * blockDim.x + threadIdx.x; - if (i2 < n2) { - // each warp does sequential reductions until reduced part_size is num_warps - int num_warp_reductions = part_size / blockDim.y; - U sum_gamma = U(0); - U sum_beta = U(0); - const U *part_grad_gamma_ptr = - part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; - const U *part_grad_beta_ptr = - part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; - for (int warp_offset = 0; warp_offset < num_warp_reductions; - ++warp_offset) { - sum_gamma += part_grad_gamma_ptr[warp_offset * n2]; - sum_beta += part_grad_beta_ptr[warp_offset * n2]; - } - // inter-warp reductions - const int nbsize3 = blockDim.x * blockDim.y / 2; - for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) { - // top half write to shared memory - if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { - const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; - buf[write_idx] = sum_gamma; - buf[write_idx + nbsize3] = sum_beta; - } - __syncthreads(); - // bottom half sums - if (threadIdx.y < offset) { - const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; - sum_gamma += buf[read_idx]; - sum_beta += buf[read_idx + nbsize3]; - } - __syncthreads(); - } - // write out fully summed gradients - if (threadIdx.y == 0) { - grad_gamma[i2] = sum_gamma; - grad_beta[i2] = sum_beta; - } - } -} - - -template -__global__ void -cuComputeGradInput(const T *__restrict__ dout, const T *__restrict__ dout_resid, - const T *__restrict__ input, const int n1, const int n2, - const U *__restrict__ mean, const U *__restrict__ invvar, - U epsilon, const T *gamma, T *grad_input) { - for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) { - U sum_loss1 = U(0); - U sum_loss2 = U(0); - const U c_mean = mean[i1]; - const U c_invvar = invvar[i1]; - const T *k_input = input + i1 * n2; - const T *k_dout = dout + i1 * n2; - const T *k_dout_resid = dout_resid + i1 * n2; - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - if (gamma != NULL) { - int l = 4 * thrx; - for (; l + 3 < n2; l += 4 * numx) { - for (int k = 0; k < 4; ++k) { - const U c_h = static_cast(k_input[l + k]); - const U c_loss = static_cast(k_dout[l + k]); - sum_loss1 += c_loss * static_cast(gamma[l + k]); - sum_loss2 += - c_loss * static_cast(gamma[l + k]) * (c_h - c_mean) * c_invvar; - } - } - for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - sum_loss1 += c_loss * static_cast(gamma[l]); - sum_loss2 += - c_loss * static_cast(gamma[l]) * (c_h - c_mean) * c_invvar; - } - } else { - int l = 4 * thrx; - for (; l + 3 < n2; l += 4 * numx) { - for (int k = 0; k < 4; ++k) { - const U c_h = static_cast(k_input[l + k]); - const U c_loss = static_cast(k_dout[l + k]); - sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; - } - } - for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; - } - } - // intra-warp reductions - for (int mask = blockDim.x / 2; mask > 0; mask /= 2) { - sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask, 32); - sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask, 32); - } - // inter-warp reductions - if (blockDim.y > 1) { - SharedMemory shared; - U *buf = shared.getPointer(); - for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { - // upper half of warps write to shared - if (threadIdx.y >= offset && threadIdx.y < 2 * offset) { - const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; - buf[2 * wrt_i] = sum_loss1; - buf[2 * wrt_i + 1] = sum_loss2; - } - __syncthreads(); - // lower half merges - if (threadIdx.y < offset) { - const int read_i = threadIdx.y * blockDim.x + threadIdx.x; - sum_loss1 += buf[2 * read_i]; - sum_loss2 += buf[2 * read_i + 1]; - } - __syncthreads(); - } - if (threadIdx.y == 0) { - buf[2 * threadIdx.x] = sum_loss1; - buf[2 * threadIdx.x + 1] = sum_loss2; - } - __syncthreads(); - if (threadIdx.y != 0) { - sum_loss1 = buf[2 * threadIdx.x]; - sum_loss2 = buf[2 * threadIdx.x + 1]; - } - } - // all threads now have the two sums over l - U fH = (U)n2; - U term1 = (U(1) / fH) * c_invvar; - T *k_grad_input = grad_input + i1 * n2; - if (gamma != NULL) { - for (int l = thrx; l < n2; l += numx) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - const T c_resid = static_cast(k_dout_resid[l]); - U f_grad_input = fH * c_loss * static_cast(gamma[l]); - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; - f_grad_input *= term1; - k_grad_input[l] = static_cast(f_grad_input) + c_resid; - } - } else { - for (int l = thrx; l < n2; l += numx) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - const T c_resid = static_cast(k_dout_resid[l]); - U f_grad_input = fH * c_loss; - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; - f_grad_input *= term1; - k_grad_input[l] = static_cast(f_grad_input) + c_resid; - } - } - } -} - -template -void HostApplyLayerNorm(T *output, U *mean, U *invvar, const T *input, int n1, - int n2, double epsilon, const T *gamma, const T *beta) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); - const dim3 threads(32, 4, 1); - const uint64_t maxGridY = - at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; - const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); - int nshared = - threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0; - cuApplyLayerNorm<<>>( - output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta); -} - -template -void HostLayerNormGradient(const T *dout, const T *dout_resid, const U *mean, - const U *invvar, const at::Tensor &input, int n1, - int n2, const T *gamma, const T *beta, - double epsilon, T *grad_input, T *grad_gamma, - T *grad_beta) { - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - if (gamma != NULL && beta != NULL) { - // compute grad_gamma(j) and grad_beta(j) - const int part_size = 16; - const dim3 threads2(32, 4, 1); - const dim3 blocks2((n2 + threads2.x - 1) / threads2.x, part_size, 1); - const int nshared2_a = - 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); - const int nshared2_b = threads2.x * threads2.y * sizeof(U); - const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; - at::Tensor part_grad_gamma = at::empty( - {part_size, n2}, - input.options().dtype(input.scalar_type() == at::ScalarType::Half - ? at::ScalarType::Float - : input.scalar_type())); - at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); - cuComputePartGradGammaBeta<<>>( - dout, static_cast(input.data_ptr()), n1, n2, mean, invvar, - U(epsilon), static_cast(part_grad_gamma.data_ptr()), - static_cast(part_grad_beta.data_ptr())); - - const dim3 threads3(32, 8, 1); - const dim3 blocks3((n2 + threads2.x - 1) / threads2.x, 1, 1); - const int nshared3 = threads3.x * threads3.y * sizeof(U); - cuComputeGradGammaBeta<<>>( - static_cast(part_grad_gamma.data_ptr()), - static_cast(part_grad_beta.data_ptr()), part_size, n1, n2, - grad_gamma, grad_beta); - } - - // compute grad_input - const uint64_t maxGridY = - at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; - const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); - const dim3 threads1(32, 4, 1); - int nshared = threads1.y > 1 ? threads1.y * threads1.x * sizeof(U) : 0; - cuComputeGradInput<<>>( - dout, dout_resid, static_cast(input.data_ptr()), n1, n2, mean, - invvar, U(epsilon), gamma, grad_input); -} -} // namespace diff --git a/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu b/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu deleted file mode 100644 index 2adb6e9..0000000 --- a/apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu +++ /dev/null @@ -1,124 +0,0 @@ -#include -#include -#include - -#include -#include -//#include -#include - -#include -#include -#include - -#include "dropout.cuh" -#include "softmax.cuh" - -namespace multihead_attn { -namespace fused_softmax { -namespace mask_softmax_dropout { - -std::vector fwd_cuda(bool is_training, int heads, - torch::Tensor const &input, - const uint8_t *pad_mask, - float dropout_prob) { - const int attn_batches = input.size(0); - const int sequences = attn_batches / heads; - const int q_seq_len = input.size(1); - const int k_seq_len = q_seq_len; - const int dropout_elems = attn_batches * q_seq_len * k_seq_len; - - // There is no reason to use more than one stream as every kernel is - // sequentially dependent - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cublasSetStream(handle, stream); - - // 3 Intermediate Results + Output (Note: dropout intermediates are generated - // by ATen library code) - auto act_options = input.options().requires_grad(false); - auto mask_options = act_options.dtype(torch::kUInt8); - - torch::Tensor softmax_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = - torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - - // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) - void *input_ptr = static_cast(input.data_ptr()); - void *softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - // Padded Softmax - bool softmax_success = false; - if (pad_mask == nullptr) { - softmax_success = dispatch_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), k_seq_len, k_seq_len, - attn_batches * q_seq_len); - } else { - softmax_success = dispatch_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), pad_mask, k_seq_len, - k_seq_len, attn_batches * q_seq_len, - attn_batches * q_seq_len / sequences); - } - - if (is_training) { - // use at:: function so that C++ version generates the same random mask as - // python version - auto dropout_tuple = - at::_fused_dropout(softmax_results, 1.0f - dropout_prob); - dropout_results = std::get<0>(dropout_tuple); - dropout_mask = std::get<1>(dropout_tuple); - } - - // Matmul2 - - return {dropout_results, dropout_mask, softmax_results}; -} - -torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads, - torch::Tensor const &softmax_results, - torch::Tensor const &dropout_mask, - const uint8_t *padding_mask, float dropout_prob) { - const int attn_batches = output_grads.size(0); - const int q_seq_len = output_grads.size(1); - const int k_seq_len = q_seq_len; - const int dropout_elems = attn_batches * q_seq_len * k_seq_len; - // TODO: Streams can be used in Backprop but I haven't added more than one - // in my first attempt to create the code - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cublasSetStream(handle, stream); - - // Output Tensor Allocations - // torch::Tensor input_grads = torch::empty_like(output_grads); - - // Apply Dropout Mask and Scale by Dropout Probability - // Softmax Grad - if (padding_mask == nullptr) { - dispatch_masked_scale_softmax_backward_stream( - static_cast(output_grads.data_ptr()), - static_cast(output_grads.data_ptr()), - reinterpret_cast(softmax_results.data_ptr()), - static_cast(dropout_mask.data_ptr()), - 1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, - attn_batches * q_seq_len, stream); - } else { - dispatch_masked_scale_softmax_backward_masked_out_stream( - static_cast(output_grads.data_ptr()), - static_cast(output_grads.data_ptr()), - reinterpret_cast(softmax_results.data_ptr()), - static_cast(dropout_mask.data_ptr()), - static_cast(padding_mask), 1.0 / (1.0 - dropout_prob), - k_seq_len, k_seq_len, attn_batches * q_seq_len, heads, stream); - } - // backward pass is completely in-place - return output_grads; -} -} // namespace mask_softmax_dropout -} // namespace fused_softmax -} // namespace multihead_attn diff --git a/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp b/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp deleted file mode 100644 index 809620e..0000000 --- a/apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp +++ /dev/null @@ -1,836 +0,0 @@ -#include - -#include -#include - - -#define CHECK_CUDA(x) \ - AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ - AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - -namespace multihead_attn { -namespace fused_softmax { -namespace additive_mask_softmax_dropout { - -std::vector fwd_cuda(bool is_training, int heads, - torch::Tensor const &input, - const half *pad_mask, float dropout_prob); - -torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads, - torch::Tensor const &softmax_results, - torch::Tensor const &dropout_mask, float dropout_prob); - -std::vector fwd(bool use_mask, bool is_training, int heads, - torch::Tensor const &input, - torch::Tensor const &pad_mask, - float dropout_prob) { - AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - if (use_mask) { - AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half, - "Only BYTE is supported"); - } - - return fwd_cuda(is_training, heads, input, - use_mask ? static_cast(pad_mask.data_ptr()) - : nullptr, - dropout_prob); -} - -torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads, - torch::Tensor const &softmax_results, - torch::Tensor const &dropout_mask, float dropout_prob) { - AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - // AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, - // "Only BYTE is supported"); - - return bwd_cuda(heads, output_grads, softmax_results, dropout_mask, - dropout_prob); -} - -} // namespace additive_mask_softmax_dropout -namespace mask_softmax_dropout { - -std::vector fwd_cuda(bool is_training, int heads, - torch::Tensor const &input, - const uint8_t *pad_mask, - float dropout_prob); - -torch::Tensor bwd_cuda(int heads, torch::Tensor const &output_grads, - torch::Tensor const &softmax_results, - torch::Tensor const &dropout_mask, - const uint8_t *padding_mask, float dropout_prob); - -std::vector fwd(bool use_mask, bool is_training, int heads, - torch::Tensor const &input, - torch::Tensor const &pad_mask, - float dropout_prob) { - AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - - if (use_mask) { - AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, - "Only BYTE is supported"); - } - - return fwd_cuda(is_training, heads, input, - use_mask ? static_cast(pad_mask.data_ptr()) - : nullptr, - dropout_prob); -} - -torch::Tensor bwd(bool use_mask, int heads, torch::Tensor const &output_grads, - torch::Tensor const &softmax_results, - torch::Tensor const &dropout_mask, - torch::Tensor const &padding_mask, float dropout_prob) { - AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor"); - - AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - // AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, - // "Only BYTE is supported"); - - return bwd_cuda(heads, output_grads, softmax_results, dropout_mask, - use_mask - ? static_cast(padding_mask.data_ptr()) - : nullptr, - dropout_prob); -} - -} // end namespace mask_softmax_dropout -} // end namespace fused_softmax - -namespace encdec { -namespace rocblas_gemmex { - -std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, torch::Tensor const &inputs_q, - torch::Tensor const &inputs_kv, - torch::Tensor const &input_weights_q, - torch::Tensor const &input_weights_kv, - torch::Tensor const &output_weights, - const uint8_t *pad_mask, - float dropout_prob); -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_q_results, - torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q, - torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q, - torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, float dropout_prob); - -std::vector -fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv, - torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv, - torch::Tensor const &output_weights, torch::Tensor const &pad_mask, - float dropout_prob) { - AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor"); - - AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - - if (use_mask) { - AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, - "Only BYTE is supported"); - } - - return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv, - input_weights_q, input_weights_kv, output_weights, - use_mask ? static_cast(pad_mask.data_ptr()) - : nullptr, - dropout_prob); -} - -std::vector -bwd(int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_q_results, - torch::Tensor const &input_lin_kv_results, torch::Tensor const &inputs_q, - torch::Tensor const &inputs_kv, torch::Tensor const &input_weights_q, - torch::Tensor const &input_weights_kv, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, float dropout_prob) { - AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(input_lin_q_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(input_lin_kv_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor"); - - AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, - "Only BYTE is supported"); - - return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, - softmax_results, input_lin_q_results, input_lin_kv_results, - inputs_q, inputs_kv, input_weights_q, input_weights_kv, - output_weights, dropout_mask, dropout_prob); -} - -} // end namespace rocblas_gemmex -} // end namespace encdec - -namespace encdec_norm_add { -namespace rocblas_gemmex { - -std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, torch::Tensor const &inputs_q, - torch::Tensor const &inputs_kv, - torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights_q, - torch::Tensor const &input_weights_kv, - torch::Tensor const &output_weights, - const uint8_t *pad_mask, - float dropout_prob); - -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_q_results, - torch::Tensor const &input_lin_kv_results, - torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean, - torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q, - torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv, - torch::Tensor const &output_weights, torch::Tensor const &dropout_mask, - torch::Tensor const &dropout_add_mask, float dropout_prob); - -std::vector -fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const &inputs_q, torch::Tensor const &inputs_kv, - torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv, - torch::Tensor const &output_weights, torch::Tensor const &pad_mask, - float dropout_prob) { - AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor"); - AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor"); - AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor"); - - AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - - if (use_mask) { - AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, - "Only BYTE is supported"); - } - - return fwd_cuda(use_time_mask, is_training, heads, inputs_q, inputs_kv, - lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q, - input_weights_kv, output_weights, - use_mask ? static_cast(pad_mask.data_ptr()) - : nullptr, - dropout_prob); -} - -std::vector -bwd(int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_q_results, - torch::Tensor const &input_lin_kv_results, - torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean, - torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs_q, - torch::Tensor const &inputs_kv, torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights_q, torch::Tensor const &input_weights_kv, - torch::Tensor const &output_weights, torch::Tensor const &dropout_mask, - torch::Tensor const &dropout_add_mask, float dropout_prob) { - AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(input_lin_q_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(input_lin_kv_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(lyr_nrm_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(lyr_nrm_mean.dim() == 1, "expected 1D tensor"); - AT_ASSERTM(lyr_nrm_invvar.dim() == 1, "expected 1D tensor"); - AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor"); - AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor"); - AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor"); - - AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float, - "Only FLOAT is supported"); - AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float, - "Only FLOAT is supported"); - AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, - "Only BYTE is supported"); - AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte, - "Only BYTE is supported"); - - return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, - softmax_results, input_lin_q_results, input_lin_kv_results, - lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, inputs_q, - inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, - input_weights_q, input_weights_kv, output_weights, - dropout_mask, dropout_add_mask, dropout_prob); -} - - -} // end namespace rocblas_gemmex -} // end namespace encdec_norm_add - -namespace self { -namespace rocblas_gemmex { - -std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, torch::Tensor const &inputs, - torch::Tensor const &input_weights, - torch::Tensor const &output_weights, - const uint8_t *pad_mask, - float dropout_prob); - -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_results, torch::Tensor const &inputs, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, float dropout_prob); - -std::vector -fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const &inputs, torch::Tensor const &input_weights, - torch::Tensor const &output_weights, torch::Tensor const &pad_mask, - float dropout_prob) { - AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor"); - - AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - - if (use_mask) { - AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, - "Only BYTE is supported"); - } - - return fwd_cuda( - use_time_mask, is_training, heads, inputs, input_weights, output_weights, - use_mask ? static_cast(pad_mask.data_ptr()) : nullptr, - dropout_prob); -} - -std::vector -bwd(int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_results, torch::Tensor const &inputs, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, float dropout_prob) { - AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor"); - - AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, - "Only BYTE is supported"); - - return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, - softmax_results, input_lin_results, inputs, input_weights, - output_weights, dropout_mask, dropout_prob); -} - -} // end namespace rocblas_gemmex -} // end namespace self -namespace self_bias { -namespace rocblas_gemmex { - -std::vector -fwd_cuda(bool use_time_mask, bool is_training, int heads, - torch::Tensor const &inputs, torch::Tensor const &input_weights, - torch::Tensor const &output_weights, torch::Tensor const &input_biases, - torch::Tensor const &output_biases, const uint8_t *pad_mask, - float dropout_prob); - -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_results, torch::Tensor const &inputs, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - // torch::Tensor const& input_biases, - // torch::Tensor const& output_biases, - torch::Tensor const &dropout_mask, float dropout_prob); - -std::vector -fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const &inputs, torch::Tensor const &input_weights, - torch::Tensor const &output_weights, torch::Tensor const &input_biases, - torch::Tensor const &output_biases, torch::Tensor const &pad_mask, - float dropout_prob) { - AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor"); - - AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - - if (use_mask) { - AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, - "Only BYTE is supported"); - } - - return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights, - output_weights, input_biases, output_biases, - use_mask ? static_cast(pad_mask.data_ptr()) - : nullptr, - dropout_prob); -} - -std::vector -bwd(int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_results, torch::Tensor const &inputs, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, float dropout_prob) { - AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor"); - - AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, - "Only BYTE is supported"); - - return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, - softmax_results, input_lin_results, inputs, input_weights, - output_weights, dropout_mask, dropout_prob); -} - -} // end namespace rocblas_gemmex -} // namespace self_bias -namespace self_bias_additive_mask { -namespace rocblas_gemmex { - -std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, torch::Tensor const &inputs, - torch::Tensor const &input_weights, - torch::Tensor const &output_weights, - torch::Tensor const &input_biases, - torch::Tensor const &output_biases, - const half *pad_mask, float dropout_prob); - -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - // torch::Tensor const& softmax_results, - torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask, - torch::Tensor const &input_lin_results, torch::Tensor const &inputs, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - // torch::Tensor const& input_biases, - // torch::Tensor const& output_biases, - torch::Tensor const &dropout_mask, float dropout_prob); - -std::vector -fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const &inputs, torch::Tensor const &input_weights, - torch::Tensor const &output_weights, torch::Tensor const &input_biases, - torch::Tensor const &output_biases, torch::Tensor const &pad_mask, - float dropout_prob) { - AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor"); - - AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(use_mask, "no mask is not supported"); - - if (use_mask) { - AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half, - "Only Half is supported"); - } - - return fwd_cuda(use_time_mask, is_training, heads, inputs, input_weights, - output_weights, input_biases, output_biases, - use_mask ? static_cast(pad_mask.data_ptr()) - : nullptr, - dropout_prob); -} - -std::vector -bwd(int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask, - torch::Tensor const &input_lin_results, torch::Tensor const &inputs, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, float dropout_prob) { - AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor"); - - AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, - "Only BYTE is supported"); - - return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, - bmm1_results, pad_mask, input_lin_results, inputs, - input_weights, output_weights, dropout_mask, dropout_prob); -} - -} // end namespace rocblas_gemmex -} // namespace self_bias_additive_mask - -namespace self_norm_add { -namespace rocblas_gemmex { - -std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, torch::Tensor const &inputs, - torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights, - torch::Tensor const &output_weights, - const uint8_t *pad_mask, - float dropout_prob); - -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_results, - torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean, - torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs, - torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask, - float dropout_prob); - -std::vector -fwd(bool use_mask, bool use_time_mask, bool is_training, int heads, - torch::Tensor const &inputs, torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &pad_mask, float dropout_prob) { - AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor"); - AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor"); - AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor"); - - AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - - if (use_mask) { - AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, - "Only BYTE is supported"); - } - - return fwd_cuda( - use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights, - lyr_nrm_beta_weights, input_weights, output_weights, - use_mask ? static_cast(pad_mask.data_ptr()) : nullptr, - dropout_prob); -} - -std::vector -bwd(int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_results, - torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean, - torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs, - torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask, - float dropout_prob) { - AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(lyr_nrm_results.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(lyr_nrm_mean.dim() == 1, "expected 1D tensor"); - AT_ASSERTM(lyr_nrm_invvar.dim() == 1, "expected 1D tensor"); - AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor"); - AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor"); - AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor"); - AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor"); - - AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float, - "Only FLOAT is supported"); - AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float, - "Only FLOAT is supported"); - AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, - "Only HALF is supported"); - AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, - "Only BYTE is supported"); - AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte, - "Only BYTE is supported"); - - return bwd_cuda(heads, output_grads, matmul2_results, dropout_results, - softmax_results, input_lin_results, lyr_nrm_results, - lyr_nrm_mean, lyr_nrm_invvar, inputs, lyr_nrm_gamma_weights, - lyr_nrm_beta_weights, input_weights, output_weights, - dropout_mask, dropout_add_mask, dropout_prob); -} - -} // end namespace rocblas_gemmex -} // end namespace self_norm_add -} // end namespace multihead_attn - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("additive_mask_softmax_dropout_forward", - &multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd, - "Self Multihead Attention masked softmax dropout -- Forward."); - m.def("additive_mask_softmax_dropout_backward", - &multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd, - "Self Multihead Attention masked softmax dropout -- Backward."); - m.def("mask_softmax_dropout_forward", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd, - "Self Multihead Attention masked softmax dropout -- Forward."); - m.def("mask_softmax_dropout_backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd, - "Self Multihead Attention masked softmax dropout -- Backward."); - m.def("encdec_multihead_attn_forward", &multihead_attn::encdec::rocblas_gemmex::fwd, - "Encdec Multihead Attention Forward."); - m.def("encdec_multihead_attn_backward", &multihead_attn::encdec::rocblas_gemmex::bwd, - "Encdec Multihead Attention Backward."); - m.def("encdec_multihead_attn_norm_add_forward", &multihead_attn::encdec_norm_add::rocblas_gemmex::fwd, - "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward."); - m.def( - "encdec_multihead_attn_norm_add_backward", &multihead_attn::encdec_norm_add::rocblas_gemmex::bwd, - "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward."); - m.def("self_attn_forward", &multihead_attn::self::rocblas_gemmex::fwd, - "Self Multihead Attention Forward."); - m.def("self_attn_backward", &multihead_attn::self::rocblas_gemmex::bwd, - "Self Multihead Attention Backward."); - m.def("self_attn_bias_forward", &multihead_attn::self_bias::rocblas_gemmex::fwd, - "Self Multihead Attention with Bias -- Forward."); - m.def("self_attn_bias_backward", &multihead_attn::self_bias::rocblas_gemmex::bwd, - "Self Multihead Attention with Bias -- Backward."); - m.def("self_attn_bias_additive_mask_forward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::fwd, - "Self Multihead Attention with Bias -- Forward."); - m.def("self_attn_bias_additive_mask_backward", - &multihead_attn::self_bias_additive_mask::rocblas_gemmex::bwd, - "Self Multihead Attention with Bias -- Backward."); - m.def("self_attn_norm_add_forward", &multihead_attn::self_norm_add::rocblas_gemmex::fwd, - "Self Multihead Attention Plus Layer Norm and Residual Add Forward."); - m.def("self_attn_norm_add_backward", &multihead_attn::self_norm_add::rocblas_gemmex::bwd, - "Self Multihead Attention Plus Layer Norm and Residual Add Backward."); -} - -#undef CHECK_CUDA -#undef CHECK_CONTIGUOUS -#undef CHECK_INPUT diff --git a/apex/contrib/csrc/multihead_attn/philox.cuh b/apex/contrib/csrc/multihead_attn/philox.cuh deleted file mode 100644 index 7660be6..0000000 --- a/apex/contrib/csrc/multihead_attn/philox.cuh +++ /dev/null @@ -1,96 +0,0 @@ -#pragma once -// Philox CUDA. - -namespace { - -class Philox { -public: - __device__ inline Philox(unsigned long long seed, - unsigned long long subsequence, - unsigned long long offset) { - key.x = (unsigned int)seed; - key.y = (unsigned int)(seed >> 32); - counter = make_uint4(0, 0, 0, 0); - counter.z = (unsigned int)(subsequence); - counter.w = (unsigned int)(subsequence >> 32); - STATE = 0; - incr_n(offset / 4); - } - __device__ inline uint4 operator()() { - if (STATE == 0) { - uint4 counter_ = counter; - uint2 key_ = key; - // 7-round philox - for (int i = 0; i < 6; i++) { - counter_ = single_round(counter_, key_); - key_.x += (kPhilox10A); - key_.y += (kPhilox10B); - } - output = single_round(counter_, key_); - incr(); - } - // return a float4 directly - // unsigned long ret; - // switch(STATE) { - // case 0: ret = output.x; break; - // case 1: ret = output.y; break; - // case 2: ret = output.z; break; - // case 3: ret = output.w; break; - //} - // STATE = (STATE + 1) % 4; - return output; - } - -private: - uint4 counter; - uint4 output; - uint2 key; - unsigned int STATE; - __device__ inline void incr_n(unsigned long long n) { - unsigned int nlo = (unsigned int)(n); - unsigned int nhi = (unsigned int)(n >> 32); - counter.x += nlo; - if (counter.x < nlo) - nhi++; - counter.y += nhi; - if (nhi <= counter.y) - return; - if (++counter.z) - return; - ++counter.w; - } - __device__ inline void incr() { - if (++counter.x) - return; - if (++counter.y) - return; - if (++counter.z) - return; - ++counter.w; - } - __device__ unsigned int mulhilo32(unsigned int a, unsigned int b, - unsigned int *result_high) { - *result_high = __umulhi(a, b); - return a * b; - } - __device__ inline uint4 single_round(uint4 ctr, uint2 key) { - unsigned int hi0; - unsigned int hi1; - unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0); - unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1); - uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0}; - return ret; - } - static const unsigned long kPhilox10A = 0x9E3779B9; - static const unsigned long kPhilox10B = 0xBB67AE85; - static const unsigned long kPhiloxSA = 0xD2511F53; - static const unsigned long kPhiloxSB = 0xCD9E8D57; -}; -// Inverse of 2^32. -constexpr float M_RAN_INVM32 = 2.3283064e-10f; -__device__ __inline__ float4 uniform4(uint4 x) { - return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32, - x.w * M_RAN_INVM32); -} - -} // namespace diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu deleted file mode 100644 index 226cfbf..0000000 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu +++ /dev/null @@ -1,504 +0,0 @@ -#include -#include -#include - -#include -#include -//#include -#include - -#include -#include -#include - -#include "dropout.cuh" -#include "softmax.cuh" -#include "strided_batched_gemm.cuh" - -namespace multihead_attn { -namespace self_bias_additive_mask { -namespace rocblas_gemmex { - -std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, torch::Tensor const& inputs, - torch::Tensor const& input_weights, - torch::Tensor const& output_weights, - torch::Tensor const& input_biases, - torch::Tensor const& output_biases, - const half* pad_mask, float dropout_prob) { - const int embed_dim = inputs.size(2); - const int sequences = inputs.size(1); - const int q_seq_len = inputs.size(0); - const int k_seq_len = q_seq_len; - const int batches = sequences * q_seq_len; - const int head_dim = embed_dim / heads; - const int output_lin_dim = 3 * embed_dim; - const int attn_batches = heads * sequences; - const int lead_dim = attn_batches * 3 * head_dim; - const int batch_stride = 3 * head_dim; - const int dropout_elems = attn_batches * q_seq_len * k_seq_len; - const float alpha = 1.0; - const float beta_zero = 0.0; - const float beta_one = 1.0; - const float scale = 1.0 / sqrt(static_cast(head_dim)); - - // There is no reason to use more than one stream as every kernel is - // sequentially dependent - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cublasSetStream(handle, stream); - - // 3 Intermediate Results + Output (Note: dropout intermediates are generated - // by ATen library code) - auto act_options = inputs.options().requires_grad(false); - auto mask_options = act_options.dtype(torch::kUInt8); - - torch::Tensor input_lin_results = - torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); - torch::Tensor bmm1_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = - torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = - torch::empty({q_seq_len, attn_batches, head_dim}, act_options); - torch::Tensor outputs = torch::empty_like(inputs, act_options); - - // Input Linear Results Pointers to Q, K, and V of interviewed activations - void *q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); - void *k_lin_results_ptr = static_cast( - static_cast(input_lin_results.data_ptr()) + head_dim); - void *v_lin_results_ptr = static_cast( - static_cast(input_lin_results.data_ptr()) + 2 * head_dim); - - // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) - void *bmm1_results_ptr = static_cast(bmm1_results.data_ptr()); - void *dropout_results_ptr = static_cast(dropout_results.data_ptr()); - - char a_layout_t{'t'}; - char a_layout_n{'n'}; - char b_layout_n{'n'}; - - rocblas_int flags = 0; - - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - - // Input Linear Fwd - input_lin_results.copy_(input_biases); - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - output_lin_dim, - batches, - embed_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(inputs.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(&beta_one), - q_lin_results_ptr, - rocblas_datatype_f16_r, - output_lin_dim, - q_lin_results_ptr, - rocblas_datatype_f16_r, - output_lin_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( a_layout_t, - b_layout_n, - k_seq_len, - q_seq_len, - head_dim, - scale, - static_cast(k_lin_results_ptr), - lead_dim, - batch_stride, - static_cast(q_lin_results_ptr), - lead_dim, - batch_stride, - beta_zero, - static_cast(bmm1_results_ptr), - k_seq_len, - k_seq_len*q_seq_len, - static_cast(bmm1_results_ptr), - k_seq_len, - k_seq_len*q_seq_len, - attn_batches, - flags); - - // Padded Softmax - bool softmax_success = false; - if (is_training) { - softmax_success = - dispatch_additive_masked_softmax_dropout( - reinterpret_cast(dropout_results_ptr), - (is_training) - ? reinterpret_cast(dropout_mask.data_ptr()) - : nullptr, - reinterpret_cast(bmm1_results_ptr), pad_mask, - attn_batches * q_seq_len * q_seq_len, k_seq_len, k_seq_len, - attn_batches * q_seq_len, attn_batches * q_seq_len / sequences, - 1.0f - dropout_prob, stream); - } else { - softmax_success = dispatch_additive_masked_softmax( - reinterpret_cast( - dropout_results_ptr), // this is actually softmax results, but - // making it consistent for the next function - reinterpret_cast(bmm1_results_ptr), pad_mask, k_seq_len, - k_seq_len, attn_batches * q_seq_len, - attn_batches * q_seq_len / sequences); - } - - // Matmul2 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - alpha, - static_cast(v_lin_results_ptr), - lead_dim, - batch_stride, - static_cast(dropout_results.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta_zero, - static_cast(matmul2_results.data_ptr()), - head_dim*attn_batches, - head_dim, - static_cast(matmul2_results.data_ptr()), - head_dim*attn_batches, - head_dim, - attn_batches, - flags); - - outputs.copy_(output_biases); - - // Output Linear - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - embed_dim, - batches, - embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(matmul2_results.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(&beta_one), - static_cast(outputs.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(outputs.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - - return {input_lin_results, bmm1_results, dropout_results, - dropout_mask, matmul2_results, outputs}; -} - -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &bmm1_results, torch::Tensor const &pad_mask, - torch::Tensor const &input_lin_results, torch::Tensor const &inputs, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, float dropout_prob) { - const int embed_dim = inputs.size(2); - const int sequences = inputs.size(1); - const int q_seq_len = inputs.size(0); - const int k_seq_len = q_seq_len; - const int batches = sequences * q_seq_len; - const int head_dim = embed_dim / heads; - const int output_lin_dim = 3 * embed_dim; - const int attn_batches = heads * sequences; - const int lead_dim = attn_batches * 3 * head_dim; - const int batch_stride = 3 * head_dim; - const int dropout_elems = attn_batches * q_seq_len * k_seq_len; - const float alpha = 1.0; - const float beta = 0.0; - const float scale = 1.0 / sqrt(static_cast(head_dim)); - - // TODO: Streams can be used in Backprop but I haven't added more than one - // in my first attempt to create the code - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cublasSetStream(handle, stream); - - // Output Tensor Allocations - torch::Tensor input_grads = torch::empty_like(inputs); - torch::Tensor input_weight_grads = torch::empty_like(input_weights); - torch::Tensor output_weight_grads = torch::empty_like(output_weights); - // Intermediate Tensor Allocations - at::Tensor output_lin_grads = torch::empty_like(matmul2_results); - at::Tensor matmul2_grads = torch::empty_like(dropout_results); - at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results); - - auto q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); - auto k_lin_results_ptr = - static_cast(input_lin_results.data_ptr()) + head_dim; - auto v_lin_results_ptr = - static_cast(input_lin_results.data_ptr()) + 2 * head_dim; - - auto q_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr()); - auto k_lin_grads_ptr = - static_cast(input_lin_output_grads.data_ptr()) + head_dim; - auto v_lin_grads_ptr = - static_cast(input_lin_output_grads.data_ptr()) + 2 * head_dim; - - char a_layout_n{'n'}; - char a_layout_t{'t'}; - char b_layout_n{'n'}; - char b_layout_t{'t'}; - - rocblas_int flags = 0; - - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - #ifdef __HIP_PLATFORM_HCC__ - #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) - #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) - #if USE_GEMM_FLAGS_FP16_ALT_IMPL - #ifdef BACKWARD_PASS_GUARD - flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0; - #endif - #endif - #endif - - // Output Linear Dgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - embed_dim, - batches, - embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(output_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(&beta), - static_cast(output_lin_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(output_lin_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // Output Linear Wgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - embed_dim, - embed_dim, - batches, - static_cast(&alpha), - static_cast(matmul2_results.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(output_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(&beta), - static_cast(output_weight_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(output_weight_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false); - // MatMul2 Dgrad1 - gemm_switch_fp32accum( a_layout_t, - b_layout_n, - k_seq_len, - q_seq_len, - head_dim, - alpha, - static_cast(v_lin_results_ptr), - lead_dim, - batch_stride, - static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, - beta, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - attn_batches, - flags); - - // Matmul2 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - alpha, - static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, - static_cast(dropout_results.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - v_lin_grads_ptr, - lead_dim, - batch_stride, - v_lin_grads_ptr, - lead_dim, - batch_stride, - attn_batches, - flags); - - // Apply Dropout Mask and Scale by Dropout Probability - // Softmax Grad - dispatch_masked_scale_softmax_backward_recompute( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - reinterpret_cast(bmm1_results.data_ptr()), - reinterpret_cast(pad_mask.data_ptr()), - static_cast(dropout_mask.data_ptr()), - 1.0/(1.0-dropout_prob), - k_seq_len, - k_seq_len, - attn_batches*q_seq_len/sequences, - attn_batches*q_seq_len, - stream); - - // Matmul1 Dgrad1 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - scale, - k_lin_results_ptr, - lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - q_lin_grads_ptr, - lead_dim, - batch_stride, - q_lin_grads_ptr, - lead_dim, - batch_stride, - attn_batches, - flags); - - // Matmul1 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - scale, - q_lin_results_ptr, - lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - k_lin_grads_ptr, - lead_dim, - batch_stride, - k_lin_grads_ptr, - lead_dim, - batch_stride, - attn_batches, - flags); - - // Input Linear Dgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - embed_dim, - batches, - output_lin_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(input_lin_output_grads.data_ptr()), - rocblas_datatype_f16_r, - output_lin_dim, - static_cast(&beta), - static_cast(input_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(input_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // Input Linear Wgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - embed_dim, - output_lin_dim, - batches, - static_cast(&alpha), - static_cast(inputs.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(q_lin_grads_ptr), - rocblas_datatype_f16_r, - output_lin_dim, - static_cast(&beta), - static_cast(input_weight_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(input_weight_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - - return {input_grads, input_weight_grads, output_weight_grads, - input_bias_grads, output_bias_grads}; -} - -} // end namespace rocblas_gemmex -} // end namespace self_bias_additive_mask -} // end namespace multihead_attn diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu deleted file mode 100644 index f9a2a49..0000000 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu +++ /dev/null @@ -1,504 +0,0 @@ -#include -#include -#include - -#include -#include -//#include -#include - -#include -#include -#include - -#include "dropout.cuh" -#include "softmax.cuh" -#include "strided_batched_gemm.cuh" - -namespace multihead_attn { -namespace self_bias { -namespace rocblas_gemmex { - -std::vector -fwd_cuda(bool use_time_mask, bool is_training, int heads, - torch::Tensor const &inputs, torch::Tensor const &input_weights, - torch::Tensor const &output_weights, torch::Tensor const &input_biases, - torch::Tensor const &output_biases, const uint8_t *pad_mask, - float dropout_prob) { - const int embed_dim = inputs.size(2); - const int sequences = inputs.size(1); - const int q_seq_len = inputs.size(0); - const int k_seq_len = q_seq_len; - const int batches = sequences * q_seq_len; - const int head_dim = embed_dim / heads; - const int output_lin_dim = 3 * embed_dim; - const int attn_batches = heads * sequences; - const int lead_dim = attn_batches * 3 * head_dim; - const int batch_stride = 3 * head_dim; - const int dropout_elems = attn_batches * q_seq_len * k_seq_len; - const float alpha = 1.0; - const float beta_zero = 0.0; - const float beta_one = 1.0; - const float scale = 1.0 / sqrt(static_cast(head_dim)); - - // There is no reason to use more than one stream as every kernel is - // sequentially dependent - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cublasSetStream(handle, stream); - - // 3 Intermediate Results + Output (Note: dropout intermediates are generated - // by ATen library code) - auto act_options = inputs.options().requires_grad(false); - auto mask_options = act_options.dtype(torch::kUInt8); - - torch::Tensor input_lin_results = - torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); - torch::Tensor softmax_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = - torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = - torch::empty({q_seq_len, attn_batches, head_dim}, act_options); - torch::Tensor outputs = torch::empty_like(inputs, act_options); - - // Input Linear Results Pointers to Q, K, and V of interviewed activations - void *q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); - void *k_lin_results_ptr = static_cast( - static_cast(input_lin_results.data_ptr()) + head_dim); - void *v_lin_results_ptr = static_cast( - static_cast(input_lin_results.data_ptr()) + 2 * head_dim); - - // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) - void *softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - char a_layout_t{'t'}; - char a_layout_n{'n'}; - char b_layout_n{'n'}; - - rocblas_int flags = 0; - - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - - // Input Linear Fwd - input_lin_results.copy_(input_biases); - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - output_lin_dim, - batches, - embed_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(inputs.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(&beta_one), - q_lin_results_ptr, - rocblas_datatype_f16_r, - output_lin_dim, - q_lin_results_ptr, - rocblas_datatype_f16_r, - output_lin_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( a_layout_t, - b_layout_n, - k_seq_len, - q_seq_len, - head_dim, - scale, - static_cast(k_lin_results_ptr), - lead_dim, - batch_stride, - static_cast(q_lin_results_ptr), - lead_dim, - batch_stride, - beta_zero, - static_cast(softmax_results_ptr), - k_seq_len, - k_seq_len*q_seq_len, - static_cast(softmax_results_ptr), - k_seq_len, - k_seq_len*q_seq_len, - attn_batches, - flags); - - // Padded Softmax - bool softmax_success = false; - if (pad_mask == nullptr) { - softmax_success = dispatch_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), k_seq_len, - k_seq_len, attn_batches * q_seq_len); - } else { - if (use_time_mask) { - softmax_success = dispatch_time_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), pad_mask, - k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len); - } else { - softmax_success = dispatch_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), pad_mask, - k_seq_len, k_seq_len, attn_batches * q_seq_len, - attn_batches * q_seq_len / sequences); - } - } - - if (is_training) { - // use at:: function so that C++ version generates the same random mask as - // python version - auto dropout_tuple = - at::_fused_dropout(softmax_results, 1.0f - dropout_prob); - dropout_results = std::get<0>(dropout_tuple); - dropout_mask = std::get<1>(dropout_tuple); - } - - // Matmul2 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - alpha, - static_cast(v_lin_results_ptr), - lead_dim, - batch_stride, - (is_training) ? static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()) , - k_seq_len, - k_seq_len*q_seq_len, - beta_zero, - static_cast(matmul2_results.data_ptr()), - head_dim*attn_batches, - head_dim, - static_cast(matmul2_results.data_ptr()), - head_dim*attn_batches, - head_dim, - attn_batches, - flags); - - outputs.copy_(output_biases); - - // Output Linear - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - embed_dim, - batches, - embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(matmul2_results.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(&beta_one), - static_cast(outputs.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(outputs.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - - return {input_lin_results, softmax_results, dropout_results, - dropout_mask, matmul2_results, outputs}; -} - -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_results, torch::Tensor const &inputs, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, float dropout_prob) { - const int embed_dim = inputs.size(2); - const int sequences = inputs.size(1); - const int q_seq_len = inputs.size(0); - const int k_seq_len = q_seq_len; - const int batches = sequences * q_seq_len; - const int head_dim = embed_dim / heads; - const int output_lin_dim = 3 * embed_dim; - const int attn_batches = heads * sequences; - const int lead_dim = attn_batches * 3 * head_dim; - const int batch_stride = 3 * head_dim; - const int dropout_elems = attn_batches * q_seq_len * k_seq_len; - const float alpha = 1.0; - const float beta = 0.0; - const float scale = 1.0 / sqrt(static_cast(head_dim)); - - // TODO: Streams can be used in Backprop but I haven't added more than one - // in my first attempt to create the code - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cublasSetStream(handle, stream); - - // Output Tensor Allocations - torch::Tensor input_grads = torch::empty_like(inputs); - torch::Tensor input_weight_grads = torch::empty_like(input_weights); - torch::Tensor output_weight_grads = torch::empty_like(output_weights); - // Intermediate Tensor Allocations - at::Tensor output_lin_grads = torch::empty_like(matmul2_results); - at::Tensor matmul2_grads = torch::empty_like(dropout_results); - at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results); - - auto q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); - auto k_lin_results_ptr = - static_cast(input_lin_results.data_ptr()) + head_dim; - auto v_lin_results_ptr = - static_cast(input_lin_results.data_ptr()) + 2 * head_dim; - - auto q_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr()); - auto k_lin_grads_ptr = - static_cast(input_lin_output_grads.data_ptr()) + head_dim; - auto v_lin_grads_ptr = - static_cast(input_lin_output_grads.data_ptr()) + 2 * head_dim; - - char a_layout_n{'n'}; - char a_layout_t{'t'}; - char b_layout_n{'n'}; - char b_layout_t{'t'}; - - rocblas_int flags = 0; - - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - #ifdef __HIP_PLATFORM_HCC__ - #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) - #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) - #if USE_GEMM_FLAGS_FP16_ALT_IMPL - #ifdef BACKWARD_PASS_GUARD - flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0; - #endif - #endif - #endif - - // Output Linear Dgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - embed_dim, - batches, - embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(output_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(&beta), - static_cast(output_lin_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(output_lin_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // Output Linear Wgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - embed_dim, - embed_dim, - batches, - static_cast(&alpha), - static_cast(matmul2_results.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(output_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(&beta), - static_cast(output_weight_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(output_weight_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false); - // MatMul2 Dgrad1 - gemm_switch_fp32accum( a_layout_t, - b_layout_n, - k_seq_len, - q_seq_len, - head_dim, - alpha, - static_cast(v_lin_results_ptr), - lead_dim, - batch_stride, - static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, - beta, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - attn_batches, - flags); - - // Matmul2 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - alpha, - static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, - static_cast(dropout_results.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - v_lin_grads_ptr, - lead_dim, - batch_stride, - v_lin_grads_ptr, - lead_dim, - batch_stride, - attn_batches, - flags); - - // Apply Dropout Mask and Scale by Dropout Probability - // Softmax Grad - dispatch_masked_scale_softmax_backward_stream( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - reinterpret_cast(softmax_results.data_ptr()), - static_cast(dropout_mask.data_ptr()), - 1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len, - attn_batches * q_seq_len, stream); - - // Matmul1 Dgrad1 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - scale, - k_lin_results_ptr, - lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - q_lin_grads_ptr, - lead_dim, - batch_stride, - q_lin_grads_ptr, - lead_dim, - batch_stride, - attn_batches, - flags); - - // Matmul1 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - scale, - q_lin_results_ptr, - lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - k_lin_grads_ptr, - lead_dim, - batch_stride, - k_lin_grads_ptr, - lead_dim, - batch_stride, - attn_batches, - flags); - // Input Linear Dgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - embed_dim, - batches, - output_lin_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(input_lin_output_grads.data_ptr()), - rocblas_datatype_f16_r, - output_lin_dim, - static_cast(&beta), - static_cast(input_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(input_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // Input Linear Wgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - embed_dim, - output_lin_dim, - batches, - static_cast(&alpha), - static_cast(inputs.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(q_lin_grads_ptr), - rocblas_datatype_f16_r, - output_lin_dim, - static_cast(&beta), - static_cast(input_weight_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(input_weight_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - - return {input_grads, input_weight_grads, output_weight_grads, - input_bias_grads, output_bias_grads}; -} - -} // end namespace rocblas_gemmex -} // end namespace self -} // end namespace multihead_attn \ No newline at end of file diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu deleted file mode 100644 index af60e5a..0000000 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu +++ /dev/null @@ -1,509 +0,0 @@ -#include -#include -#include - -#include -#include -//#include -#include - -#include -#include -#include - -#include "dropout.cuh" -#include "softmax.cuh" -#include "strided_batched_gemm.cuh" - -namespace multihead_attn { -namespace self { -namespace rocblas_gemmex { - -std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, torch::Tensor const &inputs, - torch::Tensor const &input_weights, - torch::Tensor const &output_weights, - const uint8_t *pad_mask, - float dropout_prob) { - const int embed_dim = inputs.size(2); - const int sequences = inputs.size(1); - const int q_seq_len = inputs.size(0); - const int k_seq_len = q_seq_len; - const int batches = sequences * q_seq_len; - const int head_dim = embed_dim / heads; - const int output_lin_dim = 3 * embed_dim; - const int attn_batches = heads * sequences; - const int lead_dim = attn_batches * 3 * head_dim; - const int batch_stride = 3 * head_dim; - const int dropout_elems = attn_batches * q_seq_len * k_seq_len; - const float alpha = 1.0; - const float beta = 0.0; - const float scale = 1.0 / sqrt(static_cast(head_dim)); - - // There is no reason to use more than one stream as every kernel is - // sequentially dependent - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cublasSetStream(handle, stream); - - // 3 Intermediate Results + Output (Note: dropout intermediates are generated - // by ATen library code) - auto act_options = inputs.options().requires_grad(false); - auto mask_options = act_options.dtype(torch::kUInt8); - - torch::Tensor input_lin_results = - torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); - torch::Tensor softmax_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = - torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = - torch::empty({q_seq_len, attn_batches, head_dim}, act_options); - torch::Tensor outputs = torch::empty_like(inputs, act_options); - - // Input Linear Results Pointers to Q, K, and V of interviewed activations - void *q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); - void *k_lin_results_ptr = static_cast( - static_cast(input_lin_results.data_ptr()) + head_dim); - void *v_lin_results_ptr = static_cast( - static_cast(input_lin_results.data_ptr()) + 2 * head_dim); - - // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) - void *softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - char a_layout_t{'t'}; - char a_layout_n{'n'}; - char b_layout_n{'n'}; - - rocblas_int flags = 0; - - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - - // Input Linear Fwd - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - output_lin_dim, - batches, - embed_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(inputs.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(&beta), - q_lin_results_ptr, - rocblas_datatype_f16_r, - output_lin_dim, - q_lin_results_ptr, - rocblas_datatype_f16_r, - output_lin_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( a_layout_t, - b_layout_n, - k_seq_len, - q_seq_len, - head_dim, - scale, - static_cast(k_lin_results_ptr), - lead_dim, - batch_stride, - static_cast(q_lin_results_ptr), - lead_dim, - batch_stride, - beta, - static_cast(softmax_results_ptr), - k_seq_len, - k_seq_len*q_seq_len, - static_cast(softmax_results_ptr), - k_seq_len, - k_seq_len*q_seq_len, - attn_batches, - flags); - - // Padded Softmax - bool softmax_success = false; - if (pad_mask == nullptr) { - softmax_success = dispatch_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), k_seq_len, - k_seq_len, attn_batches * q_seq_len); - } else { - if (use_time_mask) { - softmax_success = dispatch_time_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), pad_mask, - k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len); - } else { - softmax_success = dispatch_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), pad_mask, - k_seq_len, k_seq_len, attn_batches * q_seq_len, - attn_batches * q_seq_len / sequences); - } - } - assert(softmax_success); - - if (is_training) { - apex_fused_dropout_cuda( - static_cast(softmax_results.data_ptr()), - static_cast(dropout_results.data_ptr()), - static_cast(dropout_mask.data_ptr()), dropout_elems, - (1.0f - dropout_prob)); - } - - // Matmul2 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - alpha, - static_cast(v_lin_results_ptr), - lead_dim, - batch_stride, - (is_training) ? static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()) , - k_seq_len, - k_seq_len*q_seq_len, - beta, - static_cast(matmul2_results.data_ptr()), - head_dim*attn_batches, - head_dim, - static_cast(matmul2_results.data_ptr()), - head_dim*attn_batches, - head_dim, - attn_batches, - flags); - - // Output Linear - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - embed_dim, - batches, - embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(matmul2_results.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(&beta), - static_cast(outputs.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(outputs.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - - return {input_lin_results, softmax_results, dropout_results, - dropout_mask, matmul2_results, outputs}; -} - -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_results, torch::Tensor const &inputs, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, float dropout_prob) { - const int embed_dim = inputs.size(2); - const int sequences = inputs.size(1); - const int q_seq_len = inputs.size(0); - const int k_seq_len = q_seq_len; - const int batches = sequences * q_seq_len; - const int head_dim = embed_dim / heads; - const int output_lin_dim = 3 * embed_dim; - const int attn_batches = heads * sequences; - const int lead_dim = attn_batches * 3 * head_dim; - const int batch_stride = 3 * head_dim; - const int dropout_elems = attn_batches * q_seq_len * k_seq_len; - const float alpha = 1.0; - const float beta = 0.0; - const float scale = 1.0 / sqrt(static_cast(head_dim)); - - // TODO: Streams can be used in Backprop but I haven't added more than one - // in my first attempt to create the code - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cublasSetStream(handle, stream); - - // Output Tensor Allocations - torch::Tensor input_grads = torch::empty_like(inputs); - torch::Tensor input_weight_grads = torch::empty_like(input_weights); - torch::Tensor output_weight_grads = torch::empty_like(output_weights); - // Intermediate Tensor Allocations - at::Tensor output_lin_grads = torch::empty_like(matmul2_results); - at::Tensor matmul2_grads = torch::empty_like(dropout_results); - at::Tensor input_lin_output_grads = torch::empty_like(input_lin_results); - - auto q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); - auto k_lin_results_ptr = - static_cast(input_lin_results.data_ptr()) + head_dim; - auto v_lin_results_ptr = - static_cast(input_lin_results.data_ptr()) + 2 * head_dim; - - auto q_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr()); - auto k_lin_grads_ptr = - static_cast(input_lin_output_grads.data_ptr()) + head_dim; - auto v_lin_grads_ptr = - static_cast(input_lin_output_grads.data_ptr()) + 2 * head_dim; - - char a_layout_n{'n'}; - char a_layout_t{'t'}; - char b_layout_n{'n'}; - char b_layout_t{'t'}; - - rocblas_int flags = 0; - - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - #ifdef __HIP_PLATFORM_HCC__ - #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) - #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) - #if USE_GEMM_FLAGS_FP16_ALT_IMPL - #ifdef BACKWARD_PASS_GUARD - flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0; - #endif - #endif - #endif - - // Output Linear Dgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - embed_dim, - batches, - embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(output_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(&beta), - static_cast(output_lin_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(output_lin_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // Output Linear Wgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - embed_dim, - embed_dim, - batches, - static_cast(&alpha), - static_cast(matmul2_results.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(output_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(&beta), - static_cast(output_weight_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(output_weight_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // MatMul2 Dgrad1 - gemm_switch_fp32accum( a_layout_t, - b_layout_n, - k_seq_len, - q_seq_len, - head_dim, - alpha, - static_cast(v_lin_results_ptr), - lead_dim, - batch_stride, - static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, - beta, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - attn_batches, - flags); - - // Matmul2 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - alpha, - static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, - static_cast(dropout_results.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - v_lin_grads_ptr, - lead_dim, - batch_stride, - v_lin_grads_ptr, - lead_dim, - batch_stride, - attn_batches, - flags); - - // Apply Dropout Mask and Scale by Dropout Probability - apex_masked_scale_cuda( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - static_cast(dropout_mask.data_ptr()), - dropout_elems, - (1.0 / (1.0 - dropout_prob))); - - // Softmax Grad - bool softmax_success = false; - softmax_success = dispatch_softmax_backward( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - reinterpret_cast(softmax_results.data_ptr()), k_seq_len, - k_seq_len, attn_batches * q_seq_len); - assert(softmax_success); - - // Matmul1 Dgrad1 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - scale, - k_lin_results_ptr, - lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - q_lin_grads_ptr, - lead_dim, - batch_stride, - q_lin_grads_ptr, - lead_dim, - batch_stride, - attn_batches, - flags); - - // Matmul1 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - scale, - q_lin_results_ptr, - lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - k_lin_grads_ptr, - lead_dim, - batch_stride, - k_lin_grads_ptr, - lead_dim, - batch_stride, - attn_batches, - flags); - - // Input Linear Dgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - embed_dim, - batches, - output_lin_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(q_lin_grads_ptr), - rocblas_datatype_f16_r, - output_lin_dim, - static_cast(&beta), - static_cast(input_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(input_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // Input Linear Wgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - embed_dim, - output_lin_dim, - batches, - static_cast(&alpha), - static_cast(inputs.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(q_lin_grads_ptr), - rocblas_datatype_f16_r, - output_lin_dim, - static_cast(&beta), - static_cast(input_weight_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - static_cast(input_weight_grads.data_ptr()), - rocblas_datatype_f16_r, - embed_dim, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - - return { - input_grads, - input_weight_grads, - output_weight_grads - }; -} - -} // end namespace rocblas_gemmex -} // end namespace self -} // end namespace multihead_attn - diff --git a/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu b/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu deleted file mode 100644 index 711a67f..0000000 --- a/apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu +++ /dev/null @@ -1,580 +0,0 @@ -#include -#include -#include - -#include -#include -//#include -#include - -#include -#include -#include - -#include "dropout.cuh" -#include "layer_norm.cuh" -#include "softmax.cuh" -#include "strided_batched_gemm.cuh" - -namespace multihead_attn { -namespace self_norm_add { -namespace rocblas_gemmex { - -std::vector fwd_cuda(bool use_time_mask, bool is_training, - int heads, torch::Tensor const &inputs, - torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights, - torch::Tensor const &output_weights, - const uint8_t *pad_mask, - float dropout_prob) { - const int embed_dim = inputs.size(2); - const int sequences = inputs.size(1); - const int q_seq_len = inputs.size(0); - const int k_seq_len = q_seq_len; - const int batches = sequences * q_seq_len; - const int total_tokens = batches * embed_dim; - const int head_dim = embed_dim / heads; - const int output_lin_dim = 3 * embed_dim; - const int attn_batches = heads * sequences; - const int lead_dim = attn_batches * 3 * head_dim; - const int batch_stride = 3 * head_dim; - const int dropout_elems = attn_batches * q_seq_len * k_seq_len; - const float alpha = 1.0; - const float beta = 0.0; - const float scale = 1.0 / sqrt(static_cast(head_dim)); - - // There is no reason to use more than one stream as every kernel is - // sequentially dependent - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cublasSetStream(handle, stream); - - // 3 Intermediate Results + Output (Note: dropout intermediates are generated - // by ATen library code) - auto act_options = inputs.options().requires_grad(false); - auto lyr_nrm_options = act_options.dtype(torch::kFloat32); - auto mask_options = act_options.dtype(torch::kUInt8); - - torch::Tensor lyr_nrm_mean = torch::empty({batches}, lyr_nrm_options); - torch::Tensor lyr_nrm_invvar = torch::empty({batches}, lyr_nrm_options); - torch::Tensor lyr_nrm_results = torch::empty_like(inputs, act_options); - - torch::Tensor input_lin_results = - torch::empty({q_seq_len, sequences, output_lin_dim}, act_options); - torch::Tensor softmax_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_results = - torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options); - torch::Tensor dropout_mask = - torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options); - torch::Tensor matmul2_results = - torch::empty({q_seq_len, attn_batches, head_dim}, act_options); - torch::Tensor output_lin_results = torch::empty_like(inputs, act_options); - torch::Tensor dropout_add_mask = torch::empty_like(inputs, mask_options); - torch::Tensor outputs = torch::empty_like(inputs, act_options); - - // Input Linear Results Pointers to Q, K, and V of interviewed activations - void *q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); - void *k_lin_results_ptr = static_cast( - static_cast(input_lin_results.data_ptr()) + head_dim); - void *v_lin_results_ptr = static_cast( - static_cast(input_lin_results.data_ptr()) + 2 * head_dim); - - // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax) - void *softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - char a_layout_t{'t'}; - char a_layout_n{'n'}; - char b_layout_n{'n'}; - - rocblas_int flags = 0; - - //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - // Layer Norm - HostApplyLayerNorm( - static_cast(lyr_nrm_results.data_ptr()), - static_cast(lyr_nrm_mean.data_ptr()), - static_cast(lyr_nrm_invvar.data_ptr()), - static_cast(inputs.data_ptr()), - static_cast(batches), // n1 - static_cast(embed_dim), // n2 - 1.0e-5, static_cast(lyr_nrm_gamma_weights.data_ptr()), - static_cast(lyr_nrm_beta_weights.data_ptr())); - - // Input Linear Fwd - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - output_lin_dim, - batches, - embed_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), - rocblas_datatype_f16_r /*a_type*/, - embed_dim, - //static_cast(inputs.data_ptr()), - static_cast(lyr_nrm_results.data_ptr()), - rocblas_datatype_f16_r /*b_type*/, - embed_dim, - static_cast(&beta), - q_lin_results_ptr, - rocblas_datatype_f16_r /*c_type*/, - output_lin_dim, - q_lin_results_ptr, - rocblas_datatype_f16_r /*d_type*/, - output_lin_dim, - rocblas_datatype_f32_r /*compute_type*/, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) - gemm_switch_fp32accum( a_layout_t, - b_layout_n, - k_seq_len, - q_seq_len, - head_dim, - scale, - static_cast(k_lin_results_ptr), - lead_dim, - batch_stride, - static_cast(q_lin_results_ptr), - lead_dim, - batch_stride, - beta, - static_cast(softmax_results_ptr), - k_seq_len, - k_seq_len*q_seq_len, - static_cast(softmax_results_ptr), - k_seq_len, - k_seq_len*q_seq_len, - attn_batches, - flags); - - // Padded Softmax - bool softmax_success = false; - if (pad_mask == nullptr) { - softmax_success = dispatch_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), k_seq_len, - k_seq_len, attn_batches * q_seq_len); - } else { - if (use_time_mask) { - softmax_success = dispatch_time_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), pad_mask, - k_seq_len, k_seq_len, attn_batches * q_seq_len, q_seq_len); - } else { - softmax_success = dispatch_masked_softmax( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(softmax_results_ptr), pad_mask, - k_seq_len, k_seq_len, attn_batches * q_seq_len, - attn_batches * q_seq_len / sequences); - } - } - assert(softmax_success); - - if (is_training) { - apex_fused_dropout_cuda( - static_cast(softmax_results.data_ptr()), - static_cast(dropout_results.data_ptr()), - static_cast(dropout_mask.data_ptr()), dropout_elems, - (1.0f - dropout_prob)); - } - - // Matmul2 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - alpha, - static_cast(v_lin_results_ptr), - lead_dim, - batch_stride, - (is_training) ? static_cast(dropout_results.data_ptr()) : static_cast(softmax_results.data_ptr()) , - //static_cast(dropout_results.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - static_cast(matmul2_results.data_ptr()), - head_dim*attn_batches, - head_dim, - static_cast(matmul2_results.data_ptr()), - head_dim*attn_batches, - head_dim, - attn_batches, - flags); - - // Output Linear - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - embed_dim, - batches, - embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), - rocblas_datatype_f16_r /*a_type*/, - embed_dim, - static_cast(matmul2_results.data_ptr()), - rocblas_datatype_f16_r /*b_type*/, - embed_dim, - static_cast(&beta), - static_cast(output_lin_results.data_ptr()), - rocblas_datatype_f16_r /*c_type*/, - embed_dim, - static_cast(output_lin_results.data_ptr()), - rocblas_datatype_f16_r /*d_type*/, - embed_dim, - rocblas_datatype_f32_r /*compute_type*/, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - - // End-of-block Dropout-Add - if (is_training) { - apex_dropout_add_cuda( - static_cast(output_lin_results.data_ptr()), - static_cast(inputs.data_ptr()), - static_cast(outputs.data_ptr()), - static_cast(dropout_add_mask.data_ptr()), total_tokens, - (1.0f - dropout_prob)); - } else { - apex_add_cuda( - static_cast(output_lin_results.data_ptr()), - static_cast(inputs.data_ptr()), - static_cast(outputs.data_ptr()), total_tokens); - } - - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - - return {lyr_nrm_results, lyr_nrm_mean, lyr_nrm_invvar, input_lin_results, - softmax_results, dropout_results, dropout_mask, matmul2_results, - dropout_add_mask, outputs}; -} - -std::vector bwd_cuda( - int heads, torch::Tensor const &output_grads, - torch::Tensor const &matmul2_results, torch::Tensor const &dropout_results, - torch::Tensor const &softmax_results, - torch::Tensor const &input_lin_results, - torch::Tensor const &lyr_nrm_results, torch::Tensor const &lyr_nrm_mean, - torch::Tensor const &lyr_nrm_invvar, torch::Tensor const &inputs, - torch::Tensor const &lyr_nrm_gamma_weights, - torch::Tensor const &lyr_nrm_beta_weights, - torch::Tensor const &input_weights, torch::Tensor const &output_weights, - torch::Tensor const &dropout_mask, torch::Tensor const &dropout_add_mask, - float dropout_prob) { - const int embed_dim = inputs.size(2); - const int sequences = inputs.size(1); - const int q_seq_len = inputs.size(0); - const int k_seq_len = q_seq_len; - const int batches = sequences * q_seq_len; - const int total_tokens = batches * embed_dim; - const int head_dim = embed_dim / heads; - const int output_lin_dim = 3 * embed_dim; - const int attn_batches = heads * sequences; - const int lead_dim = attn_batches * 3 * head_dim; - const int batch_stride = 3 * head_dim; - const int dropout_elems = attn_batches * q_seq_len * k_seq_len; - const float alpha = 1.0; - const float beta = 0.0; - const float scale = 1.0 / sqrt(static_cast(head_dim)); - - // TODO: Streams can be used in Backprop but I haven't added more than one - // in my first attempt to create the code - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cublasSetStream(handle, stream); - - // Output Tensor Allocations - torch::Tensor input_grads = torch::empty_like(inputs); - torch::Tensor lyr_nrm_gamma_grads = torch::empty_like(lyr_nrm_gamma_weights); - torch::Tensor lyr_nrm_beta_grads = torch::empty_like(lyr_nrm_beta_weights); - torch::Tensor input_weight_grads = torch::empty_like(input_weights); - torch::Tensor output_weight_grads = torch::empty_like(output_weights); - // Intermediate Tensor Allocations - torch::Tensor dropout_add_grads = torch::empty_like(output_grads); - torch::Tensor output_lin_grads = torch::empty_like(matmul2_results); - torch::Tensor matmul2_grads = torch::empty_like(dropout_results); - torch::Tensor input_lin_output_grads = torch::empty_like(input_lin_results); - torch::Tensor input_lin_grads = torch::empty_like(inputs); - - auto q_lin_results_ptr = static_cast(input_lin_results.data_ptr()); - auto k_lin_results_ptr = - static_cast(input_lin_results.data_ptr()) + head_dim; - auto v_lin_results_ptr = - static_cast(input_lin_results.data_ptr()) + 2 * head_dim; - - auto q_lin_grads_ptr = static_cast(input_lin_output_grads.data_ptr()); - auto k_lin_grads_ptr = - static_cast(input_lin_output_grads.data_ptr()) + head_dim; - auto v_lin_grads_ptr = - static_cast(input_lin_output_grads.data_ptr()) + 2 * head_dim; - - char a_layout_n{'n'}; - char a_layout_t{'t'}; - char b_layout_n{'n'}; - char b_layout_t{'t'}; - - rocblas_int flags = 0; - - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - #ifdef __HIP_PLATFORM_HCC__ - #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) - #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) - #if USE_GEMM_FLAGS_FP16_ALT_IMPL - #ifdef BACKWARD_PASS_GUARD - flags = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0; - #endif - #endif - #endif - - // Dropout Add Backward - apex_masked_scale_cuda( - static_cast(output_grads.data_ptr()), - static_cast(dropout_add_grads.data_ptr()), - static_cast(dropout_add_mask.data_ptr()), total_tokens, - (1.0 / (1.0 - dropout_prob))); - - // Output Linear Dgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - embed_dim, - batches, - embed_dim, - static_cast(&alpha), - static_cast(output_weights.data_ptr()), - rocblas_datatype_f16_r /*a_type*/, - embed_dim, - static_cast(dropout_add_grads.data_ptr()), - rocblas_datatype_f16_r /*b_type*/, - embed_dim, - static_cast(&beta), - static_cast(output_lin_grads.data_ptr()), - rocblas_datatype_f16_r /*c_type*/, - embed_dim, - static_cast(output_lin_grads.data_ptr()), - rocblas_datatype_f16_r /*d_type*/, - embed_dim, - rocblas_datatype_f32_r /*compute_type*/, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // Output Linear Wgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - embed_dim, - embed_dim, - batches, - static_cast(&alpha), - static_cast(matmul2_results.data_ptr()), - rocblas_datatype_f16_r /*a_type*/, - embed_dim, - static_cast(dropout_add_grads.data_ptr()), - rocblas_datatype_f16_r /*b_type*/, - embed_dim, - static_cast(&beta), - static_cast(output_weight_grads.data_ptr()), - rocblas_datatype_f16_r /*c_type*/, - embed_dim, - static_cast(output_weight_grads.data_ptr()), - rocblas_datatype_f16_r /*d_type*/, - embed_dim, - rocblas_datatype_f32_r /*compute_type*/, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // MatMul2 Dgrad1 - gemm_switch_fp32accum( a_layout_t, - b_layout_n, - k_seq_len, - q_seq_len, - head_dim, - alpha, - static_cast(v_lin_results_ptr), - lead_dim, - batch_stride, - static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, - beta, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - attn_batches, - flags); - - // Matmul2 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - alpha, - static_cast(output_lin_grads.data_ptr()), - head_dim*attn_batches, - head_dim, - static_cast(dropout_results.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - v_lin_grads_ptr, - lead_dim, - batch_stride, - v_lin_grads_ptr, - lead_dim, - batch_stride, - attn_batches, - flags); - - // Apply Dropout Mask and Scale by Dropout Probability - apex_masked_scale_cuda( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - static_cast(dropout_mask.data_ptr()), - dropout_elems, - (1.0 / (1.0 - dropout_prob))); - - // Softmax Grad - bool softmax_success = false; - softmax_success = dispatch_softmax_backward( - static_cast(matmul2_grads.data_ptr()), - static_cast(matmul2_grads.data_ptr()), - reinterpret_cast(softmax_results.data_ptr()), k_seq_len, - k_seq_len, attn_batches * q_seq_len); - assert(softmax_success); - - // Matmul1 Dgrad1 - gemm_switch_fp32accum( a_layout_n, - b_layout_n, - head_dim, - q_seq_len, - k_seq_len, - scale, - k_lin_results_ptr, - lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - q_lin_grads_ptr, - lead_dim, - batch_stride, - q_lin_grads_ptr, - lead_dim, - batch_stride, - attn_batches, - flags); - - // Matmul1 Dgrad2 - gemm_switch_fp32accum( a_layout_n, - b_layout_t, - head_dim, - k_seq_len, - q_seq_len, - scale, - q_lin_results_ptr, - lead_dim, - batch_stride, - static_cast(matmul2_grads.data_ptr()), - k_seq_len, - k_seq_len*q_seq_len, - beta, - k_lin_grads_ptr, - lead_dim, - batch_stride, - k_lin_grads_ptr, - lead_dim, - batch_stride, - attn_batches, - flags); - - // Input Linear Dgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - embed_dim, - batches, - output_lin_dim, - static_cast(&alpha), - static_cast(input_weights.data_ptr()), - rocblas_datatype_f16_r /*a_type*/, - embed_dim, - static_cast(q_lin_grads_ptr), - rocblas_datatype_f16_r /*b_type*/, - output_lin_dim, - static_cast(&beta), - //static_cast(input_grads.data_ptr()), - static_cast(input_lin_grads.data_ptr()), - rocblas_datatype_f16_r /*c_type*/, - embed_dim, - static_cast(input_lin_grads.data_ptr()), - rocblas_datatype_f16_r /*d_type*/, - embed_dim, - rocblas_datatype_f32_r /*compute_type*/, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // Input Linear Wgrad - TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - embed_dim, - output_lin_dim, - batches, - static_cast(&alpha), - //static_cast(inputs.data_ptr()), - static_cast(lyr_nrm_results.data_ptr()), - rocblas_datatype_f16_r /*a_type*/, - embed_dim, - static_cast(q_lin_grads_ptr), - rocblas_datatype_f16_r /*b_type*/, - output_lin_dim, - static_cast(&beta), - static_cast(input_weight_grads.data_ptr()), - rocblas_datatype_f16_r /*c_type*/, - embed_dim, - static_cast(input_weight_grads.data_ptr()), - rocblas_datatype_f16_r /*d_type*/, - embed_dim, - rocblas_datatype_f32_r /*compute_type*/, - rocblas_gemm_algo_standard /*algo*/, - 0 /*solution_index*/, - flags)); - - // Fused Layer Norm Bwd with Residual Add - HostLayerNormGradient( - static_cast(input_lin_grads.data_ptr()), - static_cast(output_grads.data_ptr()), - static_cast(lyr_nrm_mean.data_ptr()), - static_cast(lyr_nrm_invvar.data_ptr()), inputs, - static_cast(batches), // n1 - static_cast(embed_dim), // n2 - static_cast(lyr_nrm_gamma_weights.data_ptr()), - static_cast(lyr_nrm_beta_weights.data_ptr()), 1.0e-5, - static_cast(input_grads.data_ptr()), - static_cast(lyr_nrm_gamma_grads.data_ptr()), - static_cast(lyr_nrm_beta_grads.data_ptr())); - - //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - - return {input_grads, lyr_nrm_gamma_grads, lyr_nrm_beta_grads, - input_weight_grads, output_weight_grads}; -} - -} // end namespace rocblas_gemmex -} // end namespace self_norm_add -} // end namespace multihead_attn \ No newline at end of file diff --git a/apex/contrib/csrc/multihead_attn/softmax.cuh b/apex/contrib/csrc/multihead_attn/softmax.cuh deleted file mode 100644 index 996dd41..0000000 --- a/apex/contrib/csrc/multihead_attn/softmax.cuh +++ /dev/null @@ -1,3149 +0,0 @@ -#pragma once -#include "philox.cuh" -#include -#include - -#ifdef OLD_GENERATOR_PATH -#include -#else -#include -#endif - -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef __HIP_PLATFORM_HCC__ -#define APEX_WARP_SHFL_XOR(mask, value, offset, width) __shfl_xor(value, offset, width) -#else -#define APEX_WARP_SHFL_XOR __shfl_xor_sync -#endif -namespace { -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template -__device__ __inline__ void apply_mask(Datatype *dst, Datatype value, - const uint8_t *src); - -template -__device__ __inline__ void apply_additive_mask(Datatype *dst, - const Datatype *additive_mask); - -template <> -__device__ __inline__ void copy_vector<__half, 1>(__half *dst, - const __half *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector(float *dst, const float *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector<__half, 4>(__half *dst, - const __half *src) { - *((float2 *)dst) = *((float2 *)src); -} -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *dst = *src; -} - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, - const uint8_t *src) { - *((half2 *)dst) = *((half2 *)src); -} - -template <> -__device__ __inline__ void apply_mask<__half, 1>(__half *dst, __half value, - const uint8_t *src) { - if (*src == 1) { - *dst = value; - } -} - -template <> -__device__ __inline__ void -apply_additive_mask<__half, 1>(__half *dst, const __half *additive_mask) { - *dst += *additive_mask; -} - -template <> -__device__ __inline__ void -apply_additive_mask<__half, 4>(__half *dst, const __half *additive_mask) { - *dst += *additive_mask; - *(dst + 1) += *(additive_mask + 1); - *(dst + 2) += *(additive_mask + 2); - *(dst + 3) += *(additive_mask + 3); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// Warp Softmax forward -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - -// WARP_BATCH number of batches. -// WARP_ITERATOINS The number of iterations required for one warp to iterate -// over all data. WARP_SIZE number of elements working on a single batch, has to -// be a power of two. ELEMENTS_PER_LDG_STG has to be 1. -template -__global__ void softmax_warp_forward(input_t *dst, const output_t *src, - int batch_size, int stride, - int element_count) { - assert(ELEMENTS_PER_LDG_STG == 1); - - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the - // batch - int local_idx = threadIdx.x; - - src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - input_t elements_input[WARP_BATCH][WARP_ITERATIONS]; - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements_input[i][it + element] = - -std::numeric_limits::infinity(); - } - - if (element_index < batch_element_count) { - copy_vector( - &elements_input[i][it], src + i * element_count + it * WARP_SIZE); - } - } - } - - // convert input_t to acc_t - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - for (int i = 0; i < WARP_BATCH; ++i) { - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = elements_input[i][it]; - } - } - - constexpr uint32_t FULL_MASK = 0xffffffff; - - // compute local max_value - - // take the max_value of the first element to avoid one max call - acc_t max_value[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - } - -#pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = - (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } - } - -// reduction max_value -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - float val[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE); - } -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i]; - } - } - - // compute local sum - acc_t sum[WARP_BATCH]{0.0f}; - -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - for (int it = 0; it < WARP_ITERATIONS; ++it) { - // elements[i][it] = expf(elements[i][it] - max_value[i]); - elements[i][it] = std::exp(elements[i][it] - max_value[i]); - sum[i] += elements[i][it]; - } - } - -// reduction sum -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE); - } - } - -// store result -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i]; - output_t out[ELEMENTS_PER_LDG_STG]; - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = elements[i][it + element] / sum[i]; - } - copy_vector( - dst + i * element_count + it * WARP_SIZE, out); - } else { - break; - } - } - } -} - -// WARP_BATCH number of batches. -// WARP_ITERATOINS The number of iterations required for one warp to iterate -// over all data. WARP_SIZE number of elements working on a single batch, has to -// be a power of two. ELEMENTS_PER_LDG_STG has to be 1. -template -using softmax_forward_func = void (*)(input_t *dst, const output_t *src, - int batch_size, int stride, - int element_count); - -template -bool warp_softmax_kernel(int log2_elements, int &warp_size, - int &batches_per_warp, - softmax_forward_func &kernel) { - // determine size of a warp - const int next_power_of_two = 1 << log2_elements; - warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; - - // determine how many batches a warp should process. - batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - switch (log2_elements) { - case 0: // 1 - kernel = &softmax_warp_forward; - break; - case 1: // 2 - kernel = &softmax_warp_forward; - break; - case 2: // 4 - kernel = &softmax_warp_forward; - break; - case 3: // 8 - kernel = &softmax_warp_forward; - break; - case 4: // 16 - kernel = &softmax_warp_forward; - break; - case 5: // 32 - kernel = &softmax_warp_forward; - break; - case 6: // 64 - kernel = &softmax_warp_forward; - break; - case 7: // 128 - kernel = &softmax_warp_forward; - break; - case 8: // 256 - kernel = &softmax_warp_forward; - break; - case 9: // 512 - kernel = &softmax_warp_forward; - break; - case 10: // 1024 - kernel = &softmax_warp_forward; - break; - default: - return false; - } - return true; -} - -template -bool dispatch_softmax(output_t *dst, const input_t *src, int softmax_elements, - int softmax_elements_stride, int batch_count) { - if (softmax_elements == 0) { - return true; - } else if (softmax_elements <= 1024) { - // compute function index. there's a function for each power of two size up - // to 1024. - int log2_elements = 0; - while ((1 << log2_elements) < softmax_elements) - ++log2_elements; - - softmax_forward_func kernel; - int warp_size, batches_per_warp; - if (!warp_softmax_kernel( - log2_elements, warp_size, batches_per_warp, kernel)) { - return false; - } - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - // compute warps per block. - int warps_per_block = (threads_per_block / warp_size); - - // compute launch size - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (batch_count + batches_per_block - 1) / batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - - // launch - kernel<<>>( - dst, src, batch_count, softmax_elements_stride, softmax_elements); - return true; - } - return false; -} - -template -__global__ void additive_masked_softmax_dropout_warp_forward_vec4( - output_t *dst, uint8_t *dropout_mask, const input_t *src, - const input_t *pad_mask, int batch_size, int stride, int element_count, - int pad_batch_stride, at::PhiloxCudaState philox_args, float p) { - - assert(ELEMENTS_PER_LDG_STG == 4); - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + - threadIdx.x; - acc_t pinv = acc_t(1) / p; - // batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the - // batch - int local_idx = threadIdx.x; - // vectorize if element_count is multiple of 4, else don't vectorize - input_t elements_input[WARP_BATCH][WARP_ITERATIONS]; - - int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - src += thread_offset; - dst += thread_offset; - dropout_mask += thread_offset; - - // load data from global memory - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + - ELEMENTS_PER_LDG_STG * local_idx; - const half *curr_mask = pad_mask + pad_thread_offset; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - // masking_value is a large negative value - elements_input[i][it + element] = -10000; - } - - if (element_index < batch_element_count) { - int itr_jmp = it * WARP_SIZE; - int itr_idx = i * element_count + itr_jmp; - copy_vector(&elements_input[i][it], - src + itr_idx); - apply_additive_mask( - &elements_input[i][it], - curr_mask + - itr_jmp); //(__half)-std::numeric_limits::infinity() - } - } - } - // convert input_t to acc_t - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = elements_input[i][it]; - } - } - - constexpr uint32_t FULL_MASK = 0xffffffff; - - // compute local max_value - - // take the max_value of the first element to avoid one max call - acc_t max_value[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - } - -#pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = - (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } - } - -// reduction max_value -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - float val[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE); - } -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i]; - } - } - - // compute local sum - acc_t sum[WARP_BATCH]{0.0f}; - -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = std::exp(elements[i][it] - max_value[i]); - sum[i] += elements[i][it]; - } - } - -// reduction sum -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE); - } - } - auto seeds = at::cuda::philox::unpack(philox_args); - Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds)); - uint8_t rands[WARP_BATCH][WARP_ITERATIONS]; - float4 rand_num; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - rand_num = uniform4(ph()); - rands[i][it] = (rand_num.x <= p) > 0.5; - rands[i][it + 1] = (rand_num.y <= p) > 0.5; - rands[i][it + 2] = (rand_num.z <= p) > 0.5; - rands[i][it + 3] = (rand_num.w <= p) > 0.5; - copy_vector( - dropout_mask + i * element_count + it * WARP_SIZE, &rands[i][it]); - } - } - } - -// store result -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - output_t out[ELEMENTS_PER_LDG_STG]; -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = rands[i][it + element] * - (pinv * (elements[i][it + element] / sum[i])); - } - copy_vector( - dst + i * element_count + it * WARP_SIZE, out); - - } else { - break; - } - } - } -} - -template -__global__ void additive_masked_softmax_dropout_warp_forward( - output_t *dst, uint8_t *dropout_mask, const input_t *src, - const input_t *pad_mask, int batch_size, int stride, int element_count, - int pad_batch_stride, at::PhiloxCudaState philox_args, float p) { - assert(ELEMENTS_PER_LDG_STG == 1); - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - int tid = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + - threadIdx.x; - acc_t pinv = acc_t(1) / p; - // batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the - // batch - int local_idx = threadIdx.x; - // vectorize if element_count is multiple of 4, else don't vectorize - input_t elements_input[WARP_BATCH][WARP_ITERATIONS]; - - int thread_offset = first_batch * stride + local_idx; - src += thread_offset; - dst += thread_offset; - dropout_mask += thread_offset; - - // load data from global memory - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - int pad_thread_offset = - ((first_batch + i) / pad_batch_stride) * stride + local_idx; - const half *curr_mask = pad_mask + pad_thread_offset; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += 1) { - int element_index = local_idx + it * WARP_SIZE; -#pragma unroll - for (int element = 0; element < 1; ++element) { - // masking_value is a large negative value - elements_input[i][it + element] = -10000; - } - - if (element_index < batch_element_count) { - int itr_jmp = it * WARP_SIZE; - int itr_idx = i * element_count + itr_jmp; - copy_vector(&elements_input[i][it], src + itr_idx); - apply_additive_mask(&elements_input[i][it], - curr_mask + itr_jmp); - } - } - } - // convert input_t to acc_t - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = elements_input[i][it]; - } - } - - constexpr uint32_t FULL_MASK = 0xffffffff; - - // compute local max_value - - // take the max_value of the first element to avoid one max call - acc_t max_value[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - } - -#pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = - (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } - } - -// reduction max_value -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - float val[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE); - } -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i]; - } - } - - // compute local sum - acc_t sum[WARP_BATCH]{0.0f}; - -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = std::exp(elements[i][it] - max_value[i]); - sum[i] += elements[i][it]; - } - } - -// reduction sum -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE); - } - } - curandStatePhilox4_32_10_t state; - auto seeds = at::cuda::philox::unpack(philox_args); - curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state); - -// store result -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += 1) { - int element_index = local_idx + it * WARP_SIZE; - if (element_index < element_count) { - output_t out[1]; - acc_t softmax_out[1]; - uint8_t dropout_mask_temp[1]; - // generate a vector of random numbers here - float rand = curand_uniform(&state); - float *rand_ptr = (float *)(&rand); -#pragma unroll - for (int element = 0; element < 1; ++element) { - softmax_out[element] = (elements[i][it + element] / sum[i]); - rand_ptr[element] = rand_ptr[element] <= p; - out[element] = rand_ptr[element] * pinv * softmax_out[element]; - dropout_mask_temp[element] = - rand_ptr[element] > 0.5; // just to distinguish 0.0f and 1.0f - } - copy_vector(dst + i * element_count + it * WARP_SIZE, out); - copy_vector(dropout_mask + i * element_count + - it * WARP_SIZE, - dropout_mask_temp); - - } else { - break; - } - } - } -} - -// WARP_BATCH number of batches. -// WARP_ITERATOINS The number of iterations required for one warp to iterate -// over all data. WARP_SIZE number of elements working on a single batch, has to -// be a power of two. ELEMENTS_PER_LDG_STG has to be 1. -template -using additive_masked_softmax_dropout_forward_func = void (*)( - output_t *dst, uint8_t *dropout_mask, const input_t *src, - const input_t *pad_mask, int batch_size, int stride, int element_count, - int pad_batch_stride, at::PhiloxCudaState philox_args, float p); - -template -bool warp_additive_masked_softmax_dropout_kernel( - int element_count, int log2_elements, int &warp_size, int &batches_per_warp, - additive_masked_softmax_dropout_forward_func - &kernel) { - // determine size of a warp - const int next_power_of_two = 1 << log2_elements; - warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; - - // determine how many batches a warp should process. - batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - bool flag_vec4 = (element_count % 4 == 0); - switch (log2_elements) { - case 0: // 1 - kernel = &additive_masked_softmax_dropout_warp_forward; - break; - case 1: // 2 - kernel = &additive_masked_softmax_dropout_warp_forward; - break; - case 2: // 4 - kernel = &additive_masked_softmax_dropout_warp_forward; - break; - case 3: // 8 - kernel = &additive_masked_softmax_dropout_warp_forward; - break; - case 4: // 16 - kernel = &additive_masked_softmax_dropout_warp_forward; - break; - case 5: // 32 - kernel = &additive_masked_softmax_dropout_warp_forward; - break; - case 6: // 64 - kernel = &additive_masked_softmax_dropout_warp_forward; - break; - case 7: // 128 - if (flag_vec4) - kernel = &additive_masked_softmax_dropout_warp_forward_vec4< - input_t, output_t, acc_t, 2, 4, 32, 4>; - else - kernel = - &additive_masked_softmax_dropout_warp_forward; - break; - case 8: // 256 - if (flag_vec4) - kernel = &additive_masked_softmax_dropout_warp_forward_vec4< - input_t, output_t, acc_t, 1, 8, 32, 4>; - else - kernel = - &additive_masked_softmax_dropout_warp_forward; - break; - case 9: // 512 - if (flag_vec4) - kernel = &additive_masked_softmax_dropout_warp_forward_vec4< - input_t, output_t, acc_t, 1, 16, 32, 4>; - else - kernel = - &additive_masked_softmax_dropout_warp_forward; - break; - case 10: // 1024 - if (flag_vec4) - kernel = &additive_masked_softmax_dropout_warp_forward_vec4< - input_t, output_t, acc_t, 1, 32, 32, 4>; - else - kernel = - &additive_masked_softmax_dropout_warp_forward; - break; - case 11: // 2048 - if (flag_vec4) - kernel = &additive_masked_softmax_dropout_warp_forward_vec4< - input_t, output_t, acc_t, 1, 64, 32, 4>; - else - kernel = - &additive_masked_softmax_dropout_warp_forward; - break; - default: - return false; - } - return true; -} - -template -bool dispatch_additive_masked_softmax_dropout( - output_t *dst, uint8_t *dropout_mask, const input_t *src, - const input_t *pad_mask, int totalElements, int softmax_elements, - int softmax_elements_stride, int batch_count, int pad_batch_stride, float p, - cudaStream_t streamid) // p is the probability to keep, not drop -{ - - if (softmax_elements == 0) { - return true; - } else if (softmax_elements <= 2048) { - // compute function index. there's a function for each power of two size up - // to 1024. - int log2_elements = 0; - while ((1 << log2_elements) < softmax_elements) - ++log2_elements; - - additive_masked_softmax_dropout_forward_func - kernel; - int warp_size, batches_per_warp; - if (!warp_additive_masked_softmax_dropout_kernel( - softmax_elements, log2_elements, warp_size, batches_per_warp, - kernel)) { - return false; - } - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - // compute warps per block. - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (batch_count + batches_per_block - 1) / batches_per_block; - c10::optional gen_; - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); - int64_t counter_offset = (totalElements / (blocks * threads_per_block) + 1); - at::PhiloxCudaState rng_engine_inputs; - { - std::lock_guard lock(gen->mutex_); - rng_engine_inputs = gen->philox_cuda_state(counter_offset); - } - - // compute launch size - dim3 threads(warp_size, warps_per_block, 1); - - // launch - kernel<<>>( - dst, dropout_mask, src, pad_mask, batch_count, softmax_elements_stride, - softmax_elements, pad_batch_stride, rng_engine_inputs, p); - return true; - } - return false; -} - -// WARP_BATCH number of batches. -// WARP_ITERATOINS The number of iterations required for one warp to iterate -// over all data. WARP_SIZE number of elements working on a single batch, has to -// be a power of two. ELEMENTS_PER_LDG_STG has to be 1. -template -__global__ void additive_masked_softmax_warp_forward( - input_t *dst, const output_t *src, const input_t *pad_mask, int batch_size, - int stride, int element_count, int pad_batch_stride) { - assert(ELEMENTS_PER_LDG_STG == 1); - - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the - // batch - int local_idx = threadIdx.x; - - int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - src += thread_offset; - dst += thread_offset; - - // load data from global memory - input_t elements_input[WARP_BATCH][WARP_ITERATIONS]; - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + - ELEMENTS_PER_LDG_STG * local_idx; - const half *curr_mask = pad_mask + pad_thread_offset; - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - // masking_value is a large negative value - elements_input[i][it + element] = -10000; - } - - if (element_index < batch_element_count) { - int itr_jmp = it * WARP_SIZE; - int itr_idx = i * element_count + itr_jmp; - copy_vector(&elements_input[i][it], - src + itr_idx); - // apply_mask(&elements_input[i][it], - // (__half)-std::numeric_limits::infinity(), - // curr_mask + itr_jmp); - elements_input[i][it] += *(curr_mask + itr_jmp); - } - } - } - - // convert input_t to acc_t - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - for (int i = 0; i < WARP_BATCH; ++i) { - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = elements_input[i][it]; - } - } - - constexpr uint32_t FULL_MASK = 0xffffffff; - - // compute local max_value - - // take the max_value of the first element to avoid one max call - acc_t max_value[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - } - -#pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = - (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } - } - -// reduction max_value -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - float val[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE); - } -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i]; - } - } - - // compute local sum - acc_t sum[WARP_BATCH]{0.0f}; - -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - for (int it = 0; it < WARP_ITERATIONS; ++it) { - // elements[i][it] = expf(elements[i][it] - max_value[i]); - elements[i][it] = std::exp(elements[i][it] - max_value[i]); - sum[i] += elements[i][it]; - } - } - -// reduction sum -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE); - } - } - -// store result -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i]; - output_t out[ELEMENTS_PER_LDG_STG]; - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = elements[i][it + element] / sum[i]; - } - copy_vector( - dst + i * element_count + it * WARP_SIZE, out); - } else { - break; - } - } - } -} - -// WARP_BATCH number of batches. -// WARP_ITERATOINS The number of iterations required for one warp to iterate -// over all data. WARP_SIZE number of elements working on a single batch, has to -// be a power of two. ELEMENTS_PER_LDG_STG has to be 1. -template -using additive_masked_softmax_forward_func = void (*)( - input_t *dst, const output_t *src, const half *pad_mask, int batch_size, - int stride, int element_count, int pad_batch_stride); - -template -bool warp_additive_masked_softmax_kernel( - int log2_elements, int &warp_size, int &batches_per_warp, - additive_masked_softmax_forward_func &kernel) { - // determine size of a warp - const int next_power_of_two = 1 << log2_elements; - warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; - - // determine how many batches a warp should process. - batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - switch (log2_elements) { - case 0: // 1 - kernel = &additive_masked_softmax_warp_forward; - break; - case 1: // 2 - kernel = &additive_masked_softmax_warp_forward; - break; - case 2: // 4 - kernel = &additive_masked_softmax_warp_forward; - break; - case 3: // 8 - kernel = &additive_masked_softmax_warp_forward; - break; - case 4: // 16 - kernel = &additive_masked_softmax_warp_forward; - break; - case 5: // 32 - kernel = &additive_masked_softmax_warp_forward; - break; - case 6: // 64 - kernel = &additive_masked_softmax_warp_forward; - break; - case 7: // 128 - kernel = &additive_masked_softmax_warp_forward; - break; - case 8: // 256 - kernel = &additive_masked_softmax_warp_forward; - break; - case 9: // 512 - kernel = &additive_masked_softmax_warp_forward; - break; - case 10: // 1024 - kernel = &additive_masked_softmax_warp_forward; - break; - default: - return false; - } - return true; -} - -template -bool dispatch_additive_masked_softmax(output_t *dst, const input_t *src, - const input_t *pad_mask, - int softmax_elements, - int softmax_elements_stride, - int batch_count, int pad_batch_stride) { - if (softmax_elements == 0) { - return true; - } else if (softmax_elements <= 1024) { - // compute function index. there's a function for each power of two size up - // to 1024. - int log2_elements = 0; - while ((1 << log2_elements) < softmax_elements) - ++log2_elements; - - additive_masked_softmax_forward_func kernel; - int warp_size, batches_per_warp; - if (!warp_additive_masked_softmax_kernel( - log2_elements, warp_size, batches_per_warp, kernel)) { - return false; - } - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - // compute warps per block. - int warps_per_block = (threads_per_block / warp_size); - - // compute launch size - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (batch_count + batches_per_block - 1) / batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - - // launch - kernel<<>>( - dst, src, pad_mask, batch_count, softmax_elements_stride, - softmax_elements, pad_batch_stride); - return true; - } - return false; -} - -template -bool dispatch_additive_masked_softmax_stream( - output_t *dst, const input_t *src, const input_t *pad_mask, - int softmax_elements, int softmax_elements_stride, int batch_count, - int pad_batch_stride, cudaStream_t streamid) { - if (softmax_elements == 0) { - return true; - } else if (softmax_elements <= 1024) { - // compute function index. there's a function for each power of two size up - // to 1024. - int log2_elements = 0; - while ((1 << log2_elements) < softmax_elements) - ++log2_elements; - additive_masked_softmax_forward_func kernel; - int warp_size, batches_per_warp; - if (!warp_additive_masked_softmax_kernel( - log2_elements, warp_size, batches_per_warp, kernel)) { - return false; - } - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - // compute warps per block. - int warps_per_block = (threads_per_block / warp_size); - // compute launch size - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (batch_count + batches_per_block - 1) / batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - // launch - kernel<<>>( - dst, src, pad_mask, batch_count, softmax_elements_stride, - softmax_elements, pad_batch_stride); - return true; - } - return false; -} - -// WARP_BATCH number of batches. -// WARP_ITERATOINS The number of iterations required for one warp to iterate -// over all data. WARP_SIZE number of elements working on a single batch, has to -// be a power of two. ELEMENTS_PER_LDG_STG has to be 1. -template -__global__ void -masked_softmax_warp_forward(input_t *dst, const output_t *src, - const uint8_t *pad_mask, int batch_size, int stride, - int element_count, int pad_batch_stride) { - assert(ELEMENTS_PER_LDG_STG == 1); - - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the - // batch - int local_idx = threadIdx.x; - - int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - src += thread_offset; - dst += thread_offset; - - // load data from global memory - input_t elements_input[WARP_BATCH][WARP_ITERATIONS]; - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + - ELEMENTS_PER_LDG_STG * local_idx; - const uint8_t *curr_mask = pad_mask + pad_thread_offset; - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements_input[i][it + element] = - -std::numeric_limits::infinity(); - } - - if (element_index < batch_element_count) { - int itr_jmp = it * WARP_SIZE; - int itr_idx = i * element_count + itr_jmp; - copy_vector(&elements_input[i][it], - src + itr_idx); - apply_mask( - &elements_input[i][it], - (__half)-std::numeric_limits::infinity(), - curr_mask + itr_jmp); - } - } - } - - // convert input_t to acc_t - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - for (int i = 0; i < WARP_BATCH; ++i) { - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = elements_input[i][it]; - } - } - - constexpr uint32_t FULL_MASK = 0xffffffff; - - // compute local max_value - - // take the max_value of the first element to avoid one max call - acc_t max_value[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - } - -#pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = - (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } - } - -// reduction max_value -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - float val[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE); - } -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i]; - } - } - - // compute local sum - acc_t sum[WARP_BATCH]{0.0f}; - -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - for (int it = 0; it < WARP_ITERATIONS; ++it) { - // elements[i][it] = expf(elements[i][it] - max_value[i]); - elements[i][it] = std::exp(elements[i][it] - max_value[i]); - sum[i] += elements[i][it]; - } - } - -// reduction sum -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE); - } - } - -// store result -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i]; - output_t out[ELEMENTS_PER_LDG_STG]; - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = elements[i][it + element] / sum[i]; - } - copy_vector( - dst + i * element_count + it * WARP_SIZE, out); - } else { - break; - } - } - } -} - -// WARP_BATCH number of batches. -// WARP_ITERATOINS The number of iterations required for one warp to iterate -// over all data. WARP_SIZE number of elements working on a single batch, has to -// be a power of two. ELEMENTS_PER_LDG_STG has to be 1. -template -using masked_softmax_forward_func = void (*)(input_t *dst, const output_t *src, - const uint8_t *pad_mask, - int batch_size, int stride, - int element_count, - int pad_batch_stride); - -template -bool warp_masked_softmax_kernel( - int log2_elements, int &warp_size, int &batches_per_warp, - masked_softmax_forward_func &kernel) { - // determine size of a warp - const int next_power_of_two = 1 << log2_elements; - warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; - - // determine how many batches a warp should process. - batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - switch (log2_elements) { - case 0: // 1 - kernel = &masked_softmax_warp_forward; - break; - case 1: // 2 - kernel = &masked_softmax_warp_forward; - break; - case 2: // 4 - kernel = &masked_softmax_warp_forward; - break; - case 3: // 8 - kernel = &masked_softmax_warp_forward; - break; - case 4: // 16 - kernel = - &masked_softmax_warp_forward; - break; - case 5: // 32 - kernel = - &masked_softmax_warp_forward; - break; - case 6: // 64 - kernel = - &masked_softmax_warp_forward; - break; - case 7: // 128 - kernel = - &masked_softmax_warp_forward; - break; - case 8: // 256 - kernel = - &masked_softmax_warp_forward; - break; - case 9: // 512 - kernel = - &masked_softmax_warp_forward; - break; - case 10: // 1024 - kernel = - &masked_softmax_warp_forward; - break; - default: - return false; - } - return true; -} - -template -bool dispatch_masked_softmax(output_t *dst, const input_t *src, - const uint8_t *pad_mask, int softmax_elements, - int softmax_elements_stride, int batch_count, - int pad_batch_stride) { - if (softmax_elements == 0) { - return true; - } else if (softmax_elements <= 1024) { - // compute function index. there's a function for each power of two size up - // to 1024. - int log2_elements = 0; - while ((1 << log2_elements) < softmax_elements) - ++log2_elements; - - masked_softmax_forward_func kernel; - int warp_size, batches_per_warp; - if (!warp_masked_softmax_kernel( - log2_elements, warp_size, batches_per_warp, kernel)) { - return false; - } - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - // compute warps per block. - int warps_per_block = (threads_per_block / warp_size); - - // compute launch size - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (batch_count + batches_per_block - 1) / batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - - // launch - kernel<<>>( - dst, src, pad_mask, batch_count, softmax_elements_stride, - softmax_elements, pad_batch_stride); - return true; - } - return false; -} - -// WARP_BATCH number of batches. -// WARP_ITERATOINS The number of iterations required for one warp to iterate -// over all data. WARP_SIZE number of elements working on a single batch, has to -// be a power of two. ELEMENTS_PER_LDG_STG has to be 1. -template -__global__ void time_masked_softmax_warp_forward( - input_t *dst, const output_t *src, const uint8_t *pad_mask, int batch_size, - int stride, int element_count, int mod_seq_len) { - assert(ELEMENTS_PER_LDG_STG == 1); - - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the - // batch - int local_idx = threadIdx.x; - - int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - src += thread_offset; - dst += thread_offset; - - // load data from global memory - input_t elements_input[WARP_BATCH][WARP_ITERATIONS]; - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - int pad_thread_offset = ((first_batch + i) % mod_seq_len) * stride + - ELEMENTS_PER_LDG_STG * local_idx; - const uint8_t *curr_mask = pad_mask + pad_thread_offset; - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements_input[i][it + element] = - -std::numeric_limits::infinity(); - } - - if (element_index < batch_element_count) { - int itr_jmp = it * WARP_SIZE; - int itr_idx = i * element_count + itr_jmp; - copy_vector(&elements_input[i][it], - src + itr_idx); - apply_mask( - &elements_input[i][it], - (__half)-std::numeric_limits::infinity(), - curr_mask + itr_jmp); - } - } - } - - // convert input_t to acc_t - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - for (int i = 0; i < WARP_BATCH; ++i) { - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = elements_input[i][it]; - } - } - - constexpr uint32_t FULL_MASK = 0xffffffff; - - // compute local max_value - - // take the max_value of the first element to avoid one max call - acc_t max_value[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - } - -#pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = - (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } - } - -// reduction max_value -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - float val[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE); - } -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i]; - } - } - - // compute local sum - acc_t sum[WARP_BATCH]{0.0f}; - -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - for (int it = 0; it < WARP_ITERATIONS; ++it) { - // elements[i][it] = expf(elements[i][it] - max_value[i]); - elements[i][it] = std::exp(elements[i][it] - max_value[i]); - sum[i] += elements[i][it]; - } - } - -// reduction sum -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE); - } - } - -// store result -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // dst[i * element_count + it * WARP_SIZE] = elements[i][it] / sum[i]; - output_t out[ELEMENTS_PER_LDG_STG]; - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = elements[i][it + element] / sum[i]; - } - copy_vector( - dst + i * element_count + it * WARP_SIZE, out); - } else { - break; - } - } - } -} - -// WARP_BATCH number of batches. -// WARP_ITERATOINS The number of iterations required for one warp to iterate -// over all data. WARP_SIZE number of elements working on a single batch, has to -// be a power of two. ELEMENTS_PER_LDG_STG has to be 1. -template -using time_masked_softmax_forward_func = - void (*)(input_t *dst, const output_t *src, const uint8_t *pad_mask, - int batch_size, int stride, int element_count, int mod_seq_len); - -template -bool warp_time_masked_softmax_kernel( - int log2_elements, int &warp_size, int &batches_per_warp, - time_masked_softmax_forward_func &kernel) { - // determine size of a warp - const int next_power_of_two = 1 << log2_elements; - warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; - - // determine how many batches a warp should process. - batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - switch (log2_elements) { - case 0: // 1 - kernel = - &time_masked_softmax_warp_forward; - break; - case 1: // 2 - kernel = - &time_masked_softmax_warp_forward; - break; - case 2: // 4 - kernel = - &time_masked_softmax_warp_forward; - break; - case 3: // 8 - kernel = - &time_masked_softmax_warp_forward; - break; - case 4: // 16 - kernel = &time_masked_softmax_warp_forward; - break; - case 5: // 32 - kernel = &time_masked_softmax_warp_forward; - break; - case 6: // 64 - kernel = &time_masked_softmax_warp_forward; - break; - case 7: // 128 - kernel = &time_masked_softmax_warp_forward; - break; - case 8: // 256 - kernel = &time_masked_softmax_warp_forward; - break; - case 9: // 512 - kernel = &time_masked_softmax_warp_forward; - break; - case 10: // 1024 - kernel = &time_masked_softmax_warp_forward; - break; - default: - return false; - } - return true; -} - -template -bool dispatch_time_masked_softmax(output_t *dst, const input_t *src, - const uint8_t *pad_mask, int softmax_elements, - int softmax_elements_stride, int batch_count, - int mod_seq_len) { - if (softmax_elements == 0) { - return true; - } else if (softmax_elements <= 1024) { - // compute function index. there's a function for each power of two size up - // to 1024. - int log2_elements = 0; - while ((1 << log2_elements) < softmax_elements) - ++log2_elements; - - time_masked_softmax_forward_func kernel; - int warp_size, batches_per_warp; - if (!warp_time_masked_softmax_kernel( - log2_elements, warp_size, batches_per_warp, kernel)) { - return false; - } - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - // compute warps per block. - int warps_per_block = (threads_per_block / warp_size); - - // compute launch size - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (batch_count + batches_per_block - 1) / batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - - // launch - kernel<<>>( - dst, src, pad_mask, batch_count, softmax_elements_stride, - softmax_elements, mod_seq_len); - return true; - } - return false; -} - -int log2_ceil_native(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) - ++log2_value; - return log2_value; -} - -template -__device__ __forceinline__ T -WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, - unsigned int mask = 0xffffffff) { -#if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); -#else - return __shfl_xor(value, laneMask, width); -#endif -} - -template -__device__ __forceinline__ void warp_reduce_sum(acc_t *sum) { -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = sum[i] + b; - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// Warp softmax backward functions as fused variants of -// at::softmax_backward_data function -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - -// softmax backward data function is taken from native pytorch, elementwise mul -// is fused in the epolog, as well as masking and scaling for fusing dropout - -template -__global__ void masked_scale_softmax_warp_backward_masked_dgrad( - output_t *gradInput, const input_t *grad, const input_t *output, - const uint8_t *mask, const uint8_t *pad_mask, acc_t scale, int batch_size, - int stride, int element_count, int heads) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the - // batch - int local_idx = threadIdx.x % WARP_SIZE; - - // the first element to process by the current thread - int thread_offset = first_batch * stride + local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - mask += thread_offset; - - // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified - // to one loop, but I think doing so would obfuscate the logic of the - // algorithm, thus I chose to keep the nested loops. This should have no - // impact on performance because the loops are unrolled anyway. - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]; - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - for (int it = 0; it < WARP_ITERATIONS; ++it) { - int element_index = local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - grad_reg[i][it] = - (input_t)((acc_t)mask[i * element_count + it * WARP_SIZE] * - (acc_t)grad[i * element_count + it * WARP_SIZE] * - (acc_t)scale) * - output[i * element_count + it * WARP_SIZE]; - output_reg[i][it] = output[i * element_count + it * WARP_SIZE]; - } else { - grad_reg[i][it] = acc_t(0); - output_reg[i][it] = acc_t(0); - } - } - } - - acc_t sum[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; -#pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } - } - warp_reduce_sum(sum); - -// store result -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - int element_index = local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - int total_ind = thread_offset + i * element_count + it * WARP_SIZE; - int pad_mask_ind = - element_count * - (total_ind / (heads * element_count * element_count)) + - total_ind % element_count; - uint8_t pad_mask_element = 1 - pad_mask[pad_mask_ind]; - if (pad_mask_element == 0) - gradInput[i * element_count + it * WARP_SIZE] = 0; - else { - if (is_log_softmax) { - gradInput[i * element_count + it * WARP_SIZE] = - (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]); - } else { - gradInput[i * element_count + it * WARP_SIZE] = - (grad_reg[i][it] - output_reg[i][it] * sum[i]); - } - } - } - } - } -} -template -void dispatch_masked_scale_softmax_backward_masked_out( - output_t *grad_input, const input_t *grad, const input_t *output, - const uint8_t *mask, const uint8_t *pad_mask, acc_t scale, - int softmax_elements, int softmax_elements_stride, int batch_count, - int heads) { - TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil_native(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - - // This value must match the WARP_SIZE constexpr value computed inside - // softmax_warp_backward. - int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside - // softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (batch_count + batches_per_block - 1) / batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 1: // 2 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 2: // 4 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 3: // 8 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 4: // 16 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 5: // 32 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 6: // 64 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 7: // 128 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 8: // 256 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 9: // 512 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 10: // 1024 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - default: - break; - } - } -} - -template -void dispatch_masked_scale_softmax_backward_masked_out_stream( - output_t *grad_input, const input_t *grad, const input_t *output, - const uint8_t *mask, const uint8_t *pad_mask, acc_t scale, - int softmax_elements, int softmax_elements_stride, int batch_count, - int heads, cudaStream_t streamid) { - TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil_native(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - // This value must match the WARP_SIZE constexpr value computed inside - // softmax_warp_backward. - int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - // This value must match the WARP_BATCH constexpr value computed inside - // softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (batch_count + batches_per_block - 1) / batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 1: // 2 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 2: // 4 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 3: // 8 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 4: // 16 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 5: // 32 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 6: // 64 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 7: // 128 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 8: // 256 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 9: // 512 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - case 10: // 1024 - masked_scale_softmax_warp_backward_masked_dgrad - <<>>( - grad_input, grad, output, mask, pad_mask, scale, batch_count, - softmax_elements_stride, softmax_elements, heads); - break; - default: - break; - } - } -} - -template -__global__ void -masked_scale_softmax_warp_backward(output_t *gradInput, const input_t *grad, - const input_t *output, const uint8_t *mask, - acc_t scale, int batch_size, int stride, - int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the - // batch - int local_idx = threadIdx.x % WARP_SIZE; - - // the first element to process by the current thread - int thread_offset = first_batch * stride + local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - mask += thread_offset; - - // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified - // to one loop, but I think doing so would obfuscate the logic of the - // algorithm, thus I chose to keep the nested loops. This should have no - // impact on performance because the loops are unrolled anyway. - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]; - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - for (int it = 0; it < WARP_ITERATIONS; ++it) { - int element_index = local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - grad_reg[i][it] = - (input_t)((acc_t)mask[i * element_count + it * WARP_SIZE] * - (acc_t)grad[i * element_count + it * WARP_SIZE] * - (acc_t)scale) * - output[i * element_count + it * WARP_SIZE]; - output_reg[i][it] = output[i * element_count + it * WARP_SIZE]; - } else { - grad_reg[i][it] = acc_t(0); - output_reg[i][it] = acc_t(0); - } - } - } - - acc_t sum[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; -#pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } - } - warp_reduce_sum(sum); - -// store result -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - int element_index = local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - if (is_log_softmax) { - gradInput[i * element_count + it * WARP_SIZE] = - (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]); - } else { - gradInput[i * element_count + it * WARP_SIZE] = - (grad_reg[i][it] - output_reg[i][it] * sum[i]); - } - } - } - } -} - -template -__global__ void masked_scale_softmax_warp_backward_recompute( - output_t *gradInput, const input_t *grad, const input_t *softmax_input, - const input_t *pad_mask, const uint8_t *mask, acc_t scale, int batch_size, - int stride, int pad_batch_stride, int element_count) { - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the - // batch - int local_idx = threadIdx.x % WARP_SIZE; - // vectorize if a row length is multiple of 4 - int flag_vec4 = element_count & 3 == 0; - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]; - input_t elements_input[WARP_BATCH][WARP_ITERATIONS]; - - // the first element to process by the current thread - int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - - grad += thread_offset; - softmax_input += thread_offset; - gradInput += thread_offset; - mask += thread_offset; - - // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified - // to one loop, but I think doing so would obfuscate the logic of the - // algorithm, thus I chose to keep the nested loops. This should have no - // impact on performance because the loops are unrolled anyway. - - // load data from global memory - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + - ELEMENTS_PER_LDG_STG * local_idx; - const input_t *curr_mask = pad_mask + pad_thread_offset; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - // masking_value is a large negative value - elements_input[i][it + element] = -10000; - grad_reg[i][it + element] = acc_t(0); - } - - if (element_index < batch_element_count) { - int itr_jmp = it * WARP_SIZE; - int itr_idx = i * element_count + itr_jmp; - copy_vector(&elements_input[i][it], - softmax_input + itr_idx); - apply_additive_mask( - &elements_input[i][it], - curr_mask + - itr_jmp); //(__half)-std::numeric_limits::infinity() - uint8_t mask_temp[ELEMENTS_PER_LDG_STG]; - input_t grad_temp[ELEMENTS_PER_LDG_STG]; - copy_vector(&mask_temp[0], - mask + itr_idx); - copy_vector(&grad_temp[0], - grad + itr_idx); -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - grad_reg[i][it + element] = - ((acc_t)mask_temp[element] * (acc_t)grad_temp[element] * - (acc_t)scale); - } - } - } - } - // load data from global memory - - // convert input_t to acc_t - // TODO : remove this, input is already acc_t type in register - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = elements_input[i][it]; - } - } - - constexpr uint32_t FULL_MASK = 0xffffffff; - - // compute local max_value - - // take the max_value of the first element to avoid one max call - acc_t max_value[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - } - -#pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = - (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } - } - -// reduction max_value -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - float val[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE); - } -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = max_value[i] > val[i] ? max_value[i] : val[i]; - } - } - - // compute local sum - acc_t sum[WARP_BATCH]{0.0f}; - -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - for (int it = 0; it < WARP_ITERATIONS; ++it) { - // elements[i][it] = expf(elements[i][it] - max_value[i]); - elements[i][it] = std::exp(elements[i][it] - max_value[i]); - sum[i] += elements[i][it]; - } - } - -// reduction sum -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE); - } - } - -// store result -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it++) { - elements[i][it] = elements[i][it] / sum[i]; - grad_reg[i][it] = grad_reg[i][it] * elements[i][it]; - } - } - - acc_t grad_sum[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - grad_sum[i] = grad_reg[i][0]; -#pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - grad_sum[i] += grad_reg[i][it]; - } - } - warp_reduce_sum(grad_sum); - -// store result -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t grad_input_reg[ELEMENTS_PER_LDG_STG]; -#pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; element++) { - if (is_log_softmax) { - grad_input_reg[element] = - (grad_reg[i][it + element] - - std::exp(elements[i][it + element]) * grad_sum[i]); - } else { - grad_input_reg[element] = (grad_reg[i][it + element] - - elements[i][it + element] * grad_sum[i]); - } - } - copy_vector( - gradInput + i * element_count + it * WARP_SIZE, grad_input_reg); - } - } - } -} - -template -using masked_scale_softmax_warp_backward_recompute_func = void (*)( - output_t *gradInput, const input_t *grad, const input_t *softmax_input, - const input_t *pad_mask, const uint8_t *mask, acc_t scale, int batch_size, - int stride, int pad_batch_stride, int element_count); - -template -bool masked_scale_softmax_warp_backward_recompute_kernel( - int element_count, int log2_elements, int &warp_size, int &batches_per_warp, - masked_scale_softmax_warp_backward_recompute_func &kernel) { - // determine size of a warp - const int next_power_of_two = 1 << log2_elements; - warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; - - // determine how many batches a warp should process. - batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - bool flag_vec4 = (element_count % 4 == 0); - switch (log2_elements) { - case 0: // 1 - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 2, 1, 1, 1, is_log_softmax>; - break; - case 1: // 2 - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 2, 1, 2, 1, is_log_softmax>; - break; - case 2: // 4 - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 2, 1, 4, 1, is_log_softmax>; - break; - case 3: // 8 - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 2, 1, 8, 1, is_log_softmax>; - break; - case 4: // 16 - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 2, 1, 16, 1, is_log_softmax>; - break; - case 5: // 32 - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 2, 1, 32, 1, is_log_softmax>; - break; - case 6: // 64 - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 2, 2, 32, 1, is_log_softmax>; - break; - case 7: // 128 - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 2, 4, 32, 1, is_log_softmax>; - break; - case 8: // 256 - if (flag_vec4) - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 1, 8, 32, 4, is_log_softmax>; - else - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 1, 8, 32, 1, is_log_softmax>; - break; - case 9: // 512 - if (flag_vec4) - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 1, 16, 32, 4, is_log_softmax>; - else - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 1, 16, 32, 1, is_log_softmax>; - break; - case 10: // 1024 - if (flag_vec4) - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 1, 32, 32, 4, is_log_softmax>; - else - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 1, 32, 32, 1, is_log_softmax>; - break; - case 11: // 2048 - if (flag_vec4) - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 1, 64, 32, 4, is_log_softmax>; - else - kernel = &masked_scale_softmax_warp_backward_recompute< - input_t, output_t, acc_t, 1, 64, 32, 1, is_log_softmax>; - break; - default: - return false; - } - return true; -} - -template -bool dispatch_masked_scale_softmax_backward_recompute( - output_t *grad_input, const input_t *grad, const input_t *softmax_input, - const input_t *pad_mask, const uint8_t *mask, acc_t scale, - int softmax_elements, int softmax_elements_stride, int pad_batch_stride, - int batch_count, cudaStream_t streamid) { - - if (softmax_elements == 0) { - return true; - } else if (softmax_elements <= 2048) { - // compute function index. there's a function for each power of two size up - // to 1024. - int log2_elements = 0; - while ((1 << log2_elements) < softmax_elements) - ++log2_elements; - - masked_scale_softmax_warp_backward_recompute_func - kernel; - int warp_size, batches_per_warp; - if (!masked_scale_softmax_warp_backward_recompute_kernel< - input_t, output_t, acc_t, is_log_softmax>( - softmax_elements, log2_elements, warp_size, batches_per_warp, - kernel)) { - return false; - } - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - // compute warps per block. - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (batch_count + batches_per_block - 1) / batches_per_block; - - // compute launch size - dim3 threads(warp_size, warps_per_block, 1); - - // launch - kernel<<>>( - grad_input, grad, softmax_input, pad_mask, mask, scale, batch_count, - softmax_elements_stride, pad_batch_stride, softmax_elements); - return true; - } - return false; -} - -template -void dispatch_masked_scale_softmax_backward_stream( - output_t *grad_input, const input_t *grad, const input_t *output, - const uint8_t *mask, acc_t scale, int softmax_elements, - int softmax_elements_stride, int batch_count, cudaStream_t streamid) { - TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil_native(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - // This value must match the WARP_SIZE constexpr value computed inside - // softmax_warp_backward. - int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - // This value must match the WARP_BATCH constexpr value computed inside - // softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (batch_count + batches_per_block - 1) / batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - masked_scale_softmax_warp_backward - <<>>( - grad_input, grad, output, mask, scale, batch_count, - softmax_elements_stride, softmax_elements); - break; - default: - break; - } - } -} - -// elementwise multiplication called in at::softmax_backward_data is fused -// inside softmax dgrad kernel as a result of fusion, intermediate -// multiplication result is stored in fp32 in registers, instead of fp16 -template -__global__ void -softmax_warp_backward_fused_native(output_t *gradInput, const input_t *grad, - const input_t *output, int batch_size, - int stride, int element_count) { - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the - // batch - int local_idx = threadIdx.x % WARP_SIZE; - - // the first element to process by the current thread - int thread_offset = first_batch * stride + local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // The nested loops over WARP_BATCH and then WARP_ITERATIONS can be simplified - // to one loop, but I think doing so would obfuscate the logic of the - // algorithm, thus I chose to keep the nested loops. This should have no - // impact on performance because the loops are unrolled anyway. - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]; - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - for (int it = 0; it < WARP_ITERATIONS; ++it) { - int element_index = local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - grad_reg[i][it] = grad[i * element_count + it * WARP_SIZE] * - output[i * element_count + it * WARP_SIZE]; - output_reg[i][it] = output[i * element_count + it * WARP_SIZE]; - } else { - grad_reg[i][it] = acc_t(0); - output_reg[i][it] = acc_t(0); - } - } - } - - acc_t sum[WARP_BATCH]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; //* output_reg[i][0]; -#pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; // * output_reg[i][it]; - } - } - warp_reduce_sum(sum); - -// store result -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - int element_index = local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - if (is_log_softmax) { - gradInput[i * element_count + it * WARP_SIZE] = - (grad_reg[i][it] - std::exp(output_reg[i][it]) * sum[i]); - } else { - gradInput[i * element_count + it * WARP_SIZE] = - (grad_reg[i][it] - output_reg[i][it] * sum[i]); - } - } - } - } -} - -template -void dispatch_softmax_backward_fused_native( - output_t *grad_input, const input_t *grad, const input_t *output, - int softmax_elements, int softmax_elements_stride, int batch_count) { - TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 1024); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil_native(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - - // This value must match the WARP_SIZE constexpr value computed inside - // softmax_warp_backward. - int warp_size = - (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside - // softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (batch_count + batches_per_block - 1) / batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 1: // 2 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 2: // 4 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 3: // 8 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 4: // 16 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 5: // 32 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 6: // 64 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 7: // 128 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 8: // 256 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 9: // 512 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - case 10: // 1024 - softmax_warp_backward_fused_native - <<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - break; - default: - break; - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// -// Warp softmax backward -//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__global__ void softmax_warp_backward(__half *gradInput, const __half *grad, - const __half *output, int batch_size, - int stride, int element_count) { - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the - // batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - input_t grad_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f}; - input_t output_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f}; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector( - &grad_reg_input[i][it], grad + i * element_count + it * WARP_SIZE); - copy_vector(&output_reg_input[i][it], - output + i * element_count + - it * WARP_SIZE); - } - } - } - - // convert half to floating point - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - for (int it = 0; it < WARP_ITERATIONS; ++it) { - grad_reg[i][it] = grad_reg_input[i][it]; - output_reg[i][it] = output_reg_input[i][it]; - } - } - - // compute thread local sum - acc_t sum[WARP_BATCH] = {0}; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] += grad_reg[i][it] * output_reg[i][it]; - } - } - - // reduction sum - constexpr uint32_t FULL_MASK = 0xffffffff; -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE); - } - } - -// store result -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_reg[i][it + element] * - (grad_reg[i][it + element] - sum[i])); - } - // store them in global memory - copy_vector( - gradInput + i * element_count + it * WARP_SIZE, out); - } - } - } -} - -// WARP_BATCH number of batches. -// WARP_ITERATOINS The number of iterations required for one warp to iterate -// over all data. WARP_SIZE number of elements working on a single batch, has to -// be a power of two. ELEMENTS_PER_LDG_STG has to be 1. -template -using softmax_backward_func = void (*)(output_t *gradInput, const input_t *grad, - const input_t *output, int batch_size, - int stride, int element_count); - -template -bool warp_softmax_backward_kernel( - int log2_elements, int &warp_size, int &batches_per_warp, - softmax_backward_func &kernel) { - // determine size of a warp - const int next_power_of_two = 1 << log2_elements; - warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; - - // determine how many batches a warp should process. - batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - switch (log2_elements) { - case 0: // 1 - kernel = &softmax_warp_backward; - break; - case 1: // 2 - kernel = &softmax_warp_backward; - break; - case 2: // 4 - kernel = &softmax_warp_backward; - break; - case 3: // 8 - kernel = &softmax_warp_backward; - break; - case 4: // 16 - kernel = &softmax_warp_backward; - break; - case 5: // 32 - kernel = &softmax_warp_backward; - break; - case 6: // 64 - kernel = &softmax_warp_backward; - break; - case 7: // 128 - kernel = &softmax_warp_backward; - break; - case 8: // 256 - kernel = &softmax_warp_backward; - break; - case 9: // 512 - kernel = &softmax_warp_backward; - break; - case 10: // 1024 - kernel = &softmax_warp_backward; - break; - default: - return false; - } - return true; -} - -template -bool dispatch_softmax_backward(output_t *grad_input, const input_t *grad, - const input_t *output, int softmax_elements, - int softmax_elements_stride, int batch_count) { - if (softmax_elements == 0) { - return true; - } else if (softmax_elements <= 1024) { - // compute function index. there's a function for each power of two size up - // to 1024. - int log2_elements = 0; - while ((1 << log2_elements) < softmax_elements) - ++log2_elements; - - softmax_backward_func kernel; - int warp_size, batches_per_warp; - if (!warp_softmax_backward_kernel( - log2_elements, warp_size, batches_per_warp, kernel)) { - return false; - } - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - // compute warps per block. - int warps_per_block = (threads_per_block / warp_size); - - // compute launch size - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (batch_count + batches_per_block - 1) / batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - - // launch - kernel<<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - return true; - } - return false; -} - -template -bool dispatch_softmax_backward_stream(output_t *grad_input, const input_t *grad, - const input_t *output, - int softmax_elements, - int softmax_elements_stride, - int batch_count, cudaStream_t streamid) { - if (softmax_elements == 0) { - return true; - } else if (softmax_elements <= 1024) { - // compute function index. there's a function for each power of two size up - // to 1024. - int log2_elements = 0; - while ((1 << log2_elements) < softmax_elements) - ++log2_elements; - softmax_backward_func kernel; - int warp_size, batches_per_warp; - if (!warp_softmax_backward_kernel( - log2_elements, warp_size, batches_per_warp, kernel)) { - return false; - } - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - // compute warps per block. - int warps_per_block = (threads_per_block / warp_size); - // compute launch size - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (batch_count + batches_per_block - 1) / batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - // launch - kernel<<>>( - grad_input, grad, output, batch_count, softmax_elements_stride, - softmax_elements); - return true; - } - return false; -} - -template -__global__ void -masked_softmax_warp_backward(__half *gradInput, const __half *grad, - const __half *output, const uint8_t *pad_mask, - int batch_size, int stride, int element_count, - int pad_batch_stride) { - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the - // batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - input_t grad_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f}; - input_t output_reg_input[WARP_BATCH][WARP_ITERATIONS] = {0.0f}; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector( - &grad_reg_input[i][it], grad + i * element_count + it * WARP_SIZE); - copy_vector(&output_reg_input[i][it], - output + i * element_count + - it * WARP_SIZE); - } - } - } - - // convert half to floating point - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS]; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS]; -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - for (int it = 0; it < WARP_ITERATIONS; ++it) { - grad_reg[i][it] = grad_reg_input[i][it]; - output_reg[i][it] = output_reg_input[i][it]; - } - } - - // compute thread local sum - acc_t sum[WARP_BATCH] = {0}; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] += grad_reg[i][it] * output_reg[i][it]; - } - } - - // reduction sum - constexpr uint32_t FULL_MASK = 0xffffffff; -#pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE); - } - } - -// store result -#pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - int pad_thread_offset = ((first_batch + i) / pad_batch_stride) * stride + - ELEMENTS_PER_LDG_STG * local_idx; - const uint8_t *curr_mask = pad_mask + pad_thread_offset; -#pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_reg[i][it + element] * - (grad_reg[i][it + element] - sum[i])); - } - // store them in global memory - int itr_jmp = it * WARP_SIZE; - int itr_idx = i * element_count + itr_jmp; - // It is kind of unfortunate this has to be here to zero something out - // that is close to zero in the first place - apply_mask(&out[0], 0.0, - curr_mask + itr_jmp); - copy_vector(gradInput + itr_idx, out); - } - } - } -} - -// WARP_BATCH number of batches. -// WARP_ITERATOINS The number of iterations required for one warp to iterate -// over all data. WARP_SIZE number of elements working on a single batch, has to -// be a power of two. ELEMENTS_PER_LDG_STG has to be 1. -template -using masked_softmax_backward_func = - void (*)(output_t *gradInput, const input_t *grad, const input_t *output, - const uint8_t *pad_mask, int batch_size, int stride, - int element_count, int pad_batch_stride); - -template -bool warp_masked_softmax_backward_kernel( - int log2_elements, int &warp_size, int &batches_per_warp, - masked_softmax_backward_func &kernel) { - // determine size of a warp - const int next_power_of_two = 1 << log2_elements; - warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; - - // determine how many batches a warp should process. - batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - switch (log2_elements) { - case 0: // 1 - kernel = - &masked_softmax_warp_backward; - break; - case 1: // 2 - kernel = - &masked_softmax_warp_backward; - break; - case 2: // 4 - kernel = - &masked_softmax_warp_backward; - break; - case 3: // 8 - kernel = - &masked_softmax_warp_backward; - break; - case 4: // 16 - kernel = - &masked_softmax_warp_backward; - break; - case 5: // 32 - kernel = - &masked_softmax_warp_backward; - break; - case 6: // 64 - kernel = - &masked_softmax_warp_backward; - break; - case 7: // 128 - kernel = - &masked_softmax_warp_backward; - break; - case 8: // 256 - kernel = - &masked_softmax_warp_backward; - break; - case 9: // 512 - kernel = - &masked_softmax_warp_backward; - break; - case 10: // 1024 - kernel = - &masked_softmax_warp_backward; - break; - default: - return false; - } - return true; -} - -template -bool dispatch_masked_softmax_backward(output_t *grad_input, const input_t *grad, - const input_t *output, - const uint8_t *pad_mask, - int softmax_elements, - int softmax_elements_stride, - int batch_count, int pad_batch_stride) { - if (softmax_elements == 0) { - return true; - } else if (softmax_elements <= 1024) { - // compute function index. there's a function for each power of two size up - // to 1024. - int log2_elements = 0; - while ((1 << log2_elements) < softmax_elements) - ++log2_elements; - - masked_softmax_backward_func kernel; - int warp_size, batches_per_warp; - if (!warp_masked_softmax_backward_kernel( - log2_elements, warp_size, batches_per_warp, kernel)) { - return false; - } - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - // compute warps per block. - int warps_per_block = (threads_per_block / warp_size); - - // compute launch size - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (batch_count + batches_per_block - 1) / batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - - // launch - kernel<<>>( - grad_input, grad, output, pad_mask, batch_count, - softmax_elements_stride, softmax_elements, pad_batch_stride); - return true; - } - return false; -} -} // namespace diff --git a/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh b/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh deleted file mode 100644 index 78ee110..0000000 --- a/apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh +++ /dev/null @@ -1,135 +0,0 @@ -#pragma once -#include -#include - -#include -#include -//#include -#include - -//#include -#include -#include - -//#include "cutlass/cutlass.h" -//#include "cutlass/gemm/gemm.h" -//#include "cutlass/gemm/wmma_gemm_traits.h" - -// symbol to be automatically resolved by PyTorch libs -/* -rocblas_datatype a_type = rocblas_datatype_f16_r; // OK -rocblas_datatype b_type = rocblas_datatype_f16_r; // OK -rocblas_datatype c_type = rocblas_datatype_f16_r; // OK -rocblas_datatype d_type = rocblas_datatype_f16_r; -rocblas_datatype compute_type = rocblas_datatype_f32_r; - -rocblas_gemm_algo algo = rocblas_gemm_algo_standard; -int32_t solution_index = 0; -rocblas_int flags = 0; -*/ - -namespace { -cublasOperation_t convertTransToCublasOperation(char trans) { - if (trans == 't') - return CUBLAS_OP_T; - else if (trans == 'n') - return CUBLAS_OP_N; - else if (trans == 'c') - return CUBLAS_OP_C; - else { - AT_ERROR("trans must be one of: t, n, c"); - return CUBLAS_OP_T; - } -} - -void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k, - float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB, - float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) { - cublasOperation_t opa = convertTransToCublasOperation(transa); - cublasOperation_t opb = convertTransToCublasOperation(transb); - - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - cublasSetStream(handle, stream); - float fAlpha = alpha; - float fBeta = beta; - //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, - opa, opb, (int)m, (int)n, (int)k, - (void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA, - b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB, - (void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC, - d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD, - (int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags)); -} - -void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k, - float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB, - float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) { - auto stream = c10::cuda::getCurrentCUDAStream(); - if ( (transa == 't') && (transb == 'n') ) { - if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); } - else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); } - } else if ( (transa == 'n') && (transb == 'n') ) { - if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); } - else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); } - } else if ( (transa == 'n') && (transb == 't') ) { - if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); } - else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); } - } else { - AT_ASSERTM(false, "TransA and TransB are invalid"); - } -} - -void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k, - int64_t *lda, int64_t *ldb, int64_t *ldc) { - int transa_ = ((transa == 't') || (transa == 'T')); - int transb_ = ((transb == 't') || (transb == 'T')); - - // Note: leading dimensions generally are checked that they are > 0 and at - // least as big the result requires (even if the value won't be used). - if (n <= 1) - *ldc = std::max(m, 1); - - if (transa_) { - if (m <= 1) - *lda = std::max(k, 1); - } else { - if (k <= 1) - *lda = std::max(m, 1); - } - - if (transb_) { - if (k <= 1) - *ldb = std::max(n, 1); - } else { - if (n <= 1) - *ldb = std::max(k, 1); - } -} - -void HgemmStridedBatched(char transa, char transb, long m, - long n, long k, float alpha, const half *a, long lda, - long strideA, const half *b, long ldb, long strideB, - float beta, half *c, long ldc, long strideC, - half *d, long ldd, long strideD, long batchCount) { - - if ((m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || - (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX)) - - { - AT_ERROR("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, " - "batchCount" - "with the bound [val] <= %d", - INT_MAX); - } - - adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); - - // gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA, - // b, ldb, strideB, beta, c, ldc, strideC, batchCount); - gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA, - b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, 0 /*flags*/); -} - -} // namespace diff --git a/apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp b/apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp deleted file mode 100644 index 2797cc1..0000000 --- a/apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp +++ /dev/null @@ -1,25 +0,0 @@ -/** - * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "nccl_p2p_cuda.cuh" - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("get_unique_nccl_id", &apex::contrib::nccl_p2p::get_unique_nccl_id, "get_unique_nccl_id"); - m.def("init_nccl_comm", &apex::contrib::nccl_p2p::init_nccl_comm, "init_nccl_comm"); - m.def("left_right_halo_exchange_inplace", &apex::contrib::nccl_p2p::left_right_halo_exchange_inplace, "left_right_halo_exchange_inplace"); - m.def("left_right_halo_exchange", &apex::contrib::nccl_p2p::left_right_halo_exchange, "left_right_halo_exchange"); - m.def("add_delay", &apex::contrib::nccl_p2p::add_delay, "add_delay"); -} diff --git a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu b/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu deleted file mode 100644 index 8c935ac..0000000 --- a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu +++ /dev/null @@ -1,215 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#ifdef __HIP_PLATFORM_HCC__ -#include "rccl/rccl.h" -#else -#include "nccl.h" -#endif - -/* - * This file implements a crude but effective mechanism for copying data between tenors owned by different ranks - * on the same machine using cudaMemcpyAsync peer-to-peer transfers. - */ - -namespace { - -__global__ void AddDelay_kernel(const int delay, int* counter) { - if (blockIdx.x == 0 && threadIdx.x == 0) { - // waste time while doing something compiler can't predict, thus preventing it from optimizing away this code. - int new_counter = 0; - double elapsed = 0; - clock_t start = clock(); - do { - clock_t now = clock(); - elapsed = (double)(now - start)*1e9 / CLOCKS_PER_SEC; - ++new_counter; - } while (elapsed < (double)delay); - *counter = new_counter; - } -} - -class NcclCommWrapper -{ - private: - ncclComm_t comm; - int rank, world_size; - - ncclDataType_t get_nccl_type(at::Tensor input) - { - switch (input.scalar_type()) - { - case at::ScalarType::Half: - return ncclFloat16; - case at::ScalarType::Float: - return ncclFloat32; - case at::ScalarType::Double: - return ncclFloat64; - case at::ScalarType::Byte: - return ncclUint8; - case at::ScalarType::Char: - return ncclInt8; - case at::ScalarType::Int: - return ncclInt32; - case at::ScalarType::Long: - return ncclInt64; - case at::ScalarType::BFloat16: - return ncclBfloat16; - default: - assert(false); - } - } - - public: - NcclCommWrapper() - { - memset(&comm, 0, sizeof(ncclComm_t)); - rank = 0; - world_size = 0; - } - NcclCommWrapper(ncclUniqueId id, int my_rank, int num_ranks) - { - ncclCommInitRank(&comm, num_ranks, id, my_rank); - rank = my_rank; - world_size = num_ranks; - } - - ~NcclCommWrapper() - { - printf("ncclCommDestroy()\n"); - ncclCommDestroy(comm); - } - - void left_right_halo_exchange_inplace(int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo) - { - auto stream = at::cuda::getCurrentCUDAStream(); - ncclGroupStart(); - ncclDataType_t ncclType = get_nccl_type(left_output_halo); - bool left_zero = (left_rank < 0); - bool right_zero = (right_rank < 0); - size_t left_n = torch::numel(left_output_halo); - size_t right_n = torch::numel(right_output_halo); - assert(left_n > 0 && left_n == right_n); - if (left_zero) { - left_input_halo.zero_(); - } else { - AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, left_output_halo.scalar_type(), "left_halo_exch", [&]() { - // send left (to my_rank - 1) - ncclSend(left_output_halo.data_ptr(), left_n, ncclType, left_rank, comm, stream); - // receive left (from my_rank - 1) - ncclRecv(left_input_halo.data_ptr(), right_n, ncclType, left_rank, comm, stream); - }); - } - if (right_zero) { - right_input_halo.zero_(); - } else { - AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, right_output_halo.scalar_type(), "right_halo_exch", [&]() { - // send right (to my_rank + 1 ) - ncclSend(right_output_halo.data_ptr(), right_n, ncclType, right_rank, comm, stream); - // receive right (from my_rank + 1) - ncclRecv(right_input_halo.data_ptr(), left_n, ncclType, right_rank, comm, stream); - }); - } - ncclGroupEnd(); - } - - std::vector left_right_halo_exchange(int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo) - { - // after halo exchange: - // left_output_halo of rank+1 ends up in right_input_halo of rank - // right_output_halo of rank-1 ends up in left_input_halo of rank - auto right_input_halo = torch::empty_like(left_output_halo); - auto left_input_halo = torch::empty_like(right_output_halo); - left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo); - return {left_input_halo, right_input_halo}; - } -}; - -class ManagedObjects -{ - public: - ManagedObjects() - { - } - ~ManagedObjects() - { - for (auto it = _nccl_comms.begin(); it != _nccl_comms.end(); ++it) - { - delete *it; - } - } - - int add_comm(NcclCommWrapper* comm) - { - int handle = _nccl_comms.size(); - _nccl_comms.push_back(comm); - return handle; - } - - NcclCommWrapper& get_comm(int handle) - { - assert(handle >= 0 && handle < _nccl_comms.size()); - return *_nccl_comms[handle]; - } - - private: - std::vector _nccl_comms; -}; -class ManagedObjects mo; - -} // end anonymous namespace - -namespace apex { namespace contrib { namespace nccl_p2p { - -at::Tensor get_unique_nccl_id(int n) -{ - ncclUniqueId id; - ncclGetUniqueId(&id); - auto id_tensor = torch::empty({n,(int)sizeof(ncclUniqueId)}, torch::dtype(torch::kUInt8).device(torch::kCPU).requires_grad(false)); - auto id_ptr = id_tensor.data_ptr(); - size_t offset = 0; - for (int i = 0; i < n; ++i) - { - ncclUniqueId id; - ncclGetUniqueId(&id); - memcpy(id_ptr+offset, &id, sizeof(ncclUniqueId)); - offset += sizeof(ncclUniqueId); - } - return id_tensor; -} - -int init_nccl_comm(at::Tensor unique_nccl_id, int my_rank, int num_ranks) -{ - ncclUniqueId id; - auto unique_nccl_id_ptr = unique_nccl_id.data_ptr(); - memcpy(&id, unique_nccl_id_ptr, sizeof(ncclUniqueId)); - NcclCommWrapper* comm = new NcclCommWrapper(id, my_rank, num_ranks); - int handle = mo.add_comm(comm); - comm = 0L; - return handle; -} - -void left_right_halo_exchange_inplace(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo, at::Tensor left_input_halo, at::Tensor right_input_halo) -{ - class NcclCommWrapper& communicator = mo.get_comm(handle); - return communicator.left_right_halo_exchange_inplace(left_rank, right_rank, left_output_halo, right_output_halo, left_input_halo, right_input_halo); -} - -std::vector left_right_halo_exchange(int handle, int left_rank, int right_rank, at::Tensor left_output_halo, at::Tensor right_output_halo) -{ - class NcclCommWrapper& communicator = mo.get_comm(handle); - return communicator.left_right_halo_exchange(left_rank, right_rank, left_output_halo, right_output_halo); -} - -void add_delay(int delay) -{ - auto stream = at::cuda::getCurrentCUDAStream(); - auto t = torch::empty({1}, torch::dtype(torch::kInt32).device(torch::kCUDA)); - AddDelay_kernel<<<1,1,0,stream>>>(delay, t.data_ptr()); -} - -}}} diff --git a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh b/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh deleted file mode 100644 index 6d29420..0000000 --- a/apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cuh +++ /dev/null @@ -1,45 +0,0 @@ -/** - * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#ifndef _nccl_p2p_h_ -#define _nccl_p2p_h_ - -namespace apex { namespace contrib { namespace nccl_p2p { -at::Tensor get_unique_nccl_id(int n); -int init_nccl_comm( - at::Tensor unique_nccl_id, - int my_rank, - int num_ranks - ); -void left_right_halo_exchange_inplace( - int handle, - int left_rank, - int right_rank, - at::Tensor left_output_halo, - at::Tensor right_output_halo, - at::Tensor left_input_halo, - at::Tensor right_input_halo); -std::vector left_right_halo_exchange( - int handle, - int left_rank, - int right_rank, - at::Tensor left_output_halo, - at::Tensor right_output_halo); -void add_delay(int delay); -}}} -#endif diff --git a/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp b/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp deleted file mode 100644 index c03c90f..0000000 --- a/apex/contrib/csrc/optimizers/fused_adam_cuda.cpp +++ /dev/null @@ -1,86 +0,0 @@ -#include - -// CUDA forward declaration -void fused_strided_check_finite(at::Tensor & overflow_flag, at::Tensor & p_copy, int stride, int clear_overflow_first); - -void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay); -void fused_reversible_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay); -void fused_maybe_adam_undo_cuda(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay); - -void fused_adam_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay); - -void maybe_cast_cuda(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out); -void maybe_cast_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector> tensor_lists); - -#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -// C++ interface -void strided_check_finite( - at::Tensor& overflow_flag, - at::Tensor& p_copy, - int stride, - int clear_overflow_first - ) { - CHECK_INPUT(p_copy); - fused_strided_check_finite(overflow_flag, p_copy, stride, clear_overflow_first); -} -void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) { - CHECK_INPUT(p); - if (p_copy.numel() > 0) CHECK_INPUT(p_copy); - CHECK_INPUT(m); - CHECK_INPUT(v); - CHECK_INPUT(g); - int64_t num_elem = p.numel(); - AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal"); - AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal"); - AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal"); - AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty"); - - fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay); -} -void reversible_adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) { - CHECK_INPUT(p); - if (p_copy.numel() > 0) CHECK_INPUT(p_copy); - CHECK_INPUT(m); - CHECK_INPUT(v); - CHECK_INPUT(g); - int64_t num_elem = p.numel(); - AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal"); - AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal"); - AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal"); - AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty"); - - fused_reversible_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay); -} -void maybe_adam_undo(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) { - CHECK_INPUT(p); - CHECK_INPUT(m); - CHECK_INPUT(v); - CHECK_INPUT(g); - int64_t num_elem = p.numel(); - AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal"); - AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal"); - AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal"); - - fused_maybe_adam_undo_cuda(overflow_flag, p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay); -} -void maybe_cast(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out) { - CHECK_INPUT(p_in); - CHECK_INPUT(p_out); - int64_t num_elem = p_in.numel(); - AT_ASSERTM(p_out.numel() == num_elem, "number of elements in p_in and p_out should be equal"); - - maybe_cast_cuda(overflow_flag, p_in, p_out); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("strided_check_finite", &strided_check_finite, "Strided finite check."); - m.def("adam", &adam, "Adam optimized CUDA implementation."); - m.def("reversible_adam", &reversible_adam, "Reversible Adam optimized CUDA implementation."); - m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation."); - m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation."); - m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats."); - m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats."); -} diff --git a/apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu b/apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu deleted file mode 100644 index 18b6026..0000000 --- a/apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu +++ /dev/null @@ -1,1037 +0,0 @@ -#include -#include -#include -#include - -#include "ATen/ATen.h" -#include "ATen/cuda/CUDAContext.h" -#include "ATen/cuda/detail/IndexUtils.cuh" -#include "ATen/TensorUtils.h" -// #include "ATen/Type.h" -#include "ATen/AccumulateType.h" - -#include "multi_tensor_apply.cuh" - -#define BLOCK_SIZE 512 -#define ILP 4 - -template -__device__ __forceinline__ bool is_aligned(T* p){ - return ((uint64_t)p) % (ILP*sizeof(T)) == 0; -} - -template -__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){ - typedef typename std::aligned_storage::type LT; - ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; -} - -#include "type_shim.h" - -typedef enum{ - ADAM_MODE_0 =0, // eps under square root - ADAM_MODE_1 =1 // eps outside square root -} adamMode_t; - -template -__global__ void adam_cuda_kernel( - T* __restrict__ p, - GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed - T* __restrict__ m, - T* __restrict__ v, - const GRAD_T * __restrict__ g, - const float b1, - const float b2, - const float eps, - const float grad_scale, - const float step_size, - const size_t tsize, - adamMode_t mode, - const float decay) -{ - //Assuming 2D grids and 2D blocks - const int blockId = gridDim.x * blockIdx.y + blockIdx.x; - const int threadsPerBlock = blockDim.x * blockDim.y; - const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; - const int i = (blockId * threadsPerBlock + threadIdInBlock); - const int totThreads = gridDim.x*gridDim.y*threadsPerBlock; - - for (int j = i; j < tsize; j+=totThreads) { - T scaled_grad = g[j]/grad_scale; - m[j] = b1*m[j] + (1-b1)*scaled_grad; - v[j] = b2*v[j] + (1-b2)*scaled_grad*scaled_grad; - float denom; - if (mode == ADAM_MODE_0) - denom = sqrtf(v[j] + eps); - else // Mode 1 - denom = sqrtf(v[j]) + eps; - float update = (m[j]/denom) + (decay*p[j]); - p[j] = p[j] - (step_size*update); - if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j]; - } -} - -template -struct AdamFunctor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata& tl, - const float b1, - const float b2, - const float eps, - const float grad_scale, - const float step_size, - adamMode_t mode, - const float decay) - { - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - T* p = (T *)tl.addresses[0][tensor_loc]; - p += chunk_idx*chunk_size; - T* m = (T *)tl.addresses[1][tensor_loc]; - m += chunk_idx*chunk_size; - T* v = (T *)tl.addresses[2][tensor_loc]; - v += chunk_idx*chunk_size; - GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc]; - g += chunk_idx*chunk_size; - GRAD_T* p_copy = NULL; - if (DEPTH == 5) { - p_copy = (GRAD_T *)tl.addresses[4][tensor_loc]; - p_copy += chunk_idx*chunk_size; - } - - n -= chunk_idx*chunk_size; - - T incoming_p[ILP]; - T incoming_m[ILP]; - T incoming_v[ILP]; - T incoming_g[ILP]; - - // to make things simple, we put aligned case in a different code path - if(n % ILP == 0 && - chunk_size % ILP == 0 && - is_aligned(p) && - is_aligned(m) && - is_aligned(v) && - is_aligned(g) && - is_aligned(p_copy)) - { - for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) - { - // load - GRAD_T tmp_g[ILP]; - load_store(incoming_p, p, 0, i_start); - load_store(incoming_m, m, 0, i_start); - load_store(incoming_v, v, 0, i_start); - load_store(tmp_g, g, 0, i_start); -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - incoming_g[ii] = static_cast(tmp_g[ii]); - T scaled_grad = incoming_g[ii]/grad_scale; - incoming_m[ii] = b1*incoming_m[ii] + (1-b1)*scaled_grad; - incoming_v[ii] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad; - float denom; - if (mode == ADAM_MODE_0) - denom = sqrtf(incoming_v[ii] + eps); - else // Mode 1 - denom = sqrtf(incoming_v[ii]) + eps; - float update = (incoming_m[ii]/denom) + (decay*incoming_p[ii]); - incoming_p[ii] = incoming_p[ii] - (step_size*update); - if (DEPTH == 5) tmp_g[ii] = static_cast(incoming_p[ii]); - } - load_store(p, incoming_p, i_start, 0); - load_store(m, incoming_m, i_start, 0); - load_store(v, incoming_v, i_start, 0); - if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0); - } - } - else - { - for(int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) { - -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - incoming_p[ii] = 0; - incoming_m[ii] = 0; - incoming_v[ii] = 0; - incoming_g[ii] = 0; - - int i = i_start + threadIdx.x + ii*blockDim.x; - if (i < n && i < chunk_size) { - incoming_p[ii] = p[i]; - incoming_m[ii] = m[i]; - incoming_v[ii] = v[i]; - incoming_g[ii] = static_cast(g[i]); - } - } - - // note for clarification to future michael: - // From a pure memory dependency perspective, there's likely no point unrolling - // the write loop, since writes just fire off once their LDGs arrive. - // Put another way, the STGs are dependent on the LDGs, but not on each other. - // There is still compute ILP benefit from unrolling the loop though. -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - int j = i_start + threadIdx.x + ii*blockDim.x; - - if(j < n && j < chunk_size) { - T scaled_grad = incoming_g[ii]/grad_scale; - m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad; - v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad; - float denom; - if (mode == ADAM_MODE_0) - denom = sqrtf(v[j] + eps); - else // Mode 1 - denom = sqrtf(v[j]) + eps; - float update = (m[j]/denom) + (decay*incoming_p[ii]); - p[j] = incoming_p[ii] - (step_size*update); - if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j]; - } - } - } - } - } -}; - -void fused_adam_cuda( - at::Tensor & p, - at::Tensor & p_copy, - at::Tensor & m, - at::Tensor & v, - at::Tensor & g, - float lr, - float beta1, - float beta2, - float eps, - float grad_scale, - int step, - int mode, - int bias_correction, - float decay) -{ -// using namespace at; - - //Get tensor size - int tsize = p.numel(); - //Determine #threads and #blocks - const int threadsPerBlock = 512; - const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock); - AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32"); - //Constants - float step_size = 0; - if (bias_correction == 1) { - const float bias_correction1 = 1 - std::pow(beta1, step); - const float bias_correction2 = 1 - std::pow(beta2, step); - step_size = lr * std::sqrt(bias_correction2)/bias_correction1; - } - else { - step_size = lr; - } - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) { -//all other values should be fp32 for half gradients - AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); -//dispatch is done on the gradient type - using namespace at; // prevents "toString is undefined" errors - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_kernel", - using accscalar_t = at::acc_type; - adam_cuda_kernel<<>>( - p.DATA_PTR(), - p_copy.numel() ? p_copy.DATA_PTR() : NULL, - m.DATA_PTR(), - v.DATA_PTR(), - g.DATA_PTR(), - beta1, - beta2, - eps, - grad_scale, - step_size, - tsize, - (adamMode_t) mode, - decay); - ); - } else { - using namespace at; - DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel", - adam_cuda_kernel<<>>( - p.DATA_PTR(), - NULL, //don't output p_copy for fp32, it's wasted write - m.DATA_PTR(), - v.DATA_PTR(), - g.DATA_PTR(), - beta1, - beta2, - eps, - grad_scale, - step_size, - tsize, - (adamMode_t) mode, - decay); - ); - } - C10_CUDA_CHECK(cudaGetLastError()); - -} - -void fused_adam_cuda_mt( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, // p, m, v, g, p_copy - float lr, - float beta1, - float beta2, - float eps, - float grad_scale, - int step, - int mode, - int bias_correction, - float decay) { - - //Constants - float step_size = 0; - if (bias_correction == 1) { - const float bias_correction1 = 1 - std::pow(beta1, step); - const float bias_correction2 = 1 - std::pow(beta2, step); - step_size = lr * std::sqrt(bias_correction2)/bias_correction1; - } - else { - step_size = lr; - } - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - size_t tl_sz = tensor_lists.size(); - AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5"); - - if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half || tensor_lists[3][0].scalar_type() == at::ScalarType::BFloat16) { -//alher values should be fp32 for half gradients - AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); -//dich is done on the gradient type - if (tl_sz == 5) { - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", - using accscalar_t = at::acc_type; - multi_tensor_apply<5>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - AdamFunctor<5, accscalar_t, scalar_t_0>(), - beta1, - beta2, - eps, - grad_scale, - step_size, - (adamMode_t) mode, - decay); - ); - } else { - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", - using accscalar_t = at::acc_type; - multi_tensor_apply<4>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - AdamFunctor<4, accscalar_t, scalar_t_0>(), - beta1, - beta2, - eps, - grad_scale, - step_size, - (adamMode_t) mode, - decay); - ); - } - } else { - if (tl_sz == 5) { - DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", - multi_tensor_apply<5>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - AdamFunctor<5, scalar_t_0, scalar_t_0>(), - beta1, - beta2, - eps, - grad_scale, - step_size, - (adamMode_t) mode, - decay); - ); - } else { - DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel", - multi_tensor_apply<4>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - AdamFunctor<4, scalar_t_0, scalar_t_0>(), - beta1, - beta2, - eps, - grad_scale, - step_size, - (adamMode_t) mode, - decay); - ); - } - } - C10_CUDA_CHECK(cudaGetLastError()); -} - -template -__device__ void convert(const FROM_T vi, TO_T& vo) -{ - vo = static_cast(vi); -} - -template <> -__device__ void convert(const float vi, uint8_t& vo) -{ - union S - { - float as_float; - int as_int; - }; - S s; - s.as_float = vi; - s.as_int = s.as_int & 0xFF800000; - union T - { - at::Half as_half; - uint8_t as_byte[2]; - }; - T t; - t.as_half = static_cast(vi + s.as_float / 8.0f); - vo = t.as_byte[1]; -} - -template <> -__device__ void convert(const uint8_t vi, float& vo) -{ - union T - { - at::Half as_half; - uint8_t as_byte[2]; - }; - T t; - t.as_byte[0] = 0; - t.as_byte[1] = vi; - vo = static_cast(t.as_half); -} - -template <> -__device__ void convert(const at::Half vi, uint8_t& vo) -{ - union S - { - float as_float; - int as_int; - }; - S s; - s.as_float = static_cast(vi); - s.as_int = s.as_int & 0xFF800000; - union T - { - at::Half as_half; - uint8_t as_byte[2]; - }; - T t; - t.as_half = static_cast(vi + s.as_float / 8.0f); - vo = t.as_byte[1]; -} - -template <> -__device__ void convert(const uint8_t vi, at::Half& vo) -{ - union T - { - at::Half as_half; - uint8_t as_byte[2]; - }; - T t; - t.as_byte[0] = 0; - t.as_byte[1] = vi; - vo = t.as_half; -} - -template -__global__ void strided_check_finite_cuda_kernel( - volatile int* noop_gmem, - GRAD_T* __restrict__ p_copy, - const size_t tsize, - int stride, - int clear_overflow_first) -{ - //Assuming 2D grids and 2D blocks - const int blockId = gridDim.x * blockIdx.y + blockIdx.x; - const int threadsPerBlock = blockDim.x * blockDim.y; - const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; - const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride; - const int totThreads = gridDim.x*gridDim.y*threadsPerBlock*stride; - - if (clear_overflow_first) { - if (i == 0) { - *noop_gmem = 0; - } - __syncthreads(); - } - - for (int j = i; j < tsize; j+=totThreads) { - GRAD_T pi = p_copy[j]; - if (!isfinite(pi)) { - *noop_gmem = 1; - } - } -} -template <> -__global__ void strided_check_finite_cuda_kernel( - volatile int* noop_gmem, - uint8_t* __restrict__ p_copy, - const size_t tsize, - int stride, - int clear_overflow_first) -{ - //Assuming 2D grids and 2D blocks - const int blockId = gridDim.x * blockIdx.y + blockIdx.x; - const int threadsPerBlock = blockDim.x * blockDim.y; - const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; - const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride; - const int totThreads = gridDim.x*gridDim.y*threadsPerBlock*stride; - - if (clear_overflow_first) { - if (i == 0) { - *noop_gmem = 0; - } - __syncthreads(); - } - - for (int j = i; j < tsize; j+=totThreads) { - at::Half pi; - convert(p_copy[j], pi); - if (!isfinite(pi)) { - *noop_gmem = 1; - } - } -} - -template -__global__ void maybe_cast_kernel( - volatile int* overflow_flag, - const FROM_T* p_in, - TO_T* p_out, - const size_t tsize) -{ - if (overflow_flag && *overflow_flag != 0) return; - - //Assuming 2D grids and 2D blocks - const int blockId = gridDim.x * blockIdx.y + blockIdx.x; - const int threadsPerBlock = blockDim.x * blockDim.y; - const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; - const int i = (blockId * threadsPerBlock + threadIdInBlock); - const int totThreads = gridDim.x*gridDim.y*threadsPerBlock; - - FROM_T pi[ILP]; - TO_T po[ILP]; - - for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) { -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - pi[ii] = 0; - - int j = j_start + i + totThreads*ii; - if (j < tsize) { - pi[ii] = p_in[j]; - } - } - -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - convert(pi[ii], po[ii]); - } - -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - int j = j_start + i + totThreads*ii; - if (j < tsize) { - p_out[j] = po[ii]; - } - } - } -} - -template -__global__ void reversible_adam_cuda_kernel( - T* __restrict__ p, - REDU_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed - T* __restrict__ m, - T* __restrict__ v, - const GRAD_T * __restrict__ g, - const float b1, - const float b2, - const float eps, - const float grad_scale, - const float step_size, - const size_t tsize, - adamMode_t mode, - const float decay) -{ - //Assuming 2D grids and 2D blocks - const int blockId = gridDim.x * blockIdx.y + blockIdx.x; - const int threadsPerBlock = blockDim.x * blockDim.y; - const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; - const int i = (blockId * threadsPerBlock + threadIdInBlock); - const int totThreads = gridDim.x*gridDim.y*threadsPerBlock; - - T mi[ILP]; - T vi[ILP]; - T pi[ILP]; - T gi[ILP]; - - bool overflow = false; - for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) { -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - mi[ii] = T(0); - vi[ii] = T(0); - pi[ii] = T(0); - gi[ii] = GRAD_T(0); - - int j = j_start + i + totThreads*ii; - if (j < tsize) { - pi[ii] = p[j]; - mi[ii] = m[j]; - vi[ii] = v[j]; - gi[ii] = static_cast(g[j]); - } - } - -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - T scaled_grad = gi[ii]/grad_scale; - if (isfinite(scaled_grad)) { - mi[ii] = b1*mi[ii] + (1-b1)*scaled_grad; - vi[ii] = b2*vi[ii] + (1-b2)*scaled_grad*scaled_grad; - float denom; - if (mode == ADAM_MODE_0) - denom = sqrtf(vi[ii] + eps); - else // Mode 1 - denom = sqrtf(vi[ii]) + eps; - float update = (mi[ii]/denom) + (decay*pi[ii]); - pi[ii] = pi[ii] - (step_size*update); - } else { - overflow = true; - } - } - -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - int j = j_start + i + totThreads*ii; - if (j < tsize) { - m[j] = mi[ii]; - v[j] = vi[ii]; - p[j] = pi[ii]; - if (p_copy != NULL) { - convert(pi[ii], p_copy[j]); - } - } - } - } - - if (p_copy != NULL) { - __syncthreads(); - if (overflow) { - convert(float(INFINITY), p_copy[0]); - } - } -} - -template -__global__ void maybe_adam_undo_cuda_kernel( - volatile int* overflow_flag, - T* __restrict__ p, - T* __restrict__ m, - T* __restrict__ v, - const GRAD_T * __restrict__ g, - const float b1, - const float b2, - const float eps, - const float grad_scale, - const float step_size, - const size_t tsize, - adamMode_t mode, - const float decay) -{ - // NB! Skip undo kernel when overflow flag is NOT set - if (overflow_flag && *overflow_flag == 0) return; - - //Assuming 2D grids and 2D blocks - const int blockId = gridDim.x * blockIdx.y + blockIdx.x; - const int threadsPerBlock = blockDim.x * blockDim.y; - const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x; - const int i = (blockId * threadsPerBlock + threadIdInBlock); - const int totThreads = gridDim.x*gridDim.y*threadsPerBlock; - - T mi[ILP]; - T vi[ILP]; - T pi[ILP]; - T gi[ILP]; - - for(int j_start = 0; j_start < tsize; j_start+=totThreads*ILP) { -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - mi[ii] = T(0); - vi[ii] = T(0); - pi[ii] = T(0); - gi[ii] = GRAD_T(0); - - int j = j_start + i*ILP; - if (j < tsize) { - pi[ii] = p[j]; - mi[ii] = m[j]; - vi[ii] = v[j]; - gi[ii] = static_cast(g[j]); - } - } - -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - T scaled_grad = gi[ii]/grad_scale; - if (isfinite(scaled_grad)) { - float denom; - if (mode == ADAM_MODE_0) - denom = sqrtf(vi[ii] + eps); - else // Mode 1 - denom = sqrtf(vi[ii]) + eps; - pi[ii] = (pi[ii] + step_size*(mi[ii]/denom)) / (1.0f - step_size*decay); - mi[ii] = (mi[ii] - (1-b1)*scaled_grad) / b1; - vi[ii] = (vi[ii] - (1-b2)*scaled_grad*scaled_grad) / b2; - // Make sure round off errors don't create (small) negative value. - // This can happen if we have to revert the very first step. - vi[ii] = vi[ii] >= 0.0f ? vi[ii] : 0.0f; - } - } - -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - int j = j_start + i*ILP; - if (j < tsize) { - m[j] = mi[ii]; - v[j] = vi[ii]; - p[j] = pi[ii]; - } - } - } -} - -template -struct MaybeCastFunctor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* overflow_flag, - TensorListMetadata& tl) - { - if (overflow_flag && *overflow_flag != 0) return; - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - FROM_T* p_in = (FROM_T *)tl.addresses[0][tensor_loc]; - p_in += chunk_idx*chunk_size; - TO_T* p_out = (TO_T *)tl.addresses[1][tensor_loc]; - p_out += chunk_idx*chunk_size; - - n -= chunk_idx*chunk_size; - int dim = chunk_size < n ? chunk_size : n; - - FROM_T pi[ILP]; - TO_T po[ILP]; - - for(int j_start = 0; j_start < dim; j_start+=blockDim.x*ILP) { -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - pi[ii] = FROM_T(0); - int j = j_start + threadIdx.x + ii*blockDim.x; - if (j < dim) { - pi[ii] = p_in[j]; - } - } - -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - convert(pi[ii], po[ii]); - } - -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - int j = j_start + threadIdx.x + ii*blockDim.x; - if (j < dim) { - p_out[j] = po[ii]; - } - } - } - } -}; - -void fused_strided_check_finite( - at::Tensor & overflow_flag, - at::Tensor & p_copy, - int stride, - int clear_overflow_first) -{ - //Get tensor size - int tsize = p_copy.numel(); - int niter = (tsize + stride - 1) / stride; - - //Determine #threads and #blocks - const int threadsPerBlock = 512; - //In order to avoid race condition, blocks must be 1 when clear_overflow_first flag is set. - const dim3 blocks(clear_overflow_first ? 1 : (niter+threadsPerBlock-1)/threadsPerBlock); - AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_copy), "parameter tensor is too large to be indexed with int32"); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - using namespace at; // prevents "toString is undefined" errors - DISPATCH_FLOAT_HALF_AND_BYTE(p_copy.scalar_type(), 0, "check_finite_cuda_kernel", - strided_check_finite_cuda_kernel<<>>( - overflow_flag.DATA_PTR(), - p_copy.DATA_PTR(), - tsize, - stride, - clear_overflow_first); - ); - C10_CUDA_CHECK(cudaGetLastError()); -} - -void fused_reversible_adam_cuda( - at::Tensor & p, - at::Tensor & p_copy, - at::Tensor & m, - at::Tensor & v, - at::Tensor & g, - float lr, - float beta1, - float beta2, - float eps, - float grad_scale, - int step, - int mode, - int bias_correction, - float decay) -{ -// using namespace at; - - //Get tensor size - int tsize = p.numel(); - //Determine #threads and #blocks - const int threadsPerBlock = 512; - const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock); - AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32"); - //Constants - float step_size = 0; - if (bias_correction == 1) { - const float bias_correction1 = 1 - std::pow(beta1, step); - const float bias_correction2 = 1 - std::pow(beta2, step); - step_size = lr * std::sqrt(bias_correction2)/bias_correction1; - } - else { - step_size = lr; - } - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) { - //all other values should be fp32 for half gradients - AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); - //dispatch is done on the gradient type - using namespace at; // prevents "toString is undefined" errors - if (p_copy.numel() == 0 || p_copy.scalar_type() == g.scalar_type()) { - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_kernel", - using accscalar_t = at::acc_type; - reversible_adam_cuda_kernel<<>>( - p.DATA_PTR(), - p_copy.numel() ? p_copy.DATA_PTR() : NULL, - m.DATA_PTR(), - v.DATA_PTR(), - g.DATA_PTR(), - beta1, - beta2, - eps, - grad_scale, - step_size, - tsize, - (adamMode_t) mode, - decay); - ); - } else { - AT_ASSERTM(p_copy.scalar_type() == at::ScalarType::Byte, "expected parameter to be of byte type"); - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_e5m2_kernel", - using accscalar_t = at::acc_type; - reversible_adam_cuda_kernel<<>>( - p.DATA_PTR(), - p_copy.DATA_PTR(), - m.DATA_PTR(), - v.DATA_PTR(), - g.DATA_PTR(), - beta1, - beta2, - eps, - grad_scale, - step_size, - tsize, - (adamMode_t) mode, - decay); - ); - } - } else { - using namespace at; - DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel", - reversible_adam_cuda_kernel<<>>( - p.DATA_PTR(), - NULL, //don't output p_copy for fp32, it's wasted write - m.DATA_PTR(), - v.DATA_PTR(), - g.DATA_PTR(), - beta1, - beta2, - eps, - grad_scale, - step_size, - tsize, - (adamMode_t) mode, - decay); - ); - } - C10_CUDA_CHECK(cudaGetLastError()); -} - -void maybe_cast_cuda( - at::Tensor & overflow_flag, - at::Tensor & p_in, - at::Tensor & p_out) -{ - //Get tensor size - int tsize = p_in.numel(); - AT_ASSERTM(tsize == p_out.numel(), "p_in.numel() must equal p_out.numel()"); - //Determine #threads and #blocks - const int threadsPerBlock = 512; - const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock); - AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_in), "parameter tensor is too large to be indexed with int32"); - //Constants - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - DISPATCH_FLOAT_HALF_AND_BYTE(p_in.scalar_type(), 0, "maybe_cast_cuda" - DISPATCH_FLOAT_HALF_AND_BYTE(p_out.scalar_type(), 1, "maybe_cast_cuda", - maybe_cast_kernel<<>>( - overflow_flag.numel() ? overflow_flag.DATA_PTR() : NULL, - p_in.DATA_PTR(), - p_out.DATA_PTR(), - tsize); )) - C10_CUDA_CHECK(cudaGetLastError()); -} - -void maybe_cast_cuda_mt( - int chunk_size, - at::Tensor overflow_flag, - std::vector> tensor_lists) // p_in, p_out -{ - //Constants - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - size_t tl_sz = tensor_lists.size(); - AT_ASSERTM(tl_sz == 2, "expected tensor lists of size 2"); - - DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[0][0].scalar_type(), 0, "maybe_cast_cuda_mt_kernel", - DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[1][0].scalar_type(), 1, "maybe_cast_cuda_mt_kernel", - multi_tensor_apply<2>( - BLOCK_SIZE, - chunk_size, - overflow_flag, - tensor_lists, - MaybeCastFunctor<2, scalar_t_0, scalar_t_1>()); )) - C10_CUDA_CHECK(cudaGetLastError()); -} - -void fused_maybe_adam_undo_cuda( - at::Tensor & overflow_flag, - at::Tensor & p, - at::Tensor & m, - at::Tensor & v, - at::Tensor & g, - float lr, - float beta1, - float beta2, - float eps, - float grad_scale, - int step, - int mode, - int bias_correction, - float decay) -{ - //Get tensor size - int tsize = p.numel(); - //Determine #threads and #blocks - const int threadsPerBlock = 512; - const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock); - AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32"); - //Constants - float step_size = 0; - if (bias_correction == 1) { - const float bias_correction1 = 1 - std::pow(beta1, step); - const float bias_correction2 = 1 - std::pow(beta2, step); - step_size = lr * std::sqrt(bias_correction2)/bias_correction1; - } - else { - step_size = lr; - } - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if (g.scalar_type() == at::ScalarType::Half || g.scalar_type() == at::ScalarType::BFloat16) { - //all other values should be fp32 for half gradients - AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type"); - //dispatch is done on the gradient type - using namespace at; // prevents "toString is undefined" errors - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(g.scalar_type(), 0, "adam_cuda_kernel", - using accscalar_t = at::acc_type; - maybe_adam_undo_cuda_kernel<<>>( - overflow_flag.numel() ? overflow_flag.DATA_PTR() : NULL, - p.DATA_PTR(), - m.DATA_PTR(), - v.DATA_PTR(), - g.DATA_PTR(), - beta1, - beta2, - eps, - grad_scale, - step_size, - tsize, - (adamMode_t) mode, - decay); - ); - } else { - using namespace at; - DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel", - maybe_adam_undo_cuda_kernel<<>>( - overflow_flag.numel() ? overflow_flag.DATA_PTR() : NULL, - p.DATA_PTR(), - m.DATA_PTR(), - v.DATA_PTR(), - g.DATA_PTR(), - beta1, - beta2, - eps, - grad_scale, - step_size, - tsize, - (adamMode_t) mode, - decay); - ); - } - C10_CUDA_CHECK(cudaGetLastError()); -} diff --git a/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp b/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp deleted file mode 100644 index 98a2411..0000000 --- a/apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include - -void multi_tensor_lamb_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - const float lr, - const float beta1, - const float beta2, - const float epsilon, - const int step, - const int bias_correction, - const float weight_decay, - const int grad_averaging, - const int mode, - const float global_grad_norm, - const float max_grad_norm); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("lamb", &multi_tensor_lamb_cuda, "Computes and apply update for LAMB optimizer"); -} diff --git a/apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu b/apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu deleted file mode 100644 index 3bb93b0..0000000 --- a/apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu +++ /dev/null @@ -1,294 +0,0 @@ -#include -#include -#include -#include -// Another possibility: -// #include - -#include - -#include "type_shim.h" -#include "multi_tensor_apply.cuh" - -#define BLOCK_SIZE 512 -#define ILP 4 - -typedef enum{ - MOMENT_MODE_0 =0, // L2 regularization mode - MOMENT_MODE_1 =1 // Decoupled weight decay mode -} adamMode_t; - -std::tuple multi_tensor_l2norm_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::optional per_tensor_python); - -using MATH_T = float; - -template -struct LAMBStage1Functor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<4>& tl, - const float beta1, - const float beta2, - const float beta3, - const float beta1_correction, - const float beta2_correction, - const float epsilon, - adamMode_t mode, - const float decay, - const float global_grad_norm, - const float max_global_grad_norm) - { - // I'd like this kernel to propagate infs/nans. - // if(*noop_gmem == 1) - // return; - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f; - - T* g = (T*)tl.addresses[0][tensor_loc]; - g += chunk_idx*chunk_size; - - T* p = (T*)tl.addresses[1][tensor_loc]; - p += chunk_idx*chunk_size; - - T* m = (T*)tl.addresses[2][tensor_loc]; - m += chunk_idx*chunk_size; - - T* v = (T*)tl.addresses[3][tensor_loc]; - v += chunk_idx*chunk_size; - - n -= chunk_idx*chunk_size; - - // see note in multi_tensor_scale_kernel.cu - for(int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) - { - MATH_T r_g[ILP]; - MATH_T r_p[ILP]; - MATH_T r_m[ILP]; - MATH_T r_v[ILP]; -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - r_g[ii] = g[i]; - // special ?optimization? for lamb stage 1 - if (decay == 0) { - r_p[ii] = MATH_T(0); - } - else { - r_p[ii] = p[i]; - } - r_m[ii] = m[i]; - r_v[ii] = v[i]; - } else { - r_g[ii] = MATH_T(0); - r_p[ii] = MATH_T(0); - r_m[ii] = MATH_T(0); - r_v[ii] = MATH_T(0); - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - if (mode == MOMENT_MODE_0) { - MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; - // L2 on scaled grad - scaled_grad = scaled_grad + decay*r_p[ii]; - r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; - r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - MATH_T next_v_unbiased = r_v[ii] / beta2_correction; - MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - r_p[ii] = next_m_unbiased / denom; - } - else { - MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; - r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; - r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - MATH_T next_v_unbiased = r_v[ii] / beta2_correction; - MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]); - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - g[i] = r_p[ii]; - m[i] = r_m[ii]; - v[i] = r_v[ii]; - } - } - } - } -}; - -// Step 2 reads in 'update' value and per-tensor param_norm and update_norm. -// It computes new parameter value. -template -struct LAMBStage2Functor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<2>& tl, - const float* per_tensor_param_norm, - const float* per_tensor_update_norm, - const float learning_rate, - const float decay) - { - // I'd like this kernel to propagate infs/nans. - // if(*noop_gmem == 1) - // return; - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int tensor_num = tl.start_tensor_this_launch + tensor_loc; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - MATH_T ratio = learning_rate; - // apply adaptive learning rate to parameters with non-zero weight decay - if (decay != 0.0) - { - float param_norm = per_tensor_param_norm[tensor_num]; - float update_norm = per_tensor_update_norm[tensor_num]; - ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate; - } - - T* update = (T*)tl.addresses[0][tensor_loc]; - update += chunk_idx*chunk_size; - - T* p = (T*)tl.addresses[1][tensor_loc]; - p += chunk_idx*chunk_size; - - n -= chunk_idx*chunk_size; - - for(int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) - { - MATH_T r_p[ILP]; - MATH_T r_update[ILP]; -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - r_p[ii] = p[i]; - r_update[ii] = update[i]; - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_p[ii] = r_p[ii] - (ratio * r_update[ii]); - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - p[i] = r_p[ii]; - } - } - } - } -}; - - -void multi_tensor_lamb_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - const float lr, - const float beta1, - const float beta2, - const float epsilon, - const int step, - const int bias_correction, - const float weight_decay, - const int grad_averaging, - const int mode, - const float global_grad_norm, - const float max_grad_norm) -{ - using namespace at; - // Master weight and 32bit momentum(potentially changing) is not handled by this - // So we assume every tensor are all in the same type - - // Handle bias correction mode - float bias_correction1 = 1.0f, bias_correction2 = 1.0f; - if (bias_correction == 1) { - bias_correction1 = 1 - std::pow(beta1, step); - bias_correction2 = 1 - std::pow(beta2, step); - } - - // Handle grad averaging mode - float beta3 = 1.0f; - if (grad_averaging == 1) beta3 = 1 - beta1; - - std::vector> grad_list(tensor_lists.begin(), tensor_lists.begin()+1); - std::vector> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2); - - // Compute per tensor param norm - auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true); - - // We now in-place modify grad to store update before compute its norm - // Generally this is not a issue since people modify grad in step() method all the time - // We can also grab list of empty tensor to avoid this, but I'd like to save space/cpu code - DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1", - multi_tensor_apply<4>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - LAMBStage1Functor(), - beta1, - beta2, - beta3, // 1-beta1 or 1 depends on averaging mode - bias_correction1, - bias_correction2, - epsilon, - (adamMode_t) mode, - weight_decay, - global_grad_norm, - max_grad_norm); ) - - // Compute update norms - auto update_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true); - - std::vector> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2); - - DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2", - multi_tensor_apply<2>( - BLOCK_SIZE, - chunk_size, - noop_flag, - grad_param_list, - LAMBStage2Functor(), - std::get<1>(param_norm_tuple).DATA_PTR(), - std::get<1>(update_norm_tuple).DATA_PTR(), - lr, - weight_decay); ) - - AT_CUDA_CHECK(cudaGetLastError()); - -} diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp deleted file mode 100644 index 7ae13d5..0000000 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include - -void multi_tensor_fused_adam_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor per_tensor_beta1, - at::Tensor per_tensor_beta2, - at::Tensor per_tensor_bias_correction, - at::Tensor per_tensor_eps, - at::Tensor per_tensor_weight_decay, - float lr, - float grad_scale, - int step, - int mode); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("multi_tensor_fused_adam", &multi_tensor_fused_adam_cuda, - "Multi tensor Adam optimized CUDA implementation."); -} diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu b/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu deleted file mode 100644 index f89fb59..0000000 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu +++ /dev/null @@ -1,228 +0,0 @@ -#include -#include -#include -#include -// Another possibility: -// #include - -#include -#include -#include "type_shim.h" -#include "multi_tensor_apply.cuh" - -#define BLOCK_SIZE 512 -#define ILP 4 - -template -__device__ __forceinline__ bool is_aligned(T* p){ - return ((uint64_t)p) % (ILP*sizeof(T)) == 0; -} - -template -__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){ - typedef typename std::aligned_storage::type LT; - ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; -} - -typedef enum{ - ADAM_MODE_0 =0, // eps under square root - ADAM_MODE_1 =1 // eps outside square root -} adamMode_t; - -template -struct DistAdamFunctor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata& tl, - const float* per_tensor_beta1, - const float* per_tensor_beta2, - const int* per_tensor_bias_correction, - const float* per_tensor_eps, - const float* per_tensor_weight_decay, - const float lr, - const float grad_scale, - const int step, - adamMode_t mode) - { - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int tensor_num = tl.start_tensor_this_launch + tensor_loc; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - float b1 = per_tensor_beta1[tensor_num]; - float b2 = per_tensor_beta2[tensor_num]; - float eps = per_tensor_eps[tensor_num]; - float decay = per_tensor_weight_decay[tensor_num]; - - float beta1_correction = 1.0f, beta2_correction = 1.0f; - if (per_tensor_bias_correction[tensor_num] == 1) { - beta1_correction = 1 - std::pow(b1, step); - beta2_correction = 1 - std::pow(b2, step); - } - - T* p = (T *)tl.addresses[0][tensor_loc]; - p += chunk_idx*chunk_size; - T* m = (T *)tl.addresses[1][tensor_loc]; - m += chunk_idx*chunk_size; - T* v = (T *)tl.addresses[2][tensor_loc]; - v += chunk_idx*chunk_size; - GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc]; - g += chunk_idx*chunk_size; - GRAD_T* p_copy = NULL; - if (DEPTH == 5) { - p_copy = (GRAD_T *)tl.addresses[4][tensor_loc]; - p_copy += chunk_idx*chunk_size; - } - - n -= chunk_idx*chunk_size; - - T incoming_p[ILP]; - T incoming_m[ILP]; - T incoming_v[ILP]; - T incoming_g[ILP]; - - // to make things simple, we put aligned case in a different code path - if (n % ILP == 0 && - chunk_size % ILP == 0 && - is_aligned(p) && - is_aligned(m) && - is_aligned(v) && - is_aligned(g) && - is_aligned(p_copy)) { - for (int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) { - // load - GRAD_T tmp_g[ILP]; - load_store(incoming_p, p, 0, i_start); - load_store(incoming_m, m, 0, i_start); - load_store(incoming_v, v, 0, i_start); - load_store(tmp_g, g, 0, i_start); -#pragma unroll - for (int ii = 0; ii < ILP; ii++) { - incoming_g[ii] = static_cast(tmp_g[ii]); - T scaled_grad = incoming_g[ii]/grad_scale; - incoming_m[ii] = b1*incoming_m[ii] + (1-b1)*scaled_grad; - incoming_v[ii] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad; - T next_m_unbiased = incoming_m[ii] / beta1_correction; - T next_v_unbiased = incoming_v[ii] / beta2_correction; - float denom; - if (mode == ADAM_MODE_0) - denom = sqrtf(next_v_unbiased + eps); - else // Mode 1 - denom = sqrtf(next_v_unbiased) + eps; - float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]); - incoming_p[ii] = incoming_p[ii] - (lr * update); - if (DEPTH == 5) tmp_g[ii] = static_cast(incoming_p[ii]); - } - load_store(p, incoming_p, i_start, 0); - load_store(m, incoming_m, i_start, 0); - load_store(v, incoming_v, i_start, 0); - if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0); - } - } else { - for (int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) { - -#pragma unroll - for (int ii = 0; ii < ILP; ii++) { - incoming_p[ii] = 0; - incoming_m[ii] = 0; - incoming_v[ii] = 0; - incoming_g[ii] = 0; - - int i = i_start + threadIdx.x + ii*blockDim.x; - if (i < n && i < chunk_size) { - incoming_p[ii] = p[i]; - incoming_m[ii] = m[i]; - incoming_v[ii] = v[i]; - incoming_g[ii] = static_cast(g[i]); - } - } - -#pragma unroll - for (int ii = 0; ii < ILP; ii++) { - int j = i_start + threadIdx.x + ii*blockDim.x; - - if (j < n && j < chunk_size) { - T scaled_grad = incoming_g[ii]/grad_scale; - m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad; - v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad; - T next_m_unbiased = m[j] / beta1_correction; - T next_v_unbiased = v[j] / beta2_correction; - float denom; - if (mode == ADAM_MODE_0) - denom = sqrtf(next_v_unbiased + eps); - else // Mode 1 - denom = sqrtf(next_v_unbiased) + eps; - float update = (next_m_unbiased / denom) + (decay * incoming_p[ii]); - p[j] = incoming_p[ii] - (lr * update); - if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j]; - } - } - } - } - } -}; - -void multi_tensor_fused_adam_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, // p, m, v, g, p_copy - at::Tensor per_tensor_beta1, - at::Tensor per_tensor_beta2, - at::Tensor per_tensor_bias_correction, - at::Tensor per_tensor_eps, - at::Tensor per_tensor_weight_decay, - float lr, - float grad_scale, - int step, - int mode) -{ - using namespace at; - - size_t tl_sz = tensor_lists.size(); - AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5"); - - if (tl_sz == 5) { - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g - using accscalar_t = at::acc_type; - multi_tensor_apply<5>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - DistAdamFunctor<5, accscalar_t, scalar_t_0>(), - per_tensor_beta1.DATA_PTR(), - per_tensor_beta2.DATA_PTR(), - per_tensor_bias_correction.DATA_PTR(), - per_tensor_eps.DATA_PTR(), - per_tensor_weight_decay.DATA_PTR(), - lr, - grad_scale, - step, - (adamMode_t) mode); - ); - } else { - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[3][0].scalar_type(), 0, "dist_adam_cuda_kernel", // g - using accscalar_t = at::acc_type; - multi_tensor_apply<4>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - DistAdamFunctor<4, accscalar_t, scalar_t_0>(), - per_tensor_beta1.DATA_PTR(), - per_tensor_beta2.DATA_PTR(), - per_tensor_bias_correction.DATA_PTR(), - per_tensor_eps.DATA_PTR(), - per_tensor_weight_decay.DATA_PTR(), - lr, - grad_scale, - step, - (adamMode_t) mode); - ); - } - C10_CUDA_CHECK(cudaGetLastError()); -} diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp b/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp deleted file mode 100644 index 584b2a0..0000000 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include - -void multi_tensor_lamb_compute_update_term_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor per_tensor_beta1, - at::Tensor per_tensor_beta2, - at::Tensor per_tensor_beta3, - at::Tensor per_tensor_bias_correction, - at::Tensor step, - at::Tensor per_tensor_epsilon, - const int mode, - at::Tensor per_tensor_decay, - at::Tensor global_scale, - at::Tensor global_grad_norm, - const float max_grad_norm); - -void multi_tensor_lamb_update_weights_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor per_tensor_param_norm, - at::Tensor per_tensor_update_norm, - at::Tensor update_norm_offset, - at::Tensor learning_rate, - at::Tensor per_tensor_decay, - at::Tensor global_grad_norm, - bool use_nvlamb); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("multi_tensor_lamb_compute_update_term", &multi_tensor_lamb_compute_update_term_cuda, - "Computes update term for LAMB optimizer"); - m.def("multi_tensor_lamb_update_weights", &multi_tensor_lamb_update_weights_cuda, - "Applies update term for LAMB optimizer"); -} diff --git a/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu b/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu deleted file mode 100644 index 95ee009..0000000 --- a/apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu +++ /dev/null @@ -1,506 +0,0 @@ -#include -#include -#include -#include -// Another possibility: -// #include - -#include - -#include "type_shim.h" -#include "multi_tensor_apply.cuh" - -#define BLOCK_SIZE 512 -#define ILP 4 - -template -__device__ __forceinline__ bool is_aligned(T* p){ - return ((uint64_t)p) % (ILP*sizeof(T)) == 0; -} - -template -__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){ - typedef typename std::aligned_storage::type LT; - ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; -} - -template -__device__ void convert(const FROM_T vi, TO_T& vo) -{ - vo = static_cast(vi); -} - -template <> -__device__ void convert(const float vi, uint8_t& vo) -{ - union S - { - float as_float; - int as_int; - }; - S s; - s.as_float = vi; - s.as_int = s.as_int & 0xFF800000; - union T - { - at::Half as_half; - uint8_t as_byte[2]; - }; - T t; - t.as_half = static_cast(vi + s.as_float / 8.0f); - vo = t.as_byte[1]; -} - -template <> -__device__ void convert(const uint8_t vi, float& vo) -{ - union T - { - at::Half as_half; - uint8_t as_byte[2]; - }; - T t; - t.as_byte[0] = 0; - t.as_byte[1] = vi; - vo = static_cast(t.as_half); -} - -template <> -__device__ void convert(const at::Half vi, uint8_t& vo) -{ - union S - { - float as_float; - int as_int; - }; - S s; - s.as_float = static_cast(vi); - s.as_int = s.as_int & 0xFF800000; - union T - { - at::Half as_half; - uint8_t as_byte[2]; - }; - T t; - t.as_half = static_cast(vi + s.as_float / 8.0f); - vo = t.as_byte[1]; -} - -template <> -__device__ void convert(const uint8_t vi, at::Half& vo) -{ - union T - { - at::Half as_half; - uint8_t as_byte[2]; - }; - T t; - t.as_byte[0] = 0; - t.as_byte[1] = vi; - vo = t.as_half; -} - -typedef enum{ - MOMENT_MODE_0 =0, // L2 regularization mode - MOMENT_MODE_1 =1 // Decoupled weight decay mode -} adamMode_t; - -template -struct DistOptLAMBStage1Functor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<5>& tl, - const MATH_T* per_tensor_beta1, - const MATH_T* per_tensor_beta2, - const MATH_T* per_tensor_beta3, - const int* per_tensor_bias_correction, - const int* step, - const MATH_T* per_tensor_epsilon, - adamMode_t mode, - const MATH_T* per_tensor_decay, - const MATH_T* global_scale, - const MATH_T* global_grad_norm, - const float max_grad_norm) - { - // I'd like this kernel to propagate infs/nans. - if (*noop_gmem == 1) - return; - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int tensor_num = tl.start_tensor_this_launch + tensor_loc; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - float combined_scale = *global_scale; - if (max_grad_norm > 0) { - combined_scale = max_grad_norm / (*global_grad_norm / *global_scale + 1e-6); - combined_scale = *global_scale / std::min((float) 1.0, combined_scale); - } - - MATH_T beta1 = per_tensor_beta1[tensor_num]; - MATH_T beta2 = per_tensor_beta2[tensor_num]; - MATH_T beta3 = 1 - beta1; - MATH_T beta1_correction, beta2_correction; - if (per_tensor_bias_correction[tensor_num] == 1) { - beta1_correction = 1 - pow(beta1, *step); - beta2_correction = 1 - pow(beta2, *step); - } else { - beta1_correction = (MATH_T) 1.0; - beta2_correction = (MATH_T) 1.0; - } - MATH_T epsilon = per_tensor_epsilon[tensor_num]; - MATH_T decay = per_tensor_decay[tensor_num]; - - GRAD_T* g = (GRAD_T*)tl.addresses[0][tensor_loc]; - g += chunk_idx*chunk_size; - - T* p = (T*)tl.addresses[1][tensor_loc]; - p += chunk_idx*chunk_size; - - T* m = (T*)tl.addresses[2][tensor_loc]; - m += chunk_idx*chunk_size; - - T* v = (T*)tl.addresses[3][tensor_loc]; - v += chunk_idx*chunk_size; - - MATH_T* u = (MATH_T*)tl.addresses[4][tensor_loc]; - u += chunk_idx*chunk_size; - - n -= chunk_idx*chunk_size; - - MATH_T r_g[ILP]; - MATH_T r_p[ILP]; - MATH_T r_m[ILP]; - MATH_T r_v[ILP]; - // to make things simple, we put aligned case in a different code path - if(n % ILP == 0 && - chunk_size % ILP == 0 && - is_aligned(g) && - is_aligned(p) && - is_aligned(m) && - is_aligned(v)) - { - GRAD_T l_g[ILP]; - T l_p[ILP]; - T l_m[ILP]; - T l_v[ILP]; - for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) - { - // load - load_store(l_g, g, 0, i_start); - if (decay != 0) - load_store(l_p, p, 0, i_start); - load_store(l_m, m, 0, i_start); - load_store(l_v, v, 0, i_start); - // unpack -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_g[ii] = l_g[ii]; - if (decay == 0) { - r_p[ii] = MATH_T(0); - } - else { - r_p[ii] = l_p[ii]; - } - r_m[ii] = l_m[ii]; - r_v[ii] = l_v[ii]; - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - if (mode == MOMENT_MODE_0) { - MATH_T scaled_grad = r_g[ii] / combined_scale; - // L2 on scaled grad - scaled_grad = scaled_grad + decay*r_p[ii]; - r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; - r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - MATH_T next_v_unbiased = r_v[ii] / beta2_correction; - MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - r_p[ii] = next_m_unbiased / denom; - } - else { - MATH_T scaled_grad = r_g[ii] / combined_scale; - r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; - r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - MATH_T next_v_unbiased = r_v[ii] / beta2_correction; - MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]); - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - l_m[ii] = r_m[ii]; - l_v[ii] = r_v[ii]; - } - // store - load_store(u, r_p, i_start, 0); - load_store(m, l_m, i_start, 0); - load_store(v, l_v, i_start, 0); - } - } - else - { - // see note in multi_tensor_scale_kernel.cu - for(int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) - { - MATH_T r_g[ILP]; - MATH_T r_p[ILP]; - MATH_T r_m[ILP]; - MATH_T r_v[ILP]; -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - r_g[ii] = g[i]; - // special ?optimization? for lamb stage 1 - if (decay == 0) { - r_p[ii] = MATH_T(0); - } - else { - r_p[ii] = p[i]; - } - r_m[ii] = m[i]; - r_v[ii] = v[i]; - } else { - r_g[ii] = MATH_T(0); - r_p[ii] = MATH_T(0); - r_m[ii] = MATH_T(0); - r_v[ii] = MATH_T(0); - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - if (mode == MOMENT_MODE_0) { - MATH_T scaled_grad = r_g[ii] / combined_scale; - // L2 on scaled grad - scaled_grad = scaled_grad + decay*r_p[ii]; - r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; - r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - MATH_T next_v_unbiased = r_v[ii] / beta2_correction; - MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - r_p[ii] = next_m_unbiased / denom; - } - else { - MATH_T scaled_grad = r_g[ii] / combined_scale; - r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; - r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - MATH_T next_v_unbiased = r_v[ii] / beta2_correction; - MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]); - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - u[i] = r_p[ii]; - m[i] = r_m[ii]; - v[i] = r_v[ii]; - } - } - } - } - } -}; - -// Step 2 reads in 'update' value and per-tensor param_norm and update_norm. -// It computes new parameter value. -template -struct DistOptLAMBStage2Functor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<3>& tl, - const MATH_T* per_tensor_param_norm, - const MATH_T* per_tensor_update_norm, - const long* update_norm_offset, - const MATH_T* learning_rate, - const MATH_T* per_tensor_decay, - const MATH_T* global_grad_norm, - bool use_nvlamb) - { - // I'd like this kernel to propagate infs/nans. - if (*noop_gmem == 1) - return; - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int tensor_num = tl.start_tensor_this_launch + tensor_loc; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - MATH_T decay = per_tensor_decay[tensor_num]; - - MATH_T ratio = *learning_rate; - // nvlamb: apply adaptive learning rate to all parameters - // otherwise, only apply to those with non-zero weight decay - if (use_nvlamb || (decay != (MATH_T) 0.0)) - { - MATH_T param_norm = per_tensor_param_norm[tensor_num]; - MATH_T update_norm = per_tensor_update_norm[update_norm_offset[tensor_num]]; - ratio = (update_norm != 0.0 && param_norm != 0.0) ? (*learning_rate) * (param_norm / update_norm) : (*learning_rate); - } - - MATH_T* update = (MATH_T*)tl.addresses[0][tensor_loc]; - update += chunk_idx*chunk_size; - - T* p = (T*)tl.addresses[1][tensor_loc]; - p += chunk_idx*chunk_size; - - GRAD_T* p_copy = (GRAD_T*)tl.addresses[2][tensor_loc]; - p_copy += chunk_idx*chunk_size; - - n -= chunk_idx*chunk_size; - - // to make things simple, we put aligned case in a different code path - if(n % ILP == 0 && - chunk_size % ILP == 0 && - is_aligned(p) && - is_aligned(update)) - { - T r_p[ILP]; - MATH_T r_update[ILP]; - GRAD_T r_p_copy[ILP]; - for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) - { - // load - load_store(r_p, p, 0, i_start); - load_store(r_update, update, 0, i_start); -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_p[ii] = static_cast(r_p[ii]) - (ratio * r_update[ii]); - convert(r_p[ii], r_p_copy[ii]); - } - load_store(p, r_p, i_start, 0); - load_store(p_copy, r_p_copy, i_start, 0); - } - } - else - { - for(int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) - { - MATH_T r_p[ILP]; - MATH_T r_update[ILP]; -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - r_p[ii] = p[i]; - r_update[ii] = update[i]; - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_p[ii] = r_p[ii] - (ratio * r_update[ii]); - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - p[i] = r_p[ii]; - convert(r_p[ii], p_copy[i]); - } - } - } - } - } -}; - -void multi_tensor_lamb_compute_update_term_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor per_tensor_beta1, - at::Tensor per_tensor_beta2, - at::Tensor per_tensor_beta3, - at::Tensor per_tensor_bias_correction, - at::Tensor step, - at::Tensor per_tensor_epsilon, - const int mode, - at::Tensor per_tensor_decay, - at::Tensor global_scale, - at::Tensor global_grad_norm, - const float max_grad_norm) -{ - using namespace at; - - DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 0, "lamb_stage_1", - DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 1, "lamb_stage_1", - DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1", - multi_tensor_apply<5>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - DistOptLAMBStage1Functor(), - per_tensor_beta1.DATA_PTR(), - per_tensor_beta2.DATA_PTR(), - per_tensor_beta3.DATA_PTR(), - per_tensor_bias_correction.DATA_PTR(), - step.DATA_PTR(), - per_tensor_epsilon.DATA_PTR(), - (adamMode_t) mode, - per_tensor_decay.DATA_PTR(), - global_scale.DATA_PTR(), - global_grad_norm.DATA_PTR(), - max_grad_norm); ))) - - AT_CUDA_CHECK(cudaGetLastError()); -} - -void multi_tensor_lamb_update_weights_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor per_tensor_param_norm, - at::Tensor per_tensor_update_norm, - at::Tensor update_norm_offset, - at::Tensor learning_rate, - at::Tensor per_tensor_decay, - at::Tensor global_grad_norm, - bool use_nvlamb) -{ - using namespace at; - - DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 0, "lamb_stage_2", - DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[2][0].scalar_type(), 1, "lamb_stage_2", - DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 2, "lamb_stage_2", - multi_tensor_apply<3>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - DistOptLAMBStage2Functor(), - per_tensor_param_norm.DATA_PTR(), - per_tensor_update_norm.DATA_PTR(), - update_norm_offset.DATA_PTR(), - learning_rate.DATA_PTR(), - per_tensor_decay.DATA_PTR(), - global_grad_norm.DATA_PTR(), - use_nvlamb); ))) - - AT_CUDA_CHECK(cudaGetLastError()); -} diff --git a/apex/contrib/csrc/peer_memory/peer_memory.cpp b/apex/contrib/csrc/peer_memory/peer_memory.cpp deleted file mode 100644 index 2c4f773..0000000 --- a/apex/contrib/csrc/peer_memory/peer_memory.cpp +++ /dev/null @@ -1,29 +0,0 @@ -/** - * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "peer_memory_cuda.cuh" - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("allocate_raw", &apex::contrib::peer_memory::allocate_raw, "allocate_raw"); - m.def("free_raw", &apex::contrib::peer_memory::free_raw, "free_raw"); - m.def("zero", &apex::contrib::peer_memory::zero, "zero"); - m.def("get_raw_ipc_address", &apex::contrib::peer_memory::get_raw_ipc_address, "get_raw_ipc_address"); - m.def("get_raw_peers", &apex::contrib::peer_memory::get_raw_peers, "get_raw_peers"); - m.def("blob_view_half", &apex::contrib::peer_memory::blob_view_half, "blob_view_half"); - m.def("blob_view_float", &apex::contrib::peer_memory::blob_view_float, "blob_view_float"); - m.def("blob_view_int", &apex::contrib::peer_memory::blob_view_int, "blob_view_int"); - m.def("push_pull_halos_1d", &apex::contrib::peer_memory::push_pull_halos_1d, "push_pull_halos_1d"); -} diff --git a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu deleted file mode 100644 index 61368eb..0000000 --- a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cu +++ /dev/null @@ -1,750 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -#ifdef __HIP_PLATFORM_HCC__ -#include -#include "rccl/rccl.h" -#else -#include -#include "nccl.h" -#endif - -namespace cg = cooperative_groups; - -#define CUDACHECK(cmd) do { \ - cudaError_t err = cmd; \ - if( err != cudaSuccess ) { \ - char hostname[1024]; \ - gethostname(hostname, 1024); \ - printf("%s: CUDA failure %s:%d '%s'\n", \ - hostname, \ - __FILE__,__LINE__,cudaGetErrorString(err)); \ - } \ -} while(0) - -// C++17 removes 'register' storage keyword -#if __cplusplus < 201703L -#define REGISTER register -#else -#define REGISTER -#endif - -namespace { - -/* Basic deleter function for from_blob function. -void deleter(void* ptr) -{ - printf("deleter(ptr=%p)\n",ptr); - cudaFree(ptr); -} -*/ - -template -at::Tensor blob_view(T* raw_ptr, std::vector shape, const at::TensorOptions& options, bool channels_last) -{ - size_t size = 1; - std::vector strides(shape.size()); - if (channels_last) { - assert(shape.size() == 4); - strides[0] = shape[1]*shape[2]*shape[3]; - strides[1] = 1; - strides[2] = shape[1]*shape[3]; - strides[3] = shape[1]; - } else { - int idx = strides.size(); - for (auto it = shape.rbegin(); it != shape.rend(); ++it) - { - strides[--idx] = size; - size *= *it; - } - } - size *= sizeof(T); - // TODO: Implement dynamic reuse of pooled peer memory. - // We provide no deleter function because all peer memory allocations are static in this implementation. - return torch::from_blob((void*)raw_ptr, shape, strides, 0L, options); -} - -void tensor_shape(at::Tensor t, bool explicit_nhwc, int& N, int& C, int& H, int& W) -{ - if (t.dim() == 3) { - N = 1; - if (explicit_nhwc) { - C = t.size(2); - H = t.size(0); - W = t.size(1); - } else { - C = t.size(0); - H = t.size(1); - W = t.size(2); - } - } else if (t.dim() == 4) { - if (explicit_nhwc) { - N = t.size(0); - C = t.size(3); - H = t.size(1); - W = t.size(2); - } else { - N = t.size(0); - C = t.size(1); - H = t.size(2); - W = t.size(3); - } - } else { - printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n",__FILE__,__LINE__,t.dim()); - assert(t.dim() == 3 || t.dim() == 4); - } -} - -void tensor_strides(at::Tensor t, bool explicit_nhwc, int& stride_N, int& stride_C, int& stride_H, int& stride_W) -{ - if (t.dim() == 3) { - if (explicit_nhwc) { - stride_C = t.stride(2); - stride_H = t.stride(0); - stride_W = t.stride(1); - } else { - stride_C = t.stride(0); - stride_H = t.stride(1); - stride_W = t.stride(2); - } - stride_N = t.size(0)*t.size(1)*t.size(2); - } else if (t.dim() == 4) { - if (explicit_nhwc) { - stride_N = t.stride(0); - stride_C = t.stride(3); - stride_H = t.stride(1); - stride_W = t.stride(2); - } else { - stride_N = t.stride(0); - stride_C = t.stride(1); - stride_H = t.stride(2); - stride_W = t.stride(3); - } - } else { - printf("%s;%d - t.dim() must be either 3 or 4 (was %d)\n",__FILE__,__LINE__,t.dim()); - assert(t.dim() == 3 || t.dim() == 4); - } -} - -template -__device__ void __zero(T* dst) -{ - *dst = T(0); -} - -__device__ void __zero(int4* dst) -{ - int4 v; - v.x = v.y = v.z = v.w = 0; - *dst = v; -} - -template -__device__ void strided_copy_kernel( - T* dst, const int dst_stride_C, const int dst_stride_H, const int dst_stride_W, - const T* src, const int src_stride_C, const int src_stride_H, const int src_stride_W, - const int NC, const int NH, const int NW - ) -{ - size_t tot_num_threads = gridDim.x * blockDim.x; - size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; - const size_t count = NC*NH*NW; - for (size_t i = thread_id; i < count; i += tot_num_threads) - { - size_t c,h,w; - if (is_HWC) { - w = i / NC; - c = i - w * NC; - h = w / NW; - w = w - h * NW; - } - else { - h = i / NW; - w = i - h * NW; - c = h / NH; - h = h - c * NH; - } - size_t dst_off = c*dst_stride_C + h*dst_stride_H + w*dst_stride_W; - if (zero) { - __zero(dst+dst_off); - } else { - size_t src_off = c*src_stride_C + h*src_stride_H + w*src_stride_W; - dst[dst_off] = src[src_off]; - } - } -} - -template -__device__ void checked_signal( - volatile int* signal1_flag, volatile int* signal2_flag, - const int v1, const int v2, const int v3, const int v4 - ) -{ - cg::this_grid().sync(); - bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false; - if (is_main_thread) { - // flush all writes to global memory - __threadfence_system(); - // wait for top or bottom neighbor to clear signal - REGISTER int r1, r2, r3, r4; - if (!(top_zero || btm_zero)) { - bool top_zeroed=false, top_done=false; - bool btm_zeroed=false, btm_done=false; - do { - do { - if (!top_zeroed) { -#ifdef __HIP_PLATFORM_HCC__ - r1 = __builtin_nontemporal_load(signal1_flag); - r2 = __builtin_nontemporal_load(signal1_flag + 1); - r3 = __builtin_nontemporal_load(signal1_flag + 2); - r4 = __builtin_nontemporal_load(signal1_flag + 3); -#else - asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory"); -#endif - if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true; - } - if (!btm_zeroed) { -#ifdef __HIP_PLATFORM_HCC__ - r1 = __builtin_nontemporal_load(signal2_flag); - r2 = __builtin_nontemporal_load(signal2_flag + 1); - r3 = __builtin_nontemporal_load(signal2_flag + 2); - r4 = __builtin_nontemporal_load(signal2_flag + 3); -#else - asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory"); -#endif - if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true; - } - } while((top_zeroed == top_done) && (btm_zeroed == btm_done)); - if (!top_done && top_zeroed) { - // signal to top neighbor my output is ready -#ifdef __HIP_PLATFORM_HCC__ - __builtin_nontemporal_store(v1, signal1_flag); - __builtin_nontemporal_store(v2, signal1_flag + 1); - __builtin_nontemporal_store(v3, signal1_flag + 2); - __builtin_nontemporal_store(v4, signal1_flag + 3); -#else - asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory"); -#endif - top_done = true; - } - if (!btm_done && btm_zeroed) { - // signal to bottom neighbor my output is ready -#ifdef __HIP_PLATFORM_HCC__ - __builtin_nontemporal_store(v1, signal2_flag); - __builtin_nontemporal_store(v2, signal2_flag + 1); - __builtin_nontemporal_store(v3, signal2_flag + 2); - __builtin_nontemporal_store(v4, signal2_flag + 3); -#else - asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory"); -#endif - btm_done = true; - } - } while (!top_done || !btm_done); - } else if (top_zero) { - bool btm_zeroed=false, btm_done=false; - do { - do { - if (!btm_zeroed) { -#ifdef __HIP_PLATFORM_HCC__ - r1 = __builtin_nontemporal_load(signal2_flag); - r2 = __builtin_nontemporal_load(signal2_flag + 1); - r3 = __builtin_nontemporal_load(signal2_flag + 2); - r4 = __builtin_nontemporal_load(signal2_flag + 3); -#else - asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory"); -#endif - if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true; - } - } while(btm_zeroed == btm_done); - if (!btm_done && btm_zeroed) { - // signal to bottom neighbor my output is ready -#ifdef __HIP_PLATFORM_HCC__ - __builtin_nontemporal_store(v1, signal2_flag); - __builtin_nontemporal_store(v2, signal2_flag + 1); - __builtin_nontemporal_store(v3, signal2_flag + 2); - __builtin_nontemporal_store(v4, signal2_flag + 3); -#else - asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory"); -#endif - btm_done = true; - } - } while (!btm_done); - - } else if (btm_zero) { - bool top_zeroed=false, top_done=false; - do { - do { - if (!top_zeroed) { -#ifdef __HIP_PLATFORM_HCC__ - r1 = __builtin_nontemporal_load(signal1_flag); - r2 = __builtin_nontemporal_load(signal1_flag + 1); - r3 = __builtin_nontemporal_load(signal1_flag + 2); - r4 = __builtin_nontemporal_load(signal1_flag + 3); -#else - asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory"); -#endif - if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true; - } - } while(top_zeroed == top_done); - if (!top_done && top_zeroed) { - // signal to top neighbor my output is ready -#ifdef __HIP_PLATFORM_HCC__ - __builtin_nontemporal_store(v1, signal1_flag); - __builtin_nontemporal_store(v2, signal1_flag + 1); - __builtin_nontemporal_store(v3, signal1_flag + 2); - __builtin_nontemporal_store(v4, signal1_flag + 3); -#else - asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory"); -#endif - top_done = true; - } - } while (!top_done); - } - } -} - -__device__ void wait_for( - volatile int* wait_flag, - const int v1, const int v2, const int v3, const int v4 - ) -{ - bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false; - if (is_main_thread) { - REGISTER int r1, r2, r3, r4; - // wait for senders to signal their output is read - do { -#ifdef __HIP_PLATFORM_HCC__ - r1 = __builtin_nontemporal_load(wait_flag); - r2 = __builtin_nontemporal_load(wait_flag + 1); - r3 = __builtin_nontemporal_load(wait_flag + 2); - r4 = __builtin_nontemporal_load(wait_flag + 3); -#else - asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait_flag) : "memory"); -#endif - } while (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4); - } - cg::this_grid().sync(); // all threads wait for main -} - - -__device__ void clear_flag( - volatile int* wait_flag - ) -{ - cg::this_grid().sync(); // wait for all threads in kernel to finish - bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false; - if (is_main_thread) { - REGISTER int r1, r2, r3, r4; - r1 = 0; r2 = 0; r3 = 0; r4 = 0; -#ifdef __HIP_PLATFORM_HCC__ - __builtin_nontemporal_store(r1, wait_flag); - __builtin_nontemporal_store(r2, wait_flag + 1); - __builtin_nontemporal_store(r3, wait_flag + 2); - __builtin_nontemporal_store(r4, wait_flag + 3); -#else - asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory"); -#endif - } -} - -template -#if __CUDA_ARCH__ == 700 || __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 900 -__launch_bounds__(128, 16) -#endif -__global__ void push_pull_halos_1d_kernel( - // top halo, - const T* toh, int toh_stride_C, int toh_stride_H, int toh_stride_W, // top output halo - T* tox, int tox_stride_C, int tox_stride_H, int tox_stride_W, // top output tx buffer - T* tix, int tix_stride_C, int tix_stride_H, int tix_stride_W, // top input tx buffer - T* tih, int tih_stride_C, int tih_stride_H, int tih_stride_W, // top input halo - // btm halo - const T* boh, int boh_stride_C, int boh_stride_H, int boh_stride_W, // btm output halo - T* box, int box_stride_C, int box_stride_H, int box_stride_W, // btm output tx buffer - T* bix, int bix_stride_C, int bix_stride_H, int bix_stride_W, // btm input tx buffer - T* bih, int bih_stride_C, int bih_stride_H, int bih_stride_W, // btm input halo - // dimensions - int NC, int NH, int NW, - // signals - int* signal1_flag, - int* signal2_flag, - int* wait1_flag, - int* wait2_flag - ) -{ - // push top output halo to transfer buffer - if (!top_zero) strided_copy_kernel(tox, tox_stride_C, tox_stride_H, tox_stride_W, toh, toh_stride_C, toh_stride_H, toh_stride_W, NC, NH, NW); - // push btm output halo to transfer buffer - if (!btm_zero) strided_copy_kernel(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW); - // signal to top and btm neigbhbors that output halos are ready to be read - // the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values - if (!(top_zero || btm_zero)) { - checked_signal(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358); - } else if (top_zero) { - checked_signal(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358); - } else if (btm_zero) { - checked_signal(signal1_flag, signal2_flag, -987751720, 840868300, -225529332, 281513358); - } - // pull top halo from transfer buffer in peer memory to input - if (top_zero) { - strided_copy_kernel(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW); - } else { - wait_for(wait1_flag, -987751720, 840868300, -225529332, 281513358); - strided_copy_kernel(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW); - clear_flag(wait1_flag); - } - // pull btm halo from transfer buffer in peer memory to input - if (btm_zero) { - strided_copy_kernel(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW); - } else { - wait_for(wait2_flag, -987751720, 840868300, -225529332, 281513358); - strided_copy_kernel(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW); - clear_flag(wait2_flag); - } -} - -__global__ void delay_kernel(int delay_nanoseconds, int* counter) -{ - if (blockIdx.x == 0 && threadIdx.x == 0) { - // waste time while doing something compiler can't predict, thus preventing it from optimizing away this code. - int new_counter = 0; - double elapsed = 0; - clock_t start = clock(); - do { - clock_t now = clock(); - elapsed = (double)(now - start)*1e9 / CLOCKS_PER_SEC; - ++new_counter; - } while (elapsed < (double)delay_nanoseconds); - *counter = new_counter; - } -} - -} - -namespace apex { namespace contrib { namespace peer_memory { - -int64_t allocate_raw(int64_t size) -{ - float* ptr = 0L; - cudaMalloc(&ptr, size); - cudaMemset(ptr, 0, size); - return (int64_t)ptr; -} - -void free_raw(int64_t raw) -{ - cudaFree((void*)raw); -} - -void zero(int64_t raw, int64_t size) -{ - cudaMemset((void*)raw, 0, size); -} - -at::Tensor get_raw_ipc_address(int64_t raw) -{ - cudaIpcMemHandle_t mem_handle; - CUDACHECK( cudaIpcGetMemHandle(&mem_handle, (void*)raw) ); - const int n = sizeof(cudaIpcMemHandle_t); - auto address_tensor = torch::empty({n}, torch::dtype(torch::kUInt8)); - auto address_tensor_p = address_tensor.data_ptr(); - memcpy(address_tensor_p, (uint8_t*)&mem_handle, n); - return address_tensor; -} - -std::vector get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw) -{ - int peer_group_size = ipc_addresses.size(0); - std::vector results(peer_group_size); - for (int i = 0; i < peer_group_size; ++i) { - if (i != peer_rank) { - cudaIpcMemHandle_t mem_handle; - memcpy(&mem_handle, ipc_addresses.index({i}).data_ptr(), sizeof(cudaIpcMemHandle_t)); - void* p = 0L; - CUDACHECK( cudaIpcOpenMemHandle((void**)&p, mem_handle, cudaIpcMemLazyEnablePeerAccess) ); - results[i] = (int64_t)p; - } else { - results[i] = (int64_t)raw; - } - } - return results; -} - -at::Tensor blob_view_half(int64_t raw, std::vector shape, bool channels_last) -{ - return blob_view((at::Half*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA), channels_last); -} - -at::Tensor blob_view_float(int64_t raw, std::vector shape, bool channels_last) -{ - return blob_view((float*)raw, shape, torch::dtype(torch::kFloat32).device(torch::kCUDA), channels_last); -} - -at::Tensor blob_view_int(int64_t raw, std::vector shape, bool channels_last) -{ - return blob_view((int*)raw, shape, torch::dtype(torch::kInt32).device(torch::kCUDA), channels_last); -} - -void push_pull_halos_1d( - bool diagnostics, - bool explicit_nhwc, - int numSM, // number of SMs to use - bool top_zero, // true if top halo should be zeroed - at::Tensor top_out_halo, // top output halo in sender device memory - at::Tensor top_out_tx, // top output transfer buffer in sender peer pool memory - at::Tensor top_inp_tx, // top input transfer buffer in top neighbor peer pool memory - at::Tensor top_inp_halo, // top input halo in receiver device memory - bool btm_zero, // true if btm halo should be zeroed - at::Tensor btm_out_halo, // btm output halo in sender device memory - at::Tensor btm_out_tx, // btm output transfer buffer in sender peer pool memory - at::Tensor btm_inp_tx, // btm input transfer buffer in btm neighbor peer pool memory - at::Tensor btm_inp_halo, // btm input halo in receiver device memory - at::Tensor top_signal, // top input signal in receiver device memory - at::Tensor btm_signal, // btm input signal in receiver device memory - at::Tensor waits // top and btm signals for this rank - ) -{ - // basic checks of inputs - TORCH_CHECK(top_out_halo.is_cuda()); - TORCH_CHECK(top_out_tx.is_cuda()); - TORCH_CHECK(top_inp_tx.is_cuda()); - TORCH_CHECK(top_inp_halo.is_cuda()); - TORCH_CHECK(btm_out_halo.is_cuda()); - TORCH_CHECK(btm_out_tx.is_cuda()); - TORCH_CHECK(btm_inp_tx.is_cuda()); - TORCH_CHECK(btm_inp_halo.is_cuda()); - TORCH_CHECK(top_signal.is_cuda()); - TORCH_CHECK(btm_signal.is_cuda()); - TORCH_CHECK(waits.is_cuda()); - TORCH_CHECK(!(top_zero && btm_zero)); - - // shapes and strides - int toh_N, toh_C, toh_H, toh_W; - tensor_shape(top_out_halo, explicit_nhwc, toh_N, toh_C, toh_H, toh_W); - int tox_N, tox_C, tox_H, tox_W; - tensor_shape(top_out_tx, explicit_nhwc, tox_N, tox_C, tox_H, tox_W); - int tix_N, tix_C, tix_H, tix_W; - tensor_shape(top_inp_tx, explicit_nhwc, tix_N, tix_C, tix_H, tix_W); - int tih_N, tih_C, tih_H, tih_W; - tensor_shape(top_inp_halo, explicit_nhwc, tih_N, tih_C, tih_H, tih_W); - TORCH_CHECK( - (toh_N == tox_N && tox_N == tix_N && tix_N == tih_N) && - (toh_C == tox_C && tox_C == tix_C && tix_C == tih_C) && - (toh_H == tox_H && tox_H == tix_H && tix_H == tih_H) && - (toh_W == tox_W && tox_W == tix_W && tix_W == tih_W)); - int boh_N, boh_C, boh_H, boh_W; - tensor_shape(btm_out_halo, explicit_nhwc, boh_N, boh_C, boh_H, boh_W); - int box_N, box_C, box_H, box_W; - tensor_shape(btm_out_tx, explicit_nhwc, box_N, box_C, box_H, box_W); - int bix_N, bix_C, bix_H, bix_W; - tensor_shape(btm_inp_tx, explicit_nhwc, bix_N, bix_C, bix_H, bix_W); - int bih_N, bih_C, bih_H, bih_W; - tensor_shape(btm_inp_halo, explicit_nhwc, bih_N, bih_C, bih_H, bih_W); - TORCH_CHECK( - (boh_N == box_N && box_N == bix_N && bix_N == bih_N) && - (boh_C == box_C && box_C == bix_C && bix_C == bih_C) && - (boh_H == box_H && box_H == bix_H && bix_H == bih_H) && - (boh_W == box_W && box_W == bix_W && bix_W == bih_W)); - TORCH_CHECK( - (toh_N == boh_N) && - (toh_C == boh_C) && - (toh_H == boh_H) && - (toh_W == boh_W)); - int NC=toh_C, NH=toh_H, NW=toh_W; - if (diagnostics) printf("NC=%d, NH=%d, NW=%d\n",NC,NH,NW); - - int toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W; - tensor_strides(top_out_halo, explicit_nhwc, toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W); - int tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W; - tensor_strides(top_out_tx, explicit_nhwc, tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W); - int tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W; - tensor_strides(top_inp_tx, explicit_nhwc, tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W); - int tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W; - tensor_strides(top_inp_halo, explicit_nhwc, tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W); - int boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W; - tensor_strides(btm_out_halo, explicit_nhwc, boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W); - int box_stride_N, box_stride_C, box_stride_H, box_stride_W; - tensor_strides(btm_out_tx, explicit_nhwc, box_stride_N, box_stride_C, box_stride_H, box_stride_W); - int bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W; - tensor_strides(btm_inp_tx, explicit_nhwc, bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W); - int bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W; - tensor_strides(btm_inp_halo, explicit_nhwc, bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W); - - // determine if nhwc - auto is_nhwc = (toh_stride_C == 1) ? true : false; - if (diagnostics) printf("is_nhwc = %s\n",is_nhwc?"true":"false"); - - // figure out launch parameters - int device; - cudaGetDevice(&device); - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, device); - assert(numSM > 0 && numSM <= prop.multiProcessorCount); - auto current_stream = at::cuda::getCurrentCUDAStream(); - const int numThreads = 128; - dim3 block(numThreads,1,1); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, top_out_halo.scalar_type(), "push_pull_halos_1d_kernel", [&]{ - if (diagnostics) printf("size(scalar_t) = %ld\n",sizeof(scalar_t)); - scalar_t* toh_p = top_out_halo.data_ptr(); - scalar_t* tox_p = top_out_tx.data_ptr(); - scalar_t* tix_p = top_inp_tx.data_ptr(); - scalar_t* tih_p = top_inp_halo.data_ptr(); - scalar_t* boh_p = btm_out_halo.data_ptr(); - scalar_t* box_p = btm_out_tx.data_ptr(); - scalar_t* bix_p = btm_inp_tx.data_ptr(); - scalar_t* bih_p = btm_inp_halo.data_ptr(); - if (diagnostics) printf("waypoint1\n"); - int* top_signal_p = top_signal.data_ptr() + 4; - int* btm_signal_p = btm_signal.data_ptr(); - int* top_wait_p = waits.data_ptr(); - int* btm_wait_p = waits.data_ptr() + 4; - if (diagnostics) printf("waypoint2\n"); - - // do int4 vector loads if channel count permits - int elem_size_in_bytes = toh_C * sizeof(scalar_t); - int elem_size_in_int4 = (elem_size_in_bytes / 16); - if (diagnostics) printf("elem_size_in_bytes = %d, elem_size_in_int4 = %d\n",elem_size_in_bytes,elem_size_in_int4); - if (is_nhwc && elem_size_in_int4*16 == elem_size_in_bytes) { - // can do int4 transfers - int divisor = toh_C / elem_size_in_int4; - if (diagnostics) printf("CAN DO INT4 :: divisor = %d\n",divisor); - toh_stride_N /= divisor; toh_stride_H /= divisor; toh_stride_W /= divisor; - tox_stride_N /= divisor; tox_stride_H /= divisor; tox_stride_W /= divisor; - tix_stride_N /= divisor; tix_stride_H /= divisor; tix_stride_W /= divisor; - tih_stride_N /= divisor; tih_stride_H /= divisor; tih_stride_W /= divisor; - boh_stride_N /= divisor; boh_stride_H /= divisor; boh_stride_W /= divisor; - box_stride_N /= divisor; box_stride_H /= divisor; box_stride_W /= divisor; - bix_stride_N /= divisor; bix_stride_H /= divisor; bix_stride_W /= divisor; - bih_stride_N /= divisor; bih_stride_H /= divisor; bih_stride_W /= divisor; - NC /= divisor; - if (diagnostics) { - printf("divisor=%d\n",divisor); - printf("toh_stride :: N=%d, C=%d, H=%d, W=%d\n",toh_stride_N,toh_stride_C,toh_stride_H,toh_stride_W); - printf("tox_stride :: N=%d, C=%d, H=%d, W=%d\n",tox_stride_N,tox_stride_C,tox_stride_H,tox_stride_W); - printf("tix_stride :: N=%d, C=%d, H=%d, W=%d\n",tix_stride_N,tix_stride_C,tix_stride_H,tix_stride_W); - printf("tih_stride :: N=%d, C=%d, H=%d, W=%d\n",tih_stride_N,tih_stride_C,tih_stride_H,tih_stride_W); - printf("boh_stride :: N=%d, C=%d, H=%d, W=%d\n",boh_stride_N,boh_stride_C,boh_stride_H,boh_stride_W); - printf("box_stride :: N=%d, C=%d, H=%d, W=%d\n",box_stride_N,box_stride_C,box_stride_H,box_stride_W); - printf("bix_stride :: N=%d, C=%d, H=%d, W=%d\n",bix_stride_N,bix_stride_C,bix_stride_H,bix_stride_W); - printf("bih_stride :: N=%d, C=%d, H=%d, W=%d\n",bih_stride_N,bih_stride_C,bih_stride_H,bih_stride_W); - printf("NC=%d, NH=%d, NW=%d\n",NC,NH,NW); - } - void *kernelArgs[] = { - (int4**)&toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W, - (int4**)&tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W, - (int4**)&tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W, - (int4**)&tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W, - (int4**)&boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W, - (int4**)&box_p, &box_stride_C, &box_stride_H, &box_stride_W, - (int4**)&bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W, - (int4**)&bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W, - &NC, &NH, &NW, - &top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p - }; - if (top_zero) { - int numBlocksPerSm; - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel, numThreads, 0); - dim3 grid(numSM*numBlocksPerSm,1,1); -#ifdef __HIP_PLATFORM_HCC__ - hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); -#else - cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); -#endif - } else if (btm_zero) { - int numBlocksPerSm; - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel, numThreads, 0); - dim3 grid(numSM*numBlocksPerSm,1,1); -#ifdef __HIP_PLATFORM_HCC__ - hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); -#else - cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); -#endif - } else { - int numBlocksPerSm; - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel, numThreads, 0); - dim3 grid(numSM*numBlocksPerSm,1,1); -#ifdef __HIP_PLATFORM_HCC__ - hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); -#else - cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); -#endif - } - } else { - // cannot do int4 transfers - if (diagnostics) printf("CAN NOT DO INT4\n"); - void *kernelArgs[] = { - &toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W, - &tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W, - &tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W, - &tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W, - &boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W, - &box_p, &box_stride_C, &box_stride_H, &box_stride_W, - &bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W, - &bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W, - &NC, &NH, &NW, - &top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p - }; - int numBlocksPerSm; - if (is_nhwc) { - if (top_zero) { - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel, numThreads, 0); - dim3 grid(numSM*numBlocksPerSm,1,1); -#ifdef __HIP_PLATFORM_HCC__ - hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); -#else - cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); -#endif - } else if (btm_zero) { - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel, numThreads, 0); - dim3 grid(numSM*numBlocksPerSm,1,1); -#ifdef __HIP_PLATFORM_HCC__ - hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); -#else - cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); -#endif - } else { - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel, numThreads, 0); - dim3 grid(numSM*numBlocksPerSm,1,1); -#ifdef __HIP_PLATFORM_HCC__ - hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); -#else - cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); -#endif - } - } else { - if (top_zero) { - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel, numThreads, 0); - dim3 grid(numSM*numBlocksPerSm,1,1); -#ifdef __HIP_PLATFORM_HCC__ - hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); -#else - cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); -#endif - } else if (btm_zero) { - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel, numThreads, 0); - dim3 grid(numSM*numBlocksPerSm,1,1); -#ifdef __HIP_PLATFORM_HCC__ - hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); -#else - cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); -#endif - } else { - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, push_pull_halos_1d_kernel, numThreads, 0); - dim3 grid(numSM*numBlocksPerSm,1,1); -#ifdef __HIP_PLATFORM_HCC__ - hipLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); -#else - cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel, grid, block, kernelArgs, 0, current_stream); -#endif - } - } - } - } ); -} - -} } } - diff --git a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh b/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh deleted file mode 100644 index 4f0169f..0000000 --- a/apex/contrib/csrc/peer_memory/peer_memory_cuda.cuh +++ /dev/null @@ -1,50 +0,0 @@ -/** - * Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#ifndef _peer_memory_h_ -#define _peer_memory_h_ - -namespace apex { namespace contrib { namespace peer_memory { - int64_t allocate_raw(int64_t size); - void free_raw(int64_t raw); - void zero(int64_t raw, int64_t size); - at::Tensor get_raw_ipc_address(int64_t raw); - std::vector get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw); - at::Tensor blob_view_half(int64_t raw, std::vector shape, bool channels_last); - at::Tensor blob_view_float(int64_t raw, std::vector shape, bool channels_last); - at::Tensor blob_view_int(int64_t raw, std::vector shape, bool channels_last); - void push_pull_halos_1d( - bool diagnostics, - bool explicit_nhwc, - int numSM, // number of SMs to use - bool top_zero, // true if top halo should be zeroed - at::Tensor top_out_halo, // top output halo in sender device memory - at::Tensor top_out_tx, // top output transfer buffer in sender peer pool memory - at::Tensor top_inp_tx, // top input transfer buffer in top neighbor peer pool memory - at::Tensor top_inp_halo, // top input halo in receiver device memory - bool btm_zero, // true if btm halo should be zeroed - at::Tensor btm_out_halo, // btm output halo in sender device memory - at::Tensor btm_out_tx, // btm output transfer buffer in sender peer pool memory - at::Tensor btm_inp_tx, // btm input transfer buffer in btm neighbor peer pool memory - at::Tensor btm_inp_halo, // btm input halo in receiver device memory - at::Tensor top_signal, // top input signal in receiver device memory - at::Tensor btm_signal, // btm input signal in receiver device memory - at::Tensor waits // top and btm signals for this rank - ); -} } } -#endif diff --git a/apex/contrib/csrc/transducer/transducer_joint.cpp b/apex/contrib/csrc/transducer/transducer_joint.cpp deleted file mode 100755 index 351e7ca..0000000 --- a/apex/contrib/csrc/transducer/transducer_joint.cpp +++ /dev/null @@ -1,98 +0,0 @@ -#include -#include - -#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -std::vector transducer_joint_cuda_forward( - torch::Tensor f, - torch::Tensor g, - torch::Tensor fLen, - torch::Tensor gLen, - torch::Tensor batchOffset, - int64_t packedBatch, - int opt, - bool packOutput, - bool relu, - bool dropout, - float dropoutProb, - int tileSize); - - -std::vector transducer_joint_cuda_backward( - std::vector in, - torch::Tensor fLen, - torch::Tensor gLen, - torch::Tensor batchOffset, - int maxFLen, - int maxGLen, - bool packOutput, - float scale); - -std::vector transducer_joint_forward( - torch::Tensor f, - torch::Tensor g, - torch::Tensor fLen, - torch::Tensor gLen, - torch::Tensor batchOffset, - int64_t packedBatch, - int opt, - bool packOutput, - bool relu, - bool dropout, - float dropoutProb, - int tileSize) { - CHECK_INPUT(f); - CHECK_INPUT(g); - CHECK_INPUT(fLen); - CHECK_INPUT(gLen); - if (packOutput) - CHECK_INPUT(batchOffset); - return transducer_joint_cuda_forward( - f, - g, - fLen, - gLen, - batchOffset, - packedBatch, - opt, - packOutput, - relu, - dropout, - dropoutProb, - tileSize); -} - -std::vector transducer_joint_backward( - std::vector in, - torch::Tensor fLen, - torch::Tensor gLen, - torch::Tensor batchOffset, - int maxFLen, - int maxGLen, - bool packOutput, - float scale) { - for (auto t : in){ - CHECK_INPUT(t); - } - CHECK_INPUT(fLen); - CHECK_INPUT(gLen); - if (packOutput) - CHECK_INPUT(batchOffset); - return transducer_joint_cuda_backward( - in, - fLen, - gLen, - batchOffset, - maxFLen, - maxGLen, - packOutput, - scale); -} - - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &transducer_joint_forward, "transducer joint forward (CUDA)"); - m.def("backward", &transducer_joint_backward, "transducer joint backward (CUDA)"); -} \ No newline at end of file diff --git a/apex/contrib/csrc/transducer/transducer_joint_kernel.cu b/apex/contrib/csrc/transducer/transducer_joint_kernel.cu deleted file mode 100755 index c0fb572..0000000 --- a/apex/contrib/csrc/transducer/transducer_joint_kernel.cu +++ /dev/null @@ -1,985 +0,0 @@ -#include -#include -#include - -#include -#include - -#ifdef OLD_GENERATOR_PATH -#include -#else -#include -#endif - -#include -#include -#include - -#include "philox.cuh" - -#ifdef __HIP_PLATFORM_HCC__ -#define SHFL_DOWN(val, laneMask, width) __shfl_down(val, laneMask, width) -#else -#define SHFL_DOWN(val, laneMask, width) __shfl_down_sync(0xffffffff, val, laneMask, width) -#endif - -// Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width. -// width should be a power of 2 and should be less than warpSize. -template -__device__ __forceinline__ scalar_t warpReduce(scalar_t x, int width=C10_WARP_SIZE){ - for (unsigned offset = width/2; offset > 0; offset /= 2){ - x += SHFL_DOWN(x, offset, width); - } - return x; -} - -inline int largestPowerOfTwo(int x){ - int y = 1; - while (y <= x) - y <<= 1; - return y >> 1; -} - -/* -Figure out vectorization type for masks. -Similar to how PyTorch figures out acc_t here: -aten/src/ATen/AccumulateType.h -*/ -template -struct MaskVecType { }; - -template <> struct MaskVecType<1> { using type = uint8_t; }; -template <> struct MaskVecType<2> { using type = uint16_t; }; -template <> struct MaskVecType<4> { using type = uint32_t; }; - -template -using mvec_type = typename MaskVecType::type; - -// Helper class to calculate pointer offset that can be shared by different flavors of kernels. -// For fwd, batch offset and stride are different for packing and non-packing mode. -struct OffsetCalFwd{ - __device__ __forceinline__ OffsetCalFwd( - int64_t batch, - const int64_t *batchOffset, - int64_t maxFLen, - int64_t maxGLen, - int64_t gLen, - int64_t hiddenSize, - bool packOutput) : - batch(batch), - batchOffset(batchOffset), - maxFLen(maxFLen), - maxGLen(maxGLen), - gLen(gLen), - hiddenSize(hiddenSize), - packOutput(packOutput) - {} - - int64_t batch; - const int64_t *batchOffset; - int64_t maxFLen; - int64_t maxGLen; - int64_t gLen; - int64_t hiddenSize; - bool packOutput; - - __device__ __forceinline__ int64_t getBatchOffset(){ - return packOutput ? ((batch==0) ? 0 : batchOffset[batch-1])*hiddenSize - : batch*maxFLen*maxGLen*hiddenSize; - } - - __device__ __forceinline__ int64_t getStrideF(){ - return packOutput ? gLen*hiddenSize : maxGLen*hiddenSize; - } - - -}; - -// Helper class to calculate pointer offset that can be shared by different flavors of kernels -// For bwd, batch offset and stride are different for packing and non-packing mode. -// The reducion is done for two input tensors. Therefore, generating two sets of offsets -// according to bwdFasterDim can lead to a unified implementation in the actual kernel. -struct OffsetCalBwd{ - __device__ __forceinline__ OffsetCalBwd( - int64_t batch, - const int64_t *batchOffset, - const int *fLen, - const int *gLen, - int64_t maxFLen, - int64_t maxGLen, - int64_t hiddenSize, - bool packOutput, - bool bwdFasterDim) : - batch(batch), - batchOffset(batchOffset), - maxFLen(maxFLen), - maxGLen(maxGLen), - fLen(fLen), - gLen(gLen), - hiddenSize(hiddenSize), - packOutput(packOutput), - bwdFasterDim(bwdFasterDim) - {} - - int64_t batch; - const int64_t *batchOffset; - const int *fLen; - const int *gLen; - int64_t maxFLen; - int64_t maxGLen; - int64_t hiddenSize; - bool packOutput; - bool bwdFasterDim; // whether doing bwd on the faster moving dimension - - __device__ __forceinline__ int64_t getBatchOffset(){ - return packOutput ? ((batch==0) ? 0 : batchOffset[batch-1])*hiddenSize - : batch*maxFLen*maxGLen*hiddenSize; - } - - __device__ __forceinline__ int64_t getMaxXLen(){ - return bwdFasterDim ? maxGLen : maxFLen; - } - - __device__ __forceinline__ auto getMyXLen() -> decltype(gLen[batch]){ - return bwdFasterDim ? gLen[batch] : fLen[batch]; - } - - __device__ __forceinline__ auto getMyYLen() -> decltype(gLen[batch]){ - return bwdFasterDim ? fLen[batch] : gLen[batch]; - } - - __device__ __forceinline__ int64_t getStrideX(){ - return bwdFasterDim ? hiddenSize : ((packOutput ? gLen[batch] : maxGLen) * hiddenSize); - } - - __device__ __forceinline__ int64_t getStrideY(){ - return bwdFasterDim ? ((packOutput ? gLen[batch] : maxGLen) * hiddenSize) : hiddenSize; - } -}; - - -// Vanila transducer joint forward kernel -// Detail of this joint function can be found in: -// [1] Sequence Transduction with Recurrent Neural Networks. - -// f is a tensor of shape [batch, T, H] -// g is a tensor of shape [batch, U, H] -// the transducer joint does -// sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1) -// The resultant tensor is of shape [batch, T, U, H] -// Each thread block is working on one "batch" of data in the output tensor, [batch, t, u, :] - -// This joint function can optionally pack the output where the output tensor with a shape of -// [B, T, U, H] is packed into [B_packed, H]. -// Don't-care region (t > fLen) or (u > gLen) is removed. -// To enable packing, the starting offset for each batch need to be specified with batchOffset. -template -__global__ void transducer_joint_forward( - const scalar_t *f, - const scalar_t *g, - const int *fLen, - const int *gLen, - const int64_t *batchOffset, - int64_t maxFLen, - int64_t maxGLen, - int64_t hiddenSize, - bool packOutput, - scalar_t *sum) { - - - const int batch = blockIdx.z; - const int t = blockIdx.y; - const int u = blockIdx.x; - const auto myFLen = fLen[batch]; - const auto myGLen = gLen[batch]; - - OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput); - const auto myBatchOffset = offsetCal.getBatchOffset(); - const auto strideF = offsetCal.getStrideF(); - scalar_t const *myF = f + batch*maxFLen*hiddenSize + t*hiddenSize; - scalar_t const *myG = g + batch*maxGLen*hiddenSize + u*hiddenSize; - scalar_t *mySum = sum + myBatchOffset + t*strideF + u * hiddenSize; - - if (t < myFLen and u < myGLen){ - #pragma unroll - for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x){ - if (h < hiddenSize){ - mySum[h] = myF[h] + myG[h]; - } - } - } - else if (packOutput == false and t < maxFLen and u < maxGLen){ - // Need to write finite data to don't-care region because we instantiate the result tensor - // with torch::empty for performance reasons. Even though it is don't-care region, the - // contents need to be finite, otherwise could lead to NaN in WGRAD. - // In packing mode, this write is no longer necessary as we remove the don't-care region - // from the output. - // Picking -1 (over 0) here for ease of testing. - #pragma unroll - for (int h = threadIdx.x; h < hiddenSize; h += blockDim.x){ - if (h < hiddenSize){ - mySum[h] = -1; - } - } - } -} - -/* -Tiled version of the joint forward kernel -Detail of this joint function can be found in: -[1] Sequence Transduction with Recurrent Neural Networks. - -f is a tensor of shape [batch, T, H] -g is a tensor of shape [batch, U, H] -the transducer joint does -sum = f.unsqueeze(dim=2) + g.unsqueeze(dim=1) -The resultant tensor is of shape [batch, T, U, H] -Each thread is working on a tile of the shape of tileF x tileG in the result tensor. -The input for the tile is first loaded in the register and is reused tileG and tileF times. - -This joint function can optionally pack the output where the output tensor with a shape of -[B, T, U, H] is packed into [B_packed, H]. -Don't-care region (t > fLen) or (u > gLen) is removed. -To enable packing, the starting offset for each batch need to be specified with batchOffset. - -Optionally this joint function performs ReLU and/or dropout on the joint output, which is -controlled by arguments relu and dropout, respectively. philoxArgs is argument used for generating -pseudorandom number. When at least one of operations in ReLU and dropout is activated, the joint -function is a masked operation, which is controlled by the template argument masked. In this case, -masks are saved to backward. -*/ -template -__global__ void transducer_joint_tiled_forward( - const scalar_t *f, - const scalar_t *g, - const int *fLen, - const int *gLen, - const int64_t *batchOffset, - int64_t maxFLen, - int64_t maxGLen, - int64_t hiddenSize, - int64_t hiddenPerBlock, - bool packOutput, - bool relu, - bool dropout, - float p, - at::PhiloxCudaState philoxArgs, - scalar_t *sum, - uint8_t *mask) { - - static_assert(U == 4, "U has to be 4, as random numbers are generated in batch of 4"); - - const int batch = blockIdx.z; - const int t = blockIdx.y * tileF; - const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock; - const int u = blockIdx.x / hiddenBlock * tileG; - const int hOffset = (blockIdx.x % hiddenBlock) * hiddenPerBlock; - const int h = threadIdx.x; - const auto myFLen = fLen[batch]; - const auto myGLen = gLen[batch]; - - OffsetCal offsetCal(batch, batchOffset, maxFLen, maxGLen, myGLen, hiddenSize, packOutput); - const auto myBatchOffset = offsetCal.getBatchOffset(); - const auto strideF = offsetCal.getStrideF(); - - scalar_t const *myF = f + batch*maxFLen*hiddenSize + t*hiddenSize + hOffset; - scalar_t const *myG = g + batch*maxGLen*hiddenSize + u*hiddenSize + hOffset; - scalar_t *mySum = sum + myBatchOffset + t*strideF + u*hiddenSize + hOffset; - uint8_t *myMask = mask + myBatchOffset + t*strideF + u*hiddenSize + hOffset; - - // The following code is only needed for dropout. We try to bypass them as much as possible. - auto seeds = masked ? at::cuda::philox::unpack(philoxArgs) - : std::make_tuple(static_cast(0), static_cast(0)); - uint64_t tid = masked ? (static_cast(blockIdx.z)*gridDim.y*gridDim.x + - blockIdx.y*gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x - : 0; - Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds)); - scalar_t scale = masked ? ((p == 0) ? 0 : 1 / p) : 0; - bool dropoutMask[U]; - - if (t < myFLen and u < myGLen and hOffset+h < hiddenSize){ - // register buffers for tiled input reuse - scalar_t fBuffer[tileF], gBuffer[tileG]; - for (int i = 0; i < tileF; ++i){ - if (t + i < myFLen) - fBuffer[i] = myF[i*hiddenSize + h]; - } - for (int j = 0; j < tileG; ++j){ - if (u + j < myGLen) - gBuffer[j] = myG[j*hiddenSize + h]; - } - #pragma unroll - for (int i = 0; i < tileF; ++i){ - if (t + i < myFLen){ - #pragma unroll - for (int j = 0; j < tileG; ++j){ - int idx = i*tileG + j; - if (masked and dropout and idx % U == 0){ - // For performance, generate 4 random numbers in one shot - // auto rand4 = curand_uniform4(&state); - auto rand4 = uniform4(ph()); - dropoutMask[0] = rand4.x < p; - dropoutMask[1] = rand4.y < p; - dropoutMask[2] = rand4.z < p; - dropoutMask[3] = rand4.w < p; - } - - if (u + j < myGLen){ - scalar_t out = fBuffer[i] + gBuffer[j]; - if (masked){ - // Apply ReLU here when relu is True - bool localMask = relu ? (out>0) : 1; - localMask = dropout ? localMask & dropoutMask[idx%U] : localMask; - out = dropout ? out*localMask*scale : out*localMask; - myMask[i*strideF + j*hiddenSize + h] = static_cast(localMask); - } - mySum[i*strideF + j*hiddenSize + h] = out; - } - else if (packOutput == false and u + j < maxGLen) - mySum[i*strideF + j*hiddenSize + h] = -1; - } - } - else if (packOutput == false and t + i < maxFLen){ - // Again need to write finite data to don't-care region - #pragma unroll - for (int j = 0; j < tileG; ++j){ - if (u + j < maxGLen) - mySum[i*strideF + j*hiddenSize + h] = -1; - } - } - } - } - else if (packOutput == false and t < maxFLen and u < maxGLen and hOffset+h < hiddenSize){ - // Only need to ensure the finity in normal mode - #pragma unroll - for (int i = 0; i < tileF; ++i){ - if (t + i < maxFLen){ - #pragma unroll - for (int j = 0; j < tileG; ++j){ - if (u + j < maxGLen) - mySum[i*strideF + j*hiddenSize + h] = -1; - } - } - } - } -} - -/* -Bwd operation (reduction) on one input tensor. Since the operation performed for the two input -tensors are exactly the same, only one kernel is needed, and the different indexing offsets -and strides are handled by OffsetCalBwd. - -When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a -non-packed form. - -When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation, -and mask contains the mask information. -*/ -template -__device__ void transducer_joint_single_backward( - const scalar_t *grad, - const uint8_t *mask, - const int *fLen, - const int *gLen, - const int64_t *batchOffset, - int64_t maxFLen, - int64_t maxGLen, - int64_t hiddenSize, - bool packOutput, - bool bwdFasterDim, // whether bwd on the faster moving dimension (u) - float scale, - scalar_t *inGrad, - int yBlockOffset=0) { - - - const int batch = blockIdx.z; - // For the second input tensor, this offset need to be subtracted because the first yBlockOffset - // sets of thread blocks are for the first input tensor. - const int x = blockIdx.y-yBlockOffset; - const int hOffset = blockIdx.x*C10_WARP_SIZE; - const int wid = threadIdx.y; - const int lid = threadIdx.x; - const int numWarp = blockDim.y; - extern __shared__ char smem8[]; - auto smem = reinterpret_cast(smem8); - - OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput, - bwdFasterDim); - const auto maxXLen = offsetCal.getMaxXLen(); - const auto myXLen = offsetCal.getMyXLen(); - const auto myYLen = offsetCal.getMyYLen(); - scalar_t *myInGrad = inGrad + batch*maxXLen*hiddenSize + x*hiddenSize + hOffset; - - if (x < myXLen){ - - const auto myBatchOffset = offsetCal.getBatchOffset(); - const auto strideX = offsetCal.getStrideX(); - const auto strideY = offsetCal.getStrideY(); - const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset; - const uint8_t *myMask = masked ? mask + myBatchOffset + x*strideX + hOffset : nullptr; - - // Each warp reduces numYPerWarp "y" first - acc_t warpSum = 0; - auto numYPerWarp = (myYLen+numWarp-1)/numWarp; - #pragma unroll - for (int warpY = 0; warpY < numYPerWarp; ++warpY){ - auto y = wid*numYPerWarp + warpY; - if (y < myYLen and (hOffset+lid) < hiddenSize) - if (masked) - warpSum += static_cast(myGrad[y*strideY + lid]) * myMask[y*strideY + lid] * scale; - else - warpSum += myGrad[y*strideY + lid]; - } - - // transpose partial sum in SMEM and reduce further using warpReduce - smem[lid*numWarp + wid] = warpSum; - __syncthreads(); - auto sum = smem[wid*C10_WARP_SIZE + lid]; - sum = warpReduce(sum, numWarp); - - // a a b b c c d d - // a a b b c c d d - // a a b b c c d d - // a a b b c c d d - // example of 4 warps (a, b, c, d) with 8 threads per warp - // Each warp need 8 / 4 = 2 threads to write the results. - if (hOffset+wid*C10_WARP_SIZE/numWarp+lid/numWarp < hiddenSize){ - if (lid % numWarp == 0){ - myInGrad[wid*C10_WARP_SIZE/numWarp + lid/numWarp] = sum; - } - } - } - else if (wid == 0 and hOffset + lid < hiddenSize){ - // Need to ensure the grad is zero for don't care region - myInGrad[lid] = 0; - } -} - -/* -Actual bwd (reduction) kernel get launched. -Call transducer_joint_single_backward twice on two input tensors. -The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op -uses the rest. -When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation, -and mask contains the mask information. -*/ -template -__global__ void transducer_joint_combined_backward( - const scalar_t *grad, - const uint8_t *mask, - const int *fLen, - const int *gLen, - const int64_t *batchOffset, - int64_t maxFLen, - int64_t maxGLen, - int64_t hiddenSize, - bool packOutput, - float scale, - scalar_t *fGrad, - scalar_t *gGrad) { - if (blockIdx.y < maxFLen){ - transducer_joint_single_backward( - grad, - mask, - fLen, - gLen, - batchOffset, - maxFLen, - maxGLen, - hiddenSize, - packOutput, - false, - scale, - fGrad); - } - else{ - transducer_joint_single_backward( - grad, - mask, - fLen, - gLen, - batchOffset, - maxFLen, - maxGLen, - hiddenSize, - packOutput, - true, - scale, - gGrad, - maxFLen); - } -} - -/* -Vectorized version of transducer_joint_single_backward -Doing exact same operation as transducer_joint_single_backward except the load and store are -vectorized. -When packing is enabled in the fwd op, unpacking is needed to restore the gradients in a -non-packed form. -When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation, -and mask contains the mask information. -*/ -template -__device__ void transducer_joint_single_vec_backward( - const scalar_t *grad, - const uint8_t *mask, - const int *fLen, - const int *gLen, - const int64_t *batchOffset, - int64_t maxFLen, - int64_t maxGLen, - int64_t hiddenSize, - bool packOutput, - bool bwdFasterDim, - float scale, - scalar_t *inGrad, - int yBlockOffset=0){ - - const int batch = blockIdx.z; - const int x = blockIdx.y - yBlockOffset; - const int hOffset = blockIdx.x*C10_WARP_SIZE*V; - const int wid = threadIdx.y; - const int lid = threadIdx.x; - const int numWarp = blockDim.y; - - // Figure out the vectorization type for mask - using mvec_t = mvec_type; - - OffsetCal offsetCal(batch, batchOffset, fLen, gLen, maxFLen, maxGLen, hiddenSize, packOutput, - bwdFasterDim); - const auto maxXLen = offsetCal.getMaxXLen(); - const auto myXLen = offsetCal.getMyXLen(); - const auto myYLen = offsetCal.getMyYLen(); - scalar_t *myInGrad = inGrad + batch*maxXLen*hiddenSize + x*hiddenSize + hOffset; - extern __shared__ char smem8[]; - auto smem = reinterpret_cast(smem8); - - acc_t warpSum[V]; - scalar_t inBuffer[V]; - uint8_t maskBuffer[V]; - scalar_t outBuffer[V]; - auto myInGradVec = reinterpret_cast(myInGrad); - auto outBufferVec = reinterpret_cast(outBuffer); - - if (x < myXLen){ - const auto myBatchOffset = offsetCal.getBatchOffset(); - const auto strideX = offsetCal.getStrideX(); - const auto strideY = offsetCal.getStrideY(); - const scalar_t *myGrad = grad + myBatchOffset + x*strideX + hOffset; - const uint8_t *myMask = masked ? mask + myBatchOffset + x*strideX + hOffset - :nullptr; - - for (int i = 0; i < V; ++i) - warpSum[i] = 0; - - // Each warp reduces numYPerWarp "y" first - auto numYPerWarp = (myYLen+numWarp-1)/numWarp; - for (int warpY = 0; warpY < numYPerWarp; ++warpY){ - auto y = wid*numYPerWarp + warpY; - auto myGradVec = reinterpret_cast(myGrad + y*strideY); - auto myMaskVec = masked ? reinterpret_cast(myMask + y*strideY) - : nullptr; - auto inBufferVec = reinterpret_cast(inBuffer); - auto maskBufferVec = reinterpret_cast(maskBuffer); - if (hOffset + lid*V < hiddenSize and y < myYLen){ - *inBufferVec = myGradVec[lid]; // vectorized load - if (masked){ - *maskBufferVec = myMaskVec[lid]; - #pragma unroll - for (int i = 0; i < V; ++i) - warpSum[i] += static_cast(inBuffer[i]) * maskBuffer[i] * scale; - } - else{ - #pragma unroll - for (int i = 0; i < V; ++i) - warpSum[i] += inBuffer[i]; - } - } - } - - // transpose partial sum in SMEM and reduce further using warpReduce - for (int i = 0; i < V; ++i){ - smem[lid*numWarp + wid] = warpSum[i]; - __syncthreads(); - auto sum = smem[wid*C10_WARP_SIZE + lid]; - - if (hOffset+(wid*C10_WARP_SIZE/numWarp)*V < hiddenSize){ - sum = warpReduce(sum, numWarp); - if (lid % numWarp == 0){ - outBuffer[i] = sum; - } - } - __syncthreads(); - } - - // a a b b c c d d - // a a b b c c d d - // a a b b c c d d - // a a b b c c d d - // example of 4 warps (a, b, c, d) with 8 threads per warp - // Each warp need 8 / 4 = 2 threads to write the results. - if (lid % numWarp == 0 and hOffset+(wid*C10_WARP_SIZE/numWarp + lid/numWarp)*V < hiddenSize) - myInGradVec[wid*C10_WARP_SIZE/numWarp + lid/numWarp] = *outBufferVec; - } - else if (wid == 0 and hOffset + lid*V < hiddenSize){ - // Need to ensure the grad is zero for don't care region - myInGradVec[lid] = 0; - } -} - -/* -Vecotrized version of transducer_joint_combined_backward -Call transducer_joint_single_vec_backward twice on two input tensors. -The two bwd ops are launched together, the first op uses blockIdx.y < maxFLen, and the second op -uses the rest. -When ReLU and/or dropout are performed in the fwd pass, this operation becomes a masked operation, -and mask contains the mask information. -*/ -template -__global__ void transducer_joint_combined_vec_backward( - const scalar_t *grad, - const uint8_t *mask, - const int *fLen, - const int *gLen, - const int64_t *batchOffset, - int64_t maxFLen, - int64_t maxGLen, - int64_t hiddenSize, - bool packOutput, - float scale, - scalar_t *fGrad, - scalar_t *gGrad) { - if (blockIdx.y < maxFLen){ - transducer_joint_single_vec_backward( - grad, - mask, - fLen, - gLen, - batchOffset, - maxFLen, - maxGLen, - hiddenSize, - packOutput, - false, - scale, - fGrad); - } - else{ - transducer_joint_single_vec_backward( - grad, - mask, - fLen, - gLen, - batchOffset, - maxFLen, - maxGLen, - hiddenSize, - packOutput, - true, - scale, - gGrad, - maxFLen); - } -} - - - - -std::vector transducer_joint_cuda_forward( - torch::Tensor f, - torch::Tensor g, - torch::Tensor fLen, - torch::Tensor gLen, - torch::Tensor batchOffset, - int64_t packedBatch, - int opt, - bool packOutput, - bool relu, - bool dropout, - float dropoutProb, - int tileSize){ - - - auto tensorOpt = f.options(); - auto dtype = f.scalar_type(); - const auto batchSize = f.size(0); - const auto maxFLen = f.size(1); - const auto maxGLen = g.size(1); - const auto hiddenSize = f.size(2); - bool masked = dropout or relu; - - int64_t *batchOffsetPtr = nullptr; - torch::Tensor sum, mask; - auto maskOpt = tensorOpt.dtype(torch::kUInt8); - if (!packOutput){ - sum = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, tensorOpt); - batchOffsetPtr = nullptr; - if (masked) - mask = torch::empty({batchSize, maxFLen, maxGLen, hiddenSize}, maskOpt); - } - else{ - sum = torch::empty({packedBatch, hiddenSize}, tensorOpt); - batchOffsetPtr = batchOffset.data_ptr(); - if (masked) - mask = torch::empty({packedBatch, hiddenSize}, maskOpt); - } - uint8_t *maskPtr = masked ? mask.data_ptr() : nullptr; - - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - TORCH_CHECK(opt == 0 or opt == 1, "Got an invalid optimization level ", opt); - // Simple heuristics - const int numThread = std::min(128, (static_cast(hiddenSize)+C10_WARP_SIZE-1) - / C10_WARP_SIZE * C10_WARP_SIZE); - - if (opt == 0){ - // vanilla kernel - const int threads = numThread; - const dim3 blocks(maxGLen, maxFLen, batchSize); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_forward", ([&] { - transducer_joint_forward - <<>>( - f.data_ptr(), - g.data_ptr(), - fLen.data_ptr(), - gLen.data_ptr(), - batchOffsetPtr, - maxFLen, - maxGLen, - hiddenSize, - packOutput, - sum.data_ptr()); - })); - } - if (opt == 1){ - // tiled version. For simplicity, assume tileF == tileG, even though the kernel can - // support more general cases. - const int threads = numThread; - const int hiddenPerBlock = numThread; - const int hiddenBlock = (hiddenSize + hiddenPerBlock - 1) / hiddenPerBlock; - const dim3 blocks( (maxGLen+tileSize-1)/tileSize * hiddenBlock, - (maxFLen+tileSize-1)/tileSize, - batchSize); - - TORCH_CHECK(tileSize == 1 or tileSize == 2 or tileSize == 4, - "Expected tileSize to be in [1, 2, 4], but got ", tileSize); - - at::PhiloxCudaState rng_engine_inputs; - if (masked){ - // set up PRG when the input is masked. rng_engine_inputs will be used as a space filler - // for non-masked calls. - // Therefore no need to initialize. - c10::optional gen_; - auto gen = at::get_generator_or_default(gen_, - at::cuda::detail::getDefaultCUDAGenerator()); - // counterOffset records how many cuRAND calls each thread makes. For a tiled kernel, - // each thread processes tileF * tileG output elements. - int64_t counterOffset = tileSize * tileSize; - { - std::lock_guard lock(gen->mutex_); - rng_engine_inputs = gen->philox_cuda_state(counterOffset); - } - } - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_forward", ([&] { - void(*kernel)(const scalar_t*, const scalar_t*, const int*, const int*, const int64_t*, - int64_t, int64_t, int64_t, int64_t, bool, bool, bool, float, - at::PhiloxCudaState, scalar_t*, uint8_t*); - if (masked){ - switch (tileSize){ - case 2: - kernel = &transducer_joint_tiled_forward; - break; - case 4: - kernel = &transducer_joint_tiled_forward; - break; - } - } - else{ - switch (tileSize){ - case 1: - kernel = &transducer_joint_tiled_forward; - break; - case 2: - kernel = &transducer_joint_tiled_forward; - break; - case 4: - kernel = &transducer_joint_tiled_forward; - break; - } - } - - kernel<<>>( - f.data_ptr(), - g.data_ptr(), - fLen.data_ptr(), - gLen.data_ptr(), - batchOffsetPtr, - maxFLen, - maxGLen, - hiddenSize, - hiddenPerBlock, - packOutput, - relu, - dropout, - 1.0f - dropoutProb, - rng_engine_inputs, - sum.data_ptr(), - maskPtr); - })); - } - - C10_CUDA_CHECK(cudaGetLastError()); - if (masked) - return {sum, mask}; - else - return {sum}; -} - -std::vector transducer_joint_cuda_backward( - std::vector in, - torch::Tensor fLen, - torch::Tensor gLen, - torch::Tensor batchOffset, - int maxFLen, - int maxGLen, - bool packOutput, - float scale){ - - auto grad = in[0]; - bool masked = (in.size() == 2); - uint8_t *maskPtr = masked ? in[1].data_ptr() : nullptr; - - auto tensorOpt = grad.options(); - auto dtype = grad.scalar_type(); - const int batchSize = fLen.size(0); - const int hiddenSize = grad.size(-1); - - const auto deviceProperties = at::cuda::getCurrentDeviceProperties(); - const int maxNumWarp = deviceProperties->maxThreadsPerBlock / C10_WARP_SIZE; - - torch::Tensor fGrad = torch::empty({batchSize, maxFLen, hiddenSize}, tensorOpt); - torch::Tensor gGrad = torch::empty({batchSize, maxGLen, hiddenSize}, tensorOpt); - - int64_t *batchOffsetPtr = (!packOutput) ? nullptr : batchOffset.data_ptr(); - - // The number "y" I would like each thread to work on - const int workPerThread = 32; - // Since the bwd for f and g have the same thread block size, we need to use the max of the two. - int numWarp = largestPowerOfTwo((std::max(maxFLen, maxGLen) + workPerThread-1) / workPerThread); - // Would like to have at least 2 warps - numWarp = std::max(2, numWarp); - // cap on the maximum number of warps allowed - numWarp = std::min(maxNumWarp, numWarp); - - // Need smem for transposing the partial sum. The partial sum is in a matrix of the shape - // numWarp x warpSize - const int smemSize = numWarp * C10_WARP_SIZE; - const dim3 threads(C10_WARP_SIZE, numWarp, 1); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_cuda_backward_kernel", ([&] { - auto gradPtr = grad.data_ptr(); - auto fLenPtr = fLen.data_ptr(); - auto gLenPtr = gLen.data_ptr(); - auto fGradPtr = fGrad.data_ptr(); - auto gGradPtr = gGrad.data_ptr(); - - // resolve the acc_t type - using acc_t = at::acc_type; - using vec_t = uint64_t; - - constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t); - constexpr int vecAlignment = std::alignment_of::value; - - // if all input and output tensors meet the alignment requirement - bool memAlign = (reinterpret_cast(gradPtr) % vecAlignment == 0) - and (reinterpret_cast(fGradPtr) % vecAlignment == 0) - and (reinterpret_cast(gGradPtr) % vecAlignment == 0); - - if (vectFactor > 1 and hiddenSize%vectFactor == 0 and memAlign){ - // If vectorization helps and the alignment requirement is met, use the vectorized - // kernel. For simplicity, hiddenSize needs to be a multiple vecFactor. - const dim3 blocks( (hiddenSize+C10_WARP_SIZE*vectFactor-1)/(C10_WARP_SIZE*vectFactor), - maxFLen+maxGLen, - batchSize); - if (masked){ - transducer_joint_combined_vec_backward - - <<>>( - gradPtr, - maskPtr, - fLenPtr, - gLenPtr, - batchOffsetPtr, - maxFLen, - maxGLen, - hiddenSize, - packOutput, - scale, - fGradPtr, - gGradPtr); - } - else{ - transducer_joint_combined_vec_backward - - <<>>( - gradPtr, - maskPtr, - fLenPtr, - gLenPtr, - batchOffsetPtr, - maxFLen, - maxGLen, - hiddenSize, - packOutput, - scale, - fGradPtr, - gGradPtr); - } - } - else{ - const dim3 blocks((hiddenSize+C10_WARP_SIZE-1)/C10_WARP_SIZE, - maxFLen + maxGLen, batchSize); - if (masked){ - transducer_joint_combined_backward - <<>>( - gradPtr, - maskPtr, - fLenPtr, - gLenPtr, - batchOffsetPtr, - maxFLen, - maxGLen, - hiddenSize, - packOutput, - scale, - fGradPtr, - gGradPtr); - } - else{ - transducer_joint_combined_backward - <<>>( - gradPtr, - maskPtr, - fLenPtr, - gLenPtr, - batchOffsetPtr, - maxFLen, - maxGLen, - hiddenSize, - packOutput, - scale, - fGradPtr, - gGradPtr); - } - } - })); - - return {fGrad, gGrad}; -} diff --git a/apex/contrib/csrc/transducer/transducer_loss.cpp b/apex/contrib/csrc/transducer/transducer_loss.cpp deleted file mode 100644 index f63a67f..0000000 --- a/apex/contrib/csrc/transducer/transducer_loss.cpp +++ /dev/null @@ -1,109 +0,0 @@ -#include -#include - -#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -std::vector transducer_loss_cuda_forward( - torch::Tensor x, - torch::Tensor label, - torch::Tensor audLen, - torch::Tensor txtLen, - torch::Tensor batchOffset, - int maxFLen, - int blankIdx, - int opt, - bool packedInput); - -torch::Tensor transducer_loss_cuda_backward( - torch::Tensor x, - torch::Tensor lossGrad, - torch::Tensor alpha, - torch::Tensor beta, - torch::Tensor audLen, - torch::Tensor txtLen, - torch::Tensor label, - torch::Tensor batchOffset, - int maxFLen, - int blankIdx, - int opt, - bool fuseSoftmaxBackward, - bool packedInput); - - -std::vector transducer_loss_forward( - torch::Tensor x, - torch::Tensor label, - torch::Tensor fLen, - torch::Tensor yLen, - torch::Tensor batchOffset, - int maxFLen, - int blankIdx, - int opt, - bool packedInput - ) { - - CHECK_INPUT(x); - CHECK_INPUT(label); - CHECK_INPUT(fLen); - CHECK_INPUT(yLen); - if (packedInput) - CHECK_INPUT(batchOffset); - return transducer_loss_cuda_forward( - x, - label, - fLen, - yLen, - batchOffset, - maxFLen, - blankIdx, - opt, - packedInput); -} - -torch::Tensor transducer_loss_backward( - torch::Tensor x, - torch::Tensor lossGrad, - torch::Tensor alpha, - torch::Tensor beta, - torch::Tensor fLen, - torch::Tensor yLen, - torch::Tensor label, - torch::Tensor batchOffset, - int maxFLen, - int blankIdx, - int opt, - bool fuseSoftmaxBackward, - bool packedInput){ - - CHECK_INPUT(x); - CHECK_INPUT(label); - CHECK_INPUT(lossGrad); - CHECK_INPUT(alpha); - CHECK_INPUT(beta); - CHECK_INPUT(fLen); - CHECK_INPUT(yLen); - if (packedInput) - CHECK_INPUT(batchOffset); - - return transducer_loss_cuda_backward( - x, - lossGrad, - alpha, - beta, - fLen, - yLen, - label, - batchOffset, - maxFLen, - blankIdx, - opt, - fuseSoftmaxBackward, - packedInput); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &transducer_loss_forward, "transducer loss forward (CUDA)"); - m.def("backward", &transducer_loss_backward, "transducer loss backward (CUDA)"); -} diff --git a/apex/contrib/csrc/transducer/transducer_loss_kernel.cu b/apex/contrib/csrc/transducer/transducer_loss_kernel.cu deleted file mode 100755 index 295e14b..0000000 --- a/apex/contrib/csrc/transducer/transducer_loss_kernel.cu +++ /dev/null @@ -1,767 +0,0 @@ -#include - -#include -#include - -#include -#include -#include -#include - -template -__device__ __forceinline__ scalar_t logSumExp(scalar_t a, scalar_t b) { - // standard log-sum-exp trick is used here to provide better numerical stability - return (a >= b) ? a + std::log1p(exp(b-a)) : b + std::log1p(exp(a-b)); -} - -// Vanilla transducer loss function (i.e. forward-backward algorithm) -// Detail of this loss function can be found in: -// [1] Sequence Transduction with Recurrent Neural Networks. - -// Forward (alpha) and backward (beta) path are launched together. Input is assumed to be converted -// into log scale by the preceding log_softmax layer -// Diagonal wavefront advancing usually used in dynamic programming is leveraged here. -// alpha and beta are of acc_t type, as they are essentially accumulators. - -// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into -// [B_packed, H]. -// Don't-care region (t > audLen) or (u > txtLen) is removed. -// To support the packed input, the starting offsets for each batch need to be specified with -// batchOffset. -template -__global__ void transducer_loss_forward( - const scalar_t* x, - const int* label, - const int* audLen, - const int* txtLen, - const int64_t* batchOffset, - int64_t dictSize, // 64-bit indexing for data tensor - int64_t blankIdx, - int64_t maxFLen, - int64_t maxGLen, - bool packedInput, - acc_t* alpha, - acc_t* beta, - scalar_t* loss) { - - const int batch = blockIdx.y; - const int tid = threadIdx.x; - const auto myFLen = audLen[batch]; - // Note that start of the sentence is added as 1 here - const auto myGLen = txtLen[batch] + 1; - const auto myLabel = label + batch * (maxGLen-1); - const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) - : batch * maxFLen * maxGLen; - const int64_t myStrideT = packedInput ? myGLen : maxGLen; - const scalar_t* myX = x + myBatchOffset * dictSize; - int u = tid; - - if (blockIdx.x == 0){ - // alpha path - acc_t* myAlpha = alpha + batch*maxFLen*maxGLen; - if (u == 0) - myAlpha[0] = 0; - __syncthreads(); - - for (int64_t step = 1; step < myFLen+myGLen-1; ++step){ - // Move along the diagonal wavefront to leverage available parallelism - for (u = tid; u < myGLen; u += blockDim.x){ - int64_t t = step - u; - if (t >= 0 and t < myFLen and u >= 0 and u < myGLen){ - // Eq(16) in [1] - if (u == 0){ - // alpha(t, u) = alpha(t-1, u) * null(t-1, u) - myAlpha[t*maxGLen + u] = myAlpha[(t-1)*maxGLen] - + myX[((t-1)*myStrideT) * dictSize + blankIdx]; - } - else if (t == 0){ - // alpha(t, u-1) = alpha(t, u-1) * y(t, u-1) - myAlpha[u] = myAlpha[u - 1] + myX[(u - 1) * dictSize + myLabel[u - 1]]; - } - else{ - // alpha(t, u) = alpha(t-1, u) * null(t-1, u) + alpha(t, u-1) * y(t, u-1) - acc_t current = myAlpha[(t-1)*maxGLen + u] - + myX[((t-1)*myStrideT + u) * dictSize + blankIdx]; - acc_t next = myAlpha[t*maxGLen + u - 1] - + myX[(t*myStrideT + u - 1) * dictSize + myLabel[u - 1]]; - myAlpha[t*maxGLen + u] = logSumExp(next, current); - } - } - } - __syncthreads(); - } - } - else if (blockIdx.x == 1){ - // beta path - acc_t* myBeta = beta + batch*maxFLen*maxGLen; - if (u == 0){ - myBeta[(myFLen-1)*maxGLen + myGLen - 1] = myX[((myFLen-1)*myStrideT - + myGLen - 1) * dictSize + blankIdx]; - } - __syncthreads(); - - for (int64_t step = myFLen+myGLen - 3; step >= 0; --step){ - for (u = tid; u < myGLen; u += blockDim.x){ - int64_t t = step - u; - if (t >= 0 and t < myFLen and u >=0 and u < myGLen){ - // Eq(18) in [1] - if (u == myGLen - 1){ - // beta(t, u) = beta(t+1, u) * null(t, u) - myBeta[t*maxGLen + u] = myBeta[(t+1)*maxGLen + u] - + myX[(t*myStrideT + u) * dictSize + blankIdx]; - } - else if (t == myFLen - 1){ - // beta(t, u) = beta(t, u+1) * y(t, u) - myBeta[t*maxGLen + u] = myBeta[t*maxGLen + u + 1] - + myX[(t*myStrideT + u) * dictSize + myLabel[u]]; - } - else{ - // beta(t, u) = beta(t+1, u)*null(t, u) + beta(t, u+1)*y(t, u) - acc_t current = myBeta[(t+1)*maxGLen + u] - + myX[(t*myStrideT + u) * dictSize + blankIdx]; - acc_t next = myBeta[t*maxGLen + u + 1] - + myX[(t*myStrideT + u) * dictSize + myLabel[u]]; - myBeta[t*maxGLen + u] = logSumExp(next, current); - } - } - } - __syncthreads(); - } - if (tid == 0) - loss[batch] = -myBeta[0]; - } - -} - -// transudcer loss function (i.e. forward-backward algorithm) with batch loading optimization. -// Compared to the vanilla version, there are two optimizations: -// 1. load x in batch through loop unrolling to reduce the latency. -// 2. Use registers and shared memory to hold alpha and beta values passed from one step the next. -// For simplicity, this kernel currently only supports U <= maxThread, which should be the common -// case. For cases where U > maxThread, the vanilla kernel is used as a fallback option. - -// Detail of this loss function can be found in: -// [1] Sequence Transduction with Recurrent Neural Networks. -// Forward (alpha) and backward (beta) path are launched together. Input is assumed to be converted -// into log scale by the preceding log_softmax layer -// Diagonal wavefront advancing usually used in dynamic programming is leveraged here. -// alpha and beta are of acc_t type, as they are essentially accumulators. - -// This loss function supports packed input where a tensor of shape [B, T, U, H] is packed into -// [B_packed, H]. -// Don't-care region (t > audLen) or (u > txtLen) is removed. -// To support the packed input, the starting offsets for each batch need to be specified with -// batchOffset. -template -__global__ void transducer_loss_batch_load_forward( - const scalar_t* x, - const int* label, - const int* audLen, - const int* txtLen, - const int64_t* batchOffset, - int64_t dictSize, - int64_t blankIdx, - int64_t maxFLen, - int64_t maxGLen, - bool packedInput, - acc_t* alpha, - acc_t* beta, - scalar_t* loss) { - - const int batch = blockIdx.y; - int u = threadIdx.x; - const auto myFLen = audLen[batch]; - const auto myGLen = txtLen[batch] + 1; - const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) - : batch * maxFLen * maxGLen; - const int64_t myStrideT = packedInput ? myGLen : maxGLen; - const scalar_t* myX = x + myBatchOffset * dictSize; - scalar_t next[batchLdSize], current[batchLdSize]; - extern __shared__ char smem8[]; - auto smem = reinterpret_cast(smem8); - - if (blockIdx.x == 0){ - // alpha path - acc_t* myAlpha = alpha + batch*maxFLen*maxGLen; - // two SMEM regions for double buffering read and write data to avoid data race - acc_t * const sharedAlpha[2] = {smem, smem+maxGLen}; - - sharedAlpha[0][u] = 0; - __syncthreads(); - - if (u == 0) - myAlpha[0] = 0; - - auto myAlphaLabel = (u == 0) ? 0 : label[batch*(maxGLen-1) + u - 1]; - // register used to pass value to the next step for the same thread - acc_t prvStepAlpha = 0; - for (int64_t step = 1; step < myFLen+myGLen-1+batchLdSize; step += batchLdSize){ - // Move along the diagonal wavefront to leverage available parallelism - // Batch loading X through loop unrolling - #pragma unroll - for (int i = 0; i < batchLdSize; ++i){ - if (step+i= 0 and t < myFLen and u >= 0 and u < myGLen){ - if (u == 0){ - current[i] = myX[currentId]; - } - else if (t == 0){ - next[i] = myX[nextId]; - } - else{ - current[i] = myX[currentId]; - next[i] = myX[nextId]; - } - } - } - } - // main computing loop - for (int i = 0; i < batchLdSize; ++i){ - // swap the pointer for double buffering - auto sharedAlphaRd = sharedAlpha[(step+i-1)%2]; - auto sharedAlphaWr = sharedAlpha[(step+i)%2]; - if (step+i= 0 and t < myFLen and u >= 0 and u < myGLen){ - // Eq(16) in [1] - if (u == 0) - prvStepAlpha = prvStepAlpha+current[i]; - else if (t == 0) - prvStepAlpha = sharedAlphaRd[u-1]+next[i]; - else - prvStepAlpha = logSumExp(prvStepAlpha+current[i], sharedAlphaRd[u-1] - + next[i]); - sharedAlphaWr[u] = prvStepAlpha; - myAlpha[t*maxGLen + u] = prvStepAlpha; - } - } - __syncthreads(); - } - } - } - else if (blockIdx.x == 1){ - // beta path - acc_t* myBeta = beta + batch*maxFLen*maxGLen; - // two SMEM regions for double buffering read and write data to avoid data race - acc_t * const sharedBeta[2] = {smem, smem + maxGLen}; - sharedBeta[0][u] = myX[((myFLen-1)*myStrideT + myGLen - 1) * dictSize + blankIdx]; - __syncthreads(); - - auto myBetaLabel = (u == maxGLen - 1) ? 0 : label[batch*(maxGLen-1) + u]; - // register used to pass value to the next step for the same thread - acc_t prvStepBeta = myX[((myFLen-1)*myStrideT + myGLen - 1) * dictSize + blankIdx]; - if (u == 0) - myBeta[(myFLen-1)*maxGLen + myGLen - 1] = prvStepBeta; - - for (int64_t step = 1; step < myFLen+myGLen-1; step += batchLdSize){ - // Move along the diagonal wavefront to leverage available parallelism - // Batch loading X - #pragma unroll - for (int i = 0; i < batchLdSize; ++i){ - if (step+i= 0 and t < myFLen and u >= 0 and u < myGLen){ - if (u == myGLen - 1){ - current[i] = myX[currentId]; - } - else if (t == myFLen - 1){ - next[i] = myX[nextId]; - } - else{ - current[i] = myX[currentId]; - next[i] = myX[nextId]; - } - } - } - } - // main computing loop - for (int i = 0; i < batchLdSize; ++i){ - // swap the pointer for double buffering - auto sharedBetaRd = sharedBeta[(step+i-1)%2]; - auto sharedBetaWr = sharedBeta[(step+i)%2]; - if (step+i= 0 and t < myFLen and u >= 0 and u < myGLen){ - // Eq(18) in [1] - if (u == myGLen - 1) - prvStepBeta = prvStepBeta+current[i]; - else if (t == myFLen - 1) - prvStepBeta = sharedBetaRd[u+1]+next[i]; - else - prvStepBeta = logSumExp(prvStepBeta+current[i], sharedBetaRd[u+1] - + next[i]); - sharedBetaWr[u] = prvStepBeta; - myBeta[t*maxGLen + u] = prvStepBeta; - } - - } - __syncthreads(); - } - } - if (u == 0) - loss[batch] = -prvStepBeta; - } - -} - -// Vanilla transudcer loss backward operation. -// Detail of this loss function can be found in: -// [1] Sequence Transduction with Recurrent Neural Networks. -// For this backward kernel, bwd op for the preceding softmax is assumed to be handled elsewhere, -// hence only Eq(20) in [1] is implemented in this kernel. - -// Each thread block works on [batch, t, :, :] of data. Each thread works on a specific u at a time -// Since only gradients for the correct token and null token need to be updated, gradients at other -// locations are initialized to 0. - -// To support the packed input, the starting offsets for each batch need to be specified with -// batchOffset. -template -__global__ void transducer_loss_backward( - const scalar_t* x, - const scalar_t* lossGrad, - const int* audLen, - const int* txtLen, - const int* label, - const acc_t* alpha, - const acc_t* beta, - const int64_t* batchOffset, - int64_t dictSize, - int64_t blankIdx, - int64_t maxFLen, - int64_t maxGLen, - bool packedInput, - scalar_t* xGrad) { - - const int tid = threadIdx.x; - const int t = blockIdx.x; - const int batch = blockIdx.y; - const int64_t myFLen = audLen[batch]; - const int64_t myGLen = txtLen[batch] + 1; - const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) - : batch * maxFLen * maxGLen; - const int64_t myStrideT = packedInput ? myGLen : maxGLen; - auto myX = x + (myBatchOffset + t*myStrideT)*dictSize; - auto myAlpha = alpha + batch*maxFLen*maxGLen; - auto myBeta = beta + batch*maxFLen*maxGLen; - auto myXGrad = xGrad + (myBatchOffset + t*myStrideT)*dictSize; - auto myLabel = label + batch*(maxGLen-1); - - int64_t u = tid; - while (t < myFLen and u < myGLen){ - // Do the update - // loss = -ln(Pr(y*|x)) - acc_t grad = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0]; - if (u != myGLen - 1) - myXGrad[u*dictSize + myLabel[u]] = -std::exp(grad + myBeta[t*maxGLen + u + 1] - + myX[u*dictSize + myLabel[u]]); - if (t == myFLen - 1 and u == myGLen - 1) - myXGrad[u*dictSize + blankIdx] = -std::exp(grad + myX[u*dictSize + blankIdx]); - else if (t != myFLen - 1) - myXGrad[u*dictSize + blankIdx] = -std::exp(grad + myBeta[(t+1)*maxGLen + u] - + myX[u*dictSize + blankIdx]); - - u += blockDim.x; - } -} - -// Fused transudcer loss backward operation. -// Detail of this loss function can be found in: -// [1] Sequence Transduction with Recurrent Neural Networks. -// The bwd op of the preceding softmax layer is fused in this kernel. -// Each thread block works on [batch, t, u, :] of data. Each thread works on a specific h at a time - -// To support the packed input, the starting offsets for each batch need to be specified with -// batchOffset. -template -__global__ void transducer_loss_fused_backward( - const scalar_t* x, - const scalar_t* lossGrad, - const int* audLen, - const int* txtLen, - const int* label, - const acc_t* alpha, - const acc_t* beta, - const int64_t* batchOffset, - int64_t dictSize, - int64_t blankIdx, - int64_t maxFLen, - int64_t maxGLen, - bool packedInput, - scalar_t* xGrad) { - - const int tid = threadIdx.x; - const int u = blockIdx.x; - const int t = blockIdx.y; - const int batch = blockIdx.z; - const int64_t myFLen = audLen[batch]; - const int64_t myGLen = txtLen[batch] + 1; - const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) - : batch * maxFLen * maxGLen; - const int64_t myStrideT = packedInput ? myGLen : maxGLen; - - __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared; - auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize; - - if (t < myFLen and u < myGLen){ - auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize; - auto myAlpha = alpha + batch*maxFLen*maxGLen; - auto myBeta = beta + batch*maxFLen*maxGLen; - auto myLabel = label + batch*(maxGLen-1); - - // load and store shared variables in SMEM - if (tid == 0){ - commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0]; - myBetaTU = myBeta[t*maxGLen + u]; - myBetaTUp1 = myBeta[t*maxGLen + u + 1]; - myBetaTp1U = myBeta[(t+1)*maxGLen + u]; - myLabelShared = myLabel[u]; - } - - __syncthreads(); - - for (int64_t h = tid; h < dictSize; h += blockDim.x){ - // Do the update - acc_t grad = commonFactor + myX[h]; // loss = -ln(Pr(y*|x)) - acc_t myGrad = std::exp(grad + myBetaTU); - if (u != myGLen - 1 and h == myLabelShared){ - myGrad -= std::exp(grad + myBetaTUp1); - } - else if (h == blankIdx){ - if (t == myFLen - 1 and u == myGLen - 1) - myGrad -= std::exp(grad); - else if (t != myFLen - 1) - myGrad -= std::exp(grad + myBetaTp1U); - } - myXGrad[h] = myGrad; - } - } - else if (!packedInput){ - // In non-pack mode, need to make sure the gradients for don't-care regions are zero. - for (int64_t h = tid; h < dictSize; h += blockDim.x){ - myXGrad[h] = 0; - } - } -} - - -// Vectorized version of fused transudcer loss backward operation. -// Detail of this loss function can be found in: -// [1] Sequence Transduction with Recurrent Neural Networks. -// The bwd op of the preceding softmax layer is fused in this kernel. -// Each thread block works on [batch, t, u, :] of data. Each thread works on a specific h at a time - -// To support the packed input, the starting offsets for each batch need to be specified with -// batchOffset. -template -__global__ void transducer_loss_fused_vec_backward( - const scalar_t* x, - const scalar_t* lossGrad, - const int* audLen, - const int* txtLen, - const int* label, - const acc_t* alpha, - const acc_t* beta, - const int64_t* batchOffset, - int64_t dictSize, - int64_t blankIdx, - int64_t maxFLen, - int64_t maxGLen, - bool packedInput, - scalar_t* xGrad) { - - const int tid = threadIdx.x; - const int u = blockIdx.x; - const int t = blockIdx.y; - const int batch = blockIdx.z; - const int64_t myFLen = audLen[batch]; - const int64_t myGLen = txtLen[batch] + 1; - const int64_t myBatchOffset = packedInput ? (batch == 0 ? 0 : batchOffset[batch-1]) - : batch * maxFLen * maxGLen; - const int64_t myStrideT = packedInput ? myGLen : maxGLen; - - __shared__ acc_t commonFactor, myBetaTU, myBetaTUp1, myBetaTp1U, myLabelShared; - auto myXGrad = xGrad + (myBatchOffset + t*myStrideT +u)*dictSize; - auto myX = x + (myBatchOffset + t*myStrideT +u)*dictSize; - auto myAlpha = alpha + batch*maxFLen*maxGLen; - auto myBeta = beta + batch*maxFLen*maxGLen; - auto myLabel = label + batch*(maxGLen-1); - - // Variabels for vectorization - scalar_t myXBuffer[V], myXGradBuffer[V]; - auto myXVec = reinterpret_cast(myX); - auto myXGradVec = reinterpret_cast(myXGrad); - auto myXBufferVec = reinterpret_cast(myXBuffer); - auto myXGradBufferVec = reinterpret_cast(myXGradBuffer); - if (t < myFLen and u < myGLen){ - // load and store shared variables in SMEM - if (tid == 0){ - commonFactor = std::log(lossGrad[batch]) + myAlpha[t*maxGLen + u] - myBeta[0]; - myBetaTU = myBeta[t*maxGLen + u]; - if (t != myFLen - 1) - myBetaTp1U = myBeta[(t+1)*maxGLen + u]; - if (u != myGLen - 1){ - myBetaTUp1 = myBeta[t*maxGLen + u + 1]; - myLabelShared = myLabel[u]; - } - } - - __syncthreads(); - - #pragma unroll - for (int64_t h0 = tid*V; h0 < dictSize; h0 += blockDim.x*V){ - // Load myX in a vector form - *myXBufferVec = myXVec[h0/V]; - // Do the update for a vector of input - #pragma unroll - for (int i = 0; i < V; ++i){ - auto h = h0 + i; - acc_t grad = commonFactor + myXBuffer[i]; // loss = -ln(Pr(y*|x)) - acc_t myGrad = std::exp(grad + myBetaTU); - if (u != myGLen - 1 and h == myLabelShared){ - myGrad -= std::exp(grad + myBetaTUp1); - } - else if (h == blankIdx){ - if (t == myFLen - 1 and u == myGLen - 1) - myGrad -= std::exp(grad); - else if (t != myFLen - 1) - myGrad -= std::exp(grad + myBetaTp1U); - } - myXGradBuffer[i] = myGrad; - } - - // Store myXGrad in a vector form - myXGradVec[h0/V] = *myXGradBufferVec; - - } - } - else if (!packedInput){ - // In non-pack mode, need to make sure the gradients for don't-care regions are zero. - for (int64_t h0 = tid*V; h0 < dictSize; h0 += blockDim.x*V){ - myXGradVec[h0/V] = 0; - } - } -} - - -std::vector transducer_loss_cuda_forward( - torch::Tensor x, - torch::Tensor label, - torch::Tensor audLen, - torch::Tensor txtLen, - torch::Tensor batchOffset, - int maxFLen, - int blankIdx, - int opt, - bool packedInput){ - - auto scalarType = x.scalar_type(); - auto tensorOpt = x.options(); - const int batchSize = label.size(0); - const int maxGLen = label.size(1) + 1; - const int dictSize = x.size(-1); - - TORCH_CHECK(blankIdx >= 0 and blankIdx < dictSize, - "Expected blank index to be in the range of 0 to ", - dictSize-1, - ", but got ", - blankIdx); - TORCH_CHECK(opt == -1 or opt == 0 or opt == 1, - "Got an invalid optimization level ", - opt); - - // The data type of alpha and beta will be resolved at dispatch time, - // hence defined here and assigned later - torch::Tensor alpha; - torch::Tensor beta; - torch::Tensor loss = torch::empty({batchSize}, tensorOpt); - const auto deviceProperties = at::cuda::getCurrentDeviceProperties(); - const auto maxThreadPerBlock = deviceProperties->maxThreadsPerBlock; - const auto maxSmemPerBlock = deviceProperties->sharedMemPerBlock; - const auto batchOffsetPtr = packedInput ? batchOffset.data_ptr() : nullptr; - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(scalarType, "transducer_loss_cuda_forward", ([&] { - // resolve accumulation type - using acc_t = at::acc_type; - auto accType = c10::CppTypeToScalarType::value; - auto accTensorOpt = tensorOpt.dtype(accType); - alpha = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt); - beta = torch::empty({batchSize, maxFLen, maxGLen}, accTensorOpt); - - // decide what kernel to launch based on the problem size - // if the required SMEM size or number threads exceeds the limit, fall back to the vanilla - // kernel. - const auto smemSize = 2*maxGLen*sizeof(acc_t); - const auto optFallBack = (maxGLen > maxThreadPerBlock or smemSize > maxSmemPerBlock) ? 0 - : (opt == -1) ? 1 : opt; - const int threads = std::min(maxThreadPerBlock, maxGLen); - const dim3 blocks(2, batchSize, 1); - - if (optFallBack == 0) - transducer_loss_forward<<>>( - x.data_ptr(), - label.data_ptr(), - audLen.data_ptr(), - txtLen.data_ptr(), - batchOffsetPtr, - dictSize, - blankIdx, - maxFLen, - maxGLen, - packedInput, - alpha.data_ptr(), - beta.data_ptr(), - loss.data_ptr()); - else if (optFallBack == 1) - transducer_loss_batch_load_forward - <<>>( - x.data_ptr(), - label.data_ptr(), - audLen.data_ptr(), - txtLen.data_ptr(), - batchOffsetPtr, - dictSize, - blankIdx, - maxFLen, - maxGLen, - packedInput, - alpha.data_ptr(), - beta.data_ptr(), - loss.data_ptr()); - - })); - C10_CUDA_CHECK(cudaGetLastError()); - - return {alpha, beta, loss}; -} - - - - -torch::Tensor transducer_loss_cuda_backward( - torch::Tensor x, - torch::Tensor lossGrad, - torch::Tensor alpha, - torch::Tensor beta, - torch::Tensor audLen, - torch::Tensor txtLen, - torch::Tensor label, - torch::Tensor batchOffset, - int maxFLen, - int blankIdx, - int opt, - bool fuseSoftmaxBackward, - bool packedInput){ - - auto dtype = x.scalar_type(); - torch::Tensor xGrad; - const int batchSize = label.size(0); - const int maxGLen = label.size(1) + 1; - const int dictSize = x.size(-1); - const auto deviceProperties = at::cuda::getCurrentDeviceProperties(); - const int maxThreadPerBlock = deviceProperties->maxThreadsPerBlock; - const int warpSize = deviceProperties->warpSize; - const auto batchOffsetPtr = packedInput ? batchOffset.data_ptr() : nullptr; - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - if (fuseSoftmaxBackward){ - // alloc empty tensors for performance, hence need to ensure zeros are writtern to - // don't-care region in the kernel. - xGrad = torch::empty_like(x); - - // Would like each thread to work on 4 hidden units - const int workPerThread = 4; - // Don't want to have more than 128 threads per thread block - const int maxThreadPerElmt = std::min(128, maxThreadPerBlock); - const int threads = std::min(maxThreadPerElmt, std::max(warpSize, - (dictSize+workPerThread-1)/workPerThread)); - const dim3 blocks(maxGLen, maxFLen, batchSize); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] { - using vec_t = uint64_t; - using acc_t = at::acc_type; - constexpr int vectFactor = sizeof(vec_t) / sizeof(scalar_t); - constexpr int vecAlignment = std::alignment_of::value; - // if all input and output tensors meet the alignment requirement - bool memAlign = reinterpret_cast(x.data_ptr()) % vecAlignment == 0 - and reinterpret_cast(xGrad.data_ptr()) - % vecAlignment == 0; - - if (vectFactor > 1 and dictSize%vectFactor == 0 and memAlign){ - transducer_loss_fused_vec_backward - <<>>( - x.data_ptr(), - lossGrad.data_ptr(), - audLen.data_ptr(), - txtLen.data_ptr(), - label.data_ptr(), - alpha.data_ptr(), - beta.data_ptr(), - batchOffsetPtr, - dictSize, - blankIdx, - maxFLen, - maxGLen, - packedInput, - xGrad.data_ptr()); - } - else{ - transducer_loss_fused_backward<<>>( - x.data_ptr(), - lossGrad.data_ptr(), - audLen.data_ptr(), - txtLen.data_ptr(), - label.data_ptr(), - alpha.data_ptr(), - beta.data_ptr(), - batchOffsetPtr, - dictSize, - blankIdx, - maxFLen, - maxGLen, - packedInput, - xGrad.data_ptr()); - - } - })); - } - else{ - // for non-fused kernel, the gradients need to be writtern are very sparse, hence initialize - // the tensor with all zeros. - xGrad = torch::zeros_like(x); - // don't launch more threads than needed. - const int threads = std::min(maxThreadPerBlock, maxGLen); - const dim3 blocks(maxFLen, batchSize); - AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_loss_cuda_backward", ([&] { - using acc_t = at::acc_type; - transducer_loss_backward<<>>( - x.data_ptr(), - lossGrad.data_ptr(), - audLen.data_ptr(), - txtLen.data_ptr(), - label.data_ptr(), - alpha.data_ptr(), - beta.data_ptr(), - batchOffsetPtr, - dictSize, - blankIdx, - maxFLen, - maxGLen, - packedInput, - xGrad.data_ptr()); - })); - } - C10_CUDA_CHECK(cudaGetLastError()); - - return xGrad; -} diff --git a/apex/contrib/csrc/xentropy/interface.cpp b/apex/contrib/csrc/xentropy/interface.cpp deleted file mode 100644 index 3a8c201..0000000 --- a/apex/contrib/csrc/xentropy/interface.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include - -// CUDA forward declarations - -std::vector softmax_xentropy_cuda( - const at::Tensor &input, - const at::Tensor &labels, - const float smoothing, - const bool half_to_float); - -at::Tensor softmax_xentropy_backward_cuda( - const at::Tensor &grad_loss, - const at::Tensor &logits, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing); - -// C++ interface - -#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -std::vector softmax_xentropy_forward( - const at::Tensor &input, - const at::Tensor &labels, - const float smoothing, - const bool half_to_float) { - CHECK_CUDA(input); - CHECK_INPUT(labels); - - return softmax_xentropy_cuda(input, labels, smoothing, half_to_float); -} - -at::Tensor softmax_xentropy_backward( - const at::Tensor &grad_loss, - const at::Tensor &logits, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing) { - CHECK_CUDA(grad_loss); - CHECK_CUDA(logits); - CHECK_INPUT(max_log_sum_exp); - CHECK_INPUT(labels); - - return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, smoothing); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)"); - m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)"); -} diff --git a/apex/contrib/csrc/xentropy/xentropy_kernel.cu b/apex/contrib/csrc/xentropy/xentropy_kernel.cu deleted file mode 100644 index 4d75956..0000000 --- a/apex/contrib/csrc/xentropy/xentropy_kernel.cu +++ /dev/null @@ -1,726 +0,0 @@ -/** - * From PyTorch: - * - * Copyright (c) 2016- Facebook, Inc (Adam Paszke) - * Copyright (c) 2014- Facebook, Inc (Soumith Chintala) - * Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) - * Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) - * Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) - * Copyright (c) 2011-2013 NYU (Clement Farabet) - * Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) - * Copyright (c) 2006 Idiap Research Institute (Samy Bengio) - * Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) - * - * From Caffe2: - * - * Copyright (c) 2016-present, Facebook Inc. All rights reserved. - * - * All contributions by Facebook: - * Copyright (c) 2016 Facebook Inc. - * - * All contributions by Google: - * Copyright (c) 2015 Google Inc. - * All rights reserved. - * - * All contributions by Yangqing Jia: - * Copyright (c) 2015 Yangqing Jia - * All rights reserved. - * - * All contributions from Caffe: - * Copyright(c) 2013, 2014, 2015, the respective contributors - * All rights reserved. - * - * All other contributions: - * Copyright(c) 2015, 2016 the respective contributors - * All rights reserved. - * - * Caffe2 uses a copyright model similar to Caffe: each contributor holds - * copyright over their contributions to Caffe2. The project versioning records - * all such contribution and copyright details. If a contributor wants to further - * mark their specific copyright on a particular contribution, they should - * indicate their copyright solely in the commit message of the change when it is - * committed. - * - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America - * and IDIAP Research Institute nor the names of its contributors may be - * used to endorse or promote products derived from this software without - * specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - * POSSIBILITY OF SUCH DAMAGE. - */ -#include -#include - -#include -#include - -#include "type_shim.h" -#include "compat.h" - -#define ALIGN_BYTES 16 - -#ifdef __HIP_PLATFORM_HCC__ -#define WARP_SIZE 64 -#define SYNCWARP(mask) -#else -#define WARP_SIZE 32 -#define SYNCWARP(mask) __syncwarp(mask) -#endif - -using Tensor = at::Tensor; -using TensorList = at::TensorList; -using ScalarType = at::ScalarType; -using at::acc_type; - -template -struct LogSoftMaxForwardEpilogue { - __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum) - : logsum(max_input + std::log(sum)) {} - - __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp) - : logsum(max_log_sum_exp) {} - - __device__ __forceinline__ OutT operator()(T input) const { - return static_cast(input - logsum); - } - - const AccumT logsum; -}; - -template -struct LogSoftMaxBackwardEpilogue { - __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum) - : sum(sum) {} - - __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const { - return static_cast(gradOutput - std::exp(static_cast(output)) * sum); - } - - const AccumT sum; -}; - - - -const int max_threads = 1024; - -inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) { - uint64_t block_size = 1; - uint64_t max_block_size = std::min(dim_size / ILP, static_cast(max_threads)); - while (block_size < (max_block_size/2)) block_size *= 2; - // Launch at least a single warp - the kernel assumes that. - block_size = std::max(block_size, static_cast(WARP_SIZE)); - return dim3(block_size); -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - - -//////////////////////////////////////////////////////////////////////////////// -// Regular kernel (fast when dim_size is large; requires inner_size == 1) -//////////////////////////////////////////////////////////////////////////////// - - -template -struct MaxFloat -{ - __device__ __forceinline__ AccumT operator()(AccumT max, T v) const { - return ::max(max, (AccumT)v); - } -}; - -template -struct AddFloat -{ - __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { - return sum + v; - } -}; - -template -struct SumExpFloat -{ - __device__ __forceinline__ SumExpFloat(AccumT v) - : max_k(v) {} - - __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { - return sum + std::exp(v - max_k); - } - - const AccumT max_k; -}; - -template class Reduction, typename AccumT> -__device__ __forceinline__ AccumT -blockReduce(AccumT* smem, AccumT val, - const Reduction& r, - AccumT defaultVal) -{ - // To avoid RaW races from chaining blockReduce calls together, we need a sync here - __syncthreads(); - - smem[threadIdx.x] = val; - - __syncthreads(); - - AccumT warpVal = defaultVal; - - // First warp will perform per-warp reductions for the remaining warps - uint32_t mask = (((uint64_t)1) << (blockDim.x / WARP_SIZE)) - 1; - if (threadIdx.x < WARP_SIZE) { - int lane = threadIdx.x % WARP_SIZE; - if (lane < blockDim.x / WARP_SIZE) { -#pragma unroll - for (int i = 0; i < WARP_SIZE; ++i) { - warpVal = r(warpVal, smem[lane * WARP_SIZE + i]); - } - SYNCWARP(mask); - smem[lane] = warpVal; - } - } - - __syncthreads(); - - // First thread will perform a reduction of the above per-warp reductions - AccumT blockVal = defaultVal; - - if (threadIdx.x == 0) { - for (int i = 0; i < blockDim.x / WARP_SIZE; ++i) { - blockVal = r(blockVal, smem[i]); - } - smem[0] = blockVal; - } - - // Sync and broadcast - __syncthreads(); - return smem[0]; -} - -template class Reduction1, template class Reduction2, typename AccumT> -__device__ __forceinline__ void -blockReduce(AccumT* smem, - AccumT* reducVal1, - AccumT val1, - const Reduction1& r1, - AccumT defaultVal1, - AccumT* reducVal2, - AccumT val2, - const Reduction2& r2, - AccumT defaultVal2) -{ - // To avoid RaW races from chaining blockReduce calls together, we need a sync here - __syncthreads(); - - smem[threadIdx.x] = val1; - smem[blockDim.x + threadIdx.x] = val2; - - __syncthreads(); - - AccumT warpVal1 = defaultVal1; - AccumT warpVal2 = defaultVal2; - - // First warp will perform per-warp reductions for the remaining warps - uint32_t mask = (((uint64_t)1) << (blockDim.x / WARP_SIZE)) - 1; - if (threadIdx.x < WARP_SIZE) { - int lane = threadIdx.x % WARP_SIZE; - if (lane < blockDim.x / WARP_SIZE) { -#pragma unroll - for (int i = 0; i < WARP_SIZE; ++i) { - warpVal1 = r1(warpVal1, smem[lane * WARP_SIZE + i]); - warpVal2 = r2(warpVal2, smem[lane * WARP_SIZE + i + blockDim.x]); - } - SYNCWARP(mask); - smem[lane] = warpVal1; - smem[lane + blockDim.x] = warpVal2; - } - } - - __syncthreads(); - - // First thread will perform a reduction of the above per-warp reductions - AccumT blockVal1 = defaultVal1; - AccumT blockVal2 = defaultVal2; - - if (threadIdx.x == 0) { - for (int i = 0; i < blockDim.x / WARP_SIZE; ++i) { - blockVal1 = r1(blockVal1, smem[i]); - blockVal2 = r2(blockVal2, smem[i + blockDim.x]); - } - smem[0] = blockVal1; - smem[blockDim.x] = blockVal2; - } - - // Sync and broadcast - __syncthreads(); - *reducVal1 = smem[0]; - *reducVal2 = smem[blockDim.x]; - __syncthreads(); -} - -template class Reduction, int ILP, typename T, typename AccumT> -__device__ __forceinline__ AccumT -ilpReduce(int shift, - T* data, - int size, - const Reduction& r, - AccumT defaultVal) -{ - typedef typename std::aligned_storage::type LoadT; - AccumT threadVal = defaultVal; - int offset = threadIdx.x; - - // shift and do 1 - if(shift > 0){ - data -= shift; - size += shift; - if(threadIdx.x >= shift){ - threadVal = r(threadVal, data[offset]); - } - size -= blockDim.x; - data += blockDim.x; - } - int last = size % (ILP * blockDim.x); - - T v[ILP]; - LoadT* value = reinterpret_cast(&v); - - for (; offset * ILP < (size - last); offset += blockDim.x) { - *value = reinterpret_cast(data)[offset]; - - for (int j = 0; j < ILP; ++j) { - threadVal = r(threadVal, v[j]); - } - } - - offset = size - last + threadIdx.x; - // Epilogue - for (; offset < size; offset += blockDim.x) - threadVal = r(threadVal, data[offset]); - - return threadVal; -} - -template class Reduction1, template class Reduction2, int ILP, typename T, typename AccumT> -__device__ __forceinline__ void -ilpReduce(int shift, - T* data, - int size, - AccumT* reducVal1, - const Reduction1& r1, - AccumT defaultVal1, - AccumT* reducVal2, - const Reduction2& r2, - AccumT defaultVal2) -{ - typedef typename std::aligned_storage::type LoadT; - - AccumT threadVal1 = defaultVal1; - AccumT threadVal2 = defaultVal2; - int offset = threadIdx.x; - - // shift and do 1 - if(shift > 0){ - data -= shift; - size += shift; - if(threadIdx.x >= shift){ - threadVal1 = r1(threadVal1, data[offset]); - threadVal2 = r2(threadVal2, data[offset]); - } - size -= blockDim.x; - data += blockDim.x; - } - int last = size % (ILP * blockDim.x); - - T v[ILP]; - LoadT* value = reinterpret_cast(&v); - - for (; offset * ILP < (size - last); offset += blockDim.x) { - *value = reinterpret_cast(data)[offset]; - - for (int j = 0; j < ILP; ++j) { - threadVal1 = r1(threadVal1, v[j]); - threadVal2 = r2(threadVal2, v[j]); - } - } - - offset = size - last + threadIdx.x; - // Epilogue - for (; offset < size; offset += blockDim.x) { - threadVal1 = r1(threadVal1, data[offset]); - threadVal2 = r2(threadVal2, data[offset]); - } - - *reducVal1 = threadVal1; - *reducVal2 = threadVal2; -} - -template class Epilogue> -__global__ void -cunn_SoftMaxXEntropyForward( - accscalar_t *losses, - outscalar_t *max_log_sum_exp, - scalar_t *input, - int64_t *labels, - int64_t classes, - const float smoothing) -{ - extern __shared__ unsigned char smem[]; - auto sdata = reinterpret_cast(smem); - // forward pointers to batch[blockIdx.x] - // each block handles a sample in the mini-batch - input += blockIdx.x * classes; - //output += blockIdx.x * classes; - const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t); - - int64_t label = labels[blockIdx.x]; - - // find the max and sum - accscalar_t threadMax, threadSum, max_k, sum_k; - ilpReduce( - shift, input, classes, - &threadMax, MaxFloat(), - -at::numeric_limits::max(), - &threadSum, AddFloat(), - static_cast(0)); - - blockReduce( - sdata, - &max_k, threadMax, Max(), - -at::numeric_limits::max(), - &sum_k, threadSum, Add(), - static_cast(0)); - - accscalar_t threadExp = ilpReduce(shift, input, classes, SumExpFloat(max_k), static_cast(0)); - accscalar_t sumAll = blockReduce( - sdata, threadExp, Add(), static_cast(0)); - - Epilogue epilogue(max_k, sumAll); - - // calculate per element loss with label smoothing - // reserve max + log_sum_exp for bprop - if (threadIdx.x == 0) { - accscalar_t log_prob = epilogue(static_cast(input[label])); - losses[blockIdx.x] = (max_k + std::log(sumAll) - sum_k / classes) \ - * smoothing - log_prob * (1 - smoothing); - max_log_sum_exp[blockIdx.x] = max_k + std::log(sumAll); - } -} - -template -__device__ __forceinline__ void -apply(scalar_t *gradInput, - scalar_t *logits, - outscalar_t *max_log_sum_exp, - outscalar_t *gradOutput, - int64_t *labels, - const float smoothing, - int classes) -{ - accscalar_t smooth_positives = 1.0 - smoothing; - accscalar_t smooth_negatives = smoothing / classes; - accscalar_t tmpGradOutput = gradOutput[blockIdx.x]; - int64_t label = labels[blockIdx.x]; - accscalar_t coeff = max_log_sum_exp[blockIdx.x]; - - int offset = threadIdx.x; - int last = classes % (ILP * blockDim.x); - - for (; offset < classes - last; offset += blockDim.x * ILP) { - accscalar_t tmpLogits[ILP]; - -#pragma unroll - for (int j = 0; j < ILP; ++j) { - tmpLogits[j] = static_cast(logits[offset + j * blockDim.x]); - } - -#pragma unroll - for (int j = 0; j < ILP; ++j) - gradInput[offset + j * blockDim.x] = tmpGradOutput * ( - std::exp(tmpLogits[j] - coeff) - static_cast( - (offset + j * blockDim.x == label) ? 1 : 0) * - smooth_positives - smooth_negatives); - } - - for (; offset < classes; offset += blockDim.x) - gradInput[offset] = tmpGradOutput * (std::exp( - static_cast(logits[offset]) - coeff) - - static_cast((offset == label) ? 1 : 0) * - smooth_positives - smooth_negatives); -} - - -template -__device__ __forceinline__ void -aligned_apply(int shift, - scalar_t *gradInput, - scalar_t *logits, - outscalar_t *max_log_sum_exp, - outscalar_t *gradOutput, - int64_t *labels, - const float smoothing, - int classes) -{ - accscalar_t smooth_positives = 1.0 - smoothing; - accscalar_t smooth_negatives = smoothing / classes; - accscalar_t tmpGradOutput = gradOutput[blockIdx.x]; - int64_t label = labels[blockIdx.x]; - accscalar_t coeff = max_log_sum_exp[blockIdx.x]; - - int offset = threadIdx.x; - - // shift and do 1 - if(shift > 0){ - logits -= shift; - gradInput -= shift; - classes += shift; - if(threadIdx.x >= shift){ - gradInput[offset] = tmpGradOutput * (std::exp( - static_cast(logits[offset]) - coeff) - - static_cast(((offset - shift) == label) ? 1 : 0) * - smooth_positives - smooth_negatives); - } - classes -= blockDim.x; - gradInput += blockDim.x; - logits += blockDim.x; - shift -= blockDim.x; - } - - int last = classes % (ILP * blockDim.x); - - typedef typename std::aligned_storage::type LoadT; - // input - scalar_t v[ILP]; - LoadT* value = reinterpret_cast(&v); - // output - scalar_t r[ILP]; - LoadT* result = reinterpret_cast(&r); - - for (; offset * ILP < (classes - last); offset += blockDim.x) { - *value = reinterpret_cast(logits)[offset]; - -#pragma unroll - for (int j = 0; j < ILP; ++j) { - r[j] = tmpGradOutput * (std::exp( - static_cast(v[j]) - coeff) - - static_cast(((ILP * offset + j - shift) == label) ? 1 : 0) * - smooth_positives - smooth_negatives); - } - reinterpret_cast(gradInput)[offset] = *result; - } - - offset = classes - last + threadIdx.x; - for (; offset < classes; offset += blockDim.x) - gradInput[offset] = tmpGradOutput * (std::exp( - static_cast(logits[offset]) - coeff) - - static_cast(((offset - shift) == label) ? 1 : 0) * - smooth_positives - smooth_negatives); - -} - -template class Epilogue> -__global__ void -cunn_SoftMaxXEntropyBackward( - scalar_t *gradInput, - scalar_t *logits, - outscalar_t *max_log_sum_exp, - outscalar_t *gradOutput, - int64_t *labels, - const float smoothing, - int classes) -{ - gradInput += blockIdx.x * classes; - logits += blockIdx.x * classes; - - // Do vectorized load/store when input/output have same alignment - const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t); - const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t); - if (shift == shift_){ - aligned_apply(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes); - } - else { - apply(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes); - } - -} - -template class Epilogue> -std::vector host_softmax_xentropy( - const Tensor & input_, - const Tensor & labels_, - const float smoothing, - const bool half_to_float){ - if (half_to_float) AT_ASSERTM(input_.type().scalarType() == ScalarType::Half || input_.type().scalarType() == ScalarType::BFloat16,"conversion is supported for Half and BFloat16 type only"); - AT_ASSERTM(labels_.type().scalarType() == ScalarType::Long,"Label type should be CUDA Long"); - - auto input = input_.contiguous(); - Tensor max_log_sum_exp = at::empty_like(labels_, half_to_float ? input.options().dtype(ScalarType::Float) : input.options()); - Tensor losses = at::empty_like(labels_, input_.options().dtype(ScalarType::Float)); - - static_assert(std::is_same, float>::value || - std::is_same, double>::value, - "accscalar_t for half should be float or double"); - AT_ASSERTM(input.dim() == 2, "Currently only 2 dim input supported"); - AT_ASSERTM(labels_.dim() == 1, "Labels should be 1 dimensional"); - AT_ASSERTM(input.size(0) == labels_.size(0), "Input and label should have same number of examples"); - AT_ASSERTM(input.numel() > 0, "Number of classes in input should not be 0"); - - const int64_t dim = 1; - int64_t outer_size = 1; - int64_t dim_size = input.size(dim); - int64_t inner_size = 1; - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - for (int64_t i = 0; i < dim; ++i) - outer_size *= input.size(i); - for (int64_t i = dim + 1; i < input.dim(); ++i) - inner_size *= input.size(i); - // This kernel spawns a block per each element in the batch. - // XXX: it assumes that inner_size == 1 - TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported"); - - dim3 grid(outer_size); - - using namespace at; - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(input.scalar_type(), 0, "host_softmax_xentropy", - using accscalar_t = at::acc_type; - const int ILP = sizeof(float4)/sizeof(scalar_t_0); - dim3 block = SoftMax_getBlockSize(ILP, dim_size); - if (!half_to_float) { - cunn_SoftMaxXEntropyForward - <<>>( - losses.DATA_PTR(), max_log_sum_exp.DATA_PTR(), - input.DATA_PTR(), labels_.DATA_PTR(), - dim_size, smoothing - ); - } else { - cunn_SoftMaxXEntropyForward - <<>>( - losses.DATA_PTR(), max_log_sum_exp.DATA_PTR(), - input.DATA_PTR(), labels_.DATA_PTR(), - dim_size, smoothing - ); - } - ); - - C10_CUDA_CHECK(cudaGetLastError()); - - std::vector ret = {losses, max_log_sum_exp}; - return ret; -} - -template class Epilogue> -Tensor host_softmax_xentropy_backward( - const at::Tensor &grad_loss, - const at::Tensor &logits_, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing, - bool half_to_float) { - const int64_t dim = 1; - Tensor gI = at::empty_like(logits_); - if (grad_loss.numel() == 0) { - return gI; - } - - auto grad = grad_loss.contiguous(); - auto logits = logits_.contiguous(); - - static_assert(std::is_same, float>::value || - std::is_same, double>::value, - "accscalar_t for half should be float or double"); - if (grad.dim() == 0) grad = grad.view(1); - - AT_ASSERTM(logits_.dim() == 2, "Currently only 2 dim input supported"); - AT_ASSERTM(labels.dim() == 1, "Labels should be 1 dimensional"); - AT_ASSERTM(logits_.numel() > 0, "Number of classes in input should not be 0"); - AT_ASSERTM(logits_.size(0) == labels.size(0), "Input and label should have same number of examples"); - AT_ASSERTM(labels.size(0) == grad.size(0), "Label and loss should have same number of examples"); - - int64_t outer_size = 1; - int64_t dim_size = logits.size(dim); - int64_t inner_size = 1; - for (int64_t i = 0; i < dim; ++i) - outer_size *= logits.size(i); - for (int64_t i = dim + 1; i < logits.dim(); ++i) - inner_size *= logits.size(i); - // See descriptions of kernels above. - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported"); - - dim3 grid(outer_size); - - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(gI.scalar_type(), 0, "host_softmax_xentropy_backward", - using accscalar_t = acc_type; - const int ILP = sizeof(float4)/sizeof(scalar_t_0); - dim3 block = SoftMax_getBlockSize(ILP, dim_size); - if (!half_to_float) { - cunn_SoftMaxXEntropyBackward - <<>>( - gI.DATA_PTR(), logits.DATA_PTR(), - max_log_sum_exp.DATA_PTR(), - grad.DATA_PTR(), labels.DATA_PTR(), - smoothing, dim_size - ); - } else { - cunn_SoftMaxXEntropyBackward - <<>>( - gI.DATA_PTR(), logits.DATA_PTR(), - max_log_sum_exp.DATA_PTR(), - grad.DATA_PTR(), labels.DATA_PTR(), - smoothing, dim_size - ); - } - ); - - C10_CUDA_CHECK(cudaGetLastError()); - return gI; -} - -std::vector softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing, const bool half_to_float){ - return host_softmax_xentropy(input, labels, smoothing, half_to_float); -} - -at::Tensor softmax_xentropy_backward_cuda( - const at::Tensor &grad_loss, - const at::Tensor &logits, - const at::Tensor &max_log_sum_exp, - const at::Tensor &labels, - const float smoothing) { - bool half_to_float = grad_loss.type().scalarType() != logits.type().scalarType(); - if (half_to_float) { - AT_ASSERTM((grad_loss.type().scalarType() == ScalarType::Float && (logits.type().scalarType() == ScalarType::Half || logits.type().scalarType() == ScalarType::BFloat16)), "expected input and grad types to match, or input to be at::Half or at::Bfloat16 and grad to be at::Float"); - } - return host_softmax_xentropy_backward(grad_loss, logits, max_log_sum_exp, labels, smoothing, half_to_float); -} diff --git a/apex/contrib/examples/multihead_attn/func_test_multihead_attn.py b/apex/contrib/examples/multihead_attn/func_test_multihead_attn.py deleted file mode 100644 index 10407b9..0000000 --- a/apex/contrib/examples/multihead_attn/func_test_multihead_attn.py +++ /dev/null @@ -1,108 +0,0 @@ -import torch -import torch.nn.functional as F -import argparse - -from apex.contrib.multihead_attn import SelfMultiheadAttn -from apex.contrib.multihead_attn import EncdecMultiheadAttn - -parser = argparse.ArgumentParser(description='Multihead Attention Standalone Test') -parser.add_argument('--seq-length', default=64, type=int, help='Sequence Length of Input') -parser.add_argument('--num-seqs-start', default=5, type=int, help='Start Range of Number of Sequences') -parser.add_argument('--num-seqs-stop', default=80, type=int, help='Stop Range of Number of Sequences') -parser.add_argument('--num-seqs-inc', default=5, type=int, help='Range Increment of Number of Sequences') -parser.add_argument('--trials', default=20, type=int, help='Number of Trials to Execute') -parser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard') -parser.add_argument('--layers', default=18, type=int, help='Attention Layers to Execute to Gain CPU/GPU Time Overlap') -parser.add_argument('--seed-start', default=1, type=int, help='Attention Layers to Execute to Gain CPU/GPU Time Overlap') -parser.add_argument('--seed-end', default=100, type=int, help='Attention Layers to Execute to Gain CPU/GPU Time Overlap') -parser.add_argument('--hidden-dim', default=1024, type=int, help='Multihead Attention hidden dimension') -parser.add_argument('--heads', default=16, type=int, help='Number of Multihead Attention heads') -parser.add_argument('--encdec-attn', action='store_true', help='Use Encoder-Decoder Attention instead of Self Attention.') -parser.add_argument('--norm-add', action='store_true', help='Include Layer Norm and Dropout-Add in Multihead Attention block.') -parser.add_argument('--ref', action='store_true', help='Reference implementation in python pytorch.') -parser.add_argument('--native', action='store_true', help='torch.nn.MultitheadAttention Version.') -parser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.') -parser.add_argument('--eval', action='store_true', help='Inference only, no backward pass.') - -args = parser.parse_args() -assert args.seq_length % 64 == 0, "Sequence Length should be a multiple of 64!" - -if not torch.cuda.is_available(): - raise NotImplementedError('Running on CPU is not supported') -torch.cuda.set_device(0) - -dropout_prob = 0.1 - -for seed in range(args.seed_start, args.seed_end+1) : - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - ref_layer = None - if args.encdec_attn : - ref_layer = EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='default') - else : - ref_layer = SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='default') - ref_layer.cuda() - ref_layer.half() - ref_layer.reset_parameters() - - ref_inputs = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - ref_inputs_kv = None - if args.encdec_attn : - ref_inputs_kv = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - - ref_grads = torch.randn_like(ref_inputs) - - ref_outputs,_ = ref_layer.forward(ref_inputs, - ref_inputs_kv, - ref_inputs_kv, - key_padding_mask=None, - need_weights=False, - attn_mask=None, - is_training=(not args.eval)) - - ref_outputs.backward(ref_grads) - - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - tst_layer = None - if args.encdec_attn : - tst_layer = EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='fast') - else: - tst_layer = SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='fast') - tst_layer.cuda() - tst_layer.half() - tst_layer.reset_parameters() - - tst_inputs = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - tst_inputs_kv = None - if args.encdec_attn : - tst_inputs_kv = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - - assert torch.equal(ref_inputs,tst_inputs), "ERROR: Inputs are different!" - - tst_grads = torch.randn_like(tst_inputs) - - tst_outputs,_ = tst_layer.forward(tst_inputs, - tst_inputs_kv, - tst_inputs_kv, - key_padding_mask=None, - need_weights=False, - attn_mask=None, - is_training=(not args.eval)) - - tst_outputs.backward(tst_grads) - - fwd_close = torch.equal(ref_outputs, tst_outputs) - bwd_close = torch.equal(ref_inputs.grad, tst_inputs.grad) - - diff_fwd = ref_outputs - tst_outputs - diff_cnt_fwd = diff_fwd.ne(0.0).sum() - diff_accum_fwd = diff_fwd.abs().sum() - - diff_bwd = ref_inputs.grad - tst_inputs.grad - diff_cnt_bwd = diff_bwd.ne(0.0).sum() - diff_accum_bwd = diff_bwd.abs().sum() - - print(">>> Seed: ", seed, fwd_close, diff_cnt_fwd.item(), diff_accum_fwd.item(), bwd_close, diff_cnt_bwd.item(), diff_accum_bwd.item()) diff --git a/apex/contrib/examples/multihead_attn/perf_test_multihead_attn.py b/apex/contrib/examples/multihead_attn/perf_test_multihead_attn.py deleted file mode 100644 index f81522a..0000000 --- a/apex/contrib/examples/multihead_attn/perf_test_multihead_attn.py +++ /dev/null @@ -1,115 +0,0 @@ -import torch -import torch.nn.functional as F -import argparse - -from apex.contrib.multihead_attn import SelfMultiheadAttn -from apex.contrib.multihead_attn import EncdecMultiheadAttn - -parser = argparse.ArgumentParser(description='Multihead Attention Standalone Test') -parser.add_argument('--seq-length', default=64, type=int, help='Sequence Length of Input') -parser.add_argument('--num-seqs-start', default=10, type=int, help='Start Range of Number of Sequences') -parser.add_argument('--num-seqs-stop', default=120, type=int, help='Stop Range of Number of Sequences') -parser.add_argument('--num-seqs-inc', default=5, type=int, help='Range Increment of Number of Sequences') -parser.add_argument('--trials', default=20, type=int, help='Number of Trials to Execute') -parser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard') -parser.add_argument('--layers', default=18, type=int, help='Attention Layers to Execute to Gain CPU/GPU Time Overlap') -parser.add_argument('--hidden-dim', default=1024, type=int, help='Multihead Attention hidden dimension') -parser.add_argument('--heads', default=16, type=int, help='Number of Multihead Attention heads') -parser.add_argument('--encdec-attn', action='store_true', help='Use Encoder-Decoder Attention instead of Self Attention.') -parser.add_argument('--norm-add', action='store_true', help='Include Layer Norm and Dropout-Add in Multihead Attention block.') -parser.add_argument('--ref', action='store_true', help='Reference implementation in python pytorch.') -parser.add_argument('--native', action='store_true', help='torch.nn.MultitheadAttention Version.') -parser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.') -parser.add_argument('--biases', action='store_true', help='Execute multihead attention with Linear Biases.') - -args = parser.parse_args() - -if not torch.cuda.is_available(): - raise NotImplementedError('Running on CPU is not supported') -torch.cuda.set_device(0) - -torch.manual_seed(111) -if torch.cuda.is_available(): - torch.cuda.manual_seed_all(111) - -attn_layers = [] -for idx in range(0, args.layers) : - if args.encdec_attn : - if args.ref : - attn_layers.append(EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=False, impl='default')) - else : - attn_layers.append(EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=args.norm_add, impl='fast')) - else : - if args.native : - attn_layers.append(torch.nn.MultiheadAttention(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases)) - elif args.ref : - attn_layers.append(SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=args.norm_add, impl='default')) - else : - attn_layers.append(SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=args.norm_add, impl='fast')) - attn_layers[idx].cuda() - attn_layers[idx].half() - if not args.native : - attn_layers[idx].reset_parameters() - -start_evt_fwd = [] -start_evt_bwd = [] -stop_evt_bwd = [] -for recorded_trial in range(0, args.trials) : - start_evt_fwd.append(torch.cuda.Event(enable_timing=True)) - start_evt_bwd.append(torch.cuda.Event(enable_timing=True)) - stop_evt_bwd.append(torch.cuda.Event(enable_timing=True)) - -for sequences in range(args.num_seqs_start, args.num_seqs_stop + args.num_seqs_inc, args.num_seqs_inc) : - inputs = torch.randn(args.seq_length, sequences, args.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - grads = torch.randn_like(inputs) - - for trial in range(0, args.trials + args.warmup_trials) : - layer_inputs = inputs - evt_idx = trial - args.warmup_trials - - if evt_idx >= 0 : - start_evt_fwd[evt_idx].record() - - for lyr_idx in range(0, args.layers) : - if args.native : - outputs,_ = attn_layers[lyr_idx].forward(layer_inputs, - layer_inputs, - layer_inputs, - key_padding_mask=None, - need_weights=False, - attn_mask=None) - else : - outputs,_ = attn_layers[lyr_idx].forward(layer_inputs, - layer_inputs, - layer_inputs, - key_padding_mask=None, - need_weights=False, - attn_mask=None, - is_training=True) - layer_inputs = outputs - - if evt_idx >= 0 : - start_evt_bwd[evt_idx].record() - - if not args.fwd : - layer_inputs.backward(grads) - - if evt_idx >= 0 : - stop_evt_bwd[evt_idx].record() - - torch.cuda.synchronize() - elapsed_time_fwd = 0.0 - elapsed_time_bwd = 0.0 - for evt_idx in range(0, args.trials) : - elapsed_time_fwd += start_evt_fwd[evt_idx].elapsed_time(start_evt_bwd[evt_idx]) - elapsed_time_bwd += start_evt_bwd[evt_idx].elapsed_time(stop_evt_bwd[evt_idx]) - - print("[ {} Attn {} ]Total Tokens: {:4d} Sequences: {:3d} Sequence Length: {:3d} Fwd Time / Layer: {:.3f} ms Bwd Time / Layer: {:.3f} ms".format( - 'Encdec' if args.encdec_attn else 'Self', \ - 'Norm&Add' if args.norm_add else '', \ - sequences*args.seq_length, \ - sequences, \ - args.seq_length, \ - elapsed_time_fwd / ( args.trials * args.layers ), \ - elapsed_time_bwd / ( args.trials * args.layers ))) - diff --git a/apex/contrib/fmha/__init__.py b/apex/contrib/fmha/__init__.py deleted file mode 100644 index ec2e9c6..0000000 --- a/apex/contrib/fmha/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .fmha import FMHAFun diff --git a/apex/contrib/fmha/fmha.py b/apex/contrib/fmha/fmha.py deleted file mode 100644 index 6aaca80..0000000 --- a/apex/contrib/fmha/fmha.py +++ /dev/null @@ -1,76 +0,0 @@ -############################################################################### -# Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of the NVIDIA CORPORATION nor the -# names of its contributors may be used to endorse or promote products -# derived from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY -# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -############################################################################### - - -import torch -import torch.nn.functional as F -import fmhalib as mha - -class FMHAFun(torch.autograd.Function): - @staticmethod - def forward(ctx, qkv, cu_seqlens, p_dropout, max_s, is_training, zero_tensors): - batch_size = cu_seqlens.numel() - 1 - if batch_size < 4: - max_s = 512 - context, S_dmask = mha.fwd_nl(qkv, cu_seqlens, p_dropout, max_s, is_training, True, zero_tensors, None) - else: - context, S_dmask = mha.fwd(qkv, cu_seqlens, p_dropout, max_s, is_training, False, zero_tensors, None) - ctx.save_for_backward(qkv, S_dmask) - ctx.cu_seqlens = cu_seqlens - ctx.p_dropout = p_dropout - ctx.max_s = max_s - ctx.zero_tensors = zero_tensors - return context - - @staticmethod - def backward(ctx, dout): - qkv, S_dmask = ctx.saved_tensors - batch_size = ctx.cu_seqlens.numel() - 1 - if batch_size < 4: - dqkv, dp, _ = mha.bwd_nl(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s, ctx.zero_tensors) - else: - dqkv, dp = mha.bwd(dout, qkv, S_dmask, ctx.cu_seqlens, ctx.p_dropout, ctx.max_s, ctx.zero_tensors) - - return dqkv, None, None, None, None, None - -class FMHA(torch.nn.Module): - - def __init__(self, config): - - super(FMHA, self).__init__() - - self.p_dropout = config.attention_probs_dropout_prob - self.h = config.num_attention_heads - self.hidden_size = config.hidden_size - self.d = self.hidden_size // self.h - assert self.d * self.h == self.hidden_size, "Invalid hidden size/num_heads" - - def forward(self, qkv, cu_seqlens, max_s, is_training=True, zero_tensors=False): - - ctx = FMHAFun.apply(qkv.view(-1, 3, self.h, self.d), cu_seqlens, self.p_dropout, max_s, is_training, zero_tensors) - - return ctx.view(-1, self.hidden_size) diff --git a/apex/contrib/focal_loss/__init__.py b/apex/contrib/focal_loss/__init__.py deleted file mode 100644 index 0589eef..0000000 --- a/apex/contrib/focal_loss/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -try: - import torch - import focal_loss_cuda - from .focal_loss import focal_loss - del torch - del focal_loss_cuda - del focal_loss -except ImportError as err: - print("apex was installed without --focal_loss flag, apex.contrib.focal_loss is not available") diff --git a/apex/contrib/focal_loss/focal_loss.py b/apex/contrib/focal_loss/focal_loss.py deleted file mode 100644 index 85c6f62..0000000 --- a/apex/contrib/focal_loss/focal_loss.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch - -import focal_loss_cuda - - -class FocalLoss(torch.autograd.Function): - @staticmethod - def forward( - ctx, - cls_output, - cls_targets_at_level, - num_positives_sum, - num_real_classes, - alpha, - gamma, - label_smoothing=0.0, - ): - loss, partial_grad = focal_loss_cuda.forward( - cls_output, - cls_targets_at_level, - num_positives_sum, - num_real_classes, - alpha, - gamma, - label_smoothing, - ) - - ctx.save_for_backward(partial_grad, num_positives_sum) - return loss - - @staticmethod - def backward(ctx, grad_loss): - partial_grad, num_positives_sum = ctx.saved_tensors - - # The backward kernel is actually in-place to save memory space, - # partial_grad and grad_input are the same tensor. - grad_input = focal_loss_cuda.backward(grad_loss, partial_grad, num_positives_sum) - - return grad_input, None, None, None, None, None, None - - -def focal_loss( - cls_output: torch.Tensor, - cls_targets_at_level: torch.Tensor, - num_positive_sum: torch.Tensor, - num_real_classes: int, - alpha: float, - gamma: float, - label_smoothing: float = 0.0, -) -> torch.Tensor: - """Fused focal loss function.""" - return FocalLoss.apply( - cls_output, - cls_targets_at_level, - num_positive_sum, - num_real_classes, - alpha, - gamma, - label_smoothing, - ) diff --git a/apex/contrib/groupbn/__init__.py b/apex/contrib/groupbn/__init__.py deleted file mode 100644 index 2f85770..0000000 --- a/apex/contrib/groupbn/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -try: - import torch - import bnp - from .batch_norm import BatchNorm2d_NHWC - del torch - del bnp - del batch_norm -except ImportError as err: - print("apex was installed without --bnp flag, contrib.groupbn is not available") diff --git a/apex/contrib/groupbn/batch_norm.py b/apex/contrib/groupbn/batch_norm.py deleted file mode 100644 index af0b7e9..0000000 --- a/apex/contrib/groupbn/batch_norm.py +++ /dev/null @@ -1,260 +0,0 @@ -import torch -import numpy as np -from torch.nn.modules.batchnorm import _BatchNorm - -import bnp - -def check_if_rocm_pytorch(): - is_rocm_pytorch = False - if torch.__version__ >= '1.5': - from torch.utils.cpp_extension import ROCM_HOME - is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False - - return is_rocm_pytorch - -IS_ROCM_PYTORCH = check_if_rocm_pytorch() - -def check_and_convert_channels_last(tensor, torch_channels_last): - if torch_channels_last: - channels_last = tensor.is_contiguous(memory_format = torch.channels_last) - if not channels_last: - tensor = tensor.to(memory_format = torch.channels_last) - return tensor - -class bn_NHWC_impl(torch.autograd.Function): - @staticmethod - def forward(ctx, x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, is_train, torch_channels_last, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream): - x = check_and_convert_channels_last(x, torch_channels_last) - if is_train: - ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv) - ctx.torch_channels_last = torch_channels_last - ctx.epsilon = epsilon - ctx.momentum = mom - ctx.ret_cta = ret_cta - ctx.fuse_relu = fuse_relu - ctx.my_data = my_data - ctx.pair_data = pair_data - ctx.magic = magic - ctx.pair_data2 = pair_data2 - ctx.pair_data3 = pair_data3 - ctx.bn_group = bn_group - ctx.bwd_occup = bwd_occup - ctx.bwd_grid_x = bwd_grid_x - ctx.multi_stream = multi_stream - - res = bnp.bn_fwd_nhwc(x, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, fwd_occup, fwd_grid_x, multi_stream) - return res - else: - return bnp.bn_fwd_eval_nhwc(x, s, b, rm, riv, ret_cta, bn_group, mom, epsilon, fuse_relu) - - @staticmethod - def backward(ctx, grad_y): - x, s, b, rm, riv, mini_m, mini_riv = ctx.saved_variables - grad_y = check_and_convert_channels_last(grad_y, ctx.torch_channels_last) - x = check_and_convert_channels_last(x, ctx.torch_channels_last) - epsilon = ctx.epsilon - mom = ctx.momentum - ret_cta = ctx.ret_cta - fuse_relu = ctx.fuse_relu - my_data = ctx.my_data - pair_data = ctx.pair_data - magic = ctx.magic - pair_data2 = ctx.pair_data2 - pair_data3 = ctx.pair_data3 - bn_group = ctx.bn_group - bwd_occup = ctx.bwd_occup - bwd_grid_x = ctx.bwd_grid_x - multi_stream = ctx.multi_stream - - dx, dscale, dbias = bnp.bn_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, ret_cta, mom, epsilon, fuse_relu, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, bwd_occup, bwd_grid_x, multi_stream) - - return dx, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None - - -class bn_addrelu_NHWC_impl(torch.autograd.Function): - @staticmethod - def forward(ctx, x, z, s, b, rm, riv, mini_m, mini_riv, grid_dim_y, ret_cta, mom, epsilon, is_train, torch_channels_last, bn_group, my_data, pair_data, magic, pair_data2, pair_data3, fwd_occup, fwd_grid_x, bwd_occup, bwd_grid_x, multi_stream): - x = check_and_convert_channels_last(x, torch_channels_last) - z = check_and_convert_channels_last(z, torch_channels_last) - if is_train: - if IS_ROCM_PYTORCH: - if torch_channels_last: - nhw = x.shape[0] * x.shape[2] * x.shape[3] - else: - nhw = x.shape[0] * x.shape[1] * x.shape[2] - shape = int(((nhw + 3) & ~3) * 2 * grid_dim_y) - bitmask = torch.cuda.LongTensor(shape) - else: - bitmask = torch.cuda.IntTensor(((x.numel()+31)//32) * 2 * grid_dim_y) - ctx.save_for_backward(x, s, b, rm, riv, mini_m, mini_riv, bitmask) - ctx.torch_channels_last = torch_channels_last - ctx.epsilon = epsilon - ctx.momentum = mom - ctx.ret_cta = ret_cta - ctx.my_data = my_data - ctx.pair_data = pair_data - ctx.magic = magic - ctx.pair_data2 = pair_data2 - ctx.pair_data3 = pair_data3 - ctx.bn_group = bn_group - ctx.bwd_occup = bwd_occup - ctx.bwd_grid_x = bwd_grid_x - ctx.multi_stream = multi_stream - - res = bnp.bn_addrelu_fwd_nhwc(x, z, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta, mom, epsilon, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, fwd_occup, fwd_grid_x, multi_stream) - return res - else: - return bnp.bn_addrelu_fwd_eval_nhwc(x, z, s, b, rm, riv, ret_cta, bn_group, mom, epsilon) - - @staticmethod - def backward(ctx, grad_y): - x, s, b, rm, riv, mini_m, mini_riv, bitmask = ctx.saved_variables - grad_y = check_and_convert_channels_last(grad_y, ctx.torch_channels_last) - x = check_and_convert_channels_last(x, ctx.torch_channels_last) - epsilon = ctx.epsilon - mom = ctx.momentum - ret_cta = ctx.ret_cta - my_data = ctx.my_data - pair_data = ctx.pair_data - magic = ctx.magic - pair_data2 = ctx.pair_data2 - pair_data3 = ctx.pair_data3 - bn_group = ctx.bn_group - bwd_occup = ctx.bwd_occup - bwd_grid_x = ctx.bwd_grid_x - multi_stream = ctx.multi_stream - - dx, dz, dscale, dbias = bnp.bn_addrelu_bwd_nhwc(x, grad_y, s, b, rm, riv, mini_m, mini_riv, bitmask, ret_cta, mom, epsilon, my_data, pair_data, pair_data2, pair_data3, bn_group, magic, bwd_occup, bwd_grid_x, multi_stream) - - return dx, dz, dscale, dbias, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None - - - - - -class BatchNorm2d_NHWC(_BatchNorm): - # if using BatchNorm2d_NHWC simultaneously with multiple streams set multi_stream to True - def __init__(self, num_features, fuse_relu=False, bn_group=1, torch_channels_last=False,max_cta_per_sm=2, cta_launch_margin=12, multi_stream=False): - super(BatchNorm2d_NHWC, self).__init__(num_features) - - self.fuse_relu = fuse_relu - self.torch_channels_last = torch_channels_last - self.multi_stream = multi_stream - - self.minibatch_mean = torch.cuda.FloatTensor(num_features) - self.minibatch_riv = torch.cuda.FloatTensor(num_features) - - #defaut to distributed bn disabled - self.bn_group = bn_group - self.max_cta_per_sm = max_cta_per_sm #used only in training fwd and bwd - self.cta_launch_margin = cta_launch_margin #used only in training fwd and bwd - self.my_data = None - self.pair_data = None - self.pair_data2 = None - self.pair_data3 = None - self.local_rank = 0 - self.magic = torch.IntTensor([0]) - - #calculate cta per sm occupancies - assert(max_cta_per_sm>0) # won't be able to do much with 0 CTAs :) - self.fwd_occupancy = min(bnp.bn_fwd_nhwc_occupancy(), max_cta_per_sm) - self.bwd_occupancy = min(bnp.bn_bwd_nhwc_occupancy(), max_cta_per_sm) - self.addrelu_fwd_occupancy = min(bnp.bn_addrelu_fwd_nhwc_occupancy(), max_cta_per_sm) - self.addrelu_bwd_occupancy = min(bnp.bn_addrelu_bwd_nhwc_occupancy(), max_cta_per_sm) - - #calculate grid dimentions based on occupancy numbers - mp_count = torch.cuda.get_device_properties(None).multi_processor_count - self.fwd_grid_dim_x = max(mp_count*self.fwd_occupancy - cta_launch_margin , 1) - self.bwd_grid_dim_x = max(mp_count*self.bwd_occupancy - cta_launch_margin , 1) - self.addrelu_fwd_grid_dim_x = max(mp_count*self.addrelu_fwd_occupancy - cta_launch_margin , 1) - self.addrelu_bwd_grid_dim_x = max(mp_count*self.addrelu_bwd_occupancy - cta_launch_margin , 1) - self.grid_dim_y = (num_features + 63) // 64 - - # allocate scratch space used by implementation - # TODO: scratch space that is not supposed to be exposed at user code. We only need one time initialization, the - # same buffer could be reused in future iterations. Currently we exposed it here instead of requesting new - # buffer from cache allocator to avoid unnecessary initialization at future iterations. - self.ret_cta = torch.cuda.ByteTensor(8192).fill_(0) - - #FIXME: turn pair handles into an array - if bn_group>1: - local_rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - assert(world_size >= bn_group) - assert(world_size % bn_group == 0) - - bn_sync_steps = 1 - if (bn_group==4): - bn_sync_steps = 2 - if (bn_group==8): - bn_sync_steps = 3 - - self.ipc_buffer = torch.cuda.ByteTensor(bnp.get_buffer_size(bn_sync_steps)) - self.my_data = bnp.get_data_ptr(self.ipc_buffer) - # we are walking on very thin ice here by utilizing internal `_share_cuda_()` - self.storage = self.ipc_buffer.storage() - self.share_cuda = self.storage._share_cuda_() - internal_cuda_mem = self.share_cuda - # internal_cuda_mem[1]: ipc_mem_handle - my_handle = torch.cuda.ByteTensor(np.frombuffer(internal_cuda_mem[1], dtype=np.uint8)) - # internal_cuda_mem[3]: offset - my_offset = torch.cuda.IntTensor([internal_cuda_mem[3]]) - - handles_all = torch.empty(world_size, my_handle.size(0), dtype=my_handle.dtype, device=my_handle.device) - handles_l = list(handles_all.unbind(0)) - torch.distributed.all_gather(handles_l, my_handle) - - offsets_all = torch.empty(world_size, my_offset.size(0), dtype=my_offset.dtype, device=my_offset.device) - offsets_l = list(offsets_all.unbind(0)) - torch.distributed.all_gather(offsets_l, my_offset) - - #whom do I actually care about? that would be local_rank XOR 1 - self.pair_handle = handles_l[local_rank ^ 1].cpu().contiguous() - pair_offset = offsets_l[local_rank ^ 1].cpu() - self.pair_data = bnp.get_remote_data_ptr(self.pair_handle, pair_offset) - - if bn_group>2: - self.pair_handle2 = handles_l[local_rank ^ 2].cpu().contiguous() - pair_offset2 = offsets_l[local_rank ^ 2].cpu() - self.pair_data2 = bnp.get_remote_data_ptr(self.pair_handle2, pair_offset2) - - if bn_group>4: - self.pair_handle3 = handles_l[local_rank ^ 4].cpu().contiguous() - pair_offset3 = offsets_l[local_rank ^ 4].cpu() - self.pair_data3 = bnp.get_remote_data_ptr(self.pair_handle3, pair_offset3) - - #FIXME: get magic value into C code and eliminate from here - self.magic = torch.IntTensor([2]) - self.local_rank = local_rank - - - def forward(self, x, z=None): - if z is not None: - assert(self.fuse_relu==True) - return bn_addrelu_NHWC_impl.apply(x, z, - self.weight, self.bias, - self.running_mean, self.running_var, - self.minibatch_mean, self.minibatch_riv, self.grid_dim_y, self.ret_cta, - self.momentum, - self.eps, self.training, self.torch_channels_last, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3, - self.addrelu_fwd_occupancy, self.addrelu_fwd_grid_dim_x, - self.addrelu_bwd_occupancy, self.addrelu_bwd_grid_dim_x, - self.multi_stream) - else: - return bn_NHWC_impl.apply(x, - self.weight, self.bias, - self.running_mean, self.running_var, - self.minibatch_mean, self.minibatch_riv, self.ret_cta, - self.momentum, - self.eps, self.fuse_relu, self.training, self.torch_channels_last, self.bn_group, self.my_data, self.pair_data, (self.magic), self.pair_data2, self.pair_data3, - self.fwd_occupancy, self.fwd_grid_dim_x, - self.bwd_occupancy, self.bwd_grid_dim_x, - self.multi_stream) - - def __del__(self): - if self.bn_group>1: - bnp.close_remote_data(self.pair_handle) - if self.bn_group>2: - bnp.close_remote_data(self.pair_handle2) - if self.bn_group>4: - bnp.close_remote_data(self.pair_handle3) diff --git a/apex/contrib/index_mul_2d/__init__.py b/apex/contrib/index_mul_2d/__init__.py deleted file mode 100644 index edb63d3..0000000 --- a/apex/contrib/index_mul_2d/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .index_mul_2d import index_mul_2d diff --git a/apex/contrib/index_mul_2d/index_mul_2d.py b/apex/contrib/index_mul_2d/index_mul_2d.py deleted file mode 100644 index 1d34fe2..0000000 --- a/apex/contrib/index_mul_2d/index_mul_2d.py +++ /dev/null @@ -1,144 +0,0 @@ -import torch - -import fused_index_mul_2d - -class IndexMul2d_(torch.autograd.Function): - ''' - Currently only support index in dimension 0 with a 2-dimension tensor. - The shape of indexed in1 must be same with in2. Now this kernel does not support broadcast. - The datatype must be float32 or float16. - ''' - @staticmethod - def forward(ctx, in1: torch.Tensor, in2: torch.Tensor, idx1: torch.Tensor) -> torch.Tensor: - assert in2.size(0) == idx1.size(0) - if ((in1.dtype != torch.float32 and in1.dtype != torch.half) or in2.dtype != in1.dtype): - raise RuntimeError("input1'dtype and input2's dtype must be fp32 or fp16. And input type must be same") - if (in1.dim() != 2 or in2.dim() != 2): - raise RuntimeError("in1 and in2 must be 2-dimension tensor.") - if (idx1.dim() != 1): - raise RuntimeError("idx1 must be 1-dimension tensor.") - - if not in1.is_contiguous(): - in1 = in1.contiguous() - if not in2.is_contiguous(): - in2 = in2.contiguous() - if not idx1.is_contiguous(): - idx1 = idx1.contiguous() - - assert in1.is_contiguous() - assert in2.is_contiguous() - assert idx1.is_contiguous() - - out = torch.empty_like(in2) - - if (in1.dtype == torch.float32): - fused_index_mul_2d.float_forward( - out, - in1, - in2, - idx1) - elif (in1.dtype == torch.half): - fused_index_mul_2d.half_forward( - out, - in1, - in2, - idx1) - - ctx.for_backwards = (in1, in2, idx1) - return out - - @staticmethod - def backward(ctx, grad_out): - - in1, in2, idx1 = ctx.for_backwards - - grad_in1, grad_in2 = index_mul_2d_backward(in1, in2, idx1, grad_out) - - return grad_in1, grad_in2, None - - -class IndexMul2dBackward_(torch.autograd.Function): - @staticmethod - def forward(ctx, in1: torch.Tensor, in2: torch.Tensor, idx1: torch.Tensor, - grad_out: torch.Tensor) -> torch.Tensor: - if not in1.is_contiguous(): - in1 = in1.contiguous() - if not in2.is_contiguous(): - in2 = in2.contiguous() - if not idx1.is_contiguous(): - idx1 = idx1.contiguous() - if not grad_out.is_contiguous(): - grad_out = grad_out.contiguous() - - assert in1.is_contiguous() - assert in2.is_contiguous() - assert idx1.is_contiguous() - assert grad_out.is_contiguous() - - grad_in1 = torch.zeros_like(in1) - grad_in2 = torch.empty_like(in2) - - if (in1.dtype == torch.float32): - fused_index_mul_2d.float_backward( - grad_in1, - grad_in2, - grad_out, - in1, - in2, - idx1) - elif (in1.dtype == torch.half): - fused_index_mul_2d.half_backward( - grad_in1, - grad_in2, - grad_out, - in1, - in2, - idx1) - - ctx.for_backwards = (in1, in2, idx1, grad_out) - return grad_in1, grad_in2 - - @staticmethod - def backward(ctx, grad_grad_in1, grad_grad_in2): - if not grad_grad_in1.is_contiguous(): - grad_grad_in1 = grad_grad_in1.contiguous() - if not grad_grad_in2.is_contiguous(): - grad_grad_in2 = grad_grad_in2.contiguous() - - assert grad_grad_in1.is_contiguous() - assert grad_grad_in2.is_contiguous() - - in1, in2, idx1, grad_out = ctx.for_backwards - - grad_in1 = torch.zeros_like(in1) - grad_in2 = torch.empty_like(in2) - grad_grad_out = torch.empty_like(grad_out) - - if (in1.dtype == torch.float32): - fused_index_mul_2d.float_backward_backward( - grad_grad_out, - grad_in1, - grad_in2, - grad_out, - grad_grad_in1, - grad_grad_in2, - in1, - in2, - idx1) - elif (in1.dtype == torch.half): - fused_index_mul_2d.half_backward_backward( - grad_grad_out, - grad_in1, - grad_in2, - grad_out, - grad_grad_in1, - grad_grad_in2, - in1, - in2, - idx1) - - return grad_in1, grad_in2, None, grad_grad_out - -index_mul_2d = IndexMul2d_.apply -index_mul_2d_backward = IndexMul2dBackward_.apply - diff --git a/apex/contrib/layer_norm/__init__.py b/apex/contrib/layer_norm/__init__.py deleted file mode 100644 index 4bbc476..0000000 --- a/apex/contrib/layer_norm/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .layer_norm import FastLayerNorm diff --git a/apex/contrib/layer_norm/layer_norm.py b/apex/contrib/layer_norm/layer_norm.py deleted file mode 100644 index b084b1a..0000000 --- a/apex/contrib/layer_norm/layer_norm.py +++ /dev/null @@ -1,53 +0,0 @@ -import torch -from torch.nn import init - -from apex._autocast_utils import _cast_if_autocast_enabled -import fast_layer_norm - - -class FastLayerNormFN(torch.autograd.Function): - @staticmethod - def forward(ctx, x, gamma, beta, epsilon): - x = x.contiguous() - gamma = gamma.contiguous() - beta = beta.contiguous() - hidden_size = gamma.numel() - xmat = x.view((-1, hidden_size)) - ymat, mu, rsigma = fast_layer_norm.ln_fwd(xmat, gamma, beta, epsilon) - ctx.save_for_backward(x, gamma, mu, rsigma) - return ymat.view(x.shape) - - @staticmethod - def backward(ctx, dy): - # assert dy.is_contiguous() - dy = dy.contiguous() # this happens! - x, gamma, mu, rsigma = ctx.saved_tensors - - hidden_size = gamma.numel() - xmat = x.view((-1, hidden_size)) - dymat = dy.view(xmat.shape) - dxmat, dgamma, dbeta, _, _ = fast_layer_norm.ln_bwd(dymat, xmat, mu, rsigma, gamma) - dx = dxmat.view(x.shape) - return dx, dgamma, dbeta, None - - -def _fast_layer_norm(x, weight, bias, epsilon): - args = _cast_if_autocast_enabled(x, weight, bias, epsilon) - with torch.cuda.amp.autocast(enabled=False): - return FastLayerNormFN.apply(*args) - - -class FastLayerNorm(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-5): - super().__init__() - self.epsilon = eps - self.weight = torch.nn.Parameter(torch.empty(hidden_size)) - self.bias = torch.nn.Parameter(torch.empty(hidden_size)) - self.reset_parameters() - - def reset_parameters(self): - init.ones_(self.weight) - init.zeros_(self.bias) - - def forward(self, x): - return _fast_layer_norm(x, self.weight, self.bias, self.epsilon) diff --git a/apex/contrib/multihead_attn/MHA_bwd.png b/apex/contrib/multihead_attn/MHA_bwd.png deleted file mode 100644 index 7069973e0a88d6f459a88dc85c0d6d92f4e85d73..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 86630 zcmeFZWmweR_dW^;iYU^cw1FT9sC3RCB_$0K11dd~bc3LRfP~Ur(lAKF451()NJ%$H zGjum+&-4BLzmMQKZ_ahjyPwyE3^054{_MTtUiVsS`zt+_CM2LHz{0{JlzAedf`x?( z!NR(53;!bcjoDq(ci@PA%k zxP=+UFH52TzFmIxM9U5fi;5KUg>91b3oILoB_koG=8U~GdMTPoJ?3=1+)lAwV6OS< zBST_eNkd6urfAr0^8pXte!Lm$K6Q;}5?T!=6n0Pv&1V`CVsEi8N&Wtcdv|Vx%+=M^ zSa7%2(}Z84ZJB1#aeUxdq=U;&JY6*@$L+9BXRa#qyWTzXzP!qt&Vqxrvn1HK#89j= z|IJ^`m@MzlD5zbLlWj_}2f;yVyGi229f>t1qe1pIOa6i+getx^^3n zIE>`nQlZHec=(>=yX$E?e=RWd@#_g`0>-<4EeUFKh4_eNWlDPaua9GVhxILl94aq$ zc1NIC-oGgs_moGPLWchSI09@RhI_B0Kb+kk@Q4lU(6zLR5Qb9ezdnu}2N$A(_3*}D zj~bu@R&B!ofrtF{arbcXJwIc;e01g*FgrgINVcq`|IQW|2m*~9vnhK!oHrK@Y3V9or7Wx z8;uW8B9aNN0r6{hP~QLS5!9OioM-0+@ls;wgu3_XVeM%#k^97jos-?#@tu{?!^OAU zNxl93aVHIt^7#V4PaABvv8#`fEthRSQU^WD``nymS{Dto4!gyn<=x<%iNzG&H= zi?XW}lFoJFY_04@g+SawpbrVcsSMJyxp43e^^)`R+dQ{(^#{E#T*~|W`Fmq9ow=s2 z%{z4s4XL5%m7$Xrr`qDhR1#Ca-ls3|g zpwaiPFbhQySTbz8?9Xs|nVrP1ZL_hl71wwi4prDpEydfl7h+$yxF_S{vSgO@ZH4i` zfMRA7^`idbVLii1=+11r(5kRbiRnj*?#lR;r;8}_?Zw_Qmrdj0BBQ|4va+Gjo8lHb zOM@A!T*pCa_ON8;Bq5mQ3mIHu7JDE#?>ygTG2jT04wRUMiL8CSz-*~Nbq6j+oIYaL zA+pm1IVm*gGir?-G>U>5Bb!?syNHJC{fJNu#kZwCyhl}*n)f4jDWHR3wKCsADz0p8 z@Hs6F*!+@=;JDKcIdx~_R6H`d-C{#*-m|}9gP+=yAKlIb+Y{O96qmW@b&S>@=IPfp zuy(Ge`PAO@y^gH2GHZ3nzO^$&9JZDg-9D3bm<6rDS>GLVD_A%b)pcAQTgJH+d*bA^ z>x7VeVRkPlxn4>+Y4^Jw^I{Db15ls>=x_bP29#nCEjBKaF&Xu;aHdO|!GTILTic6P z$nht)wT1ujiMh=OSC#QV>_e6^}tZ>-1_+3r!b97Q^0FdTCk|9me*WL?9$ z)?a(F-Anb2YxJGxdIQxcUbhU{W6o6COKtQHHJv$jJKVCP0SmeWUKg7_+%O2ct6)p@ z77k0siV8-rPc<#iK~5`!=tYL~He>szf5s9z=tJ3;X4?~n%_PzN1*@h^RMyBy)q5Fs zmEBJpUBu@$^2^8IoV7=`O=_j^SYFH6D_=bNQ{;(GB%BrN3yNMtrLS^x za*8@UR_reCra$YLkr4W9H*iU`XK_uO;UE;|Jo%2f4Z*kva~hrTnXEj~aZr&AIFeKp zatb(lrR0H2{GJ7>Q~O@&r!Ul+6#w*@`}9X?m6d~fBFNJ04PJ|AR9LjFw{Tms?%f6Q zQr`vHU=sGQl@NGa!L$v$D8WiOH%guoCd{%M7=RQZq!%6#@jBU`Q^8?K;D`~e!y5^t`!P-dk0?QvsKEV2YC+OU`t?{-W zpYb`I9%9l?>vT!s*wtxeWP$W{W13beTwR*$<#vjO)q7{jedv}VmF~ghV6E5b(_#-` zyh>wZsi~k{%^W`hY~7FMyiPJv(HjsP_5Kd`SuVQ31&;FCu;aa{5R~Tx1qH><>B;!1 z-Kf5x@duaG6*?l(y<{_w0}_bmk+aa$2d<@@Tttz?aec)6w%5s=m&H_LP`liwnZR*{ zhRAQGaNSDVbdj9_0|b$)rDcBZyG6bib-3Ly9!oA|mzB!7r4Fx?;^9P}yRu`M#?96q zXdN$qM(2D zn`-$TmEHrqxxPHDDGtitbgfy3q3xBQ9a!gH+gh%9-6MOlNZlfk`6^ zdLJb#CW#^r13Bs(zR|k{JIx9A)|~E?*E)<}Ij}&-|9Yr&5-rDB&8_m-$fGGsCm6|{ z6;ba8IeM`0oL8nXbonFI^a_ zsCLpIGm;Xnl`J^VmDX_r6x^kJ-1GL}(b0jWgp2%6_4P4kkI1dTe$fRtWvX9`cOCX; zqJ|}@VAH(m?gFy0><|hM<^<1!g_T3jop@U=+|%^K{hihZx1PDSOR%2DmSa^e>&AAg z*`imdIF`6Y$dZQT1s7J|?8y21L$`T2U-=#Vid;CD zg`CugnBsb!3O&iPD6&tN&`v(0hlNaJJIXby2IZ~hxWw}?;P_qp*Nv@7-rrAVD5=^1 zqV0854>=MaG4jYI;(6Hj2(te}@ppTcdSHMJgNl32LvvFgS= zI@6oLd_3O`DUI&Li(WM;z>6^4@H*Y~66%{wntv&CxZP&$**&A9{FSL}zNz5RvR~k8 zD}n)~yu#}jMF@$S96#CN+l`cx3=QMx^LV6#C-w3TgL3d2Z9K^sPTeR(#UhT*X?W{n zD#F_RHtV`v`=*CG&x{XrT^G`$mQdrzR+%Ik_9oh=+{c3^Nu?p4TWK6?6Kn$!$U`I4 zK^CPb0an)hOWvJZr9yMr?KgtB$!ex@J0gx8KISD$rP7LcdfB(TPtlCen+i-x&v(v! zcQe~VE2l~AOOB+qtDd^_&uM-8{;I2tB3h;ET@@#6R9~Me!O3aFrb$NT$#x>cH_722 zvvd8c(mhRvz>D*|}BDm1(F5a%Oq zyYQwwRdeUN5?j}*<7Mkr`zQ^nr^xrJT?xG7BDa2BL<)p)_>65MoGPzS1#zb_XizUm zIz_m5!J6y?YuEbej7zIJ-}NS)Y%!O6;ytwIV;gDGyEqO|*+_^g-fQwarQhnu&w61( zqjajRmtl9@m(BIj%>t>Ew^(JN*+Ly;%unqLdKV?qR!qIfX)E`B2Pz(j&4j65bxqCe z2}tS7HSZ9$rbmmSdk)NQYSj*fh^d0qzFP+trq8rS^LrZtGZvmL>@C_+=={odTPj+p z(&&_?ViD|Dt4pG*O++MS61>VP_0=zLHyBOn+x6f0G4H;O-Wd+?=xNwyFCY0uvgf=R zH$L*vea9>_!B}=NI~7%pk>~Nkin5#L`jMRaWB2w_cHJ+}6|OtT?%FN%ryI;k>nqFM z#|tkh5FLnH?kFgbzROUmD7u#7@e8u=B2H&A+S^~K(Rz0#vMa%}G_l)_z;@pGtsm9u z%AAnnX$hT&W-iOBnzK&fW}{m|c86s~f;+7DQRgMmEhh!9IPzyIsWy21^NbUqAN#>#Rav2lgbm1j2N~LMKh395?J1P zcR`pFMfb{)$~y~ZG0}F+Wv-Gz*Pk?rRi|dP-s=ZblmyH{YPOqbgs)~j$O^VQ>Xd}m zNWKy@AXEh#8~_euTBaPVz`(;l+!uNWUrtj5l4ubezP1?Nfq z>QmdPGRjkuwJJM_&dwNud!$UtlS5M4!~y87YdbdBy)~*)vg2m;$vM&Z=%u<8%~Q9d++&rEHyj2F|LQsJA}w{ z?G4BKIrJ8?zq%a{Z#(9RTF+ytxgzU_DPN^4-0kOFW$5i)X18_;^cq}`Uhc4N@-14Epz5m$3=B?@l2Qr0K!L|# z?|HfjIkl5U-*h>=qJkGukL@q-QFCm?`gIB71aS5${-eW%Y+gN_O+qxS;Tcs7`yIqJT@*dCg_AA?#xfqJE-!H81gzXS#Cn^A z2n_7-+$m-rRp0DR6<)lSTFVj-@v0%D<^N*v(7|K9;nI^Mv^>PIiAOieKG?{)5+SPT z#8LW5rHf7xFY{%f=MnmrvLux`n?~OIgof&NsZj^-tVp4QOjIM#62hD-Pbx~^7N}{q ztKzL5d1g0+);`}s_oQhHF)hkwLL1un#+PCb&Bq45`i-d~E{yj|e!EusbgjX9xVCSf z+*5J%fqT?-4`T=A1v$-~eoiz)FDsMY#O>iE{n}#-;hH&-Re=S5R(|SsowPp2(=El* z9+EvJy1H(z?dllQ@j4hSnS+^piN#$=d5nEP{wi=QD`7Ycv4~osBF(P3r%9IjzzZK* zes}8$BGd1oxOXIywp;HVB0pp@CDgro=|4xHM_9cR@9|QR7|JM))1sVmw;1k`EY!HD zlNWf6VmqyW)PMr-QXLiXO53etMaXr7K;->-t5}YfBT;jDyt_w4LQJgNu-A31=u6q0 zHCkm}be=u#@S(gwe!QFRmpKapHDS=~`AgP&(w~#`o8)EgA1CmgdWK7v!q*SUt~r`& zF^z1_`LtgcElr~3Hfopp5pFKn^oZ;*Whz%_f)zeVQiG3QpwA9>5B;wH)-Mz7!n=r| zye9f28+eG%1hiRkK&fSK<$RBy6xu4>4zNx1-Ic+v84$izi2mS_ZzV2DBvbX(FG=4! z5OOq0gh0pHI`2)SDOP0LJ!8LS|1>wymq>7bg<4xN2VtYU)8F3&@eRv7*xkaYh_gW?aJD>i9~HF0_@$Z=fHs?{nS#{+uM=ww`UY)B2F#7qp$ zB5c(|?HzylZCG?px?imh`j7#g3K@IlBar(%n0NXO13<-XzN@?kLW@6g$Y?@46C53^ zmCXa!8pKi^SoRY*;X?@vsm9$k^T99qJJWW1sU^H0P?VlDYpkWA5C|WclR&5YVL$0z zRQM~>8BYBJI-*h+CWU_c&l1&+Cm_+cWJVv>Ns9TNo~)ia!isuGB}dim3rRT^5;G^K zqE=wemi@r6>(1$>{(b{w48iHju2{OHq#P=PE$P8av8MD~R!p1WC<}oI7P|EWd5Gqv zJWw3>UzaU|)UKDswC$1`Drgwy10K^-lqGAweok7mhHuJ}i}rn&q>~`iaMCbBP;%67 z&E)>h>XWke88rvhM*Xm&hZrhgvLO$=*AEf#f)-tMx?9?A`hOw*8cZYX;rMgiKF!Ls7nCPP4#B%xg z3JwDtk?;8%E;7Z*?|@D1whLpss*iOb5h}wM-YDT!c*B8YR9RE=jf6sB8iIMu9p(^J zj?k^4vI@GLB*EmqRCgF_dSX#q&{46tq9dQRf(#q$-Q*rAV$$L|k~mD4C?rT`5&M!y zj-*TDafg1Gl`27~spxoVt5ewe_Wa_> z!HlYi@#bb6eE1!=hEILc>Ajwh)!o@+8D@^7r-|4!APc^HC#B zWm|Qg3p?H+hG=|!;d^Vu1z8PGG-*B7#J@!YI;C~R3MIwle@ykXXfet^U%F>1)U*bH zZ0t)DDj@4+9~7)tSoQcl;JkULK@gi^`0bri9ZcbfZWCa`$*DP=_vytU1aSB1+YB}? z7fRy$CZ%-(V2b>&Q=a=?q#fS122O97ib% z`L#bLU&lqaSYQUh9*A&JTmzY@6=fQY1QZL@d9*IX2{~%ICk@j`YcIt*@GgB!AGY( zA~^zizG!rdBuXP$sGoJX%o6A7Ljy`#I;8gkcW4wY=XolnWLl%sGcn5qf;A&WLo zc=R=4)#6U{s$lPs4mM~z4^wUmrr{6lwtb<+tV11Z4;l+MXBSIxVrzH;7FVVB7ZzP| zP#uDr>fv^bV{?yTiXbXuU>cExJveTyJ(80yCuidA-KW%^W1I~GJCYd!bXLIw|xGKQ>D6gv9mbl63X#&6&aSR)tiMr|@Vo^I0} z4j6<~V#WEqXxOMcZg;9NnAC_G;yi9s-KC_!V?X}V1_X2TBH2-9WWprE+LIUWmcx?z z(Rk>~!rjBKw}I8G#aKOv?-DkS7F8!v#49wBr4lgT2a8E zA_ApTQQFFVt8oW8JPjdUM+W_Q8dL(*5qE}X=vJMjo9RKeVHe}A1~cg zcid$zfg{sL(MrJP_xE5U^7+oS(wxJ{=eQ9V28H7?yy@yg#X4+ehfAFy&JLQLu01?o z`3{4%?k}{yx)QuL+Ume&g8jRBb<0sN+HD2;Q+Ih*)>mUiMbMLY}uy`lOZmT@901oEUO;YP9Uf)+JZF?M_7}QaIa`zp?3l4iV6~s40!>G-FETh|dmUM_iHb?*Jso3EiE#Ef&&A^abfaB&8yDkglcEG4?F zESPGgeD$-YXD~t|(~gGvYx+jF`y`vKy$*GDJ2-PG$qzn~yV*=-g2+)tBPnn5qcr`) zV^tSG5e`4(wl11MaV?j0KeandXGleUIoR*;S~+RZ$*kp9MH9}jcG3ws6zt9N)w=vz zvQ_9=ek{fUt6JnJJ5WSRyH{k<}HjW!|&7WPgll;(guaslHXIT>c%jEI5Z_xWqv!70fJZ4F6u=z;`th0^Qv+M ziIx{1-FM+`D`!+Lry75E$9*!@rEat7X}5h)Drrqx&hFNQ`y)yGgMyzFx{+!}!s`~? z+P<;#Y-_}*4s8vH{5T*h9PG8a+_a@2J@uQd@>Fm*V%d*#Yq<9jE5Qf%o#LcYJCK+C zJP{VV#>SX_Mchl1J+arnwCbSNHh=G}i(pNS)<&+Ky=TTv36J1Uuab^gn#9Zl1@$#` z>-GG1#yiMXpGjrM>FXVmq6q0a6y1N5adaiE;-u}XaVyMKe8Tbs1o`TgfLz%anfVh> z`U97`5bDu8@z%`33Ui-0K9OhbEo#=nG1^@pHi~g7Md}P*u+WA?Rd{}iB^=55W9Jyd zmfx!l4e)V=-=yF`hVIOzjfD_en8sQkrz|Y?&Dh}grb1&OP|S;*1q7W!c=jgh?AWA)uv~_RvLCWm z1J)X(p2<0(ctSkGnGOn5TzUwh=~f3NWaR@EOd>I4L?yLL9tcBp{=J-UZb34 zc-$saOh1OSJibMj>w4CA6ESr1jyW7YveqW3Z;$YWjK?Oz#~l#4>0vc& zbxG#zxxZmsbR?^r8QbPABCJaD5H{w>l_M|YVO-8hLV0Lqhjo<$dNc1>VVf>AZ|q0L zzU3lHyyubaP*<3lEBzL2?`WtTAP~S>y2#|rl zaKZl^1LgL9l^Cv8DUzE1u@aObc#pAvRVuGO+$-MwHvx^ul6><84?KvJN%{sg63fY( z^(IJs<1qo^Df!x^2(f?j>l5E!)lmn>Vsaylp*OIR)rO59OaFa}|9SOi%J|>0|GZ-4 z0$lf_sb=l-Wf9P1C?;Vx@13Ik@527)6`09~A;!j#{ja|$C4-M6*!rr({#VYt`2zq4 zKlzI&|N0Ay68Jbg*XLQ3e zaK4GM7sn0(MxD}jd#lM~_Un$&>Cxu$q#uK0{$pxt>gH3Di@yc_T0Al1f&`^fsjP#Y zzoxd1&ggKgo|e`PYsxY4zs~k5D6?p0G|o|{_DG!)1r<;T^E zscSkYN6SohOte(^BaGR}HHq*1qeB@tG5TsfRM@z`IfoiyP*70NP*t7M_BQ=(Sw1=| zc{;-mL0O4_m254BA9N!A$+;bk^cjf1bTFDx~Gyse?B z8I|38S65L>D+Xky`6un~o;PWo-FQV*8Vd`y_jU@Zt3_65{jSj!9UQpwSpvbgnVVEN z|J=mQ@fh=?<(OwSSY))b9^y4}vXu#Oz!l`SSskl(JX|i1Bl-Z!9Fh%gB%Yt`HPjIo zSbx0!vNB%2x)w#R1J%R94o_o{(RR?e9D=k37->MuBq8EmezELdGvkB$D|~aq`N~hx z@aw}(3|=FpJNTDGe2ExVd@Es9&NjYfjqvk33uTOf`rk!pm%^_Jg=*fA`>j<=}sH0if6aclIEB&ZO9a;xBX-pp2=@ zDl=%h);+biK4m%G{1LB~2H?pPVr8-OW_5$@)!y2R;TjLrq0@Y4N^nb_R&gVynl6st z=4F%HyS=v)(l_ZT7SH2#C_UcL@bL0i?+b``em?Zl!L%~ZXz-x;o_D5S;QZyB$pfG| zI&sUx6-1Q>Bz28w+RPp~IJ(mSNDudl@Xk@MQ~xCbwM z6ykW>4tAE~b~Y;Y08~IU4AP%prvN3t<5INwXv)fcx7`(*FeeG$1iI2-5EM9t*w?z1BDrVbf5^T_~jF0Y9ofro)XG0>CYb;ffDp~$*1d+fK!6Khh{( zUfzuAT3MBx=YHDGEKz>QimmC#YiZ03upPp)LX|9y&&|g-jDVqZzlw)Yr}BN|#l@TO+2IQ*92^UgFY0|AO4hFPzj}@g+>PG9 z?70;!{v-GtP9R4?RaL`4ZPH+w<*+e0ceB-}GBW-&DUzdFTZ*GPEXKGH{wwDladYu& zDKo8ds{~M0*&E({J0TS}9YZR!3v)k~1EuDIfx7`NV8VdmdIlu+hELe;>|PRBGhI#9L!=K83kY@`XUiMt=SD@Sq^~ZD~VLk?LYAoo##Ms zIlQx_Ed2!~mO-tnW#fk%ya4ChI)7$QbO#aR50B`8^Rd+(@5o3>Oi}L@*PgyUDea7s zk{&!E@~!Kl?#`c$qVIWtaSL&UxM#%+3x^J)1qhu$WI2!*r>dkS88Ex$!PWx8Tn>n6 zkycD3HW+#XYDjVJ8LQ=>FfB(=fvD}WG>YQeXg7Bris3RrPMXQ+QOD#DbCDj+^WJ1T zSIIxmj?)yvP)r_Z?#_d6h*Y{YB*Nk#o zSca$4(znb!wD~Oo%Af;_@HY)*fN6+)uEl=#$w2Wxi&xWZTO6!&YHhOqqob5qFbFMB z2zGJbtNtZdxs3F|xSgQEbEXeUa@q7VIBBH`@!J&=_byz#`YK;I<7omsA|gVW@wc`9;P#Etyt>UY0Mn z9G>-|2P*QhEdOmfFdKIH!crmhbU-3nLI6+?Jp?+7A)yOF5qLX@)Pyov;7*`yX>cY= z--`N;fO?nU3Tn0HZLvc1gKt(f`?GJmW6nj|p)}sWsvxJ+Y51&4e0_bN(Db%hFmd+Q zf;wy|ZQqy;Hn7vvr=a(zCYzw*MFex z@!e?(cE+GFQsKbUYe5YlpM1WR4>P=!K_?2R5fZ_UOre`;bY9x1YRBY(L#usgL7ier zZb8kEhp`w( zRu^g1N=dbqRZyTk-i7O*xEpqN@8er_h0pY!doTE!K-6dyNXahyXg4zCiFlVw^3vIC6+-o@Xq1*>6md_EQO0l0Yc!|WN|7IK0mtE-*VRELSNw1 z3febw>sp4#)4#bygo&yvfjBDvELC|yE3SWq=MLlEB#-S@1t5x0$2-HL#N`15HK)g` zAyr1(RR2JyBiYY6{hPCG#h9w-49c6-kNMya(hVeG&Z!eu{9Tb)O}G9ufl8e?@P$IF zQNi5YTsbX87+JPkD*#uj?X6e$$(oJ(Z)3;&v}zhjks^NWpHp}c(CK`e$-^}<6?W)oaa(K83Kp9tH5ZLRXSwh;h&eV z%GZcAJ=X;$)xZ!nOf}V3^EBa-+OVOk4G5E*XZY_uoChrni;X^gI{h%G1cg}cA20o+ zRg@&yAFHiY@9G_nu6mllWz(%5v9f6*>~ABhRvZfKJX&OBWk(LF4GzHJ*R<}YAXqRR zF>XnxXC4U~SMZAI7koUC@Soi3%AbC!J{4#n)QW@aW$Zo$E~|icH;9H0NuqURwg#r@ zhlngFPeI2)naN7^Olwq}&s`zMrC0l6;q&%@2Pv~(P}wa2HR8O6&BRpD%gmE5!^307 zp#EIMth(x)6MK-GpYI?>1T#rzZL!XZqzC9&k%z6@trDoL=2I*@jvkMOx#=hNdT&t` z^1xs*ajHWlT65zahs?_0y{xLwCB)hHCnHNnO3b>KfI&pH#c(lqExItA0deyYIMj}g z4FME%**KqWno1~?ZbbuByx+I;W7dQk2canSs!){gI;lfMV-FLRneq* zg~gpoGpB2{q80)Cjcfo4-VwSpH5I_(o z;#PK;IY*9L1acU+b&B$+&TPfJBu;*QJ|Y4j`K(^)qErrc_)JP4J$zeCjByy~W)b_( zgEMRMrsw$Lw5sQJu#mY72t%1&wI6c8H@b*~_k{aE5rANymW^FLd(B>RM#a&XdUg#R0A120AwKQc`YC21ka6gC?+U zi<6`MzH;mFx0?^Ky)6L0*If1heKxdm4gK1cCcTWQ?T4WpnMT;~L^42YoDV(dl2hbo zAR}&$0BWKuQ%rYe^OE1=4Gj%hfCj6b-+`R5pW^Paf)dyA;&#E-i-gqSrj%pSjB{YA zwoltHo?Ys1QhxjSP*A+!RSYcu-Mk9l989kk%!;V=D49{|2^d7MV5&rQcDmQN+S=L{ zK)W!8|yzkMvSpBIXuc5(s+^(&g9T01sG7A~x%ndH%3M-DqmB(WLMVl#OHVORE><%t7mx%3ISAMiub#WPyPn%)k}Wap0n7! zITUJVS6sdIORmgudFZowJh#~=cI}d8pub`5O%k{PVqmU^v`mZd&qOnZ(C6S0toDWg zv|-ZK8L#nxQAHIRG~NAYsyNJ4LN9yHD$*0x2|%3efVMcR*;m)$$13xW5*_$tQAj|=s`EP}~x5b?iXIvZY0 z;Bgx+FZSkI_U5Q(q*#zefGVm1AaXA-y1{*A^>(KLIJCb)Y^XI=R(7^ceSc1`bgGKF zx>06=V=O?zOkMY@LZBfa+1fRA?VivX#-~~VTHnuF&~&hk4qHLu&=a_`U;(k?0}S`P zf^?k6pX9eUJ@XfjLr1rrK@)}B$UCfyL$Xv*-Nz=>&s||(oc;Cd z*QZx=zJxTh+HR)TbqH{(VPJqYKCuk&XdLlgA7t{p1gE&KQOfO+nL}hgFh^LC@1#J>OiS^)D_qH zhqCwB*-I_5a2m54 zZQPDa&*B8^GJgHzx<5!&-q*{Uu!#ISg11VaVzRFURmX9L^B)F{5H45d!?`q=y4x0 z0=IwReICx4Q6aB4oGbuah9`23L6p|4skNoWVq^LT!n_rDI~h7;h^!mLW3Sa)Y6ug7M8#VfIBqAm{dw7Ev7$h+}2fWrU#>C zmT`a^>gkafy2o^%j?XIIeCDC+*8G#!2V`KJCK%`Fp(XU^IPw&XCx>nP0ytV5u!5gO zKsmW1032UStjB9UM!vXzf6U2Z3q+Ru1n{t@;T;gj20#L&?FHZ9`SYF#Zw`DGn`~{DJe_im zUm&Bdoim-@ouN<+a+3Bh0`G%i?{0x%O2N3>|2`JT806t&q)VefZ&}LP%+gY!l8{g_ z7r~Bka^7aL?8eY!evHFzvPq5n^UOmtUi;>%cZ&azA08U|!~s)L`Ar(#o*>i@Y(%uZ z-rZZTq3gh=v(y+!1^tPO82vG9XTEl63$Nu+xNU{@%x1KpLrz;9Uk7jrQHJ^QkFoym zT^rO73izult`>tuOh%n)d3ir)Hk9OuW_EzJg?0dE*i{_D>E`#Z&xVFjWnluM5KzUC zF`|hl!VPbY7qENsGFBz<2Jt+G32|*NGz2{@{{{8)>NuB(=t}@|G6bB6%8C;xqX3Ec zR#0eQ3BuWXb0!x-JE|6DWZ$}>^`9fxaX>`sBo+d|W%@{eON%Vk79i_7?)_6t2ApU6 zRJ-0Eo?6Ks2Y6^Zu@DeWf6hZ`J7q=q`R3(kWjzsqaQ#0$P}dcffs_??0?*H^7zdE# zx8a(Vdj%8yhGLChjKB2f0V##zEcT@)`yneovr-pI%E)IbCRxeCWid2z9e98`YT-Wt zvUe3hNqPA&I6+va_+viG| zk~27BA>rCTfi4x|NRk|30605fi#)mjzRNOT05P?=fxU8*gR$$8e&0+`{Scp{BKG*PyC#u$Miy-Y5-HU$__n{OeQO9^J* zF!&KMl<8PiHn@lHfgorHbXgCA%9kYrhQnv6$1~uDfPV$8`ORx>bM)*CCYS@nv!d1; zlG{=h$xVP!Q2|U1I7ld~cPhST-m}8G-r}fRJ%uuxdmnY)hiAw`v7n(fsXc9qt8Kk(jxf_kw3)iGoWpT4m=PX z9E*#eaX?!$E_#<5BGTc1HK6(ZYh)RC4sF-Hc=v1#VZh|`8RlD$Ccr&2Frqg~!V)mi zgt_+v;BhB*mY#selI+l}8cqOUT|PA)B!X)%yo9Vn^bd#n<5ame|B0v`Z8iuFKRCK?zp@E z-2zN;%-XvWS=b|B0OH;KaT0SlUrL5uN!dHPQ|M4R6)ZIJp4;`5W_-`L6`NFusNfHS zz{2f~u(woBHkkbx!?7kdhI%ZN+H=W!4`;cJu&?Nt$$V_%$p3XyaD*{?(7K7|%;<*y z70U@f&v^Sp_oo)XS*C)*C5mh%9EKe03U%3v6M!REoNq9p9h!drsTGVwI17%J$`ldP z`uW0J&z{tFZGLUKOTyypz0;V6;E@g_a$M53q_nq`$qZN9ku%w;`l+MzEd*dh6lhO= zDoY%T2&mB=--s+Hx3EKEXmps0h|Eg(N|%x6e-+$Y7vDb9JW$(}{4w8_E`K2*Zh&&~ z^1uu8!JWsQIk)HQNRP}%s~hvG^U)|3J)>SK&n$H7oaqlA0#8%^0Za?x=osj&7b$0a z!vHAwVhq*3Jiato3?hZPkYICj^F zZ^6~*+njNxjh_~=HAEM<*@#FJG!ksHBd*cLnD+U&)F%5dP?u>KxHqA%6@kln%&n{( z+h)o6@u`BgR5UeB7Nz$Z`D(hrS1BHcIzlOpH7BA-+d%qWy)R+$WwN^iAYC|WZBHl<6~pqzA3wz3Ge^T(2G-^VLY6Za%R3A$2cn!q0~) zY5%8*;PM=Zrr}1RH2-$d&nt^CnVm0J#Dt4sZVU`sh&xs=q~l!bWL2y(?0j@-c`?C zAyqQr(#7|L?xq}*x?acIJD96xFeU*7ZVZ5ZpYP?F-i^`D81Xz@wtBBa*o|o>ftXOJ zOv`)4gG%)OTE!>O{A_Sb(F?%X6izki(;*(k1S8EcmZLF*L0_r1IIK6u41Y__d*^~q zn`4#XvXDLjy*_vD`1mo5j=Ry~=rP-k+OHxSxCLy;+P;$hWGwSCS+UZiCnaI#eYlsO z`K0R7`v)MAKV!I>G!VLGAm+0lOebrLYuzQH6ErLzu?@U8?&o`eTyLbYi^a+T_zR?r z<6B|II{&T`kEJt0S!atz=d%2b(sV2$7WG96oMkb-l0FZbkS&w9XjK~UqAR#VZ*CgvI2DL|jOWI; zXHhnDtN*y7&e+_h%xmjPJpau8)&$8NR`z8`i5vQUU3%cHy@wco0yu|daL-B#dh8L1 zvcs|>dNx5*f9dk5+pQlhfN^{pn;1d`U$EJ1<3mAsTZ~ez1crz}xcT1$8>D>R6KQd! z*c12B?ar}Z7VJWrm63GpVqel+hHZ;8Oh>y&`;tHRv^2Jt5%5vBrao7T?0;sK zm7Of$)}9S$r$gR+xw|^PiVR{Q6#-m<>Lc>fjc3)a4$@u?7|QOB{J!vO*+-!5wR{MZ zQsxQ&tnP6KjYccxMEcx4Y4>W-&gUraH6r+@G(>bHX6LN$OX!=Mt@@t515=?X?VY-HKz zw1z7DX^rIS^(^zYuGHsEk#l)Q{qSO7&dOz7A|e!@b#}c=cdxsrp`8=J$`=sCE2}k`x2$aA3?TvDsiVmSrfh6eOpphDa(5S^|DNXO3C>rujOQi@-*lO z9G76@%hfE|^u*H(YqtyIjg1V9b!0#}O8;;#1$TML;Lx(z7lWx6IwZKl8div6-;A&e zg8I&@_uZC#CSjuzeNk*Vkqlb>&oBHdiTLE+uaTAab;|A=O^BzJnKELPZUNc6#s5XE zkOg<>+UoLke@S+sp&jc5$7;C)LYr^b2nBNc=#0{gpUFIUELS6IV3(UeX8RpaQ12=Z zy?I%mjw&&hjK+z{4}{M(&#V>6?q$gc7`R}`d?@oyX3O)0tg0!x)n>P`5DAB=KVdKO z_8K^~3abhKKC%8tiAR7XJiy7n9P)I4DG*Ge^5}k=qwi+1fS-38o6Xwqd#3{uGW0d z3AA+`t$&J~lG78fe;Uu<#WN}w{=okGdy;pi3JL$tPCQ@q72~!l=Mn=^N8n1Ic2-gEG3*{X&4~8j<-f(iFy8jG1aT{^f`c|h_^7dC=bG^g2Jl39#OstmH{J&f>?8h$K+Jy(j zpTJv#lWFA4PE!{0Zcc6X_3%2{zfeX8FCPXwxzW5i)WrQ3>|OSJh*9;cK}&RvOPk!U z2s&1zMJ4H%BTBN(XyIZ=uyF;Y? z=IYa*-&*&Nw_HB&+?g{c_St98)kGEg#a6GZ!L$FRNU4}Ijdf(=(O-ld1Ly$dsS5Dt zyBhle>f$Jv+>kN-glo3^ponj1)#^QGb%P5FS%xBu{dqtHZ7{ zLZeq?&5%*Q&auSDnsFEBXI0bp`5an>=g@7P%WSY^WUBHy_^Lt60_D1t%u@c z`>%sB|GH|h{|Hc-y(h-Ab8~Y&x^vZ5_MnsXi|@a+4`XI`o$quTF;Px$v|M5pt1Qo| z!W|i3VBS;%1x7W!IP;f+Htw>r|Jxe+N0i$?(?Yp-V2Dcc08^?u z0Tiv#l<*pRH}yiv6iC%`_?+(&5bR;;ASw>Rc$B1tqQAVQC5fRh>KdI|*ldh{)&nwkNVT~2uZ)dTTxR*I0lQxn!`;|D?D{|KKCV9EitAed1N5O=&9 zlb5cM)#SctK_S0MopkID^xm{lTN62fLR^1*3NMV;#DpTW=Gq4&ri6_lnM!O8$CRiT z><;NmsqdRlU>S*r?CS@MdU~NLsH}3+7d9ECWpCg??1Xdt>1v`S5n#P(3i~_Utt;D8s=< z#`*Ru9$lQMy;W8=j-ZgHufF{?nFyM(A!5OQ*STA(ul zr^v8ROM~}%R!poE~w9ej{W(0!VHU_Dk|3S>CR|T)lV64K*!aS-CLE z4-l*!!1)62ygl+z-1|d1=ICHaU~ExE;4?mll187}^27t$@*i6KLtjaahS6ZzSn5Yg zSBJt595iJUdi%soCq58O9g7Ws2EdiVIz{p*ewyh^RZuufyY0wFOWo63>%Iv|i{JO@ z{HK%aoauQoSD;>3IXbYtvG88=S_O}TsfpjnJ{QM==(D!h1b>~Bt-fqCDNe|QO6>&t z8;~h{4$C4E=GIcELdFc=1(LdDh;fdF%l&AOgUS1xKjgT>**^PRQ*<8t0XPH5o5|aR|4=u)5~+&;Ey!ir7klL(@!>rjz>}IYt`Fd4Kq{}Qtfr^!-(`Nf zk1?TAmh-+2;{{p47iR<1J+?(^A1s0lA0az4E@Rqp+!v(tSfA=km!q)od&f(xCS^Gn zrW&5(m%Nps332)rJ+NTg1_mChI4Qv1M-_91;zS6C;VPbcfK&{?`P~HqK%RrjUyl9! zg_QiUABIPu0P6zfG*;CPG{Jy2x(=8Y1d?&!WXg`hEHD#cT{ zBl5MkK*qHA%JkfbfvUIcJGH8)ci#AM_V{I)OSclyBSe`D#(IlHhQd5?Xlwg2NTJ)S zSE{k-PBOJ*GEBg4S-Jnb{V_}U^FnyESaUg*xRJi`LZ@ygH_vn^dG5#Jqe7X@J()6s zTJS0Hi+y)h73bq93B6U+<3EP-b5bLRLS5l=_<}2J13J2d)B6&DFbHZhg=FR|6@(YqMivm0e!Ljv(V2H2QwtW8T)klBK~iET zDn?@n*CUpdU`93`i@&{P#2>U8ifY2S5tJv*x|CqB0mVL}Qd?a)qseGKUJQLI75e$KNzJA}Sfmwo& zvuv&=i@o_!DL-i$YyoFyfB)8|*3V@jmz9hMl_oX`@25+4d7VjvgmOHZb{OMW4g{*A zww)h)<6{V&m&nx?3M}`{K!80u9(RCI=VUe8VL|(OyUBzP41oO;_a96Sm0=hi(%K;z z1%!u)ja9h6E)ELu56#S&h)2_w)YjHYy_KWcKzdm})Z2^Gkb;41@id;-#U>5Oo{8>f zL*JsdnYzfi^7vrI4jbAL!pGEo9@?~Pr!A3)TV@1DNTq920fIMy`nFX>_L~RSQFTin zRcs|ttZl!6oLKkcTSF$T2&38hq(ms%XTLYKe7-Xx{!+5q0ZObCjUbOL)w2LiIRKm| zSZe$`E(Yxn@+bi^SZnwE7O})K)WafM2vJ1RKlJ^tZ=u|ad!9l(%-qwgC zii`V6BMa#}=~K=Sg$j!*3x2H-;l2*jrK_~zFqbwG-CWO?x07_01N*#O93;;E|s`euepOX+(RpS2p;ysRrA<34p6` z+@KMG?H;#@@bnY!u9w+Fwr5umijwO7!PyQb2we=9XBcT^u<9zk-eD??>XA*WId9?j zOR&#p%Rs>zm23+FK5K8_Ze=PKK3`u5@dE@ggR3v=)^6ZnK?^Pw_}2+&I|zW!ifV++ z_;dA6WpOMf3g=z6A*zTA|1gY*0-AH*=5bh0`x=gJ`C_YCdR@jjEgM6GiIPBlf~-2) z>j&DGc3pwxxpEwqV|o0;Qe2~A)Ys`1u930KjD4l!e%_wK!CVdGIB%-R8qk!*ywUW! zo*a%2hqqO2fgnXmAC-p6Fh~O>(|qhw&AudxeKX$X6P{xUb^SmLuZwsYbJ5Q{Bu6HA{!x|pXYCol zGpnw*_T^jO?H0J6eaqi_BQGAub$XJLEZ*m?idBAQWO@LRuq}NON;uqwmBcZJ)&G1C zJ>oovq)CdgD3P%Mq$ck)oHC%?8T#tunnAY#riGf1cey~T7qrD|c}Td$&+!(H9_AGm z9#D1p-TqDXQMc~H)Hxj5lRoX7jT)8W>|ZYoVh#R*+MU7McaAx6p)*uKhEtMp>$s^` zkGLy0JxRUUFp!JYINX~RpBSZ=L|0&TK3A(UnX%pz1;4MP>m2ZQJb1m;x6SPM{l~fk zTEri|h+K+L7`^xVF?4|O0E~=c??H)O+nT1z#adns`(_qFZgD^!_b-gICdtdXo9MU_ zp!FeOp#L~3JxSo!iId5c4b_ONyl?$ThCZ~?XErS^GBOb-Wqqu#z>(WOW3R$Ub2Bi2&&^0l;I9g`0UK9B>r5h@zbHlkJRCu!19?Tp6s>J7onQKZGIns{BYgQNM#WS&*T>dzwO(lN25NRXd8WInAU6 zFWJ22F2jcuOj3J9rP_48B98Ul{ zc<;>;>HC{}wiiRCF8u+^#+uGUGob{JBt?&zwmRk5H%amysW?SSbM=867r>~_vX|{3 zDn8z;yO`2huDDoK`)weBhskJ}p_PoDj)rt83`^@}d<`WDTt!92IaGxiOpj)#)A#@E&wJ>`D{6N_$yqVlW))(bg!EVbE>S=d5UcEq9h?L#MsT2*w3 z&Fi-rA>a4dr|F(Ose(bKtsfArvX04NWZ@#a(u~*n`DjAn^X&#_vV&?adnpokUmt>W>ck2x7cpXs3e z&Fl_85OqMY2_>PIgC1!7+VI?*sl&WoOdnDzBJS(|XGip~k&!=v(!u-BIw_}uiy*l* z1Cg{Q`{C@C7mV3=cy)*9%qWl2=z^E&F>gMwxB4b5fQmRA zu^hjsY$N8UPB4wPXBRmcy_q=6wdkmrw<3O2mYPWv-%y&4&c6%|#*XgE{R z53~nh%)44qoRI&M89sYzWMot;<4`94NT z9`(p*q4-;V0a*unN>9eRBEu;prn+2b15Qyc@v@MV^BlhlK51cB1&njC`mrS5Y)3M* z3=M#~>SrhNVT7weCdp@g=pJ@20n%dy;5kkK0esUqRtx(!zipJ@6E8}qvF)rRbHMH- z_s}YJZ7p{G8+GoNfnNXQ3P0#u^@KNEJ48xFe(KMV1 z_P}SSYeLCJ-h?Y#i%m^ANRR)mJzWtghwm_?A=;p?x~Bj~wxC(>Sn^gu@9w|_ggl{P z#B%uiAsa+PDOd!O?+K;f8>#G@ZWoJYrf0lGdRbPw98p&K6VHCEZY{&Zsj)MJ#r9y} z`_E|?lq0sn-Myct9rVtJ4WEgW$Q1456iRfb1L;(wgG~D_fo$3;y6`}j>s>gT)zdC( zg?Lwwj&KYz`=<`dmFwL*=yYN)I;6Z|MW3er(5O7TJFE$UlQqCD3sTp@ay&f!WZR)fsK6T-WVkgNrnQx_pf2L)l!cC$Qy0M1}1ric>YNO48CO!66|d z0ENu1tW;#w_sPl0p(<4=keRMHm;ZR|c@N>gPct=1^IP?)AanQEas7Z-#eE6k@E~j~ zQitIV4N{ZzyY&`s+Z71uDc0H=Me<#;iaR@bh1K2(acJEj$KtHk5tq-pI$lRS3~|`+ zjgTc^-#$7<*zHU`EyeK1J6;7AdIo@|eg#`?IY22s12bQeZ#Z2qUbz8Hj<9Wpq@OpN zG59a&5nzb!{rMDMA;xpq1p-AAY>6?gSKa=hfHekpU97#2cI3phX3oBY4f)`dBmL#e z$5Gpy8odQd5{Q6UP>G{Fo8QOrAGV{}67QGKu(DJWN^)fd?9{5LA{H7rlL|Fr9?p6d z`=_0Xq5C5pr=_L&5V-6ySaaMuS*or3-f|nnV9Zuqr6CkCc3&P=h|&E#7h|4G6$6yYmjRa_FYZ##WhQ%h)FX2M=h&Z z&(^hR>!=MwJ=MSC^qu6B6p$sE&nNQdh>GmWJwe82<|m?W4+YGO7`cn= z+}uUsebu1arY2|r z14b)5AB^5dc-GxAu`3<1U87%iUblHCmt}_492hh$$(BUeI=d-#ljsSVc z)(jJ0skB_>Z_dg&zJj5(rBB~>C1W5MMe(Zn6Ve)3uG&V|*F!oucIuPF14}L#AkPzf zd%IG-{+bc^4^MQjyE}Id$D^hA>{Z?-Fl#Cc9Ox3w#QP7Pb#*OCBi0*&Cn5HJ%*dZ!a`5lnz1SrvAGHt+n86M?8(#WyWx`mI(F67<^jtlPe*%_Y>6G6P zb%@$)>NW|hI{;}v^Yx)!%Jd9^ipK@^mXt3dx zrqeJY&}9@Es~kCc)${K!_s^J;onGMq^X{)mf89whnAE_z!102m4P@YvXkA8@;Ajzpk0{?7tDVeZHE z3l%k-&wM(Z+oQj6tp$@D({V>mnCM53EBg@$67PJ4s)(~pC5T3Acd7UkOc)B9yllzE zyAQ;|;r3IYS>ez$6bxx(vXA?Z;@aBDt*zSALlFJc9 z7|L_U=C)3wnXLy~a1o6`oke&Q*T!Fz zuoaUk%s=~Kr^FyDcDpj6hOvT-=XGW3ryaf&QiS#|(48W@XolCr#HKg9&Uon-0H|Lt zz#I#fGTcV1FA@_YBT3af(3I3GS@KxI`8l_A_RsatMX-Fz;NoIyZ{Mr0^X3vo)E%)3 z%$zXduC6dRznxE0lKLpGq877&{#HeUM7nGGh(3e$bB=Wk6CJZzb?n*16rNAwd;F|T zlCYN;%t_G1R0`HGM%61SefSV{5TwFiZ(+pQ2DIn$%4YG@Sp zP$h7=L^lGa*vQh7*3dEFkNGqn`Mz9=k(GHfb^nit0`=oK?F*=Vlv(bo1haTTX+ zF7KN+Fi;WU_O}m*SK>S!@*u?laSsG_+4r+JKIB4x7 zfw}gMHU)-GCs@}IDr0uJ8AyrSnC zg$Vf#qKs~n?&PIy^X?zJ@ir7aMM-$fsiEoF?TR$ogg`iS7VXdvaqd$rpf z>@P4AM5wP}@>YH#Aj=9Y=znt;vefyzURb+*4-ZG50VB`f_O%|2Q6#yB zb8v9z!F>T#qEYa2mDLIY+P?wDtAfu{^TYS^^8o~ccwBtZRYQ?N+-`69mwU%vnd&eL zJ=(FID<)5KXj`k=I4AHfwSo;8qs=j*RL(!woRc%IDG0b13WlsPmUq$nD`15USf^`$ z=_Dr&5}8(urA~veHITLFDXryBJps0hz0w6uB&po%fs%6^8jp~$@V4sW6&754byru{ zHW*wWubY=jx}Mv#G~~Mn%lqe}$Z+5N;p6nOV(j?Iyuy_Z*B8dd#tI4wFFKIV!}%{~ zEuO4uGY z{OHBqqYZx!ar_L>3WAlZUC`_PpYiAU+zY}}3uGa3JPIln!M%1^+uRf{@evGdZ*Lco z^3jd<@|2#lk(Pd7rRT>)ZN|m7W+mvV@1}D9S)M&&i*6OAWF2QFrpj}Al!SUF4qe3= z)|F32b5yNWy2550`wO4>I6K0>8U=1a*brcOJy+U}Y)wriGAs1-F2qHwj8?(4IrZTp zcs+_Nq3W^^5SF$<7R31E;R?Ka{r`58s1eapz=C&At4uD2rj*AuA|k?2PPtl)2#LsO z=Ec&pgUCkwcyQwGey)|>0`!Rq1K|>#C-|eo_;o#>)v-oJe>|Sv_Rs&uC_#ZM+`S$` zRL}pZn1m3c^WJV2b_+cmt7BA=pX-GK*Z2eM_UyMD#C%$=lg6D ziG?;VdXEJlGW=3x&+>kBGhxYhAdkfA?)oq=E~94{`6_SM5OX+de<0sa^zTeeiY4+z z#)4EnT5dBP6SgG)t0PKI{1wk`mf*M7dLqUa7U;1P31}|rxrQ#VH$XExAPS06HXG#s z?{M$IG6^*=TMmIaU;SKbO-)UK3?nD8_>JeUXx0aBp8IFgf$TaCGT&OYdSd1AL3r1S zL|YBlcr;@TAw_qlieJECtK)|fONh|QNtxQm=)PfVqn^g=?H+QS6rX;vsv_yE7vgF+ z+n5eKHaM>kvdM_nkHOp`$8leJ;`_Lmn3y~+MVs4;80x1U{qtCl!EO5OX|$eTeEIq{ zbz{T8#Ka_g{IXTBYqBqaS0@YjyXPtB@fb!QhjnMwKvYXt@1kqDlz^DT7C@vVv2 zId6}n9qntOpd=8VQ%N+m=AZSK@iQi@+m7Nmk>m){kQW!em`)lXfknCd)Fi2Med<>R z2-+~ut*Ud^gl|?;&nMFU$s^lC?g`sjiZLP)sAc$S%H){Qnp;L%@E&=_#Hm>KjSjvZ zp=(D*8z;t!L50-c(up?~bLZL4q$>#6BP=2}km>hqD%f~WBvLt_ej~z%aEiomkdHqv zO_tX2<`c5Bw1pRT(o6q>eb(|21S&Jsq4J1-6knkYPDKv6TC97ZT zb$Y1Z@V&B;FAHq(&tX#qcx_gM`S~9}g(0n-0oHa8j6cUh373~a3j6WXgf9|D9mWPQ z03!eQ1Tzq;=go~p1~mw7Z*Oz`+QyZssj1gqs|pEy+*uuBu!nUeh{clzx}{@CXx^JxtQYxWs=$uzk=NAdKE$06)(iJ?8&X$5cJi49}v zDYtUNBZC4>$qD8w22W~OS3q`X)X$bygukm$WWCRlvPp?jIMcdrGS1{Kw8~hn78^=x zmv%gtwNbm3CZQF=PK88(Zy2o*KxT>QrsDOKwf z4$slaj44c)T|Td+#~HsXctwMmVN{eWN*Xkc`h{Qu9Ez0~ug%A-cv9!RDF&U>aZvt* zwP+C`r#z;xWn^V#m9l{U2x9M_!-)B3w>I22-goRO0lH*Zi4P}y3F^~n;uvZ*Z;cgei{A~0z-kiN^JFPm}k z=^+eJu``0v$Wi3ty*&V#A?w>;3SG{yiyXoIITjEMf_XJHlW1jU7Z#-U1~W5{U#nYL zl{|$^bv+2HG07IJlBWG8Dr)h9U%TMb))EtzZjS$+h(ypq^V>u-tua1}AcbgFU+wLl zPFeIrnKly=)Bu*~tdZ}uj-~^tHxlV(AK*|E_=pA&jOh2t1}?O#|Nda2Vrs-GQx;>j zURi?G)z#e1FVWHDa46@eA|fI(ZqGUBgYU&TIbX%~xH(0$Zz~;+Nc9TG(YCJ}Na{bI zCU)fvWQ`awRljl~h42eEN#aIZf9P`3*1$D?{miypGQFgD%4)w;LctIbo9N1i_c;c* zOfCAuY~RU7IxR3|_Q~X&*^vM2f)*XdlHn_Q$v$s*o`@Wvs5E{lCN7>+Tl=-NtZb68 zRX?F_aG#{>*s#ZLo6fo_0d*r?SG@7=xs4pLBhI6Vh=Jl#4@M??iJ`*}Q!U=Y?~Ca6 z*i{7CXVfdU<&KuZ%LAgL5~Qavn(u=jFH4-9^Yaw`;lr)F>qE_xcA0(s9t10xKbff) zvw#-3;g1Nw+Ypkuz>v>5ecMFB}U4Cw0HR96EtW(N}5b{OsZ0o zJj^jjT{rm!`G1-DH&m^|kdP2$D(ADavu~!Fdk^A9Y{Zb>nTkqFYvyyFcX@)k-HSZ& zj?sx!UkEqyb;&s|y+ z2%gO|#NeBK76~BWHO5?iTAz7x1s6o2hqHgfCG<@fnu~Baf)N{3dYH8t5Pf7AwGs5x7st% ztesrN=HHm-(McwvK>}?Q2V9xgh)OVEd3kMMnJCzZX0~RZCr*m+rOPH#n2+ziRgX$}6xou^wDC!7 z6S?87c;JI!GJTm7^+h7xjsZ;IAb5DOIKS4CZ_9P1e>%u+`Zm~o>f(Dhz~il&HzFXa zBAX(ZUZIzj5-4M&Glc}Q7^oNT5_{*O9dioM;{fPu7GZ76$~S1H(klsB@h1Gn|rV%WX(CWrfR*Is-R?en?z1Vrk^Q5C+wPR>;nqk zK?~CvJ1HyPkM+1>cT;J4A}?j;n$+-W7-a@O``~g?v@?oI>MVZ;@XlE7w65fJ-!FBgyJbdx_|8_ z65;WaC&M7Y)^|pVi_JCOn{RokUoAif8687aUEn?lamM|*oD^nJuX)^>`Q%!!-EW}2 zUMFgHHG@UgI=bTt7N^qI(pHFQU2PI)M~{DSkK%)pA*%%cqbcGKC_&E_7+aLIIT1X0 zFJPI-+y(N)#9wT(JD+7UsjmC^e){xyWlo9yx-!zf3*TV~bkJMivR>gb`?|GU6yN`| zV?^8tM>a?Ns5p&b`3JV>Qf_Sm0OwwwRk?F6N@i?0?Vtl4fg8Y zg_xp?c9)Gg^exM6@LUTI*GV_2?Z&r}oy@XsFZjTWpPyGA|MkQ+h(}1cSP-@fL9-65 zh#pY9VR223K9-DwWpn!=DY7X)gwemTd3DtIQ{Nzmr|6RJ6-rPfY|jD;Ya|4tTA9bl&r5}!!aNYK4YlOMg%M@Aw5 zFpmF;%mWEhH#M^?Bky2I3O`I?y#-r5vWA})B8H`u2Pmd`9>^g}NAFaQ->C;=@(3?+ z)JsXTh+6r;h2$Ei_>3xlZe@%^=ce9Xy!@jD07^ik0I=%Nj~^d8Z>9LXNP{PE^aO$+ z5-2=3#ovtIA{~JGERwyzi6-@65$AyhqsH{#<9iL{k^`Ifc7qMAOy7u&xX?#@$99^4-Qk z$^)C4hO=4H2HRX?2}Skoy(*e>RiNWU+xKax1r7i0nxnJ;!KLzxtFR(Axky4Hqxcj@Vou2Zp#(Jz^=+JWHp}?@wklfpz7Nq0K=fl?I`Z3I>&Ip8~jecc~xPjY{SzyL8WEn>Z7vFv$xcpQUO%t}JaE6GJ1mO_YfF!pQGYq8w*@I=Ba+s^3f{l9Inh)6?D$!AD66JR zP`Z&#OqXTR+s%aa#TQ$Jc>UxR(=;W#9D#f!#!q_j@+K3a^mP$J<2)bp8X#BDK7c35 z{ee`a*DLVC?!U56guC?t@p{&C5~`OkW42$;&(HTU;wug)sytSS4p<@iW84(57rs(1 zE_G)omTY4U6`Ov1Aty^QG&kXJbnb{3LUlbqZYBdBzNC>()q3!3@rOfXRJvNo#ri7$ znZ+Q==h-}s3lVWf8NtT4v{*Hfo63|tqmMn1*Eu~4z!%Vn5aGrWwR|k9=?>L;piVdY z;G_QY+}h{;TCZfBov*Cb9!>u3@F4Jc%F9jt4Dv+@xsn;3YV&D;D4i`WZW|p-Hrrj| zNZxF@r&2Xh%-u|XscQ5KKe%a1e@%Zjs^t;-^GA@uRe#D*j-+pEk`EXAls}V<6&c3e z^HbSX5I&l5UnGkjuR1hcnh81Iax#Q@%@E3 zl;&ZlZsL=SmCj`1KK7Ep^H6|k{}4t-Mn%?SW+u+-VMC*x8Vhu%c_hPv(J;H8n?25} zZb&}@==gP|uC6>-hHCnHw`EPauh#4GjBEs{$yAsw9gEmrngS;5J+Ji_$vqzUO+b7u zuTZxLb{EXmrGc5N40YN0^B~O|??k7KXVjFG$+ng? zn0y7^DeJG8)W91rSs{B>ZTO4_`8jqhqzB+{N?g!q*I#i^yr9d`2&zEp@!9r z3G?+S4@?d*L;{Ops){7Ollh>{4CnzR15nj^qu6uw3XJh>1MU2%S@yFt@ETI=w#~_! z=yv|X@Gze#ODB8v^ATt8WI~Nnf#PMSfsT%MXJ@BD$n}vQkHn3Z&tFFis;8q9KpNf!aBu13Ym9k#%; z*)I190s2tuuVZx%o5M-&$pEjL9#?uQ?Hdf;pW+R~4_VZ-7=`ii@v ^rN+)r%?+U zF}f;siG)fp-d#7w4=dp(ylz2kIJxgfs0g#6`%cr1d3RQkQw0CPn2!A{K!^a*3_XoA`f`3k8_i#>e zScLw}E7_wKYuubomsa^QvxcVA)pVEZo9HG#O|+OiIh<7A! zwp3CX6Z7*5C5h!KZnau^DyjgbAw*Vh{qxJCXMdd)C+x+uQqY3pl^8oiBOVR3hb7kLFT8D!KkBA-rRd%Mv+_2GE2T;J%+<)2;VZAo)&XZ`hI%9w9LN-N-7{~kVAXo{4PERK zWc}SaO2vz---nxSHyCfvz8At>JNjINrSdpiz_~4TzUjD~`rYg=X7j2B$1=K7YNZyi z-hr(dnUdwTe`Bx?R1PZ7rVvcqaQ_$yC_IRpH`XnhZY$4zAlO0TqD#G~;^J&7Sj#K; zn@1OEKS?$TEG#2;h~m-M%boZ&30}4P=FI5I(_dW^swJ;t3uA^#N-~aUx}`=}D!zm$ z9^d=r3GbuIyl7tWBhPD~46$Dj8c5QApl}YSoyK6VvbkYhqbGoX_vxP=H*Y#DP%z&H z^TMw;cEvWA11;ib
U9>|mN=DEh5uJC=m*#1dweUZe!Db1?A2~TzWqfv~btyTCS z`atw5&x1#ri$CPjujlZspy=6m>2W!YOhGH2|2oGXgteBN+L!U0NE&J)=V*%RNfz6S zy)lds5fRIQ)GakVJ+f)*KU>mD(mb58$Iv;5*MfWPq~@|;5gAreEMZ)%TkK|sR^0Sm zlj(bFngBOkC~|`r`r_@T-t?_@hbSQ%GcU^R7W#C;qpm!jOPBQ5 zRRHzQiqOpxLTJ^3vv?o@(8sVb=q1hsgVSk%{`W8?)24grZ#~V+20NwtDT(#XTkB;J zln`$HRlD*E;pR>I;k6>d7n=c^)L>RxG|ZSjGj|^8+{L#E6qS|1?iXTR-fs1LFLz$u zwYc-xr zB+A{SAjin7d`Hv3agm>!3m_+St6g#b&m@#^?@I;}N@Ya=Ze}vhdJCBud#^6|5=y+P za+lDwzhw-pS3m$_tk~@0B4MO35MTzv^KnY|px->#5M#7rtZm9MQh?B`;1*1&e=Ru@ z)E#@5r^c> z-cM9^vp@`~sZ%sD*U6+XDMz<&78VKk=@xfC$V%BErw%=btfoDw*M~X8?^|1SF;SX@ ze&JGH{s=~LH`TV=Bj5A@bK;F^m3b8M4uw?wmuZ5X2f`n}^l!+h7o+^SodB(0qLjzt z+PL=b7YdJ!rIM7CmzOuR=-gx$YPMgZDs4D<**Y~=Wx3%@O4j(}eyOR)Pj5Hw;h*p? zEBo|Yq)PZGZn(IeT&UdkCeSnHr+&v{RDI}1upP1tNhNVh)_)iUt$=Jm3y$X>!v6cs z!KhpN`%%X9rkMo!R$uqQY_x163r;d~@%;Rxz&Yz(aJ->=9-FGV4hPxdDQawQ)ypX! zeTP1y(5hWOti9DIqF4{Oj!Nvx>*(eG-a~;TPx1~FkaZrWa^|61XnF6|ddPtC$3 zBuo~TCllxx82aV#Bc?uRDh;y@x^8PG+;0jq+jQ6k8UtFABbe6|WkGsC#xFGN&5eJ{ zvVAx_krhuOAO1e0slnm_W)So5z|(w!Hk5t=rMsJ7A}ophJ>R^<=B=&l{O7EYd^KaA z`;jp*Z8C1-KVPWDCJ!-Sj}@4*te({fs?q5eHNb_7s{bJ%X70*Vk^!}F{v zHJe~c*}%x@sGMw3>!>v<3&hqELM-@cvEd6pn8l{-yS!Ix0{f~wTmz8f z<)k8(sp&nA!*FZOZ6>^ROMU|{T9*jCH%ivf>cn9+|YZiD4#^8=|N;%~ZOUwAX!LH?|vk!U%8Z@ zQO0dz>8oz-wl1d8@>sQ$O#Kep#r4o%_|ih{Zxu%L5nA^qfWXN%PNoEs;S?Y%5>g#W zr{aonRI|5c&b1x7hpOY3wYd^buGg(e(Sa;KZq693F3BPPX6O{ZiS5ne#T^`O6;;;2 zUcO{j)6IpmSmlZpC#Aq1v}p0tf8ps79D~L}5BE21;Dj+XfF`8gh&u z!JQ~ZgEPSE*n*U$L0o#WKObFCP556Hr5JLrJv2%;l;&@OLL1f$Dq33Y*~v-lGYLS> zlDjBYDd4--^0P;9G}B!($0EZSPtoi&teTP(JOn(!PeTQR4)n`R|2+;joR?P zGPe@>X(nc#*L+tl3Rv{CUUhgSiS_r{-ED*o|b)2%@$udE#5Jgn!J%91}0 zf9T@{Se8OS9R1x~Xz&ctl1*Lryt~Ov)#@v!+5y=i=Mw6VrrkL&R&DkCLhGmVK?b(*BiLo)FbH>@3t+B~hz_cnyvCn^e|8UG7l?}Ih z-Y#)2CNwPXJ6qT*`rz{I@UpB?CAKKR3;Xmn0cYxt)x=wMY_7XJtViLpf)4WvA|8&y z?9Y}CQ@T-p?L99I?8=*)C035bn+?{SloUS9Dj?U$e9_BeSR;i?to zr&Q1Pk1qtA`(U}Q4OkmoC8ifa2YxhZUS06IYLhzte%$Zx0TfrjlLT{;U9Vf=!DDLC z(daQbqr2Y&AvV^jc^5Pqu>Tng`0LTZ_wZt^VU2(eUAMmseMw6q!K9kall9Hbm$TKD zDKxMUQuRB3WMWH4wp{*vHw>S590FIOSrq)kov~ZUx;SJ9Yw5t5}~~x1zF*H^)soEm&Oxb@SKC492LtLyGrQTUhW?)=uY`bCQ!wy3?LW<#>!*pP4U2$sBpFCK zR>!GA9#27OSxx#-xowkVnGvE2x?#U4NS`%O$4{lF=M`SIBZus<(UEY{ z2U4O`iI^b*VxKy<+xdW4m|OfH09$R~H9JSBf~#=mBK> z+kv9Khe7u90dkZQ`1tsIFQXC?Bqudo=*TEYzY|fWzp}R+$CTDglX!=iUf}3Vk5+@P z9WN;OPS#}NvpV0$Rux7pQzG5Vx4@4|bEu!e`j${(Ct#AmXBAuO!CyN#woP&vkn1w; zY#}+k9j>!q?>B9O|9xH;Bv3XKkP-qh7z_uFWJ<=N-lA#p=;F}*7d$=)Ld(+y_iMA# zY16>2W5S*rNHLK2PrlO(gx6Vpd_QSa=pZp@6b;{0gj!DN>T5uKFB;2h_>b3ojy5nbTjBeQ?SAeCyYhF#5)ir!+0YWjl-t^c~C!B}k=gYP7|%*7a51 zgxu%+oo1FxLaiM9OjFtqmZ zoDL!MYN*DTe7Jz%Rt_)hmYhSnR2K9UB`miU2L7Ffb`kO+FHM+qLYHrCp1%e;(zDY^ z$)M-7_6RAYRg60u>5#-bv@VH9NWu=~Kd6xObGzs{O8vSFh;2=CPLK#N;Q|~JAHx$} zLHVwporv8RZ3r(fjP*Q?p^-vKEpV#1GO$a?@R9$SNqg%JK>251$a9yvk) zw>Kk0Fth4m-dZb!?MZb1pd>kteoFFCnPDgt&Brzo;Yeg?BR$l==2%pwDa-&7YBj*QV z7c^3MDk55Pn5hZ4Pwk%<;|RSrWXeZM`-KfaU2StJtkTlf;UNVU1PIG^uUq>grp?ty<1<#A0*-(3{9>DV4Q$JmZ!*)#B3X2IpsrEmHMBA$loC8n9wVHecuq?DHkO+Z9M$_ne6L_XTNFV zL&#p$s5}4?4#57K-FjV+G;3#qYESM6=lOF5F2@Fuq^Vjvn(>K=PNr7g@1D^$q{3Z} zf&t17UB(^?!q>*K)`~dcf&0TIS7&cztppWRwO@}Im+&XWKPwo8=NgQV$Y~Hl69Mre_zK6VW`1ZIHf0}AZx)zE~Y)!@15w^2B`D6%ehAO}@ z|H%o4Xa8@1!$Tf`F9(>?ssAR}b7VPe>j&h74y<>QR3qF2~ye8RipAEp@gr(1OVm|IZk z^w@qo_GVn7WcK5!g+ZXt*?`cJ>izt_JWiHx8rL^h*cb?NenaIcZGL}jQi7JVtBoBlr69&q$ZwL zB+#={@_{zOEkh3FW=%YF%N4lU2>uH^(QiD{H~^B8Q&Xd?Xok3{w!ICA6O=DsQsj-} z4y0HOPj1?W2}i~>$a$?ZISJRMvDmvA^gK^#`}~d{3=0r!LRvqSODbp>$A@fb6@_FL zik7lRZwxaKA_!ZVKGlI2wl_u7H=GQhM2SGP8SL=G$qy5)#`0X;glM&!-puUM?1qNyf+k7T2^90Ull_iluN^CrmwPRbnulV zYl@)1^RTAZ8pGF67=0XY#6pg;>BYz&;5jA#mmcswMkp>2+S9-kVABe@0k7C?#;3{d z>mwuoA@O_ZlfGhNdI5(bWws0tUlFXS0a{(%G1PEyaCx0wZJXs`t7|rECLFZ49IhA- zOu{}KWQ~Rq*LgoCb1y=c(q>pw2&aF9Rw9O7$;mZFOTq^8_8ps{cFhR6L?Q@(J@qDk_+ASCcB2;F@8dW@_q-^nueW0UZvjIu z?pMvyxk1273fWm%6!{F+AIR}yPQWnBJ9%9c6exq|yz@sl>jL5Q*x5tIHASQpU5;Fz zykfm5A$!1VK?5+G9%ibUYnX|PhOpCZg7lw}>i`(@mzs@koSdvu|=CCpSrVd5`X;`C}_IIa$J3|UV@xLrZsNqo%}elnsT$tbttnN!W&6aIUjHP1`%TOg;Iqzjtp&t z?yxOxU~_hor@hmN_KS=%Id`$~FUO>^_qQ__dIOu*LHwUDNv4BADOMEC@;{ElQV^l%&!f5)w+Mba!`ybV#RkcZamJbeE)bgOu>@ z#~B^;d(J=bUgx@;`N5gzvuo|O*Sc5S3>+M|tbdZhrn^LFtMGq|M zBE5lpiZ8~d5?}tuG~+Gtr^1A3B&4XJ9v^G|tuboeBu0^MI++aG0X-Cl{e$Pe0fHhS zPzQP#vVRw^enh#@xC&POVC+OIVCo-lE{C+a3HkE!GV|*YAjhMP5!?C5kmr@o{<6>~ zom0-qN{xeJSlN4h$i{UjH-3CUWc)`8iPm+U;3e(2M{jqPwdC-5?=r|s+@K=%$t>q~ z%*v?U(4Cfw=ELcEW6dgwcLBV3lh zQ6g@3V4bM=Au{&@yJ=mz!eZ!C3)_QW)^D*9YTL1d^^9yq87gX3LU6b?AeB<1g`awR z=3g8LnEKCSASMuQ*(D$l3kTIK>3V(+er@#R&US!{t{vGJ5Q2pN=+nJgN^#0R_y&80F_^m_w@cF}+waTG%Y^nfLAac9{uxvH$ zw4Lq!m>IF)B;q`+x(kG2CqKOMdg=`;=_(ze6^)PSWrVLdZu}cC|1*HXAg#UfySuLM;$J|o)s`T4T8|3B3erO{b?K_)e+$C zVG%$o7RP+a)}3c}vJEz6i5&g?Lp$P2trLQdpTi&UOKhRH`iiT@VX|WjzV+)(pI0Uw zWkwUs>$w=#rnDWE5Vpx6i{tw@m*uB_#0v+Cy(|=KMy&p+#9O-0of(~yLe~*$uDTud z#mD5=)g*;ZJnjCRiMxU7@BNmM3pjO91fjFDDoac;mBX{Gy=`Oy$P*ne9w(LYW>hMe zU2B2B(c`jWkmedQzmdKiBeb5lzx@PTelHhov)x%))T+=bjjiVp%_nqfx>o(;nZ~d5 zXzHJ(f#5kE_73(5l?rZitA#}u=EDp8nScW(oyrUia*`S1!v#ZuzdP2Z<>93c2eO|m zEa*$`ltstae~&=nypB@&rmk!~eea#dGr`d5MxskLRW%Sy$G_N%Vb|Bz-0-|FuQVm) z*pZ!HMI9GGZM4PZ0~3<;?`i=?Y;m9*9Xxp^@i=CKih@E)pxILxP=U@qE(l%n_&3mW z2O68WSsD*;XpYX#Qs-BBGgbhUgwh^Bs=b?fi@5dJ7mgRoU8mD_-43T6g4tSKxgXN* z#}qw$^so&dzln2LOtNW|*Q>RLDdMx>`ZMnr!h|+Hjcfr{1>uPZoxjmj0TDTL3Xk8D%FR?L=L?=mCt*i5SlCm<=w~|e*&xhe}gB$cPcTvk9dW<*h03jcxG?@ z-j0Q4bs>t{wV-1S%kXhpC#Wa=!Tof0ke_%^FDk-jfgs#rWG6#QmRRHnZXqNlQDC!;@()IgJRd^y98(M1xu}&CQZkgbkd-` z{z^OYdzHV7@f4yHNQ!=C?z}^Pf{h*FdEZGHR_4W6qPGtuCvXW48S{!?kT!=^z2WpA z$U6MT$sVvcih-uJzLCrEM1b4n#>aLOzv4a;ave*6 zDAkytRXW=t!vQWD#uM9r_pccNHH=p2T6rw3X^*xf`@xwAuH~ zR$iVWzB^t;;_GcvOo^*c+e7eF3{fPG2(ShFiJWj)WEAts>IKv28#P%eZbhQYi}1)^Do{s7aX3wh% z$@;tsJQe))T-<@$LG`CP;CDKE{?q6Oc7LpxJ7>F_3J1 zX}|xya=xL;qO9rUZduch1>!44>$kpN;zgskxOk&E35b$dtx2OmxK_YXj4H`2K>_0+ zU;47sLlKP82n{W-Ho^TGbj;n9H z$MlStb9!C;X!1%G@snAhX!@o1dTA-KO$(Rt@4O*@ZV0et-px&G|BN8e3cTm9~V$vv##-cK=eA{pt`EeHO zpRj+R!Hs-+GKlCFDzzii>5T$*s$~CEp`eiNRhA?I$w*S=$AQ{OhZ2H+y*?hE1jvM5 zx?a^}bfV%xx_eV_bE_vX(Leinmpy8Tnjf|63R}PkAlcpxmDA$=R8kW5jEZV?TH`r@ zJ*Bxr2Mmr7F|ELyG)pvys=9N6He9atA2v1fa+!^GrrGmJXR*BdrmEezIFcdREw_y} zISDGM1w*~`{Gby&i>8{7L0oFEGC_kH2LtW5;&}$ChZMH6s}}gtE$NEZ+WGA4V;cs< zy4PTZH;KaK{-fSAY8($cb$kyBWYjaA?a@CopU z6(<9g3+doFt6dYNp^9_ec$AD~N@Q(c=QBBEqgK?I&Iwu%l+rrhjjNvL9dC3K@eOJX z^V-5L8AA46Kj=AdMT>h$q%0lc!)2jN34$#&&o+6E;ouSf#UFVz9-%^>kMFK~-`Bgz zQ!-MnkF(oz>9DZ6I#mU^yyF>{n9DVGM+Yro2p*i+a$44+^3Icz&g<1gqr|yn7oS2*z1?>JJh)JYA*5gdWq*?H7GNnk)uf@q#Ey?2Y zPT_{qDON-r{kD&fd-0B2*&Cm_Qu9?^^}$@66_Q%_hNv)rd3GcGo+mS;r-K&A{IuSk zWi@cTsCK)owjSX8?qAL9*m2qNF&|KukfhVoW@Wv+|8JnbhR>aZ4Y;p4O--PKI__Kx zv`7K=jq74p@Hz2)9FyX=jxd32^DQiKRZEiU2g0@`x$6Nhq-_Mv9(nrY)DZhTz@%IG zX5ExS)cz2Y<+uT9yWIa1uWM+2+(2`V*R`@g{s6%)X*GGp*QiL-tk5|bzUsXm(sbVQ zDc*kOq5HpX1*ijU_u*?scJa^;1zw~hXeq-da|3};`&apFTVMb4&h?{%%H>qhm8RP| zS|2&uCkS}`ti)i08)#yi@ptGck_P_7>qT>&j;rALhp@tlQGDndPt6{u`C>3Ati)on z*f$MfG}2RBW-VHf|5%Sgx;UBPT(R`SFjcK7XhG} z27vee>nP#4#5RIOP5{48m$wkIM7_rH-9})*7>}TL;Z8ssT-}u<$FnI zPtZD}I*oNa|HKdK=%*T{L0RA0iUh@?Pok6x$FY`&+@UCit@Mc=MTm9RInllrwl<#* zcuuloM0Qx09;hl{4v_n91*u!eeWCMMoH;q$JF5{%^}P)BtD*#BB+0I}Hw8i$VJMug zHy6-lWPjeBhd*V~++1ote!%s#93+{S-!ntc0woT2pz~Ai2sb+~ zQ&TVApttJ*iW{r6b! zcmRiN$zUpoy01@MQIOC#rZpE5E+M3aI}_-@ zoEUwQ^wJ?9ewxiIL{X9}Pa-kHuu~_k%}>Av`GrJHF&vufK`RXmLN*4oE)4w8K+~S6 zI@-+A$I_2E+(2(3q4pCsupMJ{P|2B)$@Qt2(j^(-r?-9c+t{xq+QcoPbIgT_Or9sd z;sEj(JM7(b|JQp8Em7TS;)aXuHwtWEqk)8jQ_975{}Q$q_*U1lQMlH_3o0iu^#m z&iXyR{a25$>Ln1=$v8Yb40DOr(i0ZG&w0>}&S0p+TI6R0(F8m-FtaWakd6|`E#n3j z#5MN3X)s1i@nR`F4{I-T7+u~Bj%+m2Doxz5Kf~dAU6+n}!}Ef*iD+ZHj^$m*X)7Cn!K$Ik4Aut?OM)6M6Qy-7)bc07{ytiN8Kd z#wc1Nf#9Xl{3{9jIHGiqedbr5v5}{7;(lI9jp<7?PM6@MlS(sU(kJNOLy`8s=#yC- zSbRZrQtUJM?jina+BvHaJH-iG3t4R3bD?Vlnc3OG_zvEH(6W4KxfL)q z!i-}=y%x3r{W`ous=Mc`y*YlQRcs$5k zM=_UIxV2C?8&O_bR}X~>o=^3Gk}S<@AH{i)Y?XLF)*A1nR9p z%> zfv}{It{&<(>of~*m0h@2_V%!6hv)`&s=g>33m{VlCl#KBQ8O~-FErl#p|WNyk0r3b-Z51YF5Lj~jD!o53Ck?8`hPlZEZ+EGIA--Qb} zOJXkhU-fCeKkIXYNfQjd6T?mVSROr2r-5~k5%yW2`3Z`J_l} zH{&KRyEY!sI%8mB4i~A`Q_x$OWcuX@>EyWeN*^_bf0oj5BGFl5u?1VT)+@%-4R0|`L#T6a zxFs<cgp+g;QsbxR0a(tN&{ zq+Ea5+Sp4AjgAlLWSoA*Rw|V*0u?uQhHoze$`adAB}K1~m7H};1!i&jndt8?59%=3 zq>#NCP-Hk^4l{-SgqNIQY`JPitk;5BdwMXSmMT%U{gDEY=H{rCe;oqtH}X|r*(o?V zlR&D9|A?yGOp`>*TEWQmmY?Ci1Iw4-$?jMp~;og<-Av&EZG`>W9@FB8GEN{<>QA}vt2|bmRUZV zFNVG@;_Q$uw78KQ+Up#QjLj0xUw=cFKzc7I@X4H2@_IWc{Bi3N37T6&UDxv|y2p>^ zhDh(hoZ&waW?w1$dQ2uQOPua4BJh4+DP~dq$Zn3|1Nmx-nIWgprz$N0lg{_s%ZR=? zj$ZpKsSzd_R13Us>ypH=I@;l>N=0c)r;N+uWpnhAuV1(Tn4Dpn|Ab?ZUKhp~DO~ee z8hZ$H2;yD{Ja(LaAI1xsZXS8r?7^|I?uyE1acJld=y4cd|1>!vS`PHq6Cim){3Ns4 zC{E)t#)Vxnx}Q&#_Ax9Oe(0E5G@eQrBKCuOOD_r!W+xY#A;GhQFOu&$J`xNDKVqjm z7^e#>wNGFA;6i)ff&5+4iW5I$c_S?IqvU zzEi*l1od#41I2=JpH30o+yWsqc+z5%pxJ(5Z#zE)@Q+yyNyNxmC~Bs1Flj>f+e*Kn zsQUJS8#3uJic>dl%XS`6*zz-xvV&G(StN94raH32vn;daNrxl5yl(9z7m?^-$Vh1* ziFKVp5J&iE+M~Ow@`)TKa`C)dq7j03lyS2nI-A`x-tun30 z1ZebkzjU^^JGYJwUyr+m?|l9K10Rj(voT00dSC70AwD+Yy-r#tE)K7aVs`8RKjzG* zeehlcmC{$lLt9FbdeJ|EG=sz&yDfg6JngxzXu_8d?2lvxU#7iVV~^o}RQg~E**uiO zq;EBt)2ndl(?emrBuThO$svrEac-7L&l^t=iAgp4v-0_H-@OwMP~IO|w+{wtrL7o`&f|4Tz|0J-VTGuX6+9O{h;5dl<1a1($wF|5EB~? zMn~~{a}NfA%>il&)yQ_W)^MhfwaHKO>H)Sf6oLui#69=Y{|z}qjJ8Bb4l=6V1NvFe z0M0>$xM%qr@$_=H;7!;VB5GURjbvBD`DxB4kmF2LH>xrmNcwj9#R4vA>ktVS<}2u4 zft!ZC)~j$5L^&!A+Z<@%6dBSEWga0bz0(4t~L|( zT&P&T&42#(9T-i)J4?a^T}ut`(7$r$CF-dW&QRTiHne^0=eJ3#3pmN*&hSpM^O1FB z+~%dRyyNF|q;RbHu^~RJ@jTShEucSD0nwtJj03Mdk$KMEwHy`!lU&#B+rYaI+^Qk9 zo7CMP1%M*vd}0|R&mCn+y}D=b%7L=A{c5XgZ)x7qC-goYDl-wtMK6jD2j$c6`U(Tq znT9w34R8=7`wetU;gvM1RMI)0PD>jclI4|j$0LH@Z;CWYn>8lPTTb-f)!%uf^D*X3 z5?)?&O35Y@!-N-8<^$h6+%}}R@8=(Xb`YS|wC)8J3_Oxt=PP|vwVMbpuGecoNG?2O;=*n6NMp03vumVWWJ>9vvwC$pH1ecQQeeBXM5X|&_AcGjDiyOgvzsbr zZibX>oedP?p29*9%F?3^MrwZ``Z@ge{WX}!p%?d1@ibA=Bna!;Ko2>_)(H6(F_ju~ z50JZ$$1n}53j>NiDZ#O%MIb6!!v#sy=6sa3x2Xtu)w1mAzaLw1>6Su$g1`!5nz9K* zSL_TMdGSnTv08iR<)ie=_P@0?ivI`qtPW&Ml!(a4ws@ST_Y2i#DFvA5;~vtxeQwk+ zFrnw!NXEzo-}cnS7Uq&p_>RwzYL(~VIZeWq!b_IDwC67UuDR9tzkgOkflIiP{qjIF z59N`JE_2G6+Z6jWiHtt9r^uEyTd+tz)&^QKUHA=A-T*Z!7qjaEQhH6%d<4T*&-7W> zs{YiBmbui~42%J2+yY5-?iqOeD0+~XtodFB&0M6lT5N=+K&L}^>+lpoqjjwck*zU> z*vE&fxoLc2*s_=b7)-x|jv(Sn0>eb(%5l`kh&tyBtY|>E7qmX9j|Xb!)s|o-;_fs_ zS2teKtrWN^*0vQCuJ!K`Z_&7w<$(G|{(V%KG$9`HXHr-U2&)VWQZYXa8roEe`JNK` ziwg%n@ZZyFMEsm8CQcri^O-om4DJSr-?7D$OKGYZdKcuDtOKHxRtma zDx;O{zZom*k!d$cUo5|(9#}@h2yigGDA-p^-^oP>g zp1(r=h3fjOyy@Hp#2${*K1Xajh1t(Mew;y2fQAwK8Jj2-ebq78VXPHSp4ws%3@wtHdFUj?&6uu)_WVDw|w-w}KNS+WNTaPlM7&4LCEY_#y< z%Rl3_p&_q|;B+%&q-|Nk>RPwT8U%*7P{&FmzT-{Y`>hqYfmT@BOYZ3KpL#m%u=xVG zH5LY#ScZdjA8^ZdFCUPSSJFiP>II-;lbvJgLC`H#04jabc-=`jva+)DEiGkejdiC% zq9W6@*Bcwav&;ipJ|nR|qQ_ny>tVeogva+aHiA!>DiBZ809{8neRZVeNcZgkj?6BmTzOJ^oWFV@8NM1H zik1xrNstoobD!BsCNYH(2x*S@L9V2sp^Sj+X9{0t~r zI(t$(^V;`Q^pJS;906JH!xS>E@c*-StF#XkkX|Zu6ZM_#3Hwok4kO17 z#nM$VWlYrZFtoFr_{lLbvKC;Se=9Ff{(xEVwNN5+HU;_2ydKF$=Mc@XT-J!}WEfL4 z4L~9fuLPc*dEp&+`IL27|MEKTQy6iWMqc>A`%lJ#UF>4 zmcc~BTlNColV6}_f@)AvSlKR3R)SwMD|UPYt&OD}NW*&&+ljqHnj<2UQzw7bmHZ=# z75sOW{A_o9)`XXi96@08yU+ds#yKuJQKwN?A#uL{@$=#)dTs7)7!P|A6bZ&5h5n=; zY?F)zRW^_KiWnlDlUGkhCiObEw)y;N2H4D+wYf%Yg+^l z-TRKL?rm6NL_}=tlee>9T@GTa7*Y_0yBcgY-9LR#=pE%oW9u=#E;amg#4Rm2HA383 zAQD0>ZPAhH*Y(kt=;B`o5MuH*pny^7>j&zUgB_=XfRWSQk|1PSgFA!=<34DX zXYB|73D*!c?N^!z7_prh%Adf(P4Q;&0PatOh@NC{pX~Qo2z8=Y&AFS;C9LpVxlFFz;84m9$ zDD&e9*?uc~;MP6``CvBl6FCE-DIa)2g&OG(u9fbSd!m!qu4TJF9&EfROLyf=M1fgD z|B66?TiFsN1Kqn#fdm;DljW?uTRWb+k@?1@buBAf-7s5kVOwtXTR5mGC-y5SVW83> z?G8{kGP?GIhX8W_{^9YslZ^H#RU?uCidNRBsPR4MA*)T@6#oEAum`X75lky#^Sv}E zdXkc4I?H|AY{rXZo`nj-u0#oW*pK>Y8 zmY+S#ZxMfs*tfvOE+WZc7sqzFL*(~H_Je@i)hEkoxg`8D0^9#K4hNYBo@4vhjQ$*BFqjFKlW+%&t+t|n|4dd??WgY>k z0ixAAU5`A`^NPK_eg4as07e24D4~LTf1jOK!vgO$nqN54Xy0gvy1J9CDpc``cc^n zpgjIiCoAAE3n1RIk#W-woSiw|px>=0e-=izhn_Y0^|IUzLDM%i%>q5jHXJmbKUsIx}X9n5Z{Jmf3<>dS8eTL90rG{?bb0W}6w2@L=*Rc>t3E+ViushnElXZXhN% z*hqp#`-;%yixfF*Y-E}Wll<4|+-RTh#{?XM<}b=*AqKq6W!kFQ1fzx zf5domJBsy%X+FNmHhZVB(dEuK%{Oaj|6A0N^w2Qe0YgT7M@1^NW!fIdrO)o9Hd*4p zXTU#l7?+Su!CoX3a&+gfe90sl;7O2?MgClCSM*Zsxo7VrGUn>+S;$WO%!Z$-iTYG= z=-BvhoWW~quKzf*D3sjuBh%m$K11u9N;R$)X!shcYdgLrB$<6GI3U~R>&;Ww9I?8! zy(c+PBCM8x4nsS_J$F^-fU+{0;NKYkzKej}Xr6b9juLcOaO=h>lg{n3=9NylbgVj$K@ zrYSFa%=LJ-@>b%=0nsf^4huw-D}UvRfS78KQm71=&{X`gLEUQ>_W}p$RqVzO@jZiN z@jZFN&LoRRjfPWaguZSM?o+o8{P^aWAtaw@ha;(lUptMSPk5LI_JU% z7d;f;BS_rPxfONY9PQ&>p1UZ%4oX&4nu-%PF}lAyHJ9`5Lu3H*e@tL#n1pBH&z}bl zsPjC2mDj-|;d`|cd)MPB>>90#Ni}TS`1QA)b5{A0D%rv$=Y2Wbt814+9YkO9FYi12 z{^PzOMI}6wM$?spum-&?Y7Q9N+S&c4JKdyy(KlYW5NpF94?EY1%UTfQa^!DPAm$gX zn7w@WKjIJ=Q5v58^wFgPWI*4SgW^&55NGG0)M`5{No_?|PKK2EPy7b|Z+>Wn3%jiIaw9|AIfaa5`Tr_Sb%X z#vnW!ioGa+a7t4!dKi?;Pu_5G%zbvLc&Yb=t=jC}0z!N>E-I5%uB_8zv|FdHLk*@H zgy>fQk+4iOg$m=sA3N`{PFHZOX!Lg?m@AFv`$D)=i<6wCxQ%Aw`9T;)R z#$GiCHxH9OyrYs+6T_WR{;Dc1;nPw~UIK!rw|@If)%~k*W(?uEFgO`YxCT~(@VTQa zYNzkZq^w!f$KgUI=RK~CJI8I+RUszIOfNan34dQ&LMck1Xr5(ebdiINLXl3T)&Fw1 zRDYyOo;9LdY;0bH*>AL6(&#k({{CtCr@5KGe{BHybzyrK9r$&apvK8Go73S*?P>92 z?sA8~aqT$6qnRD0iRpDzmu;i;3xyaYnip=(jJGDACk`eb^@t+D>khPpq(=9l&FNr% z`A&w@rzJFxQj`S-3MM|Fxih7( z#3?1kfDYz|qp|+WwRc_@b|sNW(>=b(+nS&?kc7|so&QuQ4u13i} zwz$i;%VG0z`kL5iIUVik*st9`jW@o@zn85CeQT@ZBfxUA-cB1+LBoA$Y(A9T+6kw( z+P;qAD(`5apnaH0i>Cy`KL+@OwkKw$s~5=|iRic>GWE|C^y8ruDPQzk>gA*Ir1rtA z(J5-(APJc|7;md6S7sHp%${t`Pk+J>nzdZ#;w;p$5;8SF3A#1?_zEyQnd+4>2Q(a0 zRWrGj+;k&dU)i_ME~YlTH~0@Tld4S3$O%RVgto`~q&yhc>*k)t8vv=EQo;SGy{&a zol1Dbon&x2$XRqo=BoXk3%pxw)LNOu>ag~KyeQy-ACIH2tL_PR$}_&p3^R< zX$n`??Nc86YN=eQuPgjuOGh19$aG_@(&jJegj<1V{TV+=@D6x*!s_;`>(DjTH{(0& zRM9BZRkf^I8&AxJ;0<+tVk--^F`h_k*@*kS-w>fm18eqQYkv#TDPGWdiwD=EnWq=+ zf~Obv5lzOg#vXo2hj+9w&>@E*sm7(R$gVV1#4*4p?N)U>UysdYCduYxOtCkc8eeXXGW+R z7_U0we1A9iffT5j&l!S%ob;jr*3hs|2x)iCH-FAy3&~+%;K6?n9uV{zFna4wEIJfS zo`S|jVpI`j?_{cN=lziwC8^ju06epWD-9!UvFLYuHF{ZGc|>nU{#VvqVN4@ z+7L>Kf@_;Os;dg-BL{7TSyc66wu?G&dxtX6WbOhw&kjbJDb-FHhowXeWq;kU{&bJ$ zTc(a|kM_O)rUAw1N#KkK=Fkl{I%ZOFz>6L&0Tt?V+C3`^YvQyO-7lULjf68fHnzV$ zR4dC1oYh-*7li_%tA<~eV}keLsYgfe1T6p1y`(7DlEL{4^U^{ERe7n`j&>%i(0B;; z+xReCN-Sn_gO;(7_vE+s5d}w)e!NzI3I&6PF$u_;a9riu1`S^-Jru)fh_cIj=twnZ z!XrG9sG~n<*ICs#z1H|Z_xG>CcRUVJfc=5L%0mm50p}5w@AXliTShMvai>2OIn<|j z{82^4mmTF;XuQn=>*6GK<@)g}0Zf=~@mmM_A?}Xesg4|>_I-$W_dR5za?!csn@C)3 z42@AzCDIzYmp^T=b(#=g4g-bES)2?1Ih*v>Nx1XKf;*9ux1_Cw0uVOIwSo%yd38|7 zsIC#IS45$&HB*&mQ_1$Wp1+c#4zzdtdk>xogZF0j!y{I~)Nc5kn~+zG(KGqH&HTKf zibwK~wM%RUSk$fMraMlShZ>^4h$4Dw-#&(CZ^5s9Fmb1ZQk7ttaChnn=qy@7=y}Vd zVZ|oa8gpslCnk4i%q^NuRL%X$FLyeb!D_V0d45td_`)s zU71SROU3J#xhJluiE|iar%xEo`-6R21eW8TUHQDK!IuaT-cqSsQuN=8%~K2CcjX+l zRq~dZgX~^X%vyrd==dZ)FC!)_`zg+7)R5vKtoAM5IvAAao&;E*U;vp@W$v5oI!VAP zZ&s-M&CSkXV>mM_{24iEd}VtjVI%_{mc{EKN68qa+n}iu}N!1W`b|V!}!WR~@^o*tw7xj;XwC~;87nCjvp8aLSXvK2fDG0x(GGNAY$_6cb$IEoY z66m#j0TX&hQNy8V{ncS_q0{$WETT3uTE()bSC)a_d`6#6X%aPf@Q@%Ds5Z6CU+a&+ z(G*K4=!?k}46qAx)>j~~QO%9Ibpi`i$)r5>4f1WX6ii`Yxm0JwTdUs@tC?+bJK$UnwjC_qMD#O~*)rnO)7r?PDXSLn*M1 zXJaIaOlEZwPIKKm;ecG5w#n$zxsJTHN3ZM z8u^B7V>^SG|Yv;1y2LAR3d=#(b`f7kNxqGK?|>MWt2o3&_-$Bs!<*ZnO)3!z(b3~c`P+!TEBxGA2onV5!XCzgR|XX_pLs)GXtylNXhJy zr80}B(?8X>bDsBWx=rtN!lh9=e~(TfL|nXTT(*V!0v#DmyaV0_i+ zI0K*9TlTD59^f61&%pgewJgr><**c#kZ^r$r+#xWvD17Pxx#j5I*yG~)?8;eYs?emopCvbx-DYz|4Hg2D=pe=jZ}p~2^kUE0eH*XZqb)2nlfQXWN~TX%mDPP- z+VbBq11;L%t|X0eB{Dd5pQ#L%yOo(7qlR%sHuoh?=pqh3fpX*mAZST-9&Lw?7T%_v8;u1LUKJ7+U6-XucQ+s7MZCdncQtso$ekl#*?o~;kbTLw9DlWZiyJwm%d|awpZBk?V=HjYB=$7dS z{Cq|5-bu0X`#vZI*vmlFNR7qvcB&Urp1dwy(^Z*MBk0PdWWVybF;_!qbbS=GumTLi zSf+}3P<(uZuRSEc6;=p;?kmk+NYtZ})7xz#1I>i$tI8k_M#z)tUbZTu@N4?w|4V z0#B&9oXCEk$dbU3f1`bVxXMc&F81aq*?Ra@6=1*1R6Bz4aE8UCRyC3gO4{lwT4NTG%haS zV(NmH!%LX`%J&{T`N~GOGgn`Y=&XN_`D5Jkcb=5b(|7`|Ka4U7Rla)7b*}y92n!9` z2d?4r0L^J-Z0fY<60y*fzrw?C%&&7DwbA4zRWdZXKXQvFCv4{({@+T77uJ5H&2N~W2_Mf(YNsG(!dxJYSw+u7>DM0 zzN1G{U1B3RHSwx?r3w?^l#OJ)OR$m{xJAwmZZgjsbO3;jiYD&Rh(R*O>LHO_ zOMHlZAHx1UY1yU5SqNIH5#2V^!DIN_^GoH0Vm63l(yDF5hT*_C_E&1-M>JJaMrShE z(Cxe@p=^s7UlaDnvY}uCz_v*z5?GOhp#d9Cf6uj{9hUubkh(r%bdiJ$=IX1^FEl&r zsdq=xeZQ9h;6mu{iva~g3&V=`Q@xA*9Q zBs-h_G=s5*e7=cXqz&6G1o5ZeLjggj-kJ7|c|%NXGX(1#)<9?iqj8fiWv$CkJH`Uf zm3n#k9Tl2(O%9V&{;>8T-a7=iHtqNzp`dXc!F;MT_+d3FwQblgi@Nlo!)JCwPXs&w zFh{HY9wA*f>E3!{8OkGPa#OT6g3k>U5e-!jj$FnK#Oowg?_ za~6mCA;z{V$KS>_4yE(|-y{+ae}Uc(E~wjrds%yLBg2Q!s{ zl_veP(JzP@e_Jmlgd;3&r1IOQLx?GG3xGD1TZCce(cvzXFX1ss+n@}*K|Kg&# zsIdPR1Jx6c#0wjr-E`~|YPGNIE!YBr6s?c|XNiC&tB}G&uy;J((6-^?n$HOYVU*>o z>G!Vo$|Nq|Hf!bMKX{e!g7B~P>kxvQ_W;3z8iTfkR=>4Oe~L}H#JZ5?YRSxX|7^Zf z%%t@=vx#6JvRTj59{xzR-T9ch->K8jPMJp7Wph(<~rGG>?++&Xo6XP_L8m z2$EtX5AnRA{u2HIA|V1a)+Ok{!;u!c;L~{X*dnUlORNWpF_H z9G7{D<-UQc4bE2knv0J;g2Y>bPhgMhQ6Imp4xv$fgUM_lJrX@V9bbdlZ~50n?NI<% z1ncqZTAe~#I28^q+YlEftJ?g=(yxKKlnBl%Gu1-5RfX!`$BS(JfFprzt)}#bzd8Nw zX#oHLZhed5F%Z^O8MH@Xp$oI;;f?!WuF850$_b-~a$OL&utzu7$3N=q*K9b1ekn8t z(o*QQSEUoB&H!-Z+dCuW2tjII6n8OyQ7_BN@Ce1^jr>51_X*Ar#^WLEi#k`-S*fE( zKWyNy+3tIUDhcw4-2OF0X0(Bk8xHGRy#?wYOfZ?SO4AamA@2_?t;%zL-XJTaDu$XeSg5FKj#!h z46cBTXn->$9P-C3d3Yg9)YQ5Sj{Gwu7|4?#N=9Y{0Cs;(0s#$N02Ca?zvd34FcYZi zoDUJiDSr$C8B;0@EX;HAPax~(cb^G@X=i@GIKPF-|6{y*cmkjWqy6)3*y$hL_D8hi z&oyB{LO+Box_`nYU?dMwNYIUvRFLZr)eKo(Dj`@aY$-kZpPK{b{WbthTa=RW@DGpq z=jvd=@lyB)$NjZA;A=7veCYoj5s<830jd9YME^g!Mfz%#ScNOjploS8>vhs5%SX4H z>u~_biO%C5hWxEx1R+YWq-L!zP74rk2O>IylKxs|&_ZYu;s!oev;SzbHBQq~S+4)L zmO%--)31hhf!`O;+*?$)`7D9Mp^S2WsNT+OMmNF47*v3;==5OJANEk0Twk5T7bbQ5 zk9Ko^oVw;0<9~h-!oFJq7}Dy=Xk*k}YjbM@hjGjr-o2Q&Pe*j-IA4jhoi@m>{eJHE z*X)7a11EJv0yQBRh&gq3usjzb$^GN_?N@%Y2Y76wzQ_c-T)2I)YSeRt9vnPqNr3RH_thY%(0n8$EiOT%g`Gk4dc%)}E$F_V-Ao@Jwd&vk>fc z=CBo!_zZR2|Gd_>rvC%zOV7}l)aBIUh5Y$x88l|VUt7uZ02+9RPtQ=jKu6)l%X5WN zO~KDLK!$$7JdyTkKf#qq1*YD=UjjlF58eKVtvIiEg3?DLYq00;LRDXdYp_JKUnj8) z#5R{tPb+N!u1{z&-SEF(p0~a{mhYl^Ca(o7H@qHNt@U+Ju4r$L=H{N9ZjMo%B$EIA z145K0p!3fpOuE8|ax+74UeKP*ayy)_#>|ZXHtVrglBTl(hreeCl9v8GIT_G^rO<5Q z)9wlP&AH^|*>dRIQ83Pn>AT1r8q1xbcE!TM*irF*|2u40kUNg8cB?o1U9S*&@*F)` z)tbnF$Myh&*tys$ez8!S8{Z1}kG?o_6Osr8l~@gMPDE@EW}X#l!2i93w^hyY%tk@+ z1n^OdUt7F(mICQ)Tb|mm)!83t0$7=VAYC|=Y(G4VJrI%va$9Z*Ro&wWHi+RHpQ0Z*8NO!HXm*c{=~=q{v7g^23;@?Sje`E+ z|9s$W0odmN8a9T%l?LDo1)C`#t%?1|+#sLJQ2|5qBsO-yTN`g=17 zrAi?$QPujN@t2>5gaE2W&J2a9cqAGmj0SE&PowEbRxnUc8C+B`a5zVNb_ApC| z2n)iv%c%Z(l+FYRA3%pynGU>bs+=i zni05Pm)r~iYysVPDWH!V#H3R*?v0_V1`r>p!!t4}1l=sg8=dTMSd3$XpU}6m+3l*B zUCT5Oh8|EX`=`T97N@%*PwJ&9ug0Tws;~`lx6#+hR#DnmeanndE3;d0nb7@`xbek7 zu2?N8zCku$ZkIz3P%%~jNNp;W*Q-FWd{{1xEHrnCViP_@@E}w7GrgenC^~N7Qm>J& z-Ha-2QL;JlB zox}Vq&y4Mi3Fu>Cn#E2ang>;;KST@(!Zd!LcgmvsFpQ;__|6aCF~=!+h5PVHC~e5g zGWBoY@ibDp2dHL7U+PA7@44U}9rZAs09<){Sd6X`RD9Hh%BIU!vN?X=H34OGcFap} z+xSO7+t>lm1m7ZEf3n?);uKZa^VK3U<1C%d6MTmtVKiafsp6wcq_?eH=r)1Yg`(v# zIZ6`7b>IbSF)8)L|CZwR$!2quGTYrY2*d#TKZoEM)!mg1_=$^8 zzHan$eq53_tBEL`i5$QFx_8cEsr^0e{>KDPz&lzMjLXXK?R1@=Ik|@QYWE&yNB(e)d_ANBFVa$xJi0qQelHCkN8p=BMow0qN z)BAkR@0|1g_xIoTdtGz6x|m+eb9-#}eLv7ykA_Ojr5w91$?mVTUU^gW`gFU!n19yH z7tcyHvEIiw)jif5WZbLy36%p<#y1J3@3Qkx4wYHEQgVl`ZEQR%d8t*d!#zFn+g9@< zRmGGvonvpdwJ`@*;H_M{;Rn#3xi3@QIr`7hZwmIuD=tkJ9s`M+giDeq88kvOs zY2yrDkoq$o&J@S8)AzZ4xv3rr@|^Apn;(*VZp5$B{&`_9a(ruE7$JB+=f~~gAH?fB ziDT%Fpo+&6D#Hb8s9Fi*KU}Zy`D*l|@Z!>+mV-t5Ou;OB-fNeb1ma|*xIYG2+r9`y zH)-URKfY;Pb{HTtw9UUk07>Yv5*%ZWKHDw3kSw`5Z;sOEa??h)e-wQYtJhV?8Gjvb zy5-{8<(`b^JE@qIR60JjweVw*WyiqiBH`G}Oo@%092-fcB%Mhy4(qF)k2FIL>YC2( z=>*2B*iPKTSRL#>E?I`n3K4%C6+n~MEbd}^{89wgUYqqlL|$quk?;!)vA3}@=ld!q zr2ByLZz$)*TB6tZ zGsQa!7XOvZhS}Z@`~K(d$7Oy4EwKZCZKT$Br|K@mmhbF}M7?hE;lqvDel+K?*vYj& zoJZd`z7P?^)?+(cxC{Br;i*DqbtBgSGocbx^hQa9V&spNiB{cKUOpr^rasC12AaA~ zAcuGW8SxDb))wUu%+gbz*}fLF0mO{`Y+s&x%4C9o+dDbJ7RFiYvvU(gL5w{Dbpul~vl+c(8#3P0 zI-qq%PO(awC;EJIMz%(*Da`$)+*hL;|7}uU*vnC^-i%nbmy4?ts(%D_T0^0A{gWjX zfoz_(4=mvDHs+xHY5a)OHf zR7GxeV)QHEl^E2odeL9H*lcoBnQB42BLG;aBv0X}!irvQBRWtldm@dk9)TA>H2r9J zMb{4W8!cPDOc?S2rV=%U8ZqrsFmmI`^n1p#&#&8dCS$m_-sD;?Gb<0_m;j;036;*O{*tV@2aSa*_IygkfA56sK#{DFK+X&!$q%7Jxjsk11{fB^#wh z$AR@T+Fly7ZgU52)ngWY|`s6 za9F{4sDJ!$BBfgk*a#$)^4=z$4FWn_H@^sSm8u4o&jtJjA;;skWlixf5{Zy7lUPeh z1}%MNK0P1QoHN^))f_mA-9C&x*&>t&5h!Ip1;6}qL>x+RpRSD5ASB;RDa*-&Lj2OH zOoc1AAHHBm;9)$Ce!N_P&G>WWO|G4lYxBvO7+k9tgVrrLp8tx3vjKE{W<9Z&3#@I6 z1#!zv?p?-~cy5Hy(;vj6!)@#MX_)jxUKGde1A)1a;*Zba54lrWY6rj64HYKGJsxP$ zLgD%6a&b>;>3aBUa{YYX?tx01i_;rV^{3T$v*$lItNczWcBA_xHZJLT9kg;Do(=lOo&&-vImvqYXI1nCucOU8y^yxTeV$7u@N^NW9!T=Ap7c(c7u0ZHEr z>VgMz)2Ti6!Jy@ddAf*oMGx~f&4T!smleG%GiL2SL9Pu}yyoZg!B%l2f|vbKYvPbJG=&!{L?lN<~t)Q%5z!J9E;d=J_CEvnw@KODhuS<9%ecgtwHn4E5L|5{< zFxWUfx}en<+wXuOkx3VTX}u>o+|QK9=w4Pg6Uy7!RBc?#!Xm#69r?J&t)3OZ~s$TMmGOgy7RUJb}E`p;g7-v>NO2 z9Vi_wgxx<}H&i@->y^fvKwb`wt&HF7LOks~3Pu6uuBUN4tCPw4Lzb_5_-vj|P_ypK z%$HQM6?z-xef7gBs@7{+l#JlYG!cCdWzk4@(uq{Py{O-@{Y<8xx_2*@k_w8WL> zHq4%iqYxNnL*NJ z@{gxB1x~od{;72z*(aGbYf&p0&zsIKfZFK1hyDy408qVW?gCX5Z$HCWItHEGU6JN} zFEdN&xQ0oKQ+qlhNavh)y?4E%y#>QKJ>jMu6&v9=PE%saR6(*{r~fTBamQBxR%BWB z<7NDJp|B?KPd;2;j5`neHJ`o#NZBWdF{vEz`iX$0IRZ!1WNDCS?7jFu>Tyj9t0(moExy|k2=0n0SzZv%QVC212 z$hlju^eJ3$-Jx_+e$8|);5pmsT6*Ta7&|nf`GkJ0X);cE4?SsqgVs7G7%x2^uZHCOO$Sir# z`E&~eru7K0XDwJrW~$v-570f+OvtR;Ac-*u{5FvfMm-qWZZGzoKs}Z`7Sl<2s+TK>9 zT{-)cd`lu>YekIrTAy?#FG=xdmN#>UbCRGIlKdqC6 zdIK#?TbHq{Ibh7CyL?Y)1WT@4(+!#-KZ>b$9tE-pO$r%$qx+)T0Es~Id|u)7&j)u2Zo!rBpada2)b)&8*M#`YOu^-#d8Ut`Fto107_d zXMk4p@~^?u;hA))zb8MJ-rCUTWA}m?8eK(x#xm6C<+lcpKB|n(`bu+VGeZxz->}qr z^01KROl$Fr^SAIDhXrDC9&dVoIQ{N_nDeqN5$uCDEt#GjHWXQSE0oEtvFQYfsHJOy zaL@ZQWqY{B`fNOu=66eny#UU7zV{j4Zx8I>-;gs5oVNSk&ENA>qtJ1RptaYw2jAAL z$8*dyTxR#!E4bxDtwUb@fwoJfqnu7Y$pUX!&aATL!_$&~?SCER3$#JY%}HAc5JOTG zeRaqUI)3&!i7Klv1JO@g1JXf$&h{OzL?fd*8+3}2k5T*x3GE*{Ic9Cv{qn0`-FM>w zWm5t_5%iDueULwLK})o0C;u5<5&4#1>mZA`|CLZTteM#;zFubB8*}nqy zy^F8L^AX-a^$$O${>p)peZDjewGfAMVT&&@P~F)54N|Fm$E1Gddqxj7n;em!7xoQe z&r+W?HrePDD%aUpk$*!Y&u8~qB%l1#(l@!+mWFWd`ZPR-% zkau)drzMEpEI{`-;Z-BB?gtAze)v_;S+Bf$IbcS%tRs?Y!R%zW8o+C&8S4%tJ>+iF z!W-X(&e)SUTls>G1TDK6H&J;h8{o`Q?RHGXzvq5~Kgxh_^J7h2{E3LeuYhF2iwwxs zLY7P?YazgBzU!T&gnJG-9seEUEMCdv8lk}{)a<6hyg z01{Z$W(?$-qEw(QjIjmzY(fz8KI9eRWzso~tBE+}1mi=8DW(e2Rb$qv=;@;+-)3+2 zPEx+L-|*>_wdLgexlGfJobTk`-$D46Z_>$i9~J+j1#nOc`c!jKpTVD(~D$1nx=KCnfY{!Wo_U&#ddL22nbBK+PHt*A8_g_fmsU zXKYC1z;%~ow$_=W8wBTgqse=_z;9k~e^Vi>qR4hsr|6f9@RUD2XT z*>PK8#WQppb|aCNp} zf7HA`C0bAKi$E8iC>JRR22NyA{`_Uu^Q~ERfAYuE!q{-aWWxP)f{{?J#1WmmE?+-M@?oo76|+(dNohn@ChaGH>dKrtk^IHRDRPmahfOdpdHT zD91Wn01Sa;+McI++XHH;6d$>bZiu*BEj$UA>=xVf+Xx#M5fpd`YhjDc|DBaF`^l!Y z3U>#S1W_3BY)k>m?=qR~6J*#hmW>+nu*jQUvXUTXvAk=a*dvWEiVph#<)xN2E;gU3 zoP$=m;@>Um2a}ChSzx!+yUh@Pz6=i>fW%Xq{M4KJ(q_w~*&n#?H+v7D+vA*1ZWWk8|b2nAg{R`07|pde}NI z<)oK~1oB zGjA5s2JkmSc4j}al)cN1v7@$Sd7hXRcA7k4w8(_~xPwb$%NLmoPhZ1?Hx8j>&(ztO z*%-4M``+2eK9op`hVNtUgrQ+csLV{~XGwaUDF4vmj+9KvDKc`v2>=p3`2>F%Dn~tX zi%ssVw|er+&uJ9sj*Xqp)?+bVZy#1Ld`~pX(-UOMNGoQXVHr1Eyb&mEl0Bi6sHaYN znJ?;u)D$p4^+#*HBhJ=?SRFgL_AEY-SFzm6sIqr*;N_NiEjX*i%DmYK4fy4F&PI{J zp*dfJwoB}``?{=1S4?oMA>Y*Ax0YCDE@C01t+D4Gn{*0#r$wEV!f%@Y9ov^fW^c@`GruG5W@kLw| zgma)aK4%&<@_zCgEhrwCgOAUEGy(YZOd_4vt1HqyV`mI!ig6Pq-gLod%yPBk$lY3y zcfFtax)hjm98|Rp?cS*~!`4yP1r>U^b7`A`0uWL6LYasz`e>Id^0PWwRQPMXOrQSZ z6UF)v3>F&~c7~2uAjBIF(;&?Ibn9nLqSo>RbIY5^IN3fS(Zh#o^8~fGRJ6Fuxk|~5 zU4?R3!^fI@H^^nX9N)eDZr*$Qk>lBdnDooYo^@R(ukPeYClDK7ObENC#U>9u9`RL& zju%T;u~rZg74w3IFLKS~ER=7fZoV+r2@7`2>Tj@DuAoIhI?0;5d3Y+|4eXasUUBAQ zx=-p*92T2(k#yCjIW;!J=Q>4Ku0GMlC#<|+x})y9M@fNPhmXGjmr>|tYULDZr2cG} zehVYP`2~r6<_V_3 zKiqWyl~`bE^+lidlBws!sjROG%)z_BqjSO~!>K*+68U0OnHkj}tCg(fxr{Z<$-FG^ z0$F@Y+tU@>Qwv-?){Q#z^HaRq_WYP1ad3Zd8V5P!AMs}4Nh%Ui8Rl_qqyrkI3W`aEA2h8_gIfT*_l6(trN zGB+;Uc(d1x2_LG&IuE%u8&s6n?^Nn$@Cnhr?I7fuiozM-m`-&PKVPnYbWg!cOrr0+YL!emlvvYz+29&xaQ374BP z*)>j``x7VIV~ziUj1NdFTH90=bIOdb=qY%x)B~@s7=mKqU$7;FD81Sg#I*2_{W9_< zjQJC@o=n~&k^GGVqoaExW+~-gH8Zwcn|o8{gOJ~#tl5GoG7leWTzoC(>oxcT~@Iuj6*3~&y?ZNL@6v;j`+x8#<8wiU}_uI z=4@?-O@sLxE{Y(!ie28@V&gd_g^d_y3?&@2d-YYzPX|eGg-`7Hz#6zAEa&0=*1YtX&s%yDb44jvwt`I1e; zA)DrvHZ6g9mFkt+nQU*qFdEK)$G@8=VCyqkVx(y$tWi~3RoETUd#mQ+sq_0Q4#jI* z4ogIBGAz)3C&+-s*nnm?>Z;d^y>CyNl6|so3)@O>`-l`zzx$E2bhh|n?4ga{)CJ#U zeXdTPvFonY`;-1T6+LA}0uH6un#n%>wXr?*K8MR1i?DSeYu4x6PLg@UwdOKs)2#Dq z^nFSTLr=2!^xMn096N)QjZ7}H^%223Zp^&#*84$F;@UYUBAhLI{9t^3yil_OGI@rN zeKK%d6((OLhJTr|)ak-erHgM_dL`aN+nd(zz4!{Wsnvz^-&SitJT}3b$|7-tq*{(D zEC@M+F`zw0KB!BPp0f5Gr(`T%>&0m&vNj*Ca5=CK@h|z+rdL(# zRq^?Bl7%ag7-d z^Gjle3A4A!2jBRxe|&O${%9yq8uC)VM3kYHvxb5v6E!T3a;5Iu^CFqx5?48Q`XbOv z8^xkhlB?zuo_{RckKTl&WM&-xgMhYCcCMJumutQ-t;wCmeGwav%RAmQlnahH_(TAGKtJzX< zGjtU@$D$zb@z$(4M4)=##{THS9r?xggI1+1>;Z_x0zn*P2hI{#kc_)dYBPhHC?XOF z@nIrLn?)0A!2auLCb|Dga3>yEbH(IphSR%tr)mv7RiVO*NDNQYJSD#f|E0Y=DgPEiNeCM+Mh2&%zI_Yn2Aa zorf|{O4X80@tu23(B}0*O~|Qf<{i!6}zkd#K-6S>57 zO_${l7=QW<>Xu2R$EC9=i*$h1fU-R8P@2_txDB&h45uNx9fBljVjZk-o<`CPuWLiV z7`Smuio?9KMQN-Vxij{+VUmiUpjGaOAk6%gF>WKiT?I-pZRw5fAjTx5Y_OB(+&+JL z5)s9b-LOu~N68nrjz@ zOAx@ym%-o7ju%dRTWZ|4@%^a1_OUqGkW-jIM!Z8;cSom?4nV)|L4pAc5jfUAwDPcN z(fg@oQUF+Cl=ncoYG1H3V3T|LxY5Frjzf^vAEe6bj^&3rKyLA+`kW=?jKmagy>80E z73NC~vxOPVn++T;p5XK^C9+K!vECwU-`-I+*KXJ`24fOOZBqaCO>*4*x)+0Z~7KquOSRa zx*uXh8F&IYFQg1_zAEGM(Mx%_Chc_4ZPhG7oC=9F)l-a{^;h5Pik1E~OWs^tPieKU z^lCrp$1i%^qqiznz;A9;Fd_n$^{#%k-H!81ifTAPI2+Mmg)b9v8t)RY;w>zCYlAIyeg zt$ydVxG6!qsa{{>Kj6UP@MKkOy4e~Z6o@-FPc(}i{;Em$HE`aN=+zyE*fyJC*CGoO zbgxl68`Dc21c+lB;Eh_lr^5He9LXDV?uXn*u%XpiQx)^Qc^+!217Drcv)1K1^QR>F zRzsOYQ6ulPf`X2M&c`<`D-+5X_&D#Gu%5^;#Fm{FFO?mBc2d!6PGX{W|G4R5pU?I; zKBCC?;kuROd03*B8h>h1gKWuEaX2TnxQ|8$adlu%L0~p_+k{}~3lAKiPCfF3l2v++ z_?^cxAl!y$@3t@zX!|>gM3fLwjVy`}f2qc&Anu{w&ZHU+{iyCSi|W~sC&(Qhn;zA4 zvebLGyti5GhN$@wLiLVRLf{)~)YCtFj+YM9K9VaNPC>{er5X&;<;HW!kGQ>!bD`9O zL5qXn=#829)K%vxo9Il|{jCqp2uX-fmr@l$?3*Uc=r_By(@w-ZHuS+BFkEG#u zg1765!io8{BW`B4fmMMsvP927J^uE>`W66})5e)vb=)@k7(H2h0QC)P-fmp3?P0FE zv^iXmDUCb8jav_N(1aTAapz3Rp3rZm?RL5hD>Bhl?}Xq?x8?-%N)B|VHW))C@fGe0 z0s?l(?0}ETz*6uLS4jZ>HUUTNZ0`Pf6BKisXEPm@vLSgNjhOL|AJ-5;^6qCWaSI4^ z7Co^de-e6;D>$o#Ynn~|WmYAOXh6Z|Aon+XU++tJDw0HEmw5Tx^LTI0L1_#LJAD#| zQD5}`Fj+y5>}HoDo1Y;Uqm=yytS_4&jNGmeG-1cM8{rwf&(4FfU&Vou57IW7hx-25 z&tVEGSQZNMroSOM^RvP9Ty=g(zrmn{~tE01&3%hJn5FMu7y7-q> z#aQyx3(33os7;T#o-u_)zJ>5kMIRfm^@Qi7Mja_3>p~KonzDT9M<+5w5FQUKYmGV z8j(MPxk(VtT^4x*|>_eq6zoZ8Sdm8bxq9o2$>yK!E@r931yrT#JwiHd*W&V ze|p3ntQ(xc0$0Ci7@|rHKKK=N^j>s!k|^!DdIJY3k2`x`Ee^-aqfK+p@MQ;pYU}wsA5ER`1C3;#P0uEzZGZk}JZIdfdU%ZLER!fzRNy4BS zUcz`UoyrHw=`=80dZ5&h^-Jo&R|Yy z3N)p!8QAiZTy^)}cUrVvnX=>>oQ+I{=;^)W@Q6>^@pzQ*qR(-WI1s31yu zs@t!H{FwFtV_9m(>I^wf*!jK&=bzJ9eAmX|oh(4Nkkun0VQZCRG!&t2YGJPNW1uM8 zzjKp4H#~dkcADVaZx~+@WPJW3=h7QJzT!HgeE7GKBE-xe3@Xk`w_Ph$u@j;&=OsG7 zQ~lmEtssjEr#rQG)4+%Wc}rS7za-gc_)FWtHgQ?qbhX=7!nn4Fpq0v^)aOwdMnCoB zgLHCeoEXSSDn+^H?<&=y3GXxxQux1@=Dvyxo|C4zEnLa63>(@NO7}LOf1@t#l|N)d zcu(Tg(7kOkD0KBwhb`~XosEORTb~BQdx9km!ok>d!*Xqv&x6cm7VfG&U7a5HPGWF| zIL&|Hod#l5q%3XUNFw3DYGsRg@$l=I*H%`vqDDvCAr z2apWI@-?nX`*IH4t&!K!`bzCKhhAdt1E~(Fx{pmWM@$g@d0eG~w zC+8r^XQJRbGl0Q(`L2~PzHWueqmDMqgAu;DTK$YUa*><#!G{OS!Ngbv7Q|^L`Hh4* zedrz;VBlsTg)xxe>i~SEf+zK&;vkX|Z-Xia!#V3_2_8?;1;;S(Fz1K01{C!9thO=w4(bue@|*~U%*Ur=GN;hivo z<7N@PEraAUfSr-HSiiV-N#ES{Bz3HRVL|f4)DKIHg5#@{F0Yb0#$#SDAiMuT=2?+R-N=|NqKl99!hezsNS;Xkqnjr+UeHJw2OoB za%(Rd0q4N|sedd6Wx*Aj+HUXT*Q{SWtsz5ER5>V{-?nOTi+wOV2<1aQG_&%IaAAuT z!6fcbu#z@o*Ht5H`lC34_dTWGtOBI-m}VlIupz30-ShLgrF+N?&pX4Wi2#ed8l?V@ zf{l`v(&Ais?Q|x7-YYyRB;Nyzc0@i6)lz({bP6>SU#n3g)(Xg0SA*5;jv zE8A8%>9+ws?mfcSfJ=ZMoAK>K)28t7lYP*>wOa9TH6ANi9mN?`b0w^&91+F-F zE)mbh;%hS$I>qme`ldlW+1eT0?cj+B))CNmN@#nuJo9+$h5c*syQ$U7mfFYX9~;`K zLLJ^T5lK%=4~{_HU5ENZ>-yJ&H*#Nv;UevMN>Opm`f{t!8$tAq?h)nF&73^+X(hdF zW%O=C#BK~;g6?v~cYKr94!TfsJQOW3uCMjYZ@$A`t`Mas9LSI(-XqgWt?>S#wG6hG z33-h(cz@O1fplvbfURolY2%z0`_oi7D2UwDs~0p4e*oMO(I=FtyAqreb&RX~`x@+s zlwcCahaZkteB17c?n9q|I4?k%<4-G|I-cwbGepdRW(S{BHv!<&nt|g8vt}~jEPB{@ zzcsRO8o|lC5iiGk^2*I28U_42EK|Vq;px21`q{-QSl(p;wL;&b0b2wlPoL5+?HXbAUT%-Ts8H2Y~b z&C=b|mxCg#Gpx~mk(qDfl;10wAz7ULV4ASykR^r+fMXHq&xkqvQG6?WggM`itUKMV{kb+>gBK-#EcQ3ZdtkMu4ke0Ce)*o!$^<=Xgb}f&F)yPIr*R2Ma(3;`Ey)8sM-D5|srms0YGf8@rVlMY90kpH;+h+`O zN%l#Ji^R>;MtC^pNaSFyp_^EvrJ!M17AoiNa=uL#zyCrt(E-RODygF8LU7fGko%^D z4*uZ61(swvM}Y>pUqi?f!Pnm`eul;jvQOk3EI17T0g2)C5UOiXTiU(uYw>b9f9RIRWRbo{;&{$~aj6LsnQ_rHCjs;_yBUiUq9#R!% zRyW#2Ftx1B^y3WNfWdi4eWuF20N_2^Sajbbnez4%-QR|FKPGZ%)FlSvT^iw@WC=jJ zI(iS|?67ax2X8!_TzK6;A84UEArWcwa@z78^E5-Y*9ci}&)Ful{a!Vx1lZE8$G%C~ z0c+4vfB>u%F*Yd@tzWq9n=9_iQ^HeSRkJDMd9ZCf z+-ae>AMN{|HCrSOtWL%+!s&VRXaMRN!nq&1+q6w=BXdMF76Fjr*8{aWJ5xi%w4repZ@eXo@=Xu^x7?!?aT119 z%%LsbFIz9=8&6`}zP5_*0iwZV{Pv(Nnh3W(TReETaO1wwdarV;cS2iaY`$Jzk1(O( z9x8x8H?IOut1yz2YHVp%B_m0SW(qiS>DEVMOYuS9>m>j6j5xVz)P?2uxOV_16g-OX zy0V<~$CAzO=FbZC%2SKqPo1Lqb{bb*n(XygsY>aP%BV;W;rr6(#??Z$_j=X%9%CrG z$1YlYxBL$I>Lg&NlnE5+W=-oQm<+)>7l$l85GKX(0Dk<6op0L0EH4Mx<*EGMq7hEe z+M)hRH`H68FAxA7)oqndqq*Zdt`$qQ1I**I4@Wt^zx;UW%@ccvz$;bZUpWJg#1c9p zlNZNyA$FzUoEb}>X&=;g^vXKTTkC*QBWxav)Y)wZ)Qzq}<|8N4 zV8pn_e%Vc+nD>{Ul=&iL>0t~{6*b-@l(cZlhl;I=ZXRLkwhe9{@*08 zq|^L6$~l}oe;n2RMKkKV1ZYNVsz}7YXoyN3;DRZyGB5m_Fl`S=OU`EyL;t21^H78{ zZo^Ig&E}514j9g2cyo%<>;ER!rch)qIxyzv{`D7$NWiuFYT$(YzyIZ$!EL~(;z9lY z?}a2hUZqz2M>7F)H55>Lz5{|q3D{uOKxRk^{`H%HXe?bzr#RipL3{hBe`xce$6A=0 zPW|(ADfjFExvJWItWML_ZKxyzG23^2;5}4#Vr|eq=t`MPYRe0fD>r!8Z8`pg77UdV9|88=q*R z_?Ra15P|SxcP7UJ*uk+IG=b!ecOVKQ#z9CJb^7NN3S|@MUpcxq_3cBuEoew(eHeFV z`jRWL>EL3H^w1T;#)0!q)gelRV_bYJFv!@oqXLLhg6VV+%AMw_3tu`K5GN6b?aM4K z&aFq-0$f<^7U+uEn;m?(yE6JxQS?hSPt!Rbi%@vr&c~=ub4dk2CL0@Z>dkiB{^=cm z0MPpqgeK=#1>;dn0hnqHpkHPrsX$?$0T$ZZ0|1r1t!wrOUb5`^Zr3r0vMWM6&*e`4 zBdPe8X?H$Uaesnov-)9wz8ZjK%+ziNfimSNzzD#W*Hvt~0NB|lwC?f5ClyE0r`d*& zEL$>!{#@?8Gjny+qxry!(W&Q+3*1H`X#eTnEGmBCC13ViU<#^X-sMUo3VREG*4{qs zIA6csBfB`2Zd*M!SXAEz(2;L3`;tT3RTv#{ampfAQaf$1j>?aY>H23^*#QM_tE%HZ z!$B=!)b1Hmc2{PqBR**8>afsu7MXtrtcn};7?kdvuQ2)t#ii91oCmWSXf%(ON7Jt3 zQzMdmuW$W(-A<>PoJ4>6S1T&IZ+jr~H_eRHl7)^X3G=KT7{*=>)PKJSGgO$mGIP|PXPL_FFW|98W3IZzj;wXzzzb|E5_z;4G$25# zp6YA^ILigV+i0BEo&tkc4QlKyx+3(ofxu{#CI0|Dct%5xpIF!!)c*ijdqo$qY!{mk zH_P4X7pp5Nyoo1(btBt$LCEwTX!z|0PQ_~Ac5)FGT6-}gOiXesA(S|vvCvV-mOPNY z0KyJ;*ui^DMhAtHP2qvG&^0R@-D`h>1jK0cG!F$o9RmyIXsUi`>yFnzwBwdzAga1d zpQqLjqSpd9{RVuuK%qcP^0eVuZQYU$Bn;r=pIJXv1hjyyrTXO&O2cDcF-)8>;pZ1V z>(}HMBVuqXkP~wa-Ww(Uj5Z>iH6RG^w_)NC)B{Sa$Sa)h4C#_ z&Q3C5SbBXktK`n1h)Nh zpxhVsloCCUF#7jyv*yvO^B~Hw^h4zVQ$l6x;Wov7pFsJIB(m1_>hd4B zI=B2Mou<(=EdPbs-ENN{P>`Js`(WD~B;`7E(0GV=mV1_kyQ)6^y_%#ve)6uL zLF|n{K8HM*%O`=HWwVs!V))d~$2K6!X2o!ox-zAx8Jw=0RFH{)_FGzaB4;N4%7`g# z0+}VuFCftPcl2;dN|LhpSGj;Cg|?spP|<=3HNqLy?mN67XoG|3K964iU>Bx}mrz=d zOctTtWYx>m=(aX)Rfe(~@Uri>`;LCvC3uf@n8h+l{xOJUqE-m7Q(>L)8ugzq)$NuK z2vHsNX}=!zI@7XUujPy`vmRUuW%p8_c)0b!dhoRMLh;vd+b9o<0$$djT|rK2%QfB3 zV)9#q#R~)lm^TCAk-^957hG|g!ey@)*=|476x2lV8trmRF815{8;T2X8n5U4v%yT< zxC;SQzHU56Ck6iq>ae8I<>?W5s3zRjcT?^!g&Xq#SaB1pnn!GZSJ4zb#xyF4?6?vd zzhVmLqyyB7#wS!-j*XoY<(^ZKl`AJ}YC@f?8i?Z%JU115{rW3`sJ$H_N@EhYIm{AaJTK87s3Z6#eEveY~wdN>n& z*tWr{M`wrh>J%tjW;zNI8!bGv2}t-VDoR4ZbNF|1uf?8uh_kqc^UDje-*;B~k!IGl z{VhFb>p17m%6$LUL;3+VVU6aS9Dl`GaF)Q4L8OaX{#9f!XraQnjLQ}o2VclA++l1| z7e7sN>$dJPaZv4jW+$c+wmFs?GiZpGsdfVyXpwK=5VprjMG{X@^Q7RfV-G6|F$Xpn zbZi*sM86fjAD#KI{Q&d;L>!d!U)rP)ur`k;xgl$RnGK##t%9dxe#B<3JU0>>P5tGo z`Lvcn-puH-BEA%z8j*S8FYUNV|NR(b24>e|0>vGjR)FLfN2DD_mWLhm_QTH~b^BphuDjqes2IXlWK4GUX z&9R40TbSY#&9vTG*RIVe4sK69BkpCps-w2Q#&hOWyjtHYTz(0E&0jc7e(bSE=^6p- z{R`l)-cS*{{ExF(c+nu37VedoNe%G%y*2&_(rmKGS5!g$Qr$PghQVSOg+x(yCp<^L z`=(an@9$EZQ+Wz9_=UejH)Tc!rs9Gi8N8JWs+PM}7YBU|T>3h_yOJ!$nW4^;uPMNh zA79++l|B=}aefts3f%>$(@*k3ar}{@} zSD0IvPL*6hGGpw|ydK@aHA%Gr(Mwfl5E!h_$Ky13Aca_vWQbPz&^-f_lon}NQ5nuF zhSUB8(H%O^D0RQWB$+2U%U0bgzD|>3UB+I>TA)3$&kwI1cxg3S2Q_AJ5X?)coTX4T8ZX>%WTK|WQVVY_5Ia7}fIJS_K*7m@%4%aixQuYW{Ub9GJz7%@_d;jg2Y zvs794*v>tqu(^v!PB-2J@FlRFR%bVc`^QG-p-%*Dq=TqO_QIt;2~h4uJUs^-IiJx_ zmyL93C7TVB`n#ox3)Rq2S{}k-4Jf@Cbw_Pvi}VqOJAmphHl3P>;Iji5Yvbho75Y9e z$$gyR)dQm%>cRtU|M~&bbbst}eCE$teIl2nq+wL0!Xj9tOD=iD@<&87VdK_yL4n)? zHif3ScrsNjg=HBnRQ5P${l8X&6_v!Zj>6XeQazNY0CdOJN8=wummUOmy%v*tvT)Cr z&#K!lz?PO$VgynJ`8Ilz@o@QBs?{)hk;koxOBLTlwkY)8NJe%Nx!vH{IN;M%0S;fy z$hEqU!VZEyqXX+Tlr&b4GAE!s4<9dzd`NiP=x=+RW0MBRxRlrQs3JTSfcY78mLeUK zkbCur?B(4!w2zBLDIyGTyc!l|GO3Qm&a{~<_d0=5jVHh>8lHb#ov>_H z1`GwiJm6dPY-5YEy$*^owTb1cdCQ{z2&jifp2Y0-C!EdsM_(@CZ3u3E@46--h3n16 zA)qh*Ca3g$se`k2Wy_CJD*&03X>w)FY<`U4`ktSO>c{w=JlUzbd)wfom)|rVo>cp0 zDt-=OfQ{*>?12aoTnaNicP4XK2PPi=98D7zCHdO-R+K$eUyhg3KMjjT&2N#UnkFE_i3bu@_sRI_?oEC|kw&e-2uKTzi?8c+ySFYP$c5_WwEqe}7ONr&e^{`egP$ zx(fW)cW@MWO%qnpEC1F2I;ptB_`>_A~0K%NI|2vSU+|Gp3;|CX$9{WF~oBz2kXh1;ouiNJ0fVj3XVzt8a ze_Z;X=T=$?U;aOOJe0t*3giTEA~sTt|9V>5pcVFiixc5M7YrET%eSxp%41N1_Y?oR ztpNd87+l@cA2O6h_P=g&ZKuWefB%N^h5v67{^ub5zhM&gkG6VFXh!G9{Lz)U`z7Qx zGe2E+*lF+2A(y+%jd$j9U+X{mhl+h7`{I9p+IRBgu2J3i_vfu9bHb{9Vo;;M z%(=72|NXhFm_l<`>EypZ7hR}s2i+VfH)G=adtC$JlR)mvbsGfI^|nE~g%%vAsltP_e;QcKC6_-WRd-ti$jEs_RKg!eBLqZ6FYaM*< z$+}XWu4T23c^*YO^h)F4yFMU**Z!=NE$En0GnHoflcMNCF>Z{2t9%#I>8asRb(I)B zV7zn^4KlZpXD9Fl_CsTWrq)k&m`LdH0$(Rq8IGYyS=+g7$?{Hf`>USC_@V zQp%Hze?m`3Mlwi!|HuvkNf6#Vc~XNJ1qI&$ik3cc9iWkYfJnNY{}GU~=|FmZy?nS2 zY&xkLu%Zj@(+55Ps&7kRnARd+gID~y3A#AwvwrMUgx0xKf|j3=meZYcn1b8c(G+~|Se8fAT$*i=pDzf|(GoXStSRjY z#(;}73zAA{f+Jx$l*;WGsA(MpQc`Uog11?DMmh!NZkH$bzUb~0_WNDCMBbG#Ut(>H zXw&!pKJiwXMk(D)iPc{hK?)}Oip=BO?IwO|S~=cXs0xq}DNKPr-6WVD_=%4coUs~D z1~OrKPigz^r5SS^_vh7J*?PPE!s%S3N@v8>9s2Z+Euf0wRg@D{CkwI(;Jl;QSt#!>D(pt2kup{?aE00M(M&Qg*<`jHvZT>;5Co2xt5$u zwPG(mOI2J_5xxEUdhM&;HZ?MsK5y@tp3PAP>7RY|9!_ML8q2H}u|d~wVhx?}J53h~ zT>w1%P_}_TwaP1?7(aMsv?V#h;~~fpa3w?Kd(K{dIO_Hx_KZ_MysE^n|u58_!D?!yegAC9DA2uIuBuCMH#z*H^AzQn(sPl#n?( z+*`&NLE{g=8^-O3(Nafj+&af|DuZ7+G0YP3=PU#K6?aA_P~F;po2gT7w6k!!PyBnW z`z!&Zu9vv3E8ge6iSB#+@x>*x=!|A9bJOHhBN=ST*KT3DHfdyDZU!yl1mdLMEv>_Y zZChKjlQT@wbWUwAL~WC=Ix(9`4I2I)c zVO)#5WkC}`uGP(w5{aAO8f{1{_r0w1VM|Tvs&W`hVKH?sqo7{;#Of#)zPyDpf`7r1suasnSx^jQQz6tXip>v3KlQN-4Ek zqa;+UQq(Bg(ilN$wTV@fw$GhsT%YIqUf)0A$q%`5CHI}&Ip;p-yx*_a>%876{6%uB zL9_`T>zo`A@{|O z0?T^h4cH`n2TDln&N9lj1JOJocgWVWc(uwjf0@zOj8c?AC z{$k!hA(Jcm{Y5a)S(h;!!il?yR?X$f1~A|3UlIJ?1~o}l-@d#6$&Rn~9-IsS24om{ z6tKY9?E3jd5*8Ccvn6BaU^Mn+L7YitJ*80FOry5a{Ivkf{tKU9FRAaiq{9Hfx=EiT zBsjl)6^g~S8QY+Fo{Vgh#B*y%VKv*<$Tv^36Wmt0!zF5$>@79AX-KgW;1fe>d3{3* zz-|*B$rh^y`TNh-^h1Ez_O)J!)f zu~do6{8^N{O`Dz7=(mqwrxw<8MLxX{;pakP8bLdLmm4X(l%UHqofKukh;q+fggWw9 z0R+cZR_z35xy6U4R+4d&@=P&rWV`V?6s{5hc|o9CZz_4ux9fB8YJ7*g!LPLMmA`KB znwvT60sU+mtq2G4ToA-FMCle#tA)IAE}a77CXIHNRe4hXpV9YX!G;vl~GRhjS@%{e2-asLz;c*rti4r#QL}67abyR%<7@p6iugyS zWBr1WYjFpiP(1)3|MUr0zV{C3HAJ~tfxNhv{RewUpiin-*Wp@%)lBeMqj|UOgTEYI zc(FpQ04Et(`rt=?TJ%ym=5&bcy97$tQT@riIp+?fz5yYcMhejU|GG_7;nfWjfr#e! zlf=GW*aGB@rU&eM1f z*oJYz*X!xI$r#;au7Po(o0UCuy7N#)L;>P{C}8w))wKEcmM;acXqs_~Gh!dUnBs!_ z0*s>hiPMf|&Cb8ol|GNg&ExY+XHqn(weD3HO>Pv(`;XpH>nu<#b`eQBc8zLdXFI#i ziRIY#1XT81DRe|zy@;Wh5FaU261~|qVL#wk4js)ybt`fvE0Ns$gk05!Je_rh**Ef@ z*69ib^?Xe+)Lf`UlM(yxcg`>@nDdV)Doq1MG8=5%(M)XH3!A_aKS%>feq(XgF%*Pg zu=*^3?B*WKxC*=kX8KP6{Jm=PEJ@`_yQUN-$i3=P5*;1b#8_X)@{nfIMu=kxWE8Yr z(E#PJq+ryQ;Ir?*0B7y#GM0qKtgO;X=E!=DwnbspB&!i_IFWVo%=F`g7j9g@Odd1! zcCzK29fqOC`jGv)b%Bl8$vqD=T$%}!40~M?_I-l@3n^WBNwbUTj;3L7K{sOv2AKqg z&F@f- zabfxf@k>1kKxop@7Nq9=0^e;negIw#KWoNC;)P!j4K_Q!8&Ec$pGTh> z4mE1d$H<|`rNJnH$KS%&%(j%9hdWv+CsjQl{uYz=9t6(fD)zx=9 z71{VW%YMgD)27O#Vhl)Kj*YYjI93_KaQaBt!@?U#EJttTiN0i5`({C&i$`}wYx_pI z4fdq_W58>6kKGBdg>W2tvsMrajtw7=#&V_RGWVK}95@ef3>a1{7L-Zlm*FNmxw*NG zlVKjOSxa0?2kQ7VHwA_zb>EKBkT0GVx91;^A|5@|(E}6wP^e`2Q*s;SpT(tUen!!U z=SQktQwUUr%AXp4_VamM4-kgWzk<%A$jnS9lRZ&A(ECJIv5V|k9P5DtzxaEz`Lxy* zlyWy4YlD`T<<)!_!d|X4xa3S%`<*h&o|ey-7`P1z$#ywt9PF{lsJC7qE>A~}J1h0Z z6}u)`3vO+B#TwwOdHv&TdOy~;_$X~*zgiPT3+=NK2BV2j-KF--gf*uqVYDW{iXdzP zBCdfF2!7oLL1xBsp53O;Ewd4<7~xk_^Vw7HuBoj^hFMS|sC|R`M*vN^HPHA^55aR> zic&AvQZS96C&t}Cd{oZ*`T#s?U}SKg(eEp4J8{p*xuY7a=9bPk5H=r~sr=BiO z7iHE1)wWcM@56H)CPc*Qe(h-#Tz;Ii(7#-sRn4@ zjDobWL~TbFsj2dTb6sz~c<)XvtcY~!8cpJDU1uNcD?r*<`L)(>qsqrEqmI1Hp>Ngh zy+>%xX)r~=ht>+JgEyhWrwCtz%f`$47M*MHWydu8#;BiO<4{if*hfV)n>OP^pGHVE!SPf6Diwce}Fhs+=iV#I9p) z4))pO8eXAvf>)+qwRdn?t$6-?xag4kYKd}G9Mb>?zXbdUbQfrC-(*+c?Z>bKVy{`CC^UWbz&t3O{=p_R?cFZ!7?BA-*@$IET1Y7bF2-MmVQXI&V?e{jt&jEZYJ z^8r|LNZpp1vdKi$y<%ch2Naa11R-RTBRWz0 zoZ9_>V(&rDptvWh*I>!7kK zn?2p`^X#~Hf|@pWAlvq&$&c4Pz6|etyl6MVnSf@FVC#7m&sTI!;*%Vg>;T6XGky}P zS2QcqHF(IXy0rK9n8Fyw{#pUTugt0OR*Hs}qrOEbBOVcsut$8+SP{MomCw@__)#b2 zs^^v2?@*1;*bG;dp}&x>5g8l(^Ld({bz%gI<%tsmgl743#O*v`qtpPY55sOH$w)(A znk{K30N1?nm{h~IjjP2hHtD9Rqzu5L8DyTV>8b$&fIs(s6aV}n^Vt%|MeRhecBR(@oa>{vhROYa-c13BW{7_+QhYoI* zS))P}r$EdSd+yuP7CbEeu%_;#-9#&nv98GrlDqw^b^7Kyoi-vF5io&G0b?n0Or|j% zmmsL`uox!AH#u{S?06oaV&E5Ad9=5lM`^pyak3DkYrD8n(~|VNY!Cfr*QMQ$Pd%X8 z>+=Z57;@NU4HY0~p}LDdd~T2IND551u~!ri=;-J$4-s}R54;WxF-cl}bgEE>EAor5 zHu*6iukt>iE;;h`iQu{$$H_c5xz|}Me8*PYeMJiImMq^+d z9Iy?8@ec8La93J55ZkPnX_vrW%2|Rgm@6<+%b|e3sZ-^JdyCp1v~?CSHnzQCCwh%% z>m6L54+C@lYoO{Q4mz9S$30!^J&VlW04(hZTzWR#&fE^T^EhYjBC1;h{C2#Tv`p(R z6Kc0Cjn!<9(@j!q3z>HCH+z-U z2HhO1Fz>mKJ99VPr!QOX^ES}m4Io=+SLChsfk7*@gn38Dd|^wvRP#ZH62|$W=Wo`o zuPAcIs!G>_MO2SJzzvfAA_g()ox*Y~_vB61S~v8$Z!94mQCR>PXtEe^^Oz^R*>@!+U*er>)2@!GHCw2Fm$2I9az>mD>Zw zeBX%xZ(rbyqca(rsED}`--*5y3pcZm1x0hciFt$}A0!F4UvfQOIv@1-(=c;U!hxWQF>elbB}*}W`wkg3JSQpv{ha{gr#d)^-cA4SDpA-LOJ(qCKJ zR(j?b^62{b<37iI%Et_47T&XX|FYD$oK(s0^Q_re`GbHh=GA+%aD)mdcT3@;QO(_o}ckQ(|mAj4lVLP z{GvGf{@0{eLSeu_aeU4}u=TrX%xR>hNT;GzhqMF68X4|v-b3wiYA;EFT${&dNYuMg zCMTah1AxRe#skgagA5fV2cds+1|h;|Kv);H@5L_)cm=0hNcgz%Wc*Ds92?P*jb3n+ zuq6D4cHl0e<3YO4fU#CXXP+#-4rKTbwGb&G0xUz>UBMcy|2$uNgSGKR#WN<{9nFHt zCzj%WlMs=^qGi*L5=Xzp_h?;no}R}vLL|r5J-%tRrKN@E_~Mn0zpntNbb?Q7_~;ir z@L$I!Ax?YfB*d+zx(m%PTuFd-`DlJuGe}#pV!+@NkN8`n4TC12ZvPl?j02z9DEoK z&Ix?Nli)X{9YHzZFI+nn83~-6X2u!t4`-tXvQOmYaX7)xggE%P^f(02N5CI3T!#OC zmd544IdSxRJRBTfQyl#NeMSNN4gD7e{y^XP=kF77xc~b!_+1>{zn{j3#hv)~=M&H! za2y*Op)n9Wy{l!1gF{CN{lPVkZ3n}KKi+jm15%(qSX4Oe)3gvVAEoo_|+do-Kw|Xgm@;<|wn2`TUt&fB|#Huja zwZmR6Qoyp^;&9i!-|lu4{5dEbFC~(~!oml! zN-VgcGJlL1{8NRHeD5yqSBt+!N?3yPDfnqD%q;l#k%56;!6!r|;TQ<~HezT3&viUU zWlCFEivD_#7d0L}Oas9e+^)amUW9_ zaZP67uj#VB0Hdac|B(K3re04;$oB|0w-Y!1GURIsFlu7fCw~v1lzA~YnVH`-~~w5X`)98>$JXQzgc zCr+N8&PU}Gj~FG2ZjBbi?q@CZ=6z|tHF3IphkKY}(++lH$BjPY0dE45{C>T>tf=`i zHIurup2q#n5$j>Q3(EUHo1@K_+%UpaO;3<;+?Lz*=GH@KhmG=)eV);<5xw~+yXHJq zMbWO$SP{4KjZaH6mBv#Ip~a8BoE#{%Nh~TU8F+nO(&Fcj4=I~$?$XLP=r}eoia&2; z`QhL_f;RTg0+}I-X=$&LVPqp}cVNV^FVP#1CRPdyUUJ!vB!mU$mPlWo4h!0mBL_3&=vpm^3 zZ@XBA(bm!$4-RF;h-Zn1&-tm6Cs{p0^33>7upDmHg@nwiNH4QpQdOX9UAw#VZYC^V z(ic4{^FGFRhA|Oug*XYeTF;iUY3`ApnQ4bh(bTiH4!&5rJ|jtHU9H@@Tf4JIgxtI8 zi-T8t<@J`1nGG|q%nKa6bHvjb+7{c@hr2M>!`+eC{Pj9;A3 z)@9_gB>s9WuJ#!ZX28B=h*YKInrQf5?kBJtFIQ-b8Gd_Zb@Pdg3_+)m71&30Hd%;r zf&S0j!*1EBlqr}on^PYiejfYcyAUM2_)$GmLn{32t*3GFA8ynV=#^NHSna5~?98Mt z?5(ww%ygv6gik3v^)hffz=t8j?%;SC(6DYPXBmf0NW?MVM?^%-Cx`3VYVN&TSZwh) zC~!f@bl!-|&gSkLza5L&uUWI2SQz!^%K5liwON@%B2%$_-&F@5)FQS!KR^{Uqzb>& zv_M7EGE&~;SaxR)V~5ssS`rwp`n;p7(ik1bT{g}n*kQ{RS0i09`i-t?nV?4~TfZ9R zIG0_*K4|(N)|ZHkJU z#XKZX_tM1h2q)p5#h$m}kEsZ;0=S3aO{pTLNllI}Qqx7n2&mpq6;9DbdH!wqYvF?7Kv|EQcO-It|RFi3N2on*M! zGBWVoP4=80{xn<(O|QF7>6D!xrFl1m#ir+ViN9FfbH9t&t_y&Zt!X*^c-%-9_Fpj^$7E>OZFglY;%xlGmaK6 z6M=&BiXJ=f=Dt)}kBABCtTPB&Ma^VHZG1?pSZ(0in0%**-oIj!FzI3TZO@C$jkS%_ zE7cXDhpm0!N3gj#G}jzm^(y1iYuJiu>)JJ!)ycx&w)TLXql znFNUVo;9@=rpl8CVDNp5iY!(`MDK=Nsnnm2|U1jQaux##$N~@`iO}_n?%(rcm5RQFD zj-LqVN)cCj0;N0O#`+H{DhMy|Z4QTWq)MZ6{f zwn7!R^y%u%>?7)uTUsKmtNO5m%Hx)o%GEvLEnM#1dUj3x8tUo}gDaD3+fnn|iT)gY zS+<$eHxb$Jb(grES-XSr0A8`2oU)H(t0?!upbt>S+X6+;*0s_A9K25i*yLVwR8}{O z$9|6olaoHr)ah z3GPU#-HnZ0_vt#DiUb@V6vuQxx@HKgx<~(pyx`EJe9&;!-fH7UBFlbSW0xhOtUX+M zKfU%j>Z^zovCl7fM*Q)3@}%@*Leod|E2@@v@}rv6 z9nV#5jR#nD(5s6p?Bp$8Do}rTQDBQ=QY1}X>GjlJU5H<%xABf@5as^#))cSD%t(@O za31DHPfK!-zQ+)1>kA9B;Aj6Fi#f%$@T*z|u?N`H?EL&2+ST`GnGv_^oNR4Q!?r$= zp)X=PZ?~CnAcI`%wDt8p23xz3K`y!Km9#X3SULh-8a0(RGMC)PI}xk12|JfE3QB_A zsx~Uz@hjywVt&F>^!H~q7iF?M%I`5ETSPZYX=V3!H=Q>3H8*IMyJLqn4`yNyzE@|I z$9^|oD-Ha5U&*tbYc}D%)%RLEfgTNF@)_Of+a61mn`n5}VJTKC73MIdEEg{R?4*=2 zX7^}yd*X(_>`BG?A(!@+4iBt_$DUK6C&?OKxVM7*6PBUB@HDr@S@s z>ZUHMBrGYkW8xXh-l&bAVi&&zY+TJgxp|J6!sL_|PZHNVmhF0(!C;m^zC_+9^fR?E zgBLlvDLNG_2eSdWVRSn-8@{l;PnP=PrxB!Un;qewHmeWaH>WgkSqZp8oYnOm9l8`X zdpd0;V0=|8f!25Ig3?~RY|`%O9Hi?!_5x8w&gD5u=Ef2H zOgntqKd1SH_PNKZ_OI>+h3a#U z@X3pxi6%R29qp?Um|%+CNZ!<@nF?H8;vHVi>#$vqYAOoVR4kfbTN~*cFt~oXD)YB!zOzN%R1CTXni6*re)YRuVi+03h`CRoI5b>rp@>rx>9u>4)>gG z8JosXu-bFIJF90wg}n?qfo1cjEV%f1Bx$Ys+>c3l|%;+ ztP^aDF1Z8`bS4Ii&7x_hPif4nkKf-|&kJ_Xk}J^Lf2#TJX(%Svqe^aBQvboap~loV z2J88^E+>(>IELTQ$x=ZrSw#OV?lnZPLYAO;i-Z?+gr%66K83{zM{j;8#xlRm|x+!+2 zhaIVgW=SeG-IPl3JiL|nKX#JEf}k5|b#A|AFSB|l^Lh4b?@m!=(Q$drhd+d?^VVpR zlu2+xb(x=`-4HHQp$NxROjwRt=GT+3&i)^+ifwKLve<)(F8K~>RyF01>WQAtmpXbv zN~)X~a+@QS$lgfs|C+;V8m7-_7wf2d)Zdls*FvNrLveM zOhL8P!LAw2WD&F4L)(MBZ7WAh%j}t(@=J2a7VJHq1-tziOxGLGQBRBFT@+am#YC`4 zv10HJcun(whYsp~BBLQ-+RHv~tM|rgFi!jO*9~KhIg;FI&TvgE+d>&8#*VCO=QD|g z^O{2@j<4q4kXWbwP~-WrrcJA^q+Tww8w)fGikb%M_5?#DaObu46H|@)2e#pvw!;;y zC1|{v{f|+?PC4qDnlhTIJ#@JjcW@V)S~dh2G!LsfZigSN->A%8)HS=5oWPYrtWZn4 z@7zizPqW{vQ|0C|e__!^Y<7`fUej))7q(P^Sv}kNnN46h{+a49&i-o^<&Tn@X_b_7 z@fdH92=pUogaH=WL3S_q@oL0s_qr8T*XwSh6TQeR^*>$qXg*hTa@*Hq z%hzO7vT+N8nH5zWkXPG7P*XX`-Y>M`i7hBA^v_rV0`w}sj(W3Lb)L8 z%`dPGXY;r@OCjFu4aM^fv24{RyNeC4RZNW1;^HHNK@&Y}mvdVM-=;i)eq)QWeqHQ`L z7u7j4Tj5Qv*JhtPJ)4Gl4z_c_p`#>b|~V^@Ix&Z@qh;k?b43 z)j3z4#j?NR!#4ko)3Gs!dQWN1^^gvJ*mVz)x29hdpm`{`Wk zUdg=}3=KbNi-U53dG54-609|G#*8$<>K3Qj{`KqaWFjn=bb8br{GE7Iv#B0ep$MI8 zcUMM5H%5@#Dtt{I>OYm&tivPr*UkpfF`xMNu!5OH0*`8V;ToM zNrVU~e`Pd2?k!%H-g06>!i7Z>)|sl5!DhEZ%Aiv(z@gZt_ zNOQBo+GBTsM=QsBd^RCLqOefY9JighwSDY8ydbE*q*URia~Vm2(@TwwP{E&8DQ)~c-99;VqN%P(xH zc?MVl2~%ojAW-Y-)aUlE7-W8TS<%sp!eIe9>UC;xj0>JHRbi#2S8}d}9m)`>i~&&6 zsVtPg{<#~QJ5zxnxP8)6M>0LbggQiyx00{ktQ*nN2M_)J@VtjPS+4qns@N9sj4M6W zyXDn(;#NgFYM=GpPbsmg$bREjQSV)NcI)G5fF;U7SYk6>kyIQj0qxZXu>!M>GqVa{p8Z$zW3RUrh#O&*D$yYVHWNDDL4*uzRmUndWR{$ zGB(ZXd-0q#IydM3y{Bh$ow=F#6u*0uq}yYzb9jzjdf9dNu3jUVL#~Rlf!cnlN;uUA zm+bh88ov>4U=k08?~9nVqHv4q7XrMox$;F9u-+`y+fCN@6WcE~L{2!(G-I+ns%M6K zFexz|c~X8aE5`|_R`=aDHs~h1<#Mtg*XVD3rqdPVuuh8m*n8@>LYO#?*El)rn5j*a zOc0(IGtGN`a}BIaL{$`0K6ZVI&!p{Rh@Q*X3%ZK(MU5s?{y}|qSbq_2k%F>ve(Rz| ze-;m?9=Qxcnv&q6?v&f%UKyW=oOSW+QXuPu{F+r4_mj>v<;8d(Hf&Ca>sk?l)FCn_ zadvTiNaLF1wphp@-*7A*|7=)XLQ=`W&M8r{^svUxF1z}l0_sC=Ep=e8D|J~0iY;%F zQqb#6X2l1q+eaNd_RVYD_&TvDa+0Y$+)xsiKUeA%EbeFE^LUnL0d%~M7m={l=6A*+ zGv*CCurA9gJ=gRN9h0@S?S3^mvL!nS+|>cA@+N`?LsBgBb}DTDyV!cVVws!l&~{W{ zwwzFK_2K?fWOY|kcb9P(x+E)^x8zngb zzC95aXQU~AU?Ye#`D-gThvy=uW_{m!B!%K#UamezPO>M1AB+hZP#*dk+I|atAf)-?nhsey{>%H)SXNObuE@G1nE_W&*a^wDm~k1G@=zfz6rM308)Smnj;Ov%X46|}2_ifBF<|}D=Qa}7E8bLpHi(^m@F55wgO}kY;p4)oj1+jH7oHw$V+JJ2 zM5!&c;%j7DpWA1A%y}8>H0-jNQG%-9o7Jqg^%(w?M)0EQ@s+2qt&~+Xc(pJ9WHs7R z-R`Xuqb>PhgC0;Uu86>!E5%A_z?#8`W6P?Y=bT~7ok+>;GxoKLLtL~xy1hm=W$d$# z@vI?uOlZoSgPt9VQQleU_{75|DhW~^?G3l> zewJbi=(b*KT2ztAH~!>LUVj#%TfN~uUqBxI9b>MV7?3_U2rkz7quGu zuYO6h$JhxJGmEouDaEzm6xpt#&miYGd&gRA+L>C6tw-`o^42?R$qrw_%m^H|r^014 zNuIeE<`04(kw6zWCDRsbaE9D#kkGlQrKrexKdXyz8{47iB-{~x9+B*?S+OuplGVf4 zt+NQaC_&Ot40qy;9tov`>k9hymdBpSxx>}GH9q~sw5{9*vxusi?OnFZ1Jd~}{bBwk z_eB^;Nu0ls3YO-G=hbC7tk-U^v25QhZ1hhtWpZ!ZYh8ZM`LFU9Bx6<@JUa zptX*mO?4m=OK0%JQ_{8rNbPY8W^SH3yYj>e?K?Im*>>d~yEMa!f?(KGvnJ$$j&E$EJHmzxUm{g%3ySPwbd-Xxr zN?_B2vfb}r>lyMG5mp?|_J}z|rmK+#OSiId8QFB73aU z;q_TNSzJ=ju_P3AUr&`CQ9A zaQ(#Kq0x_l+$wz2azFps+%UHODxX^H;cLa>z{hJ9?w2wm1xL!&H)*=WHMi6>_i+@u zC8Xu!`RzX%k9{x|-jTECDTK8s>x(&55N|nmIGizH8v1eT@bO}_(;?m3tpMrE=~fx8 z6>CxRB6kGI%c$&cXH#|ZYhrF_4SCOh5mq=ivppKkwzP*_+bX27~kkL1k8t-l4B zd}G~k%lh@NRwKgfl~og!8j}ec3TXA$Ev_+jR&}mfJ1J`s#N&&W1LaL!<}Cv;hZ?z3 zOx77a*3gPz7UXhWnPd_$FE;7J%Ql%8vgfKvclxop@0#5!X+yCZ7=-wkbIY6Fg{hs} zk?En@w=x1!{n1jK);_B)^8mxGblMpyF%16X8kLK&%fP(U3*U*_5No7C~S^hbK8agcu8tcc*Up;u37*mMAu4Pi5opNr02idX{+=Xr8i0h zFouhzQh(8q3Dkfq2uTfnCh^}Tj>otjIU#CgX2jdK(ii|-O-_sJgC=eKkP!3hzEW&) z=O6KB6;U#gF$iRt1O=Rcb_TIJD3ao1&dmaH<$s@5D#W8NvHv9+0S9WwWC zX7Y*xsA2t;%M6l#U6G&yi)>>2eCLm0gRW2k@S2k$?*h;N`{;kC29F0@hU5RsW%A08 zvdVrb;5b@?J5XR8Ov)tGnUR@UAe0{#Mw1w~M#%4H=Q>XG9$5l@ z4J^)e1tun@jjvRSy$MnQ_r(LY*DwR5f`jA*Y9D=dO00@v_vej^-L_YGxI)6h`qz8( ztpO^A3ZD`ATZ=FVL2n;8AJ~9mVgT(h=canIyxex`-tJGqYSV><1*;SM{QLq)tVvJA zmk@=LFah2^1m*k=XmI{&fa+ZvZ%PYA5eWag-!$p2DVv}vWN{@ zRc$}%>8c%gkFHKzKSl7z;^E@2^1?rN7=B|Auz0%3Q@v-q+w4?j2sq@12Z!mYF*l}8 zU*Y8~$wjSaBJf!H)4aL*eo}ldwat3@#}0w>QpHWx`bz@IkJJB_l}u|v?zS;fcK3lx{$MR;axhl?l?)(bi7TiMnp+0};-EY+;=1 zNDa=`E@}WwGV?%*b?p8)UG>MCEh0CqFL!4?$73<`lK9hO;$9O_Z@s@|W1yFD7mO>5 zeyo-&cBeBntGKw^)db?s5P+G1h>!Y>r4eff2M2*ipPx_fxeb0dp1jJkZWQ^iul#Qf zDphL^U4?p}7>Niwi)i*rBN9%HsiJvV`u)1AOu%iQ$6Qs z8#1~&JMYY8GqU6*1BO;89WWfoSm8gO0|-PGun1ac&iJ(I(_n7_cNt6y*F%1XH4bU! ztLDC97VR%@g@H;5VslOgH~5=mGo#> ziPty}q1S^qtk@6#_Sjs!OZ_{5VSN%F8k+s^@oQ|Bu&{7;US1mrNAg36@z7_d$Ra3X z!J?0ZoaS${UFkNW|I3z-FSoY08`!&k|E%Qo3h+CZhP+piuLsz!P~7}egj1i1Go-iL z(D?1o3-Gv9;B;8*xrc{;FR>?NhoRSx%A!o>k76I+%87!#5~|8{xIVzl-6=9EExZ~g471LQ;fyk z_*IhQF`G0Ch@Rhq34WpLv4(hF5X1 zfVr6&v3z!;gsLj1|KAT}Z*6OP1bn*KWaW%ScC%TL|}fQs!hTqn*q^kxwc zd2sTN`FZJaeso&W17Y-;Wq;xKh{(utj`@L-mM|WK879h3A$;b=#{WxNhm(;o3fb2e zEi1Je9v@1Osm99=3uN`V^|l&ghEQyvPI6Yh%uJZy6<4~t{Z;qO_F_Qft_wkse8l$RAf)X(%Al7I4xoh zt#rC&6BCoW=H_2~$x6PzyF36ehHUGRD${HoLK3Pyj^_n%ra!h1dHpj_be_c!x~pip z^;*&FPs+m6XRp-&+i;N$k=dR9pX%t3EcKE?$AVZ8ZH<>ikOUJ^u|JTKlG2{@v3%ss@<`aQlkWG$ zi_;|G2(H^$7(moWmVk1_OBBTD1Q!tyVNT@IZ4iLR?Vncp!-2i@&VO{>w4z&9k_%>& z3BniqnV+9;!W4`LUKhr#GoM{DY?}yZ*6*NTDfg`V>PH3f>e1=}NG)#` zDR<=UTYnhk8ew5kQJDxckbl^h%uKE*y_S@fyQ~;vL5xxgqGR2%sUTAQYq`rkDZ{(F zy2{$Li>-#tkc@>UxB6`+zLjY(1DRfTs$y(h=0?L5m|;Ulv-ltWjD7B@Z9fkLdjuk# zMN-^i?N^BH??>=`ClG6BZ0yh~(&Uv}PSlJ=o%utr$U1^Y>^whMc9VYju5l7=m%L;< z=V&~PR9#M1ANc6blPs@4;sRX73+VL1v(&z{TMbwAX=`i0T z)&B@A`qMs7@kmKYc}C+km_c=9n$BTe*PCaE;0gxnAd2O?a=crg@nPONc^7|QL>xSY zhu{f0!U1^D>fe=gS4Q@-6iDXt?A9OvUWJCP*oko0T&0s1mV;~j5pFs3a2*Vl=g#b>yUGbk;Re59lRY9p`lSqdS` z5MiS-R-%1>WJ!2E!5JnYYRj=aMfuV&`vh#JGeQk#^EDIqanv$U;YJn8=^ z_pl$~c6N4rqj}~%IfdqlZCahlQn+NKs&ZGZf+Y6BK`f2*cRRoZi5PF&el51>8%Gy* z_Vnl&8yjo8GK;yDk~qI&femwVadl{izbMxPtGL1VLAUMM;Vv?zuTMXno19B)nMtp*gD=#ns3FHk#VZ>ru106slw{Ow=@zE#E$FhG6=oJB< zL}P|}rmkGDJ-UdaKK`Na0rg(ok(~L*_T2RJ$ICI!pJ}e z_=JEXIIAhWP}DNq`}f@qipHts2W7U&>E?2+307D2Y7FaNpBL9OEF%Y^>*eT+>$~XP z#d1p^9<5668&bbERkgbqa8j(W$UgFck>-i+-$l$wUr%jE4zF}?X4lO{Qv23>Y8jrs zZ^Jmft|Z?A4#Fs$D;IV(4@lG+3D~K}R4EzPi5iPVwkO|Nm0om+Hnp($;b@ojM~y(v z7Dw{Kx^g|NaVFmWMEX4~tq<~$ql-$vI{OH80O=Z!Q1wSQ_xi6Q_XqQfa$CzSPk0o3GXk;L&F|0ufy?5T%-1S)p8|y5^wZvyA=5tk-mF!a;iKn5Z1gr(R;fKG+S5AY}X+= zfj<@v`bQJ#P&wV=U+4+w$}3eKf@;BPlD24&?1pUs3eyxLI`A!!q0TyAZ~K+YZX6M+ z!tzalE0SW^b*XZ())|thO&c@|H2eG(WxVv*KbrUOgnzXM)QDG*8-QeZrj>Pc7))Gf zIZ2cE%z3%V<(G|K#dtG?K1N{ zp{6cHJ`$NvJpa22pYUoNNhy7bBUJRfeZMi&QAetrslmLzvwqVdu#M%C8y>OS{is{l zd1)?hy)o>D6kAqzbpol}+ob5Goabh&d1BubV9meKA0M(O+zpphGZTa?R}<=RHK{LM zL5#QCJ8P3GDnz=*e~@{>#ah@t1v84_h8S&qzZ`3?(-{DIw=MDT?77;k)vbx(0j}H% z15C72OtgKwEF$dt`ZG7_W?O`2F;Er%dQzq0=TEI@2E(J!b--KN^8;v(cLNG! z*?&iNq`2ZHlesi=3V{wAU+5bsE$FakMUnP#Q(A;x&Y-avPsfkisAmiA{Xzzt-+4cT z+Fh{927TaV-Q5`1#p@34A8r~4MZ0hG5rrF`JT`Fv>{;M}&f7^mI9k~0q);K0-Q(&2 z+j`TlcEsa@U-ZK-w$aeF?l8}l{@kgk*CsNGFjsD6i?+u#U3y7FPRI%>5**5aV-Ew1TFVoMOCM4q$t_}^dGXB$>Py$vTY*0Q z-06+iM67d@Qeb9!5BfNM8>7+|9@qc@#fwnEm0l~uE>$KX9n8F7>vHHMZyRJ9 z<&~0>GF^Wau%Crt_x)ELIc(CsnapO|K=FEC;ENC;_9e*l%#Vyyk6Y6dj=~ob61tCg z5R#@=CJZoz*tQ^$1+qbdO20fkQtcr<@aPggu{fCK`6z>yE`L0^BOMPhzI(z*> zKl<5u)0@hVqpyKfPxwVEA?Ua7%n@`(Kuu;F6dxa}aB_0$=BrFCDYdt+uC5+U?)g11 zz%&KA(Bl+>*DyQO47;v130eYYId_?hNyy&BoH#9UTF$#v>Vbuq7TD{ch_`3)39o1Q&94rEL)&QtI+qKKl zXrX?zq>IQ|e*tbWS{_@F!K6^$w?%&QY75*;yj>jJ#K>s@6XmCpZ21s0cIj%@U2cGE zzT^e_v5G|X_h(fQb3B>>?Dv#6REv9Ubd^}^Gr&+4pp=9%M&%1oR;3d-CnpIQ5FLLO zoYfew+-b=XXgoR+6BF559eq!qGw4WFECJmj`7+_fZxV#|AUd^mJ(D?Zd{1~YL4SY0 z#l*M3oKP7_Ny)f{Oq~F4CScG4&_~a@KCOv`g*`5rgSDxp zC267P)Z%*~hYaYL2-wVYCQP&vI9`*$pl^4t{obG6m4hR+iY*6zdNT_<&J$MTP=;rO z`dZ1|yN8!s=MC^h2*WXB3;VtQRv~5-bL;KS*1;Uu!|p0Ko&oK;9GSE^(S>g}2p
tlk&aUWLSH9J8LGSt9cBaTq6Exg|e;{)Ee@C@YV>4d?VSRaaM!b9ey4 z2f*}~*$%(W|Ay>HT!?#~OdkW4pH`*)aYXk)IjP$d2GXTgteCrmZJ`>K*DxD^!#1PL zgQ$r-?4*C3qXMK=1B@XLh2B(SG;59(N5?@zg(E;x&Hzo*^!&R_2|Q-0A}~pitO{&_1crpdM$)3Bym3 zO@HqP=K#6q`8c2QZyi&;^h0pGBqu-)IM3&%G;`1mLvA-aJNvxx`U!BP6XG4*Ft6y3 z;|g#%Zs2%6s5k&+glOjn%3v$wC>x*aW^yX*PcD#agK!cr>Y(-;zDY@B@95YM_OPy1 za9CK_od7M+!_7k{!T{JmKgM+cb&7q zc}07KF7THRPtQF8{<#8ZHI4(j9~8%+DkZg}>M&arbjY>v7uVBUST!yR-N zSm?}YTS2nkN4LdXVX*r0VDC2r|$ld^crBHEoDFN9t+`w z$)4yU6MVYBNsu$cmd}Dha=x%bp&-w&j$PO193g9~FjyjN^IEg=vHd_!>;=l!esGl2 zVs~>%@xjR3x0F4VMvH@GDHI3}FeZ2w&eG|;3w6v~=GYifR|`OBXn-Q=Vfy4d{vTlx zlcqXs(yyS?Tv{$JE~}yM!NHtgGwB9LSmnZbibxp*E_S7*r`r(8pCY9XvA_M~0??d7 zhb-+5x{?OYvn!*00N`(VFFA7Y2HA(6gN7 z^ZAXhY-&Ymj-Ro1)`=@9=;+SXw+3uw0l*L|MA?4oNlCqkc3b@}VA`Ja{+5SAgB zZtP|ssI(!q4?yIEMmn)es0oIVkAQ_ON%%&g8JC!z{=L$LD=k3~%av1lVV|_%c)2fi z70UcQwNl(4sc0N*J^(L`b{_Xj2?&5FL(uDd3O2m=F96V~0wE3E{rzU&eVXt2WD>4d zxU9{)FS*9bM+9(A)Z88h1jl8MrGB@u0{IyxJ^tkRGsQvXGOAs~} zcWo0+un1VZpBqoenK^nMgwNC!R#ZJFzNZTg6BIuklL(g=V!ChLC%jnt#bB2)#^vXS z9EKkSEn;)%4ge9AGqMf~4IbN{tqcSLR)Z$FRXZ7tD{rQ^x;{(N(K$R)=7t@uJ{qKs zAU^~+{Zjn!6JGdxKE&7FDZ>*i!+W3_K{_Ew-c6r4)~%fM(h&OhT>z}Djre@X0BM^x z6Q1w>wn&Z=Ut2d2Dsa5+Lr})Mxu6s~Ktxzs`)g#HlOAJ+L3!2M-TefL*V%fN{d~Rf zgai_Qz_+4J`RI<$5oQB3TM(Sc@tAHs-xH-6&=PgKV`p=T<*7XFv1T&v7Cv8%r|rT_ zxlID7x1L;%dg|Tu{`xaU7r)o<0T*Km-1Tnc)dCtmZ1y)!%`$Ae+|2<(^Iz6yC6~Y@ zy|~=KmD>x7hPT})0N_EMs0}h*`_y}IP?#_YLK_izr^EEC+{#3ulbTJSX)G7So)&o5 z#%Xph?TtBP$A6u|KHRhFnOU#))LLuoqLZT|`(UNER4iz$e@@k#o0`g&UcCuosWU;y zg$w$U52*I)*Vd!~j~<^XxZB^a4M3&2)rz{DNX=10P$nTRuO6$_y%>q7%NIfZff&90 z^kkWQH|F7-bx9dq2u&@cpx7Rs{r)rnzC``cTMArmI|Jnpvcm@2#AXk&#d)gb`Bb z!9frh4hkBTi7}yBuUaS`umH3L@JrY(IFY!djhZxdA`*;jsn_}(eFJ2u|K*CfY5b6+ z!m6ra!1Rn+y}QIM@3>6t$dOwI2$gr--=Na=NFtLzsb*It2Edj+e=)+U{RQjt=L-J1 zL-rl=0uHmCJdMl86d+y_w6|<0?}k70H}!3<#q)+Ql3c*l!g8fDoTsl z0E*2#UlvsFy0Y3PZ785~%O`?_Ex!1WF0Mct3M%&d*VbLJbd_V7QiQ(@f#W4Xu2d2_ zRPIzB22n8=ll!RMd(@kS_pjvoqP-v@z$=Cl>d8@{_jd$G>h82W^FGV@Z~MP^@W1wd zvFZR6i9ZO0`Si*FMsqg@s`O{OG7}s3mVIHbhBrsrPSt>8{WGy&-T$>0m?m1jhT)Br z*O~n*%M=%7O%sU2LPPri8>Aa2IpPr7A-Rn?a^Lfy?YZShiR{*z0* zqrb>}h@nD|w#=%Fl)dYejD_Yo>&W9#;We+9`|gI6og;^b3|jWy{(AdnA9MnF=`)xs z$qTo-brKsujuOk(yfDy&umQ?+gt+YwtU{m>N0W$lY^1RN@&iC4`M-f4S&cxL8`fzssEjeQ6)AW2C0o)%sk0p{aDy64iSms|g z35PGngrt6H?2Fg2?M~*ZUj_IJ2Z&L`Q}>m~j@hT3TB(G_3R&U2PY4yc?V!1nnbzRIs<1NChjBeTcB zT5L7|3R=avWTNpH>5&>9^lzka$)U*BIsk_E2ix^r?Ve?#~eHwLsxm zzoVNYz}M{8M?pdVU)U67k9%aH?=nRL)vTX|enboRDBK#u_eqoI_<9-|n=_OKBd;BI zPb)g5OpQ^tDH)#}i}kHcc%UbI>L(90Hsgy{ zAaV0l*MX92yxI7|k=dY5`@8$`0tcLx;d`~v5X!Ys)T1X?8VnkezH`P=kCKyPQ7@Oe&>4bja( zLtsJvb&n{RN0mD^k;SNDQ|!t1R}e>rT~Iq<;hJ_XcMrW z@De+h)v=H%f^C&dsRrnx^kyY^*#(q7EFqy-)my;FeFPe!cGTVki%j@$ik}x3AngGt zj%*MJN&+#!&0EsxKJ}v5;s_~_76rcf01SNZ0Z624(3$DPMGqj2(Y^Tzym)7Ks}5$> zqiYw)Ft^ud94miIvATn`H4xAD43jP0?6BKfQ>km0Q$H~9h=S0}#m#7eMTz@R z^+#U+fR%bk543ZL5?a!gEj%m6XQ8072~@D2$ITi)gK z{vaz0Q*g^1Q(^%Hj7GM$5wMxvP~k+E(dXwxJl0I#cPD19L6)wMhI4qcKwUhd+_BHH zWax=%s)ECZJgK-8uZHIfG}iC0U4k8#i?yoa0cpSns3=SK>{n)$R8$P0qr#xvVH>tc zIojrvVD_Ja!3pSn9)q08lpR%N5_0)Xe>@idthr~8fvXQcK3-E#Yv2}1fKYWfugzIE zU}I#I9PPSTCSdUO<$q3e3Ys&3EGM+xMZmEOo(}+y$%bARz=JdXLF906E$4Q8zM-b3 zrt7=1V>n0L0KNbKv%zbXWDpf9s;Y)SKoWEXPzBH$O^bO^E|}H&(-g4npVow`es+bZ zUU%1)l^DLs?4w@uzbrqkfD1f)kQoq>3Z(-8phvgq?qG>(Jp+h>pMj3Y6W)stxa_Oc z+%WIYlv{+HwYIi7GA1-cIqFSm-me<+FRlV0A$C=P>z zniFIX0Ll5yHwULaR8!CNG)>s`^IX4*XnblV6%+(c`e}AD?Sj*AbL|M>VcUQ1QlASK zCxAg6XEivnoi8>H!ffQIOt-(qfAfH|ey9OW?>u;Sh5W5o+1kh-pbG*V|2s<`m!km> zu_`HjA|uric3K=P<%R($%9#GfeQx1S{dRr6b~!`C#NCH3Ew00c;I%v(EkBO~+qbEd zZgkM;yR`{e_ES4W?o=V$WiLZ--uN%rY?4r`ElVn|@Vd{))m`B*-_WdDr4if*XA(*X zNNDY~fS-zP1X{a&fvU3zYR;GcEYv?38EX$7dKC2D{85$_5D+A7*EaneO0+m$pD75X zGG?IC3Q$f4W?d8vTjW!K7nx6qw90R9SJ45dRxZ(NXFcsrC;BhggB1gGhyiq}i0Y&m zgib7>Hp>HRWO|#gu1f(FfR)^T)K~EFds~-_oU^kDDKF?hKLHe(AFZL#e>VZ8J(2U5 z-i%tAga8W&5((fz*PD1n(Q=bOt?v-KlG2bKB~y_gEjr>HJ6A(-c0Q2pJPx>i+~MJP zO-)T3^93#1?(Xh`Q(dbMWaTd_M4fdKqdMxki6vRVz5;~CY$0m(UF`c+kFT^+qf6C? z2h+aoAy-rl=%5!F{FlQ_iROnZ2OR8Vskfka+92FkYVpO|Rg4}1Tg4gUOSG*@vT)K% z-)DiB4B7@joGEZDC^0%@i)iHNo@3H+2Pq=uAneGv{|ij*6iPv5pnsP>q^GB+JYxtF z*aKl8YCQW*fazfv1T{9Kd8yK0d;1a znUfPF7-{-zyo7fMI+JTz<_X&!w7>MmfF4)Cb;IHavjBcIM#3Rz!l8|NxyrQU>>G{R znMq--2dPH>%YXnt2jeaI`x}9LclMj==IS$QF&O_*{5{$-(+XZdPvSu2nAC8YPNxL5 zzf<-Ide7B4;efiM#}n|7M~wP1hKlaKmE*bcR3+mVNTfuFcZ!VJ{?nPMc_jlk(dNQw zg8^`)+#F~+TI{K^jy~A>3@@G-g^i;s|5bf*JoN$}aWd_**;*dUG8cJRgVV$925=aV z=Kg8V#;6I^m%*D}hCf1-*~Qd*cSFOnhN5PA*8VcLQ_HJ`JBl&<*J4IQ&{F#IZh5y#zJmeG5vavGa0bx{1 zn`XsBUv%EZ^0Tzqw3B4ajHLD))5<_SYp))5)WZ_Kf$l{nw)K#2IsT67L|w38U}g+Q z*$k-iCG3gQewD}YEV!#{`NB|niQLCTAfI*2D`Yme4+BNEJ&~hHq2xbR-wMg|3tRDv zJe>4xQysAfg>r7%d#OHb|1AtM*b|R|zDbc=I3!MOlLqH|%CtlufT0pUok6g2@$$Aj zd6#tr?vSfMVIO-wZ56!i#PrCmxjI{= zqN6jn7#{j!e9lV1w+Rha8Xin&&&TqyS#meqe!7>r&jT$?TATW5^7}i>t_*bzO-*A! z4ajYt)%qoGio?P0L>q!?7*a1fMOuE8^dc?#xl9?75qgR(pSs2`g&0)qYIxByXLNRR zO0wKj54yEn_mSoSgN_2smPtz5!c8ajc17n@5vr|qO#Y*FABwfM4vi#w|omD?Yw+myZqA9h& ztB=0f>)f3JTyD-%M$TOmLjI#AgZ7+K7M@uffbh=HE%ySorcp$in%`oj1{1og{_2Yj zrrZZD=j#CZSqSIHflwh_!)dvnb zKnYWzTVfUMz7n8+e-=4lS<(l@Z|`_I5I5Jy4*xx0VA#^~KPYaoUv&4s$^|V2A8`q( z2COCZa{c485P`+>DG2C*9pr}kpfIlV3X)@X=UKZZMf!U)$edtWL?t=+7M1+OzTRy{ z87q{82M4~mu+=s=7}cFe*Fb~Izfb|>H}bd)zb0*gw_{KVL>pkz0NyQceb;HReOys_GLHrDXu;Z0p3e&*0h+*p)d$_QWoQ11# zc6OFXBTvL8Mdp3uDdq{if8NFf^4y#WYYqWR1B-;Zj$ye?n^1e`REmmW`RTj$fPOLc8IoIJRu!0yI& z2Pkvt@xkIuyZ4rhXN`0&c<49`+-@j#=+uss@*xX$J8iMIKwCb=)73CT2oek~9)0;LhmzFxjyKEpRac>tGI)$^mgW~r^O6m*c@OLEn&+Wp zoyyY>l-jpnT)RpAiR}NO>n)(F>bmZ6fs29?_kx7dmy!kn>5^_lI+gD3k`gWMe2YjKLDv3t-W%-MV~g?G6!fK3CU2r6Kx9I z7^X2@NbPFs6y-~8*O;anx7DqeTIWitsqKNN5KT`WXF7aMD*~r>hu1{?!T!kteu!y} zfqU8gcgsjqIv~dT9iLbvEo9XoSe?o?vI#jYd=Td=W%T+^4*WQDL zZ_;hi&k6?EQ&cySiKGdG<6egi3-eRF(ijh0_LMQv{)Ai_VBe)RXB zH5pT9V)!9b?XmV=7j{#it-DY_OF*3~A; zKZD^n5Z;OUnOT;%p~*ped@u-C5;PolNxguGB1oj_4TDU z8NMX&m5-ixGB;^hNIF}iGv^jdXInIv(;x}2b^>7#=@nS~6P-x!1+cjCP3|0(VMJCs zpi6gVzkbOM2}Ku1D5$HO`6f2nG-HqhX0w?0-zO1Wy63VY(-Ea0l{ukOvolwGXlQUR z8nEIoKw1`Hz^2mp>V4v&o~h}N;sx97%+tVuqg%TIA+sS-j=iX4(z;HQlBg6Pl6v~A=QbomT$ zEzq*^Yh%|+Ps_HD2T(t%0ev&&z7L;hZ$}Ym)ES$uyXF3{RAQ4>739MvYbx>$FSSWG zE9TEz&l!|6UBV*~Dp}|O9qhX2nr3KB(f0kYa;2|AL-$HHjoZMq@ED?hPRohxDB&Ex zBe+o@3R-T6|4Cd#U*0Wv9m+~x_#Li2;gVzX$KQs>Vnur=W!l%a42=n8Wte`Rg}JDQt*;Yr3=YEj z$2g1}+Gh06=X$|?Y<0R1z2gRQH>>v|ulyF`zbr1=At~(>pn+Xu?fe5#AoHM%6!ggf zW3bcLfyT25_iUGN*2fGpnQZ}R)fDKIAAJ9WcNH1b54NRKb+)hXIW=UO^GJ`glkz2~ z6M)DL*VpQq8;FuAj*_V`cFlP${rq<4KvHs215AUC;nIR;yTHjwA$)<6s!-?Gb3A17 z!vUISh^{ZnfB$4l)|2NGzTW7Ir9K0OnKRrQlbD)$j{j~zqq>mEDsiHK1bW3Vg&;z7 zzGi>vBEmYc`hH|?0BRI|-S?%ogU80=Bdx0ky1H)t@8|jDZS{L0({D`GX_hPVeB-!v zC3U#&S}2eREbqfK;3&}?)QqKS#-CNk2{9zUuOXc$k%<60k6p_TjrYR7jf1R<04Ptc z82l~v3RqPEp}m)KQOy^ve_8ZL&;T+4w}#-uhd)?lp79W}EoXM+BG*MmJ*GuICY;lWMBi@+&4I!(`eD;da@>j4!rbOb|OrVr3HB2xQrO zD#fHxlA^DnQI{2SB3^*GiR<@0C`Y^!No$>b9Pwp+=PO7o=9i??>{0m@^hu4qxq6?L zltiw`n!Wz=9uaFKYbV*iyHkNEs6~-phJY_ga!MQ+=yP}-d%(OSS8_=4pF1VzvnjQ| z>DEd5-Jvsm(_OMJqX7vOxcWFn3T~bz*J@+TZo8Hx<$nK&qY=$0_yufndSe2%(dDc7 z{SE4>juCn8ZeCD}sS<5B06ES9`+<$db|kk&HFt_)fW#Bhyvsf~e8Lpc>MU!T;yu6d z4}DdJ#J)$?{yS0|E+?@W+5z!{;=8^Y<#ge^vQG~mJAlT)Q*`B!0l+#uk_#A?@y~YM zcLt5%A~UgHMp7gj>rHfU+%y^m$yoU;d)2&|9z$&T#fOXX2PQet;2lxf0a5sLxK#n_ znU{QN47y^aMDLO)&~v4L4&!gVA7Do~DjAQoQeE^|!=3=#*g%8h+8%I`bnBd?(~Uw_ z{Gyd$K{MzHukHijKDPme z{Dvg*{v=~^WcsFn`Y^oCwlfZ_^=LPfjptLI4IKNO^Aj918m;BdnDrX>Iv{_`sbaf+(lN$w47gTt5nt1O@S;0s0N64*x%ki`GX5`4J5`Tl6$UU z3RJCP?%G(%?2ewf;?$VhR#S_QD7m6UfB3MyAzrknxufG;pdNbcT97F1NptHdgC4S0 zA@@CfHWAwTvsCHr;|*TL+A|N}dTZYXTIcqF(NTT%F#q7>NaFvb6VEaL2V8qPQ)Bsj zd}mi+bZ@p=@#9!?@Q$qJ3p((UpiU^*vh!JlrMa_C@3@h#t?azs^e1D>D;w+TnCjk;G%`co zL?<#rjAUeCE4K(d+{}S+wjLA@-2_v~x6y^vkQ}c2gTM3vGGTg6mvP906cB*woxb#4 zE1D%Fc`TcI^%fH8e$i6&NzZS`d9{gydYU`G9w&9uJYzJK9_$Fdu?dGJKcLvSsC|zv|5GYUH~@ z_GwLH7s#q@90z6Va(+oz*G-~C#^+;K$$y)kEi%nGiXvbpvK3Rw6 zCrQK44;n^4)T{T-?1IqSPMD0Ecg+OCPfT zHpi8M!zlQJ(*Q%>ArEpBq6Es)P?A#DjfvPy%ik|-?ORAOJ!r_g85b)LD7YrlEQM$N z=GlACxVuMw5jQQStPq?XfN)r@R9u-EO9=Kc!GKYSh}sGY+U=Cab|im^%x7)|uwy)W z&~)sz_qB<~()9gz=_@Af}tlf?-IK>>besqytTjt$SLhUJEKVxkBfd8w(w(~oo z2G_iOi@Qm~q8|l@1sqfKoy>95;Q7lq=ki6WxsM zoNC=I?JHpZscXLfW4DyWY5qa~W~|*dsc`qynaBj;m!T>G-M~9+d6C=p$r8sNdLBC% zBq|Sz((AVIDt&MuPfxJMpbqDY-u5mML}XWhJKdYowRgora@)S=O-rs{jfGF%j;#BK zDEanM58G{66m{(2IMN@<=}p|#I}Y3_Vbk>0YHXLIYw5( za{pG>$1E8UTvnMyWgtxecJ+FXCxB-5AEdgk*nb?|WpAE*J1=XwXG}D-sm(iE+xzkJh zQ(Pbr(aLU7(BNST8qD*Sk(%89P$9@E?nIYClP70Zl#|yVejGdWo@azk@m*JeOMgZR z0ABfswh-$}ymNtTS~21u`QW8|CAEB;SK)RQInjmk@|OvmHs(3HgTL)cD)ZClZ54-p zVsD}gXUS8z=PQTJT5^Uotmm2!lN*`l#&TFQ04sDrR^j>4%&eu`Jmz0?#*nZVCns+0svu(nipHd&HbJbiIIvy=+S7%^isdWB^^M;0v z%rakrLSna}-H92g77e8MiJH0G`|7&))y$^wIUfmohPB4h)Rack@0%gW2VR2Q-R z&jVbc80r2YZBqnV*!KIswuRc8%E#{i(G{o!fTm`5OediBIZts3*C|Xx#?Nc=Kk-u+ zg_g>8scu|nx(#~iJvLX@OfD-tm$Yw&UuLMv%{v#@8KPVl(z<42z4>Uqxs3~N9RcB) zl13$CdorIUwZRKKH6N`g!8rkHAk|nKS>rAlQc!wem}c|r8&wP^GGjxcJUrnMA$2Wa zThzx_#Vq((C=~!2Uf%2juqji|v712S*|S(q$|pd@@R~Yf>^?Hr{PE*Q8SM->g`z3) z=l$}$n-_>ccx7~psBH80Q6BA_RF&zM1${ByK#rqf=md1@0kX4U4V7$Y=3Jmq_ZClxS)4VJ9R4&BrIS0XH7PT=Jjy7Y( zph394t?Yo8I3{Va_vj3KmYY^d+E-#GkLInzR3s{kmqOt{q)hr=FeH#F!!QDL94kKK z*>uIz%my+F2k8H?@+p!Gj6I@rxF=^)TH5fPj~L%g zh;A@}@*p!aQ;ZGng~-(uB)s%Z%=p~nRD!N6I2>k77vMojsg2#G*zKUtF?BOOcAKB) zvP}r6vOYK}bv+}*qeQ{ONb>L6t=QL^rFVNY);T)|wIN!Zr_XEG(?Zi#!q!pk5k-f! zyXi92)#hayd@SPP`tfPS#ZyEWl{>!2M!ohKd4tjse%!2eViVnd3TThSkYRI_ty#ydtue}jeY9aooQT> z8QCtBeJCVx>m?NmxUMK>gIXTJRT^8(%***R22$+zXhM_T0@Y4r{iRPe%?RFgsdsA! z4t?#OnTZSYRTQe2&1Q1IB2DQ!dc-Evp^#X)p-0eMlb0}Z7wyu)+XJKP!v%+6iiZXj zB~KI*N$jmZz%!PVhyhB6eN|B-8YmXZrCr!nIE0I^QhaeBswXZzz5KP_RZ1tm%{o*5bT z`=9#SY70?9){{fEt}o&<>c}2+TMUv$2PHY}RVBVKHiu0;G0!e`--6)CO6#U3XGumK ze^OH`kc8lB%TE!;bG@JnAcj+}O6zr9fA3WwPfK|^Op*J2XJc@6=&LBXj%h5}hD#9yFOy)7`oc-<>XRP3488J)0yUb>dKOdwvMB$=3Qu*}-jb9zMMT6w(Gmw-jYM^Ea zgA-PO@?*LKvibRK5HNWmj)5i2Sx@|f7d zoO(1rJvO#NC?Ii)(K~IhSQq{tvGN#hZ_+{l6OGcp9ZoH!pK+Drn_+vWF8*?`Q#V60 zid#aGPbygQZMoL#J};X5;rmL*PYek!TR5ihtYyL*9xA+lqWHUxy){tk@#+mMb`Z)N zS=&%0p|b_uWrXozWr}ycvqLiyFwJ&q90*lKUgi zO2TIR`?coO-nf1+UbYHL8;yyNEr8%uu5Q&YtIQj9Jj-L|i^I1JdwBeMJ%B@+lGVsh zCgV9wgthg;NDmlYKFnHxE&-8EOlWw(2Cy^iMJh@Hni|NcQ1OgS0iLI0bN7@^Z81p} zstD+6$QYw9Bhwj>-diqW{|t*m{bjxs6zL52mA)m@$WvJn)52ChjX<)O}-8cVG)A#!=#>5N%|*1?9HII>4pTl+V> zq_YE!L9ZcFoB<7CmsIw5HhLuOjB*g#r%i{g#r14WdykoQYGmx~XRFYG&}J$L^Vltg zmLEdX6PWTr`2)XtTPFTFbw6b7Q1L@bMZ^X|kJ|Ai9+kwR&Pl1QOWE@p@5Eh8EjO73 z9M2TY+MJx629^g>jk=-ic+-NHB9syLVfNTbMi1a3!vTaasDwwHDQLu>8?wpJ9aZUG zigoN0#L^__)R**Y`}OMygkMv6f8A#=nQc0-9zh9|$7$~z4Y8omU9n0C{yp5w);nu= zzu8_KpZ9^~K};5TZ|x1?dHgL*h;+j_RNjjwhkAEyRR><{yHj)YD<3+) z%*<-Uxiq-QI?oj@^1P ztA5c??uKmzl@Epm$D3sMzE+m9sKh)v$R-3-i8CXWEp}SZ?z7 zit$GWu@6fuO9QjHYc3vaGnl2Synhog;~<7(A^exBgbZ-52h(DD;XbC7(15{;W#Wdjq)Ubr$4e`wwiUPw zNkpW!kx$>k#`)6{<7NWbK=)&nZ**`(iFQbo|8-rzeMr|q3@tHJUO@K!XX0gFMJt0k zKHtGxeS#SyN`u+p#XtIMY;EguouPe>=!HPSjrYW0g@EZpVVmOaInO;Ch-P&mRXQp)_UFrIGSgQr#Db^Zoz4g|p%p;2u5dLo34|L6gLp zJ)`ZDS%Ui^FyL7+ZSW;Am>AZe=&OBg3PF_1R*-eR3f-t*Onx0zzbHwWc$k4$+Su5z zI_f=9jTf)R3xg1yiJ?-*`Y%lPyuP6NzM(h{%wjW_uKR-~fvp7Z*0(M2M?0`3 z4+a{uWnJJ}Cev)-C`9jhL1Yc}ydwRKme_oW50iMzk0_QXA>=s6rFWJr8D>gxl#LPV zOq2|?A3qc(_R6eJ6pe)k_Sxrj8N~7&R%)y_QDuctxe)I{29lAn2dBJ-83p$7c6X<$qp3}s8TPcM_)*liQ?(ceXXUlK%+=0vP+*>Z$ zY7F=ru_Q5QO$idFyQ+1Bj8^5z&?RjfSJ#+^Vvuv^Zu^al+y+}nEH(RTS9D+IOJ#8} zQM)`^{W7g!$jLXA#f5Ng8K4oLRLvL=&sjlbAGM{O0}m7E&ASN=02zp<%-HgahG4y-j_X!pk6}e0?zC{~6f=F=T zAAN=jqQ3O9>uP^N}TXg5r?~?6jH*ir@ihT1h>r4^xp>$nI+af)_|rj ze}gffXDZyJdA4@Z=Xv*L%o*jKUZ(bg&OF$FdTCk4g4L4iq7K zlwDSagFNBE%m;aGA&c&!e~_&1(SuHH1v3Q zi$>ReCh(y|v>k5BkW@Ag@pM69C>xny5WS!ny>1J44H%y1h5`6de%(G?E^gLti%4hVPww`O)7f7- zss3&5`Uo%0!V$M#4YOQbO1QM)-8s_xo08K*Co?A}oZxHA5Q%F9AKzm;VBRswPvfIYWYC)nur@h+Yj-RkoC7Cz=(%XhLu zM~PclCG&14R%6{hOzU+RC1V@N!aSVI@nNEK2rYQ8!Uuj=m4`oKyQKVxXw5=FG`8hvf#>0Rou)o7Z$md>%ATU-B55q9% zf%LBz2udlpJ=SXbNN_6mVwM!OO>~1-RU#VN!vrEQ$g75kuruS<9l^>Sqg<}~VbLKT zjH?-K?|zxr<#b*93$?^a;ang=ht@&nr?$4lg^ZlZ0#4iL40z5MQH92qE=qbFh7gRt zp{C|cg7h1U{8Q4h&jcL7@5QMCVA7(2?RH-q2>^-z{k@xRmwRP6%(_NPxs&lo_jWjG z)&9^p?Px@iK{%G2B+1BAl?-Py0Z=nvP%k_2k2d{OcWJvaZ459vI|bAHo4XS~z`5@fP8ipzPlNxZ@@c4u zb&_Z=Gm0JdE>74Q3{!JbbY#d4e2)gojItY&yXCbR?eJv(%(Gx9E+DzY@Kf#jtJSO~o5br&4QL(u8kttCCaqX;G+bp(}% z%&6`k#024rTK6HHVQlT(VY2+3Pc5JCfqW%Yxt2wgr@4tkNsYm>&0F@3oAuEw6%&Z; zt4X4t!ze;=DSDy-kW&%xcwT-6W337;jEI`yVK=Q5IkA-LxA2B$N>VP&E*J{zJ{C-( zAa(y~V^AJ)8aJ%fHFBg9$b3rt2j6WG{Ly)wK#1^oV==nCLWO z+7XlhCr~i;4x5Zt%K7;Krs;WK@#*3VxjZ6mEwD;tD2Y+w7O6mAFjpkL8X%(wbu2$o zlXw*SSruTb<8TB-8k^3w_LRvo1Uo={DlTzg{!;J_RgCX(y-LYf8eU5~yE1Q~`x;CE zk(LBtQ&R=nVtmD z1i0Jq^agSdGmiOy9iBw|G_;}WLV$Hj^kc5Ao~;?}39sNf>h0sh7H_k<`! z!VaRz@092Z=TF9R6TY%-x<%ox^nF&Ts@`XF^`TmF&ZSMS9QYBjvi_@~$-9hnLwJJE z94RaKuxTrl&r@-B^YbghYwP zuW`a=#WFjIxTj)oVRkX0BzJ}^ju<>Wuljv2x$5V4b(;sd9*b@$27FF|R(Wq`g`N;q zia&1noR+X#nI-@xu+*J^l!!~JnNZGoU|)bn?zKh5;zQgzy~@dw_nGIb3;yxR3EGVr zvtcVK=$%aUyW;sy2@@@AyYbrB6?PD*2bVxb>12P1N*X`_>-L~92~;NOMjCqhrcyqv zVjMoUH7&_^Tvya6cnEG(BPw(K--M0H*On0`C!O3F|)83ET=RfYS|vqe%n+HgxUU?e%w< ziD@WNt($QvA>1E{ zmIL&iNB8B<=AMkj`44_eUa>aUq-1&R7~LFbgpL{c7?s2P5>{BRn|XQyyt}Md#P#$f z*4ZU80J*;kPC5DDX6Om{H%HL0_Vd>}!Qahy2XcVh+ymgRZP?L}`L)HH#}6Fp`*gty zSPz7d*fn(68`^+yiHw-upyclXehg9HzL|k@do^@fyhw7dv3WJNpho~d0K*~0iE`D- z&*C-XelyQ=J087Bez?Q_?6b6S#Gwb-J!Qsr9<~ChXJzh=UFZE;@~AZIahP2Rq>Og7 zaIRnyQq~z8Hm4Itv(KtPj>}*Fohn_KZ{xo0&NZ+l6HT%%``D9TUJL1skK|Sb=MXIu z3a?jdc>x;u1yGCd;~@5okx(rdZ9`PJRVrwEx3IDrFSMVZpQqyByB$tuB-!_8vCxc4 zw~y6Gf2_ZTO*PW`#gZ0k`7d!Zf66P>7M@os?(cF}XP*c&e(YOkOp6)QzC{Opz8b1` z7ZZbj(iF5M*>#u#s-gFT!CBaMDtK5!2eAfKFSnY&I-RDTKFv&_#3h9mjce8ly)fws zDKFul5PAz!BFZ0ep8vs3XPw_f8TmHZKVk6BKfD^BD)V~a!0$6$mVbR8xluDOrEXU1 z16cAV{az9KZ;9V_+rk1y$Qvo3ggUSGI;hy&*i7cLu(5?gk_GD2G&CCB$(Sq}n11rm z)b(Q0r}}eSPc4JS<=zwy<5TR)hDO1&8RJm~H#|HCV1rKK{YV5E)c^g6F5o-VQ3U>K zH{+c&5~+|a|AqcIQC1*=-ner6v0zi~41Lh)ea_aKTsb{2F9tjog_C`JLgWs#05MX! zj`7b}6=8vQfMm=iL#47IhASxGe%Lt`T#DUML|y2M*nhkLu*9gD2CHR1q_B}lD!luZ z2#P`2?-P;V*+Ur=z**1-*dOql$pFPM`LuiM2^mZH#ct>iNKiM?lOhRMU0vgB%^`}SWp!Wk1 zEheC>%gI4{yciBx@iid*K|0k6y;)jo6*z33$hRJ-ydyj5$AGst`8@VCND4Mh@Jdoo zN8;K8C_^e-cQtI$jYvU5++{j{2;zuu&f2l9a@J-{82gI;Z|wf_J8=g@4~HZ*)UAEC z0)05hYHYXEf9Q8hr6%(O^cgKbbiXFYyE^A`OoyL(Q)Oqt)|vQ-E#}HK-|yK@ts(M; zyI4>-j1U^lH7WQaGzL-)w8X{H0h@Tm_{y|1HxlT{Hmp)UxF0=M&)|}(%iw@s2=egD z{O6KB=@=AorXhb;B63&!^FgJGd9fe((UXUHsObeCd{8Mx_nfL@ZaJEP2YzU5w9)0xwk~ zv?ZeQ$Yd-;5WyJfQ@LI)W{FvaTlcSX9fezGmgjA0>AwccZHlK14(hCWBHbuTV@ zcv=I170I?6&|?KPto;f0uQ!3e`=ppM_DVOPNoF1>ZZM;E8Wa7`+!RF>7hS%%AjOMx z;S7u#f%Rz(g|QPw2Yoa_WzLO%p&>7CML@1Uj&K(iSg#IEpusQWEhg1ueIzXxXI~6W zC##G@j#d;Ea)EQKk;FvpR9q=^c30dI*}r$55<^kNBD+ z8L9Zh!$|OJ+FvrfEttX|d#4(pc*m>4*O0!>R&{V(iPWO#WL$j7j^p6sxCDOP>uIbK zS|F`F=W-b27qy-)m2bR&vIU}zS{$^f!#SgJ=z9x#-4j@%^A8SI<c^nuo-F9uE2kOIA^qD+~sQoMciv1fLv@vhf-R3L{N=vbG3hk&- z)Im2(4Q&q)l)f%1V)SjsAO*H|STaCMIjytZk9lj16Hg&cDQjF%S+ zAq{9F{&Rd8ONxvkp2z7B&(h=%xQ~pS!Jpjl-|+yn=-^A~kzVC2^UI~>oqdcmho$|* z2O+Y#>HGNnbXL;@<|~jrTji;zy-#5LPabs`zY_8)@%B02$gay`Cs=ANc$lv=bTSsM zj+}eY2gsrX!TtJ~L|27YAc>iQ!FQ1io&^UqfkawKK)gAu#J!JFo>D_oi=cCSfx<-<6)F*ToH8IJLZE-K zhDlCLOl;H_-K+g{d*TSRtKC-?Jc67Ze2oByruY_gU8HcDjkXYKF;=Sjj>&7)UTa~NwxpV>z%L2ex&A*Za4o4a#%)!@ZJ~JOkE&Mxi zk$(x`rC;8&csjj%{;BcC_K;JSoyzeTRcj!>Va<8(*%6;h9YNR#6^d-)J^5&E4{O*V z{eH2mGRJuY>2`rSioE1j0i#m3K!k`x3D_-(RZ%(E{D`@_K0OVL3dexYYOPn_4|~f2 zY_UG5y6sFBN;#MP_f8-~S%%s7tL(>#=ID6Wg| zttc&@r~vxm5xMbIsZu;%8qNp|vfN7swD%70|5OWcq6-KLwCynY&!Fk+JCG7!-=$=H zpb6}GA^m-Eu^n96Xm1_={MWzxj|`gFY1LObM6_py@fn1sFdyrMtYxaWKeYX;=%BSWUu^HJ?vI+`a;AChL znNhX>9TgBnsH~w#s43*&{5NWWt$l8j3!&e<(}PxC(wloC)WV@!PN&^lFMM$97m%1d zr5`5Z0#^SriS}11L^QF15*T;6=jSAz*5CyvtV&<|aqv0)&mKT*3$gKT6-~}57pm#J zntWb;WOr8U;W1BhXzMeH6*fmxk*&2daZozVsdXpWX~c>aGz z_6}-XYu))0TgL8(CA5|g2B8bFoG9nDdI8u?a?N%9a`nr)L~3iP&%9>pxPm-9M4(}tu4^E>prEfMo#ek?J`EHXviD}*RrLAQVy!{wZdK0e zenR>3n6LA4;IZkE>Tvq=NP7dV*Ckx77^i`VCghS@=0$EBSPJ^ypD^i{fvh%ImJc8# zEk7GDJ?5yAsAOdqhtPz`I(`5553)pMZBoxKOT1s_FxTSNCgM$Z!0UmDQ~%A&_PK6a z)O*(A?)4&ta$A6}Aiuc$R#YLJLMeYj2zU)mKTUdpR1b52ROSWFOWpsOoglCv9-+(+ zt8e-Im>tfI?wcJJ5;p)_r&4F;BA^8{WXQQ^~ZugWJ+uhm9 z#0&^a-?cmil{2^~M3j=QJNod|?<*9n4~qZmYLs;tr)A#v^vjp-X>P*G(RO86I_-7K zx_%Dpl-n}ieSVZZ*t42*#0Snt3H+0IQ#$+s426B_{(*J=repyt^vf1}{$-%;XFdtH z1*gZfYI#4{`RHQhbw#Ih#W+m_<}`Crd*Z!;*n+x$9tbY(zbb9QTSfF=Z!O@NO#6mQ z#`%Q#Xzi(3`_x)udq4ta!+GDv%qHWHIkSTVz3XKr<7aWntF_one;?&3wdfHwmz`kv zKa~OUUtc8PS5VAE4r}^h_ugO|$j%H|dW?H>-L|!w7T6%9lB~BuA7L95OLLu8X;Yrm`>ArUyAlF zis3A6yL%V&pScF*LJ6l3qh3S#?-`r@#=pAyfxi7B?k>^qV}93UCu`3U*V{LXUNEm^ zK09eUGkB}>Gw$Iu?e$uUKXoFg($j8rNB41>M*sJPs1c%v?_OE0bbk4=>Z);PXyZb> zSAAxuQleXcgnK38-m~erb<^Va`0gv#O+3fGHR9XM`h+WVwgE9P2zKC)a zm1sb}A|umvuVr(YU>?p3Zl||v0Bga;lch(~#A81c2@HbI7%fzDwHKtHm7w;lsi|qV zs`a8hsV(=#2EITByV_!?or>+V6F8dE_%Es`JK`WI_^}y;+Pk*#mJOkOC;!y>?FH~< zH4QSH#Fp;XSJ^uDu$=$5P8Ivp+UkiMz130hlroB1;LWoqkc5_9Kv ze`-!}EJ`4PDzZ;!pQkIXwHD)%oD9d@DTTQY`J*yo--oVVz6##%`bLlORI9(ly@teZ z)@!%Yf71Eqbs>aMhtI_27G@^togQU&X4ZWuy@N__;r^7LeD@<{a4+9W=$m}?U4YM9 z#l&ac*j4{?zXHhnP4hj|uKymJni>ht@uB48e9o|ywiHD~NQfh#h4aD%sR?o*UF4Hl z2aKp@?DDQS+KLm(nL)#((fwL;EqS23#1nIhiAlna$ADYyFHWVaf54Mf@Ori6_Oyy$ zDwy;?0iKu+H7ec6I9MQ3I=Jm$Z5fTnSGU{kr=sors2FE6 zR>;4Iq`&*i1d4VB(M5ysU&~W zMNtok-cENwLm`%nF6^8@(_Pnt?3}_E;_1<_DOcjC@MFeME)H|r*At8Nz2WWZx1ZXi zWfuOFbe%GktX4F?7|@^7Psx3+t@&ixT8ZPL{%we>t)?AH< z!t)u^h$Q!@mtr1L?s(HDo1 z*_rbBRl*JyJc+`ZiNdkQ-sw!z_fb-u24Auk2{qma7(zxP+h9cU$p8rpP;R`GkbotN zhrzu{yYL8%bCj85pZv-S!+UmGCtPN~l(8nM#t}8WE1a2ocoBfA@G9L2hxXP-wNr1+ zawqp#UYbgs$pzA@ozqqFtbrV3){+2iOt=nKtZP4nh9S)M~t+jAK6Fy!9ccu@ts#*K@{Eb@|8v#Zr-((DODD; zc!DdU?mhrL7Ayqsl-|362>!EYtDwf6eRX^>=pV@>w+B5}H6?h^zsgRV$x(S%fRsRM zNj-c&N^Vj>Uj4dgv?^ujT}luIqjmH#YWQ|+v(jD!^4(?0X4Kac0U_xEI6ne3SBkeY zAAs^Ig}@Ot4MK8TQWSyk8(XslWJ6mgPZ&MhX5-_sR z|pn#Cz{On7Av4?#9nlxulLo^r?Qh57b4DLf_K0pr86YF#Zfpy8|Bd zh?;f5z{@KzqNRQEUFT)(!$#+(gyJCEUFnLlS0q!4hth&lp*e0sj1GwTa6G<*t&Y&u z_HxHSS{UZr*X)I~DXTmC<3Q;cV-bN@FY|b`QNlK&Hf%r`z$u*!n;f1f z6cvYOsihV|qM3#nrLCMSrr;})0lj}iiA0E~o)*$4kBS&2rvQQbDE88;%YY3p{ZpN9 zE}r0GpzE5qsJ;0`kea3X&}-wXU>w^q4vmlHk7B8tc&}S$J7b_g=c87pL{b~w0J}6q z&68@Rd1MWZz?1ka2jy5rq^K(p{fYR*ijWZMeeY@}-Fsw4{V$mn%ldS-64z<78)^2p zRNqC((fZJ6f2x1@Kye9oI1xD$jABI>8d|#Gf>D-R7=4OKzMJO0gybYiR_1t?2~$t0U~jgAXxeQJe~MFMuhN;)R6c?Yf^P| zz287|nuUVQKwM5V-`9S2iGsOVgnAB|<7@Bo?ISvPx=bUN8aR`YGh3eq$H7m@D8|~d zXM-uLTsp22k?aC+WOg!x%mdHLFr_DzMeQ2)nnW}Mdj%?{TJURh#h-eSGpo~gJf?gS z%J3@Ei!%4)j>SzJUlVOZ8lL#%>P$JaBv@HB0>^J{)TAEtGaEmVfF*|Tbi|R)ZzXEQ z$0Xzt6&#)Xma{Qrjf#T66<_dp@>!?i z!#;b@9%!5uNj@&!b4@f#uO@bbB`#e97+=^+ZCnKvd>Es`s~{&30_?$L~HJ{XW8 zl^By6&?Reb)j?h9hJ%-2eeWJ7L|l~oJASd*RyVIw;5I*w-F`h3CJ#&Wd-m2#e4m25 zdkQU?`Q`0+*Auy|YUY4D9bHg$6W?YJBMfFaiDwL9ik^us&z;3LZX>$}e!tqkN=%%G zg83tb<2El3gz;}$hO7K4vmq@QOFvm?`*XfN$4VwJ)=$CrTUwN@0zCxCN&$$B_>Ma7 zqlkF+*duYAtW{9=2Wk4z|;HmtTkRwcv(Q%>9KU8%0zmYef(j7sqZp$CWZ<#y&sTcj3qX?(7tkd?2?w zvnx6y+>2O|^i^Zk_lw7Ibg7v`--Gzm-pEqk1&P$ToKXl?Bo%xWQB2Ry)A9(*a?<6k z%Q&r!z-BLdu4(i;KIbf&z+qeH$o= znjbnS;A-BcFASX+ZhFv9pGtCKym3pGQU{{PxZypLKa#PT-4;sQ%>TLHLa9iFkd$_2 zrZ6z(OxJ}pOod;2>@iLDK*O<1WO<*8Jipc8cx)1$vsB+$quO5b-}yj{sd7y$bwPbO ziN~?1Qx9o@drX$!vnygGiGVOEYKY0>)2Cy{A0L?`$#gGNJPZx;BE22#DuWvK@%lF? z^S*&$bLI0z#u9>Ky=RyVQX&pBSx>?wHg1yK#$I)%J}0bot}sSY_s6e_%{NX4x;Ho) z5{l(Bx=rOIGwg!ZNbydSGAzR%l0)w`qNre2oe5IJ=>$sCdc}E*y~m$KG3bWtn|p!-6QC5~8Glbcu9} zbW3+hNOw0PNDERDA_z!#cXxMpNlW*89?=<{-~0Uq-#cs8EXTE;`@YXU`|Pvh+SkS^ zu`eZjMJMDEv~xelOpRA)laNmMd*d$g%CJey#v3$z_lIv_)Jru1804O<$KE#t5!dFm z!}!imOP1FX`aQDmkn!&+55?y_%f`$wHGcZJ?Ven0ywNMr>0ReLSm>@n+5&w2Vd0qsOb%rT%~+bBs^HCUbm^ zea=nz7(_+Q^jyC!BRBrjvQTi+;j zT~Va7%R_HX-Z7lfMe8p&Ps%$85V*C?Qm`k-5EFUf<-|p06c<}Gfxy204XUx}YGDVs z4Ldu*Mn4)ZFqMFWySh+=XL&9k_bP`uks=jY^=C^&H1_cC3m;3pXSRC5mz|cab@^as zmmAtJYu9$K7ltW{b9Hed(FV;tJ0#myT+0>3tY?7S{^LqI!<1OOs|B0H?VdNlJUq~2 zKz08C(&zi!wzZeyeHI#ScqVj@(!Hao5K$vDTH=*9QpmC(c8-Vv8n`?uo1YyvwF_rM z7&0CbVIY!+bi z?sgiwwc?6#EO13)t==#s%E;JGIF`ft2S!WYL*8{hA!b9}{lShDjUZEr!L%Qmd(b~m z^xy5mcTJ1FR$%^G}o6l zu04vPi69HNO|=-%B2YiWSVNn^aRpEKBT6BgGd%4xqIk#Km_$@qVNlV!fx!qL3UhCu zr5f3G0x3azHWc$CHpTF1ww*_PXUAp5~0hDI}b;b;O2GA`N?vyW~6c-<+22xW}dNuQb zZsX9MNeYtDpEP1KzXf zx=$UinamrFKQrQjR}mD6qMX$+)Fc3Zv(YEmqK}SeNnBf z&_6W6T~b|jR0!JsgB02>&B>7@Bx&^XCxeia+S@UrV4mRMns7zJ=u1X8ZSM{o9q%r% z72l0wBtM2H&@*5-RpHVQ35kReqJZjn8o_b%m_x&UEz6<4R!%?Zhs+kIV)9W#R+#f> z*$3-MyRao?5$$rg@zOlLu|eJj*}C3cG!{~cmWa+8{i>o)6SX2rsI_jo#uBaEmpUkV z_Kx>jZu9V7vZ6$EJLj5!8JyP0|DpqM^S%e<9rLY^E%0m`o-r{ox?eg%fFD)%+kP3QdtRWX|4NYnH+qhs>h5woFcm`FPzCGlmHRq}x$anyJY%$2hnsrANwP}^VMtm_ z>d2&(e>q7djA@nPn-lDdoxZC?qpUYi2zI#$)#7KD3@Ll=8urmN+i4N_yo&sAi#mo_ zpm|3~Bcgu{vEZ~yKw#iZE~)dHr$@7bU^6N@lQ$(GYqC`P(4EMAQ7)kW*`IJQ?-vAV zYpwv)@@mNF*cw{l_eaSgT~9y>V5Er7z6O3^TR^g{anHkM=fbCCd@)H-0o7rX`(ko} z_iZy8Fl|N{`MVlfcVj0bTm0$@hn}NI+a~&Pq@O%gg%>+yGZBKyPw0?-Ycl%~Aw2J1e8PS)42Un6A?PDMrv4;2N72j&e9 z40!p#Dp<-Wfb`c=7Ud>nwIE~`ssET57(~7E)~?le-FX{@O$`OlgJp$h;$nj--^>y9 z4e$y@BYL@Nkx#<-Vqmr|Sd-QCEeg!e`YohgB$*vw+v z&%=KUssJxIhX<(_Vy_ntyR$h%Eno&>h%ucv1o?ARTwStnKY|RfNz`3>5nkq0i@l3k z07@~ZL96F}GUHHSHpSvk5gr|#uIC|=FGnICT*U~1c|*SX1<{*JF4F70dsllC2OYff zs_JW2YafCX@o!G)@tH`sFV(a=k!Z?~tnjU1;gN{6l|}_;cCT?+B3oLt*tK(qqX@F_ zhTK>ZILUBstTeHJi**Y)`?+ah_1X|60JE$W0|~?59h38(by$t!vcw@vb01+gdka9n zzfbKBpvSjZr8`+d8#_BJEZn)AoSZWkpFs}vq|~(i^!0~s=DD+8hj*kSIdb}M(0L$n z4Jkt^$Yr4LA?~Jn|_MEs@m-ki5kHuFw%yOPHWNUD_V(xWD^a9$9F!CYA7$=V=GCBwF42y%Ks8A*RL7GixYl z+PfA)fb-iQ5V#Y9s<`)8^XYQ*P_7^T1e?G-=iJQ@+H8EZgtZ_*Bl8nY>#qmH*ZCQ6H!>4yJofQOI1-7U*4>u6Fs$~sY=dY;<=t!iM&=HiG^Yx0b3_1no#4qu z2q?{k^&f#966x=cGp~dPcLf8Ov0l9LwtszN!wZI*xf4L*AW)}Agm*O`Tk+giHACVW z7b&2=Ak7NOdxEPO^kZh%T?y8RnFlicrTlwl)*)B(4~g1tH5EN%`u)Q4T(Y6lQ>~lHs@_g2g7hbL9kOg zuUcd`hrF(Ci{(Jag`RkM^$mSU)S8|?1I@?C;6xDK_TQ&``2q!80SXF=erTVsDG6Lo z)*`M*{7rGE#NL6jg+)cdlA)Vkxl9;@%i}~j;kjcJU_yp04sUWsJnq5#C&&R>jNoHt zxWYWlb-Sx!ad|(tcZtcDjBQ&)=+WOVRc&kU)jrO=1tqL}uTTZ$|21twwb3r5o@ zhBh_3MU=9_fh1x?xbifj3^>C`*n%I~bUP{aFC(Cm@N$sX&rU@lLB)~Z3`f%HX^)f* zO5(Qa+tg#H{HCg-Fc^lK>v*y5xF`=-@@awJm>o10uW}6D#>AXx6e*sw*VTL8m0q%t zR&C{S%Y~#UfWY*+rimiE%T(7c1y4b>&lU6}iUK&IlT`X(-Wfm$CCa9b0aPSKc*`$% zq4p_!%wXY%{m2{C(g*Z>YguJYPOtOY_&b_%?UtT>)v^riMOrG51?!7PVT(Ja_@Gy% zxdj`^?B|d7OiBT#;E8Sk!hPzlr0|d5OVJ82kJJJ&sbxuvEM@}ob=T>o`G<|tR z!l^bgI9+Exfhv5%V#@8FK55q5d?0q{IOWb^D)M8pHj_=^J%OHOEBqN|O!ceVGA`K6 zx6(wp0Bi6SjXDT8)~5(k&H(xfi*+auo#~H8{KlBfB&c}-?F2I#L9x_vUl+_I(Tw0( z3SzpC?s2vfy$=SvY-~LPo!+6j5f2{&_ARO?I*ouno(m*GkIMqJJ0Kl}tm#&=o3j|t zZWGd_%Jj{~{npQaXblbW7bF+3eP1rFg!>N`mc*3VjB^Z;gb0E-6Kn3hf6({x+?Usf zvTFPrf(xyOL;R1h8G3W00)f&q?Bi|2~MwTg&^QxC}{X-;b3bueP>M7v+8?MZO-|iW{)5*rJCX1n1>Tm?Q#dVj~buul_auZ zwL<~Y&(^5{kygV9{TEQeCd*LM`Sa^V>u4q@EAKEo6&eOKy{E+(_+b#0l*Gee?~id2 zliYlxcw!}p(AJndZHV#enfE5e+flj}>A**w%z}ZB#3Y`*M5!@XvXZ`htqZ?Uod0QX z&v-ZXG~EuWk}NAM`c-&z_N!4v-GS^UF_g}ZUz%rr2?tFDIy@3`T><_`DK+hfL(z%9@ua9ybmmjox3A-4@@ILngqL>A!D2MWYS|bu#O)J>oTYKvcYUJY@lTXJAg| zJZ_Q}wkAsu0hVBn_>E)m6Gk3`?ieqCr-uTi+kXHkx*t8u`3`ekQ6MdU%2m}NgX)YL z_g-GH)3gBW{*t_x`~?mBc{4J;r+h^=Xh#k)dE9}Y7GKy;1Nz0kH+>w3Rg3&?en{_2q zm!wmkPJsxL$}iQPBfL(e6Gu>e5V-Gt@v?ff>dy9_`9OGdtbZP8wYscm5h;{8gnLUr zbsH@~Iv2qziK8$Uy{$orAjr*Z`jxnd_x7_t2-Jf-SHexDFl;m4cNMa4Y>#|i4IQ_Q zGo%Xf9t0^XlrM!oTx|c6Ks6+~E2H-?{GcM7Nk^JqNXeR5RpN|oxx0zS1jXLoKu28d z=_k`!TCN}{gLE~^qecXETUGe*2Q}%X%#${fzt1XYU47!-ELbmB5Q8R;1TR9>Ao3gI zqRj?JDyab@zKv&Fcja>Eeyl71j>w9!yG_-aA{tB1*<8LCMW-}JYvIo&T?;b=cUvk^gS?QHROzxaTGl0~UKU}nKEikUR%W0lEd05AU?2lJk^tT; zDK-OJy&wrMI#DxV_;mom7ujXm-$%s>G`Jru)tt;Wk%V$3&uExeZL>Js8E9Y9{tcc$ zdk8hPGr~cXd5~Y|377hf4F>~G9uXfSRG=bLquhfQrfa zOZ2r5&R9mnxaKF`b)iB?ms1)N18ZvB7F$P4k(4#wUtho%3OWbi zY9(mB>*}Qua28XqRzt9UXDanefcqfQpVoxghj4ck)ri~~%mU2PD9dPh=hnoYZ_A#uml)TB>7mlD@fcb$grRk*}yC$4_^Vz%|` zwEbC+`t(b`@BUHjo^;YuqLDTgr4$egPf>SJzte zMw2NttYY$`I0efxu@b|&orEca^C&Y-j^#GE- zZp4f;pHGc?ky9}~66u=*5EXMI6SG{fxP+Mnf%CGTK6_Wl{jiKJJZ|sxuO{n|>G55= zZJ{CZ8sLLuNvZ*ZB7rZeMXX$SL}VV_@1z0}I+W z^9@WskMCDta6MAXmW`_wb!x_ZO-`8=>@F(+{|GSQ4Bn`eMqo(Fa`JNigLd+lghXZ{ zU~Ixv>6fA+CKgI-`Cr?`2W<{r!vg!=zSN`(rC?{gmfqT2(Lcii(Je;9Yk0)%#`_~eBMUnh_Y>>3t*NOE%<7{$P5+k$ zgfUmCoxzFdSAr?bQy6-Zt~fW>JfKB4LtN~Yg=%INHw#{GLU^S;7@`P;y3otH8`R+h z!v)3iyT$P#N&kRZn=zSncGNh2oLYWe2Kh4jCk5@04h3A*EjA}tl&1>%FOCa-S!ECs zgC^8w$;_$5#Ch}l>eTw+`!A1wB36Lp0m($qPV>8&ZS&0#S{F_EqBp9|;GFX)`6i zyQmV|{?kQuH}Nqj$_51>-vaxD7`$3nHMveSxiz({&{(Rd*JJ-|`cH+Kf&HBtha_c{ zA-!#W9hkKpG(|PdjY@%fzx-)BXju72lZH{FOWS>qbR1yloJs%fyWg0f8(ydk;G+Mn zi-5F+B2*&@qKww`KpDke0ICYo3)D*b^r80(S&ICv&u&YA#U&a32mD{F^h=|Y3J&i4 z$_(=*)fPOyQLe-h=jDYtT6>=KDhrUI@lprE5sJQFsCx^y-uz2?Kb>iTmwb)J&0kl9 zH8lxLSso)%qZzp1CR@Nr-^N{N(J=PRwhM&f$GQ`rnk>&455ux)40|c+X~^tFs3y09 zn#<#vV1M5~VBZ@c6;^@!^q?Q1s%jt19P?DfYfzUeB?A%yiY5tDI^r;&L3VU%sSz>v zCY>zKs$(Oy+|(WS9%@}NXyyNJk-&!3G!?OLREn;`8B>=g{y{boq zd5PNa<3g6$VY{|Hc3t%!KYopobjtWFkJ=D`!6CUKX!z{8fi%HT1vS;nUOF{>)0yehSaBi{abHDgf_I;2KbIYeE5JqY!~*?6`;cVCv?3z!h&!0Nrb2sE_-@B zTi1CD+%Vhj!iq5wZ>e$Q=P%h?y2lrzvHyKs5DPRTt$Q_?VMweIy1&$>w>4fsjNkS9 zLIIhSa{0F*))84lkI1n)gRQWE$M;WLhcoV&D(9j z^cp_Re?=IO$};%|TXL-6cME;*x#M5$ zvp-b^a3@Iw5<>DQax)>T7UXL`eU%hXPl5LYU5bQ_yz-rs+`%*MhUyCO0VdcAi`lsM z7-A`$Pw%xtSXqAkn*`TfeA+<$+oF*~dd@5~hxef@Y@{BI&kZj*d@9l3fM9Jc&ktSw zBT0VW+YoUBk_`%XnBD-|-*PpDR;+qpk;CP0I6LD9&TQS4WOlbqvM`xc!l$(|+JEH= ze(m@a6tS3ac5XG2$jcvNy@Y-5JRNpF8-x4PWVZW7qGCkGj2eKx9Zf|c|5Cr7n(6ug zqLkonGcqv|FgDKI9Vm+WrLns4d=7bc^W3IA*P^!tKnu&Rkzr6*Yf0ze(#sNqUmnTN zGvJTp6@$rZUVFCM-W^i&TYDJL*S8Av_F-Z77JpnB?G|3(w`9L(S3QhTjad250zj4r zK)mPE>PLTx5{(n)<+`INJORB&K4AhkO5NQEWde+LTV4IMxzay(juhN9!UVA3|UflZMB$O1sYm3J#1Gssyzhbx0MlqCBca ze__`tAP5ApNv>ZA%a`!LqqFTF(756k4lc*@`Wl05NJQq;@r_QDv+`qFN+AW7?ta5< z3;69(0v+uX?#;8z3;ws78bVBjt5~|dDUYjop3U}hD=dimiTlfQ?gKmFyKI)_AsgKPpo1S=6Gj` zNi&-02+*)M2c)Q~*!i!2i1_eOT^QhO{xu9}N?<{b*jn6?ZAw3a17+e z4q0w)HJD8-scOuNx32O*o`>_^x)h-r!0pI?<4bW<1l?g&ic_6tg3#feyR3hjsypy- zK&V%9K41vqLI~#$pSbvECp%G5cmUNB6|qzbf$KF9HR8P%#by!#f^~X}aXh60Y-takc}{Vv35`fR`JZ zJ7^oC0NTT7mK++6t254*S2^4=&IN7Zak2h!c?6*o5!T$JHst?`6&!-oC)z(0*j2xO zii?A^vvALB=2rF6y0(YNNwMDp6z(H-kUp$R#uID9Ts5R?;pPy~@U6L9ef~X5hlPoG zQTxRuxC$61;h*DeG)P|;qlDHjXq>Xy850nk@si7)KHS#A?)}N-cB`C?M11=K4Qd6t zzEiY6xeOMbAHNrQV$R?d1-SOH!Spr`gLW_%4|g5RWrS!~ig*2>Lhxy?#i;gGuYKr2 zWKt?`Z8i4IL2yUA0Sk6$&@H_E_URA=g{_NhJeab}oO4?R#1qb4=I%N9%>__cI9Nk3 zGltcHU1U1(4Bv7WlxttlbF=cB17U!IHU5nK<&N*Ra_MTm66OPW^-#4s9)4J+qe=f? zeG$KHHxW<|gpb5VBI+p8WUMFOLMy&z9VwA!b)|&zWafwhk_b%z%G( z%q$**#!{YA!18%`Xsd2FqA>djxx@XvtZvc7yJdAQc*x#5l4v`y?JtL4Ia)dSV7%il zqv{rIl+OLuQ2a@PZHuME(Y&qZ-(w46m^=Lb(!lu~3V+>w&L)85Nb_v8~-Qo>1e zTE(5Y1jB+nSxv(EE)dHS;0OVoBud$10%^$J5I0d+oiw_7;RyI%QJ@VCqGU8cYN% zp{u3LgQmh|FnhB;Qz6T;71vN3urK>ezC0HI?-mJ-yaE(6L3q%#_$<>WTZ*ogU;1+T zhq^&D)!3IC>zO>AlIxo|X+`+Q`#}jn98lamiHV3vVP|>%eEpQL4FNB_qJ^TRrKNS% z$^@PF^#^OvIZ%T)CbA7EkqSkKMvMXp6MOr9w!NWALQ}I4G*~1~+(+l*!UFgM)_BhQ zO0*PyqGrF^O>S?mFur1N3(33%_<`5{RRNd}XuVUGB~B6BEgFSZ0i%t^4y zwEX?3Xf1bK`kF%wF>!G#z^AjG35Km9nLabR_#jN+RbC4Lg8&wl-$)d+J{WKf)@w+c zH6;4XGU;kCls%lQ`FfvHbk#ADs_l1z4Whwleh_1+RSZdOL;PCO$ZasKI!!}|R4xD= zcDZ;+s_F9V09IoW^qm_qrqtFFgs+g%218xkhgyV%z;UGOQ-6%|)R}@IJ$kHLD1H2r zMMJcw9STcU--oNw`+Oe!6M;gmSxImz&-1$eFHPYZB#b=GpO!kQ^K7fi9!3|YrQN@6 z3K!*T^lQ-_(0xDAU6y!X0D^Q{jIRF@r8#|YsF*#m%KdF)OO(C*?E8LuPmomY@o8$? zhkr6JK+Z^HQ21|hAu**hL!O>N7|&sNhW=cwTYF!hu%cqr*e0s?PF2O%{-Ah z475n>C)XkciNN(6g|EJ8(c#*=WU;m76y^he5l34a=ZTksD#EiYF; zHo8bQ$0ZgzS4a=i(&obrrQ5`UzY??txBcb*3cll7k$`5pj9OH#&nLymeyo0K(qH{+ zw_1Tj3GAXkKpZ)sJ_0msWHMz%$PPEa>V-J=8j&$oEoMKPcGSeepmkHpUvwzNWB=pe z)Qddf+t(x2knH_QBa+Nl{{=dh=d4%rQm;dkjl1o+Aiq6^`Prmn{Bs_11QPA&zZQW; z0N%W?zcc{(TZUx#E#B!IxpDK!p#{6d1L^LzP*x@52YMi|_No&H!rA9Pp5Eb)HoA_>f7 zrRnLm&D8!v<3;o`j|&|z+ew4$1|)#;qCgvWN|J+Ro^eX*GTDUyoKS1Q00P5ZiTP!! z&Kivghj6qgQd6C|*K^4J2oyx#(1a9SHBZWko_uyjmwj!ghb62$+1m-G0<#Eye&ccvM)N~r`6>4c&5?^+f$SpElJ zMERgG!aQLB7ZtonQWm<*<+?GmffE_{U)!+yOJDIuYKtD^Jc(Ew&BSNqME2EO`sOjm znk$racn0mBvkRNTrg+Db_Q5CwBQrmr@#}H@XE2z&9Cau(Uf|JQ67V!s5;{(1RM2hK zRwOmG>~8LoZ4){}Ul!Cp&pgj&EY^~|h*kdgL%`&1A-?^T3FGD;BH)4F!t29RyBe3v z)&wm=mRt74`1c&$Cad#{e|#@iSeZ^&3f()q%L?bW&CPMRt|K#)nG`GX@57QZ;3?dn z+Z`#h!Gnhp)2NlwNH`hti%W;Rh)1dB#HLk=qD)2}B|xjS>0Z@I8Sr`5GuHjrT4Et< ziKuOh0&PqEm^M-{jcX3sX_-3N~$qqOF(0+8^aj%s$goybCF9p zy4lCbIJy<=aX6`(QD+Z3_X@xs(fQ-~G{(r1&EnnrqGze|URCm6*@DF>OVy7SD0xG+ z83iqd$3_Q}^=Y*x8v9dv9L}`bs-(N@3T0hM*EPxH!ao+3>FkjZn(m)u(A|~$Q}eX3 zo+h0ZASl(Q(?P+@3(w~`tO`q`aB*nT_~So7&~2+K`Vm_?SZC44#+@pt9^y^j_NnPy zK}q?qhwwpD1N~KofBzVKi$dfb@y=(Gz8h2$YNEPr*-n8>!h=9l^393+DGJzHF}*{c z7WjC6<#xu5(i-?Jn4c%Xs~e8O{q?)SAn=&nm)aeJ`d(zO-w%9fMw#tX3%fF=-proh zHI;v2No1QXr>t6vW72K>&l!SjH{T9W7;E7zDDZ&Zr};+vgN=oSHqkED)w=_X`@t1D zIyx><@1hz4sN%AjC`>DN<{Rgd^7q&Ib+kP^jEj~{HUE6vw@6RgAFA(THK8$#jG+&3 z=-?k9c6A6@gq)=3*G@P%6jL$`wO^)gm*Q$ zIoF!GGyl;Zh>&C+K(&N^vVCW zT|*4MId%Q3!9XY>5Iggb^(5RAhQq7=L0=zSU5$mc<7EYIES_5le@%4Tn*S_bS4+qM z$$8bJ*i+k{mw&FKM*y58r|*s7;7nk_(I^*_UY@-#Yw9_vkyTF6*Yw34F?&>?I>)Br z`-a-l*7TL5hKXa{cYBt|zn0LS3haXI#`mHEr&7mKRNOjos{lCL9AL zh{-}2Y=fE9O@4gdPZsN<6NAo`1$1MGu8KOe|}f;j0^7P_enET9(V`e@K)q3om< zoxo`|A>UNxqU{k&xip(m=!b>*0umo6@u@X^gx?12=Q5I>0rlxrpQeI>fBjN>U@>(e zPg*K)pG(tOPfDS}oXuF9%{FgfA%!ys_xyNAWrXK10RUO_KLQUJZ7juEbW6QOy*ODjUIOJ={4xr zKbog>2EvcmN)LCZm|R_>OB&M=5(zeQ*^|C6#mM}reqhbOo#)LY)8_zprUG;q0u&d4 z+v<$v!Ii~Za`NgiiIE0PC+5=61%$ zXb-Lrf6+haa#P}(F}bD{r!itRs%64Q?78tRf3DZW_u9=xp(OObd$24Ba+&mntJGd} zQrfH2f(xOjby)5v^Km-ha?Rq1OJF6$UUJ%-Oc4&vJZAjxm)wZFs=>ismK?t%>lcLU z{#3;l6!)Bo3!;M=@0pqc^ZIg)jm76?xtnS^^KemUX%wm&+PyTEO=;8qdQLtS0p~Nk;a_;JgHe}xVVPHwARa>mgj}jI%qEJ;Ha*Nu-(*u>eC7!WE0gZU zR6%pJIlo-Q#OC5mt1e1vn!r}fL?*nf|B-6Un?GQT-`j@&ht~kn7u);b$Y6X2tq$-? znojzdJW(_%7~x20g#MnF^)w+G>d`G}eUh^ii=a3A(t z_VdH~CUWC5S|lXR*}kpMvgiqxQ%-cBD@vIpN;ShbS^j?2s5uovAz^`wC6qhMh@PKw zc_sCqvqJ*cyz;O$l(0@=cebW&h`nt{N1?bQg4~bravoL(1pfPn!XbFLEdfeeZp`kQ zdcn+%kX)3bS)UZg#@9VSZPxnw?tLWx7i}QH7*K`pwESTqOaO(_{WuuLm@MCPqVQn~ z$vtv1`?jEoK;s^R)Wc1zsBbfA^AiO#tpZD%z;np?*oyquf$)T>#P9-h2w=E*XY&fc zj2{*RLJ2GWWsxWSdZFCdbX8l0pm(s?5=5ottryQaZ{{pSe7^iTOSm2Woc~Gw>y5(lbA$#_bdJQtQ_UD|iwcKVVnZLs zkcE+vrm)|Jm?3(@Cb#VoyK#mGU;M2gkUKws(D!?TxZZTRRY7tH_wji0Uc);LhZ*Lr z+KZiq1wXi`z#G53sF6(I zU>=?2*vx)#AC;i`m1Hhxq-`++L)U3)^NxDrENXEbi2c9sEi{?T)KCu6s~10)_*}t| z&0MXu?n~`8N(^nYMvvlXgI478hQ`r}Mzb(mz$rBWTr0Sx)7m13TlG~wCs*f(VpA;A913V{iPc%nMqub^4R|@jKGwPMqd9 zHWrc}KYh!75dGHc%Go!E^=nEEmu9oUJyf6B8uf;>5`<684pt9Y;qyXH76WMLYk4>^ z8y$!B?|$>~W`(umt468jItSzR{`Rkdx21|9b4ju12)c-$V0kAf7P?<4W;-%k%vOtu zPR_IrmL071g>Im5?bX{48AfV@_7$Ch913pTZv6=FbrCYCp6OtQ1Kt zo0q~hiOAnIsh@bxwTAzTh0au_ad!1H^TBL}ot26GqZLa0Ds!9eB;p> z8gOeAGC84yBR#S*-5&4BznYm4Zw$1c3-EY=kY~|jl%Px58tqNxj+vjvMrgCzodFk{Y7zk^?2^{x ze76G}l~x(*QRNyg+a*na>se{UU4^VujQw$q5Zq?&0B1FinU#VdiQ<~L9Ck}xRU(qt z9+5wBT+R@^DNCi+_-os_#^Ko9P23;HE`d#m5jbygW|bdmq5FL=S@3awCpouPG{NE4 z^lju~I4#1xxQh3t9r&C6^^TWlK!AUW?78o)MW&BldL^`hF?gygT$0HB<s&NR{YTz0a~dGuOm4W*Vx0#C!=DMdFXkL8xqbo=8SBVSp{fKDVk~)IZ%uY zZ#aY<|Bk^R4i+!C@SvCQ9i|zlZHFX39D{B|be{8p&{86go2!$yJ~i`-6*2K)44GoX zzNjJAy+gK}%Y(aenwC>%p*N?Y5W8O7UPSG&t`gH!*iq@TH1~z@5k4BCrkjZ-Vu9z1 zr3v$o`Q1{#v}O%q94_KgN}7pJ&E6lf)|PM!z7nnGpCb_9!{Nhm6s?W$2J1J;-VIE| z_h)bNTYgZUuGpIfsLTI!P4hE>KSD9g2BjbrN`Maryie=+C zvR`enVzIUC3s9S}CbXC9Pft1uQgk=SrQ>iC-QO4x-}&w$D+6MO7h%1Rwu3hg?s_wDjb8S70{XtrQC<4rCoY${7k)B(aBIw?Ken;+Q>dA<6t{NEIIaQMu9 zZY@%!Sqt-;)x^&rJOTdopUVql_IeRb`xf9aEJi^6Vqc}zue~zKX{{}Q1~7A4i8M&Y zA~cCnQlhj22xQZH^2G7RuBGgvUIf5u>n~UZ`ePbFZ%h-Qltt_1x!AG~$#+@hUI)wg))&Y zSsYo)=~_0r=dLrCFdEmImh~NUxQx4!JsvCR^^-rM2hScyje&6hwVFa&O8*JLo?_Qx%CsIj75+_knH`_)bua8kX+su!6 zC3!~+fA+qU{AZHpXKG{YiRYT%qVQN|kwRLr%_=?05;>!S{+aPJddzB8IEhcu$OQe= zuG&rBV}Npqza+6nua1#O3G>!QF%%H%TY^znr)FO!4{{!YuO{hYY6OxiY13 zt0JBKQ11u%hiBe*!lA!C3F_7wcaD!_es`=vfOq897jiVFNsy8%&Yz$E*Mq>0_Y(te z3%_wyDnZio&Q}_juDJo7nnsMKu13VXBG>I7Z=S)(&sYv8Q8>;|@ZXXDpQ~C?68t{P zpvIx9Xh%WWISNbplyi|}X6KNL_%GiA2;rNLXP442RG-Fwu2P5#V&o1U)1xIE^`F2K zqLW@6AU*Z}e)r#<`+suzfBN!YFXjL5=jbq3Ps<@Ey#7~uLGL?Ya(D}7rYcpWcl>!b z|4Pn&iv}sp8knhZkIs*j3!so1)Yt}il65-?xSijE>O5OK7y+)+?2m4?wXO6&=MSXv z`0L-$tH@-BwSlsCQ^i{1V*uz!L*o$w;rytU#}y1mcjpRmzI}pquILkVbwRfeVeg{e-^kX*uAmZenc1dr)>E{db1p zr;&7pTR;17cg$hVjTPcDt-6BtENov?maWRoTpduS4xQt^+-uSS!P0)s&HqIf?U3{? ziaKbF8(^#P&32g3kB$rpDd;P6q-1UoTMKCUZ$^!Bbf*4L@2RNEuvJ zFGZKJ8%P$cST;-hpUw2jhU-__nJTkToG57*-cPv~V#@RNHdcyR0O}Mt<%fj(%Jp;G zKq$DfhL{x_f<_WN!G7C+>Be^BH( z{hd63Jf!fccKuNHd+Z-mAc_mD;Ms3}&~mcTb9FLDI0sz+s4uoF(U7#r&-eXhJ6?m_ zDhX#)OMFbqbLYeL)3dR@jc@{vt1{`}Q<74QF<;9~I6F8b(q zzbnLBfrHxJ{GNArH}<{ycFGy4(D0`nXvoKJ(OkpMk`cmnsK;uduq7R`4j@;2x7oN=WPYUSlCbD;hF`y>{f8qaS)=MuF#$b=Uh2EkYW zCD!plW`QFS3W5uA)pqo(20XK(X-BYCX!B$5jN`ArwPN1=31a80SS*Z${f(K<=rff;9Oo-WH#<@7V zVsK{9TWiw?(j!`(1;+AbPfhJZK<~LJCeOpIv%%|Ea)UQ^Mho_WpG+fmxC2u;c;uK#AN26ZO0BP>CA~0cY zgoS11&i>inr~>X6BX0C<-`Pr{#X`2xGk09r727gu zyiyI zL*J{_jFnO8kRW&T?bO+b?ZZjoq_Ud3_n=s{99=G$ic4Hc>=FVVg;&q*u zHTvE6Gw@c%34qo|y{$G>{rB80kB18j^Ge3C(C!3qy3&?{n7eh^smzJZUq3t5KvkN$ zZFFMzM=+krml0=G_R`*`c@~(wKb}na9%z{4lt4Ga%Bk$#$rx2Gq)TnPODg?KN2W;T z-DbkWP9*+&+S4(&Zh%vhMHwhm;zaRLKRd%%Rg$JnabWn+7tf?|#8{WTG~j6&J*fQ) zOxZC#{WkVA&J;s>9$+`&K-J~V$!joY0uwT`Yy<~!{p^l-ya+_w<$R4TDu03Nvs!cOnGK9xxPQQhO}h;afG zZ>BYWRMwRR*^va{w*QZmksK#E;uOx~Ko2b2-ragJIxuLt3>xPeq>gaM%G#*kzH6AS zut{-f2T7yGa;v3A0M;WZnsJ!p&NsE5g=ja8FgxV-b#`&mQgWbm?34d{_svpYRBPw* z@=I8hFTC4ZQzSudB90o0PG9Mj&|}U&k1VJdMw_});vzhQfW0dN!zxyHjR@m znd>Z2ePxy&p*Q84d8HUK5^l_yfH$=&w_VoL`=LG?l1*|1spl;2)bUXr&-J^6r+gsS z8qxo3RKTL8)>^$GInlv7+iy0&ZKd2>aCTSudnqg|WwuhS!j~eB#DVkYCltFraUNqt z^%SsOV?5w)T>#E69H=;$E}(S!w*m8<9NydB28Pr?IS5zz1Vo+oxs29C2% zB>v~avP~`L+)f11Ujth&n5THWJCGv$kjv4MXcWP_O_=w5Ab<6fiDI0c!3JPdrC6a5 zhI7PJnh38gn`-1ZoI?^S2=@tBd*W~c>!(FNTIIMy zcR#jZX~@4WdIk&U*tcwh5B}3sf@^)0*pZG6F*s}@B|DXyt2TFRwx2_T!Al_phI^yC zb9El+kttR9<1T4cBStF+{Fg=&+V48s986e=_?qrD-TgXXzu=4W4#KqNy7YJs4vccM z+27-O*vow~{V=SL9*#a5#gV0^4Zid^)!$3_k~FZG^pRhy?c< zw`uW!(1P6Ob9rTu|LuxcJC9{Sp}X|YDz576ase4_KQQlv6D4VX2HboPA@!hlD2WaJ zn}K}Q>Lwp!e>~L-pn~dB`KS$w$_OBrY$wna-6`9Zb<*8gTy3am5Wa@p63Vqo=WBmv8mX;Crz z*cd=WQ%-~Z&mp+$?k{{<168?i*^wPfnS~F-9K`)&!hzpp&(p~kPvZ&)Ojs7HhnUv3EGI)@tyI?ael0Qy(wJz*6 z^4yV@xg6B7cxh)aKledO@*zhSH@ijvR}9=Ri^Rv~3=1dj}$5ozuOAoSldxoLQzL z^rOU#P%YtO#mjZL67yQrl&i~rhhfL@D5@d+p7fln^|v;7!;Y)-QI8UAf~59E$rI$y z0q;ZzXCs7St2}6CK$tM0>ZI-eu=iGRRetN+uL2Sx-Q5k+-QfgD0qF*jE(s^yDIrKo zN=!sr8tD{}Mq0WwCk`3BnGPP_r+XF+uQu zEPp93V|`486m8jvDH$R*0w$?mJA;8ZHO&5$RAAYaht zs97t!sMckrI`-eR-klHR>2e&eE&Bn5epi^rbr3RM$zkp#7`8oz_PqJS zKb!97TJH?FNoX+6w>ADi`xw`s^W6CW4Sv_{Ie4a>yKv;?${#i197+g-@W}$3AJ=E! z&IDb|nVZx=tUlE*=taU?dQVah;#?GryvZhJ^V|-}R1@E`pDqN{5P4D>zr%p_dg1}* zaqOJYt_UW7IFYhL^3T<b!l?}F+{l?z*d4>wQe>G+Fd zLdjIhkN_2fTP2fgjcJey%&iAV!6VpSmdWZt$@`ygtQ8b2uX|)na^k{i>ghe;M{sf6r~}g{g`%XcP}}tTV|ln|13nI$ zfIp1SKgy9h`~tbmX@bte$5XqI;z-U$sCcTfNb4%}3n5_0jM^5L5_SvM@;(Z8HT{CU zzz671wzpYd1^V%yRMG5@$|QiL!OchYeewA)cTDLHiU>Bh&(d!t*hIu#wZ z3uJ9p3v&V?iq!naCXS>~503F%ZiI6>k6)lu=<*`UgMLp#N_~1YQ%RH#&D=8?&+MTE zb~)*CaJ4dG*eY0K1480VPDh>I$UA507s_`}>;^v_m%lD;NNt(J5!%WLUXjg`%4J^Y zI-w4j1W1=LzD-3jtS6jEB1e51CGq$dbPk(IHVhvaRie^~XZpsSOS^8SEmn;f?{!HG`uOP3c~Z_-y;aBfeHlBJ6kUgFDD`U9^WrZS?1`N-$7 z?vu6tST<=e)%A+7MUovWgjj61=(SvA^@YJMlm*C}t-dZ#6J+uBStKs5(EB)_8zOeq znSWt*K+j5nKb)N$b@K@r8KiI8zX{ZMqb~S?u(UhiQ^9b`sf6V>;JYKtmINBC=jpEA zDu^`DvSfW+)5F@&uWa5k9auf-RUIZlPE*9-rgxW1OV*UQW2_whW&}yQ#OSppDAAbX zKJXy1Jjn`_x!Vxel46!%Oa-(=x+Nv=S!dIS;>h+vrLtZ&+T0X+wu)uvE zErd^Oj>1;7R4HQ-7@v;+#2rLIoL<;#T7J3(;q2|w+lhG_CU5!}xO3PuJ^ED-W76c`zcD)Ay5TQj4|5xRL-dXc^9J%t0i1a0UGD{w_ z;!t&J>gAdCe(yeO`6a!KkbHUp*ue?3dl#3k7G>cF6q8pMqQ{M_9&vHv501qcL@O7I zWfT1x&YIn^b1`F(Cdx{!hy+`ChrLJ`qH+H`H&>W-0PgF5ZK-7_D6Gt71_Y3f-tl(g z@dtJFkj#wUzN`nCiXz8I(X2JHn8Nf!{OrFbKVOtcbnqc2Q`-$&Gq21>6h4TFXo`ss z0h734D`u&J(Tc-kO=9+NOmc6mXnQsEa4i15Pt8x`?CFX5i9iSkB*FPPU-L9xnj7QfRFX^ zIBeUT58@K{nh)($#J1`oeTA{rFJATm1w9WQS^hlX>976h#}st!!es#NfH@cl6FAdcWq_x9avs5Ml*&1WVXbg7!)HuQm}Q!T6znQQpUy8pbrF==~<^{cPY5h84k^|r7j9` zJ_H)|fH7ehUB-s$vGBpxRH?d|TwpwQ|FaMmIYfM$T0C9!VdUJZDps^C8t7L(8E|-s z`9Qgnyx2$3AU@I|fn??gl0H!Xk$~rA$?WPW7oG2Tjw?2WqPa%Q3fb#}dHUN0?scC( zilc(jy+;nF(^L`4Zi?Z-eU#OMKmKv{I&YN`L-*n6MgnZL-|lM?EDrg%h0o0{5cD$0B$ zHqN-J_?F@0q!!Fv9pdVIE_o+owOG=#6M7&9i;8))~4Wfm}B8?UyV|-wwB91 zc$7_L0*Szp&=|>5w9z?ph#oq`&Ks30OhK{Zp|n;0Dg8=HS*A_+uno}w$wnlzPbF+kP#FE;rXRdm=d$ZLIu#7tmd+ zZA>stwqzT8S4%fVB7?HOuhAYRz+jno`vg{jfQ=Xs!vikfa}bISSM$lwHh_MRoDb(i zG<-B_$59&Cr;bkkgP#gfD8GB0@Sl8hStnElOZny+6bE8iVr;v}U=X*k9KCJjCM#C; zA}2CqUu9mIF2{l3f|fkD;pQV~zU;8+bN*oVOzA)yQK%O+HBNY#5Nq}L?8h%ZN!IPp zEs`hsvlJG12aPRe60wf*D)AmxvD7v21w1&S031O?b}H`8_O=L-vbWJ)ykDe5D8JU5 z2#k?Ez=twnZ=p@Ku2Fsu4j2-_LZ(Dj_T7$-4~^9(@F_xsksSsMjO!QM?cfk8nsez+ z%}hK{CZS=kflRt@yRbF)ij$~e8zzyccc4e!z-|R|##;28EUIJPEa?xNej4-Kz5R_% z9rsm3)ewdSvWe-UfkyJ0HTgSOK`Pt6m61*^?#g#tJ9m{l{Boh4kHjK_8dFKbue!BY zuDL=FL21V6qbmLs)lp@=4zrGxZo!SnFWx~fle+H6SSI(F3mGeA@1mCiO)M?`gpmEuuJv=M{) z@iqDyL}CoDj+tH}eOnh+@FCi5#I3{@kNr|=Czg?3P^L5-#^Px+%jdv*EfCX0=?NJx z?-=+zfB2i0i_@mQJmRQxf=4Waj612cNo?O&`_?*+)YHo2-zy?B=E4N!>g3BM?M+c|lYEIt`tzzwrn64@O z_gUS4#q%}+G^n|(GazjD-3d0$9v9d|Kf&aiZL~9@+rX~QSp4}KM^fXL#$wUXcx{tO zWK8@@4mB$_wLd8T#rT)jygh!ZS#V1-pT{w9Iz1VUW8P$a>Ae()vO-P8*1t~izNY55 zy5?7-W2UN3k)52MdE-d!r(e%*Bui&IPuZ>?6XRjaJ`83_1 z$#1v>pu4uesUJx@>cWBslE?gHcpN~jPCrf<8q{gP=>^^tPw_PNx`b0z zSryGZ!H!5e@1P)&_Ph*!$)v7+qGYrc7J|@1y5~oCBS|Cel;w+ZLON$J9%Mi4cru%a zu4?!C1)=+gDCF2Z_)Tb{8S!5wtef5%OBzQYJ3Cr z8FS~-zuECtUYxPbs|G#?!UG(=gf?j@C5VxRr&H5W&t=6kVeySri(H+@SSBLR>pveK zS`x%IVF=3k`EVObdVW97lR`<9OdhacARwjpRC6kFE1Ps@V5OLw+;jgmsRu7cJX;Rd zia!VV*{y{LHoAu=#6Q>kWWNv0X&{;cwLt7cGql;h2$YT;P_8kGwjEZK1tmjocFa~K zJ_R)6a8PMz6WkR^FV10`;Ln>-jx*6(&K9 zy@@wN&?bTYog%A#h4ei29$O*yxS^N)u)!7nAFO=0+~$M4$mJ_dq7F$zDVwAS0<^s; zEvE$)U&`=PqKBn%Wo~wczDM-bZ=^iBW#(^kw4oYnHG3aTebpB$cKm(%H62D3!d4;w zI~G4aB^5q(;RI2!Ob^Evjq(sap6EK6hEH3C>k2k<*j99$d6mhetc95#`WXh=UhICT z*rLufovCr}h|sY(lR0k&o&4~qpq>&?n?%N%#MN;<3rAi>mbr4!hur!U^i6_pmR&7* zLn0+hDeItg*M7?*w9#=_8Hc#Sp>_rqIz_lbL%SusnSu96u8W8AhbIY^`h6;A>;>@@ zlQHX$hSz8^v2iZ`Ov}N;x05eaMYwL)S^RXTu*(-4WsLB2u@5?43uKwOtw766`GgF} z%<{(1FIxYIdx&UKD`y=PU`&b))qwjHEC6)Y-js#edRYaPyw;i(U*+u_h-XEJjip&s zwpm@%2)@KbV+&u(TAMoWfr3*ErB}4>iM-Ws3#uF4%kT0M8szo1bl-lo$SCw$_iEGm zyiYFbpEqc|UFXcZa9&smcgHxpE``6*SK@t2m=?~|>d{(sjw6++!;t zu6oHO7D~$c2MMTWB_msnn>vbUKNfBq>t7$nx#pecO)>J&i{4E-w+Wld8hn8|Fk*=0 zTuF;7MkM^jylC`$hciaWtD^CE#|TB6uhYpd75e*6)8FKgi;01pVab&5SQT^{>ND@Q z>DO(Pi4^F<+}08*9?JUFJ7kG5%cE&3=*?n>=S2%*2G#tp$L$@>_=x{qvuJM|LVl(5d6%C3m21z{x(p(3 zX-_vk;>wGPJyry>ZIh2XJ+_=SqME&9EuKjgY!mLb^PAn&?|g7dG`vybK7gJlhew!( zQF04xbr~nUcJ*kMB#ii?Vc7Sns_5&}6c3ahCzfC-R6b~Xa-HA_D_d{3EkW=Y@5XwH zN$kz$)o|lbm?L=m2@@GqG-rQ>B?kZ^FymGQX;M1r!L6m9A?ZoV)pKXs4)>NkuT|+VMcj2RQf@#wY6F`-8oa zmF!pW#2{R^qfn8nrCQAr=osSMpGf|x`qFw@@*_-pu7Rt0DfiveC`t;klS8jpk0L-% zJb#~p?9@T$vmNZFK$-%!uWc}Fa=SuACdVyzPz(CRa7nvK1UD0Q12IO+aX;HTV;2lI zWfV*#liy5`pw{E0@~)-GSF%lvVoryC&zc725>P`fo^vDz1*W`}vFpp5t%maGPyHOT z6*mHxneVe9C3b2-GKo(IcF6z-8s=&6aVA;&D2=jecK@ z3d+XQR?fPZ4X%||y~%B+EvKxgpp{Bn=H~2}&!?(pT8xX!9RoJMRf; zcsF9jpo|!cs1x z*R`eoicw##o>oh+_NDR<}D*vZ?Q+m5y^D%YL{*LFR99)|LiR?LB)9pfwP zK=h{Otrbz*svy(DjS(vM3H2j{!VfM$euez<`$VPGo@v8#k!O65H8&%xXx?klIYLMu z%m?tjOg*W+JD=-CMi-=RFRD;E@$Ws$ z+$2yhLq{L3NoN%5P0frtfOtpws9$tB1aF1<0nWvCHN!)3QQVnzIm9#I(}|!k>mN8p zU0LUHDEX*r?7?7@rU|ZjP!68L`}08{$iY|CQMB!;cA_Z+-e)VnY`LuLWnd<;-biLPQ~a?S0*Ih9A7aOPZIJ7lE^|e z@JnMu4C9&Lqrh%ESn18XgNXM-So%>dYQZ%LE9vfOrMuUW;P`xPNrQx0=I?i)^rxQ4 z`nHheX(Vdd36i25`zxD`rh|R%P|xk#1^V{TH(WnMSj^pWPz|G}`#vLfu*clHwT|eS z{b?Em+o&_R%0u{0^pJ_8R2jqKwID;)kfKlXCa7+nl2-k3gLgsGLw7%1N!#CBS(0j5 z;m29-klBExx7>4pPO->W*utTzjQ+YS*>qG)p3K2bpuXquJdMjtCYmgaNr7jDE=l#L z(%gyHI;ku>HAL$17lXk{>&>NyV-dT!lwC43lcX1X?t>>`Z>o4O5VetV<&jo>G`^)^ zGHbeu@<2+el3JN}S~^WTo!oP5RVWEs^GXJsHBlOobl>$Ol?8cSBg*eX}Lo8JMn_b2yS-$m0M@B~JW5F}1 z;VIdub4hIg%#IR)*%$h9Bti!0Pz%k%1yM6S5NO|YCO^iAVHOvp!16*u_7h?zlB9Qg z4|n+H7Amr$2Tr|UvuP7BFEWT#LWx0U-nHl~7GIU$$MV0FRStWc0UN5WGtpi&n4(~s z`MJ~hk+&Jt&CT~7fVhH3pnpqjOlUoDIJlKhh$2P{qUda}k@b3Q70_&>HNTT;N$iS< zjtv@5&1Bkau65-+Sa`|u`k|($HP_oPy{Rkp!V{O1LBZZ6t+7+0WY)scNoHB8r8*Lq zPcsm2Htr>o;7w{7lwj-U3QQcslq-qU9(4VOw`^48pA)UIZrgH zCby={j-i&;SEqQFfZ@VO9LN=SK5e%AzjDPKsepK{rj1lwkeGb8WozTnR_b^A``Zx; zG-^NM!)npVlFf;AH0;ZFsmS3}eo!qm>(HcVz| z>=*ku^;nu-k6iA)jyWOrk5n&qLBBe^M`3*RCwLR_i%zh2{5&3kmh9zevyyld$wj8h z-S9d=2B`rH)OugosxI8}N33OyFpJ+y_@}p~4g_FvCJE>TNGM?LYPHd9CfG>cji3)| zkwFB}iwu{2=}FnRZdYH}{549v~x0_*p-B(Cy?Vp5NT5S~Za(tdx zf}ljTc5(T(0;OLl;QlqzR5q|%tfm>PADKSBAU4) z0GZ96+HG_Hp*e{K|Ev4B#x>&&!H?JCO96~aT%chS?YiqQnZXCm7`XLJc<36TN zxX>l{-9GpPNA*ZXB0F%+pQ_~DX+1kRrI4qp%h%{gFW+*$5Isgd70XYE6SEPVS&~4H zv1oxP;(5Mwr9K<@rJOJ6qBc6I;Z+Smyz98`zpWaWCg^QJQTr;9Bo0Xa^>=Q)3ns>gMp z+<;sRlbhOj_7B+am7P|1uXoF-q)v(@bbgy z()02OPIzhsGN^s+c|pIlamj&HylI)*=aBT^? z>V(51Up%CSFjeGH-QdkdI*%@ddCdZKI-Bi@KK9%rtSlfc?nOMBf0Ywu69V1AwJ5Cm z!Ly-cJRIk9_3?&P`N32lTopM)$9T97*RD;QeL+Thnt~+VKL zB97!1u@U&dE^Esfu0xUBGTMq6I%c&*aS^{eoE!S zuJWPVdm!~u_qr#!+lcf-oM0O{s0Wu0xz09VNH9i`&rlSt3hVX6R@phJ!Pw1{{HA8I ze!GgO!q zD7d7!f|SrbP7w9z+hyLbWaP4A&e4APifA_8KJ#~3RC+J?u+XYy0t!Z7A ze%|u)FvBeE3yy84e|(qTL9Lr~rWJ*Rz(<9KJw!zA(nUa}_Iw*MZleoM46tRdph-11 zgAkj8hmP?9B)A}<;Pg3CQ6l3zhh%Q_D7NFAU?r)6aaB={XmVkmg>UfcPDTgBbmleJ zK&lD{c3vpX;%Q0~YeJ5$Um{S;_ozgCis2d1rLut$P`||c5I*XUWY@(8nFEt|er=^) zbf9zMW64Mv$PELyB#ooR%v{qT&xQ*0+5HoQ-$04$vQ8B=x!Rq!2hI&;C(*N?Jg?uG z49D4feoRNw$IkiS)`7;9&)3JK|9nc9~v6a}qf4zjD zRt!nCv`1P`?xl9b@`oCq_#Bi>Q*kK$JpKT;C(vb|0<9STMZ6RS(u>2RvbNPceS)+h z-iNqkXPtH=cArPI6L`AO;wW(5560=@4qbN=^% zh$vxfW%aZfZx=k-C?YU-S>s< z`Q)~vpG?=Xim~^u z)~e(ge)#JSnni=_0Aq1SvCDceKK2lvD03q?6%_H-P&HhOvs;{%%qj}ojO9;D!ODt@ z>Ohs5wNaZ2$XN*xnQGljV*RX0L>0@f?+)+lL7ul_tGX9D;tP~mFSDf)HS~LYs8=zL z0l|fS;SW=$n?@%xo2@X}z_bJh3`ZOsK`tXfMUh=E%tD#84W&mQZsyg^yf^a6m3Lkx zi7Q2gSy2N&4uGN#<7B$%Wo&J&UQg}Tt3MRnM!@8?X&d|Rh_HkfiJ0RxDSlpegqcl2 z*lQ8rfobI$GrhlVv>aaG@^wRYk7M}{{+doy5HFtOnbnKG{|_#jOliO*;{EKc{2wf8 zq;RgH+2{YC{=b~JR^$^F|28F!xC3@v1+^HM^Y_VVvBhkFqi$~^NKjKxgdhZ^T%m0@ zkCssI?;lGA0?cj-9Ayb^Mov2t7#K_jwEQoBsslMTNf>a#Y0Wm_ii0C*`SxTuZpH-Y zb3eU3m`um=%r@Y#?0YAleD#K+ja@H}QoJMJAtInS`HsOZeHYwK0f^J8D&ZBmZm)Qx6S&OJ1$2U%?Id7v4E8OOQSQK<3udtX8m)! zO0^!CE5B6zxS6+QQTiGBJ3j<)$kR8YrU0ju+2K+yQGsNNItGPvXXw<0ju|EhT{&iVI|0{BQ} zDl;VUyCfn$hs_!e&UeRmijpS4mbuJka;10)W&&z73SV5r2-VH2!$@_x@uM-IBiNTReh1A3;(upmZa2QzE3zT)*r!W)^^-a0d?I%43k?KAi{HkY#wWww&tfyZjETiIq{gJHc$`Uju)R5*)dpq?Ho- zFLFdkfNnyP;Dfw#Ao`8-6hNF~A1*!hNi`3`emPJbQJ|EfJs2vL?);+z-TasP6vpE1 zh9IJ*QbjM7f79tp{igD>CLyMNJyNUuXJJ}CrjIAlm@=@%wG?6>jD1{p@m=Qbbnj@; zv8b7_;J-@y$AH+#_)uK3@pZGF?cdv^+l?v)T3pLO?C1bpPtq~(O=97On+_qRefgw+M)vF}OmFl=^x&TbKxWn~?6Hs7jBirWX!!|} z@1)&O)6eTNh8$Bm{8M(UcvE}ebqG^&Ln6BI*~+$ueKvT_2KlFy`=ZJxC@5Ls+H@K3 zGu`<hMm*-wi81TznO3=n)i)b_|2q&69yQ0(AuGAl}R8P zjHF@NsKg@zD{}p~B6tgqGE8_Q#R6BP5}f+l_)1&ffzH)G`evNlG34hP$I>Je%o2#~ z?UNM^Afce--X0WOBzJ7whFwy=hKrPcWci-YIB%p`g=wmRS%g)C0!Q9M6bCUW#$GHw z2i}WpQ6xTSyFm5KxiMUw{|Zk0q9*331?B{~%3u~YR6KG-voDmgOC9|>H``mDG{5uh z0(^Fl?xjLHZxCdaS!?vm%5BSZjzQ%4FE8ig zpaD~zV1ZKTdalB=tUdfJkSmYQ1v;=8BT+!2mp;1bg89~JkT_%sUU9fkJu34S?r8IZ ztTlHimu_kovEq370ec%!D3(YQ-1CwO1`bQnw_yq`1vsCM%A+J>^w7b{ND)8oZ(9F| z{!W)e9;3ylf%%-~QDMLd^M-v$fuIECe}9tw&?VT0g*em49`1wa+M~|CI@y$u`v~>{ zNx!L4e`OU<;2v82O=e#q&7JNw^v*S}1p3$|7O)7PZp7W<1*gJD44 z1;wVuhY!U>Qz2Z|BRK(6561<#l8G!yk;jGM*JrCTa74QILu<7yswU!e-NY95v5xIt zryG$NITyl<4GzBvFUNxQb#b~Ys1H{RVjNaT9KG_0aQGbX77(TWjz_{AvH_|C4CT52 z3Ug!|#TXmj2!=ihAcE~J!gD)6f-3YXXueB*6MtOjhizcX->>)ty-DX#!mJTF3j7qwEL96vUMz35V)l+TcrC;Y%I{NIfx?Yode5 ze{(Z2r(bh8%bJ51g1RH#EC;`71rKKBrQ7=>E+F~)OfS2}LL>=q^Pal1u602SxfHHX=R*m$)irfwCa=PKeyR`@GjZ-kxKGnMOb_873Ow@IaiwL;~Cd= z_SqJgBCsO(xB^&=siDT$f{dlw6Z(EHB7!1^79Eg(Tl4y7PjlNi@==}}KM{fh4+G-wR^W9rza{AUtmeC;$&IXF z*K2Pls;TtZu1P7?qZ2<%qj;R=jiA#L~)Q<1Wq$$T!1o)&JV`o z-FLf4fsfe;N~2J|9H|s|m2jhf_=^~mA_6eG(PwN!=KuWjf~k!!Vq!*hT(^JtWBjE) zN2z@ROpEKw@Ro4iPD%7A>cGl;>YMf`{;pN(^FHAF(qZ>n>1n78QevQ+wzCOu*7#Uk zi^H4@rBfb&ANaV1CWngHsxXU06a)vha$6l^n=a29t)WwjsR5mq<0PS&Fo)Hi7`CrJ zj;~il+B=chF0LcU?*2t*`xYiMdgKWlOT4G|-6dy)O%6s6J zm96c-x9xxB`)?EVj^54>MAdKcctY%5&Mc)9O%C_k`W>7k#QWotx1>~4Ihp`rvTI#SR zbc^!iCr4eCEW2D+6E#&M9~vED*8>*>qEewbi<6ihtBaE7t20F=UmzliWc&Tz>SI=d z);&u$BwI(su|OxNoA%VS3$b0q%eTg(Il)R~Q=ZDVx_+K_?n42?+i|OZo<=EzB#kQ^lS75B+Q*5|`th?OZ6id*65 z{WG%wJ`fcd+YmOi9&9xv5k65vg_)RgHLgUltq9Y|nL_>HCv*3_=xoF>{JV33f(WDr0x&|%cugxYCXgx zs1M*^1$f+hiu@^fPGJ{(%M_Pz3Xh0(X?Xpu_dy1GU;o!;d-)|A%I}C+V%WjqW5U_s zU@_$w`& z#wZh@DDr_R!&!#gnqgchq2Pi7O(^ZGp+=i<#nB!xM=HO=r6C!Vfg`d_Smv5s;;Ud| zjoy5WuC0D$@@Ln-KrvN*Swg&I|k;PH2|-e0FWVSzD%4Mibk4v=p)x ze%OBY?(hPHd3n&=my(foNAE5c^AE=Dt}guk;^+f^E9IaR3cz6lkFf%e#MXN7CrhUO zhFSp8V9ovzRz+=(E%lYdB2v9cA|g1zbTo^`aYUWoL}af35{)G$dwp?W5?vhd;;Zv= zM^f?+OQ4jEE&x!(6Wd+i$<0Z9Jzc`2gDvWm6gwD^Imb4I~I<&c2^0=Pcu-I z(3n0&?L!Bs{iw^ReGaGed(%(Xk?oGlJR-p?xn1${B zON*yI{H<#IuT=l2)9zoOivRv!104aoJI_F;91P3&Ki_R7eCmI^`M<8^ua97Ztt&HD z`(HHPzkl0}RsJ7d38|4EM0t)z%~=`$EYkn+S)!;Q@5SEHA>@c;SI zKtrqv`QP{UzdCKWTa(~7jqYb&Z~nszf^lXT|B2jD3XjWoch1)S!wbH+JtX~aumA5K za}3}I_~g&szW+_N1D`$Ta+&|XU*&&YkpFvwz+L~pHwYfN|9`WBbhS|7oQ~z}Ketv; zz^&3r(mWSqs8GbwtuU3o1Xv;hc!sF8(#jYux$5W!4emeAtc>4o|1M-|(%H9sGPToV^U11h#^I&04ghkVQ%=7~ z0Tr&tSdzES_6B>PNwi{I5{+L0O=+iR87RFDd}ac zr$P4937FXbn9CbJm6taiD_#!0ZNZDX>gM$0_5oET$FoaNlfi-)$8PW86)U(I)j%cv zf=DUlDF&P$fp5;J@)7_3;?Mn}C)L<3FQYSPNrVF5)9V7mn~MV~xN-atIGuRxfuEPn zH+vl!DYVRdq1dPZ#%{%nv!!?eHJHEOZM|oM0}{+)_$N`}0?2e)gw|qzP`~R0JpclD zrN8#;;N3THn* zKf_5uxpm!eWCvVg+RQ0}h(w7OO>YKl1TWH`K|xl$dhxv4hSz1t1ki=5wrHfR^me_F zh|`qA>g2!3i@n2s+`u`;pECCI_!@LtLg4=KWy$+hN!yuHn{lsl0f(O@V~_IaVu|lT zwahYU@+uWK(s=9Eu z@DZR)hQNIH@F5cD58ibKqEHgcJ6^&&H|TJR5EPU(5XQ}j3t0F?0znTA!|k-VM*o85 z0F(;E=Ru7s(1yORw07)BLq2gdeaNGrz(?}39jgP-n>)Z=)+wrY+@(c=#E=p+Lz$nR zoVdi_gFB1^uWXvs{qmUGYBAc~UUJMt>eat+=s8^|9@J$^-iD8@ZteWFZEV-0zb-sj zO(CXt83U&eun~UP6w-WE0VW1&E02GV0e9ssyhUM=Znib@DMqZi7qIGhXkYkp(&r(5 zi~+1rezYZzk0Eg&H08O4JJ<0ggvdmA0KrY`^7q?TGF8$2>WX$`vTC8{L0o>I@PGwC zA94EjPP?-5=QOPoQn~6$NU_JOTicu27PQBZkL9sh`TJ$5 z6_L^uApjb-A1Ukh8Rxle4XEgL+3?r=wz`hRGktZ~8Cs;1eYa<2!Dy%!5fZ~Ef*BF< zQ=hK9e5gss;G5d=TC@KCPjeB!@wCVK?~d%3NZR>z``|2^Y=VNLVAcYqx$m6C)V&ma z#dJ4^O-fnhAab8p93F6U+$Mn+hU!?~x1m)=-PG#am1AB~c9o z0S?`-SKG#fTd7ZZz~lnW4p1E#d(l?@ICkzUSpF}dvI=nx&I2c9fHz4X_A0MBtXDbby!3(fG z>_O57p6;)*k5^&C8G^&gT|RPV$S)Kz)K}qTgLyf9p%hZf0?JGAXI$oAXpBk37JKlK zqBPhz-Ws?&!?7b^*J3!#=(eUiETIo04nn!Yc1|{(1jgEK#r;72vY0U*t%26l40s95 zEjl2Jieh!c7+1Fdbwud-0dx%5%=@pFJsJ)b5VQxSqogD+o@T8*W29nq!ctF zNpQkfy$l`oEhx?!+qj{rVT#0o;Sla{1mrt9-RSyp>by@O2AmiTa7g82-XE?oSp!>T z0$QpjM8z1MWp&vFz96H-)@sJK4nd_9&kr;R@gJ4f?lf1G zmqUIPZ)bgW%kEh8uDNA3tdThRHgMO~97NMtc*Cyp%~W&+NfTesHu;9AJSnSmyz=Y-l5q+ux?+92+V;7Ov(zYU4~V|k&n`F=Nd!u7+3c>*xR4r zF{ws-av3Eq%It(%t&FQlW+wDZjlH3h)A|w&s-s+nK-*CHlq6@#Q^f!!o`HiyS>pD{ zJE1JeEl0on@BNdc7|;Q|uT><_Y!(u|NR%2x+ThT|g$?PxI4mynoC@0?`S#W<5i0kG zFpu2^KSahYbZf)yvvtInP`g9&C4C<@zeca5d(+rk!Fgv{=*@l|;sOVpo#1kfsp&^d z^6H-3m=4$q4Q2~26`2xZ_W=Q%Sbp0Sf|9Jx0*SMBF0C5+<8)wevd{cms0CuDWMgB? zTstkpJmJLfj5Opr9EsZTI@@o2vBNit)e8k6(&^}97wp83d zOxDORkYmorLgk3al^#DyJAIoen&vu`74Gd3cX1NzeUagtnt3ua^J~#(Q;?S^|1NFp z{PJkExHtA%D$6;-rDIE^tAr1gIzOf&&{}43d*JCed-wN)3URJr9)mE2yd35beUf>` z4t&JHCxRv?kDc+j^x;3=xl-$d9#WsOj z)ZlJ;>jIYnk%g*@kB!*dtGO0m?yacs1H5bpP>=N4)Ud#=tRX>^P%w>%G0`~$gqX+4@wO9{G9)qxWbTU;Hdf?MX&B`Q zd)c0h1}9TRSIo@MQDLJmeX&;u#xqZ6B)<$eam0C<@3QAy9r<)C-S>KPQIP81@xU~HWSk?- zcfceoyK4B4`pKZleeKL?>B-5@B6B%(dIG8bJ78o?VwK&KLluoDCq~?)@f*=svJBN8 zsc^pDwhi6tHu*u>De=+`qBpm@5Pn><K#_Q&{4A*o~zWghVJ65D{mIBhhwLheyR99!CiV4i!@|iG=Khg<23U+9`Fo zErJM)S713B-gdQ{1bxNMzP|L0RDLu96KiSNTk8H`M@AZnkW1-BYw~%ni~_C<4gHii zo$+$Fs5DHAJyxt$G(KGd89LF)nP8+6lf1+9=X0Ch6_cml69CmFHJ{EB_tw_Y#7xDB zJhTfb#+F2liF?1tGvR%<#7#A8e9@9v{%oPR<_14wq1fWF*1JJe$^lxRLsuv=uKLe& z3A!(cc^#A;^7==A{N>hZ{N6jdMfTO5q*co1np4cJ%&0W3-z52lW0du^tu}Dmk!Ut+ zMrNO0NvPsd zTWoS-V~2xli`l(16pUpbr;R_kIz8g`tXQ8?I4F;`;_?5bPk;;VQ%`7vAlF`GzP zdg?A>HVWg6tO|{s$TRj%nt`dKw5OZ}>dVbXJ!G2iXp&~(oxErH?}^CQT?l2+qn&M) zI=MGy55J)XI$GfHw-KZGejS!1U-L(!m3Z?y+xO7#-hI%j3?o)3p0b*6jFzOdfA70T zu5$!)ZjAt|(#r>2WxK#JoGLNAYO>#WQQ*UIY+U-Q@bJtyvpI3}nJEda#IcD8`xn~9 z@L@U45N485gLMgT!E>!IQi*95I_XaP;`0kyp)Ph#A<4N~hj&KfJVc1cdb*Q) ze3v}W*N}UE&(;*{>FGU3T<1D28JfPBAyU-$qeu6C@8f&Ff583W zo`2w+$2p(RIiK@-K0og_*EqX&dV}4kIV>6~dhyH_vt=O;&>!uk7Bv;s->OXAbH90f zT^16tH>H%Gg&71+vt2G`;C-`dOQ;feqrTve3c2Koug&oKj>))RP8uu;82x6`Sbt^` z{hNJ%+0d}z*U!{Bq@PmO{*U>c4!zaxCVoz+$8EWEE+%(s<%U;%Xh&(XX`dY{%$0I@ zjgSL}po>D4pPnK*p@cX_g9`B}2tT*jQX~B|c(`bEHp6oRa$AV#Vz}}%;V&$-bk0A1 zVr_7f14NS{km_}H_ek#6$mFxakjOn5Q(th0+z!cCE-$@n^zANYl?Z&8ls0H;%08OI ze~#p>4}VKNRNl!f{cunrTs+l`bFNZZQ;9g+%;J9oO8q=3Ht#}H_Tn~vLq#I|ih{wi zosnMc=Z}0;{Ry3xr&w^2qI>sL70k_2RqAwg()28*&1Fu4)|0RL%H_vbPJYqn%2Ecs zKF@!-_I0iFJh`!YeL&kOnUkPFv*1tP48bLG^vu@si+tpyCJ$w(&%euxw^0`@woqjD zW(aPI@f|)hxthF?a{P!P`f>e_nbWQ7rc0#PyqV`~VKyTN`cu51$5y!&VJJ*61R@Xv ziP<(5oUT&PYBpJW6eT?W7TAfk@3D-)0yT0uAyJKX(XrW%h9>2YfLCKp_HsxD_2a7x z;2YTmC@aZ1Mj06uAPC3qr}dBnjU$#@5w&2K-0h7%YS+GNF$V$RjyMK!T{R*QF!99o z^fw2|vS&0j_BslvX;-p@IXoknFE?^wjNaAO%?{aOLI~l%U5tB+0<)8992x%D!t%PNAqy;nM>U*O zql5eCaJnaGM)w+^oXom_r6`fn^JLjp%GtY0sb(ZgZ{$54BeDaqgO)~kS_Dx+AJer`=9&+-b+c|6e`n%8iA>8}af4P9HqZdg9A@3{O|G7+&=a#RXv+@g zF0T^_U-O;zCx%kj_ogvohX0`-hISZx=yU2qfGH9^qUpW&%}2O&41`aZ7*m$e>zNo9 zbO=RotXEahqtEya+?iW!7+D{e`js1)} zZ61+c>E|3NPAc-B2l#+VbKI;TyNhB5tU$}B~_eL3Lh9ujg#CIK} zn)kTDWUk8BeYD>R%c;bSW#XzcM6p0Fh>pCa-7Pq(PSI}`uW`{VZlFT(by59wsa>hx zu}MWAqR2IrawJdDxf#uH;d7Oz1Cq(z8olI?=cwdqD z8AGzikk03I41|n5GGeEteZ42MW8mYh{S5Y)}?GmE3w+P(kQJw4b=`M!Cs%ZL}h(5z?!udh=DQ zkYEAhPuduV3d78MQh6$Rv6I7a#S`%Pm%|=<6|$OJ$FQn#u)BAoVui}2N6jw*KWb1% zjXr(b$G8jxRQt#@mr>{Xpcxu(IrO%aF1>)S_z5yW7W=d;6vNc^KE=f~d;j z1Cq}aO)Ne>>eDhW+FDpu3tN?NDZZdJt!yq6&@HXY@t;4k8b7WvAF9(Q>z~)!8>(Z* zp#4*esmw2EZ%86AQ=v}RR^)r-YoT^?^p)y27X+p1E+ZpvcvCE^bDuA1GRJcI~ra@oaTAU;s%jZw@4B0Z@iFy}fOJ+b-f*4E7t)1lRt z`ktM%(#lx-&X*6(!RyhNA1g0~Q)aUyerXQ15*y*H%@E@1)47+1zf9S~{JYX4C+Ludh3_6eao|iYpREDS<8UFZ zaQ09M>rbGFIk3qk)NYRBGWS!pA;fLbXxDlw+*@eUosGu0!VN0_ z(JcL6_>dq07IrZF&}DEDe+8922!Klef4aZi?Ek5o+5PIcFW6ee0=wB>omsW}uSW>r hFGej2B!Rj=1fD^ Tensor - out = F.dropout(x, p=prob, training=True) - out = residual + out - return out - - -class EncdecMultiheadAttn(nn.Module): - """Multi-headed attention. - - See "Attention Is All You Need" for more details. - """ - - def __init__(self, embed_dim, num_heads, dropout=0.0, bias=False, include_norm_add=False, impl="fast"): - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" - self.bias = bias - self.include_norm_add = include_norm_add - self.impl = impl - self.scaling = self.head_dim ** -0.5 - - self.in_proj_weight_q = Parameter(torch.empty(embed_dim, embed_dim)) - self.in_proj_weight_kv = Parameter(torch.empty(2 * embed_dim, embed_dim)) - self.out_proj_weight = Parameter(torch.empty(embed_dim, embed_dim)) - if self.bias: - assert impl != "fast", "ERROR! The Fast implementation does not support biases!" - self.in_proj_bias_q = Parameter(torch.empty(embed_dim)) - self.in_proj_bias_kv = Parameter(torch.empty(2 * embed_dim)) - self.out_proj_bias = Parameter(torch.empty(embed_dim)) - else: - self.register_parameter("in_proj_bias_q", None) - self.register_parameter("in_proj_bias_kv", None) - self.in_proj_bias_q = None - self.in_proj_bias_kv = None - self.out_proj_bias = None - if self.include_norm_add: - if impl == "fast": - self.lyr_nrm_gamma_weights = Parameter(torch.empty(embed_dim)) - self.lyr_nrm_beta_weights = Parameter(torch.empty(embed_dim)) - self.lyr_nrm = None - else: - self.register_parameter("lyr_norm_gamma_weights", None) - self.register_parameter("lyr_norm_beta_weights", None) - self.lyr_nrm_gamma_weights = None - self.lyr_nrm_beta_weights = None - self.lyr_nrm = FusedLayerNorm(embed_dim) - self.reset_parameters() - - if self.include_norm_add: - if impl == "fast": - self.attn_func = fast_encdec_attn_norm_add_func - elif impl == "default": - self.attn_func = encdec_attn_func - else: - assert False, "Unsupported impl: {} !".format(impl) - else: - if impl == "fast": - self.attn_func = fast_encdec_attn_func - elif impl == "default": - self.attn_func = encdec_attn_func - else: - assert False, "Unsupported impl: {} !".format(impl) - - def reset_parameters(self): - nn.init.xavier_uniform_(self.in_proj_weight_q) - # in_proj_weight_kv has shape [2 * hidden, hidden] but it should be - # initialized like a [hidden, hidden] matrix. - # sqrt(6 / (hidden + hidden)) / sqrt(6 / (2 * hidden + hidden)) = sqrt(1.5) - # therefore xavier_uniform gain should be set to sqrt(1.5). - nn.init.xavier_uniform_(self.in_proj_weight_kv, gain=math.sqrt(1.5)) - nn.init.xavier_uniform_(self.out_proj_weight) - if self.bias: - nn.init.constant_(self.in_proj_bias_q, 0.0) - nn.init.constant_(self.in_proj_bias_kv, 0.0) - nn.init.constant_(self.out_proj_bias, 0.0) - if self.include_norm_add: - if self.impl == "fast": - nn.init.ones_(self.lyr_nrm_gamma_weights) - nn.init.zeros_(self.lyr_nrm_beta_weights) - else: - self.lyr_nrm.reset_parameters() - - def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True): - """Input shape: Time x Batch x Channel - - Self-attention can be implemented by passing in the same arguments for - query, key and value. Future timesteps can be masked with the - `mask_future_timesteps` argument. Padding elements can be excluded from - the key by passing a binary ByteTensor (`key_padding_mask`) with shape: - batch x src_len, where padding elements are indicated by 1s. - """ - - if key_padding_mask is not None: - assert attn_mask is None, "ERROR attn_mask and key_padding_mask should not be both defined!" - mask = key_padding_mask - elif attn_mask is not None: - mask = attn_mask - else: - mask = None - - if self.include_norm_add: - if self.impl == "fast": - outputs = self.attn_func( - attn_mask is not None, - is_training, - self.num_heads, - query, - key, - self.lyr_nrm_gamma_weights, - self.lyr_nrm_beta_weights, - self.in_proj_weight_q, - self.in_proj_weight_kv, - self.out_proj_weight, - mask, - self.dropout, - ) - else: - lyr_nrm_results = self.lyr_nrm(query) - outputs = self.attn_func( - attn_mask is not None, - is_training, - self.num_heads, - self.scaling, - lyr_nrm_results, - key, - self.in_proj_weight_q, - self.in_proj_weight_kv, - self.out_proj_weight, - self.in_proj_bias_q, - self.in_proj_bias_kv, - self.out_proj_bias, - mask, - self.dropout, - ) - if is_training: - print('default:', outputs) - outputs = jit_dropout_add(outputs, query, self.dropout, is_training) - else: - outputs = outputs + query - else: - if self.impl == "fast": - outputs = self.attn_func( - attn_mask is not None, - is_training, - self.num_heads, - query, - key, - self.in_proj_weight_q, - self.in_proj_weight_kv, - self.out_proj_weight, - mask, - self.dropout, - ) - else: - outputs = self.attn_func( - attn_mask is not None, - is_training, - self.num_heads, - self.scaling, - query, - key, - self.in_proj_weight_q, - self.in_proj_weight_kv, - self.out_proj_weight, - self.in_proj_bias_q, - self.in_proj_bias_kv, - self.out_proj_bias, - mask, - self.dropout, - ) - - return outputs, None diff --git a/apex/contrib/multihead_attn/encdec_multihead_attn_func.py b/apex/contrib/multihead_attn/encdec_multihead_attn_func.py deleted file mode 100644 index 5710e87..0000000 --- a/apex/contrib/multihead_attn/encdec_multihead_attn_func.py +++ /dev/null @@ -1,357 +0,0 @@ -import torch -import torch.nn.functional as F - - -class EncdecAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - use_time_mask, - is_training, - heads, - scale, - inputs_q, - inputs_kv, - input_weights_q, - input_weights_kv, - output_weights, - input_biases_q, - input_biases_kv, - output_biases, - mask, - dropout_prob, - ): - use_biases_t = torch.tensor([input_biases_q is not None]) - heads_t = torch.tensor([heads]) - scale_t = torch.tensor([scale]) - dropout_prob_t = torch.tensor([dropout_prob]) - null_tensor = torch.tensor([]) - head_dim = inputs_q.size(2) // heads - - # Input Linear GEMM Q - # input1: (activations) [seql_q, seqs, embed_dim(1024)] - # input2: (weights) [embed_dim (1024), embed_dim (1024)] (transpose [0,1]) - # output: [seql_q, seqs, embed_dim] - # GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim) - if use_biases_t[0]: - input_lin_q_results = torch.addmm( - input_biases_q, - inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), - input_weights_q.transpose(0, 1), - beta=1.0, - alpha=1.0, - ) - else: - input_lin_q_results = torch.mm( - inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), input_weights_q.transpose(0, 1) - ) - input_lin_q_results = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1), input_weights_q.size(0)) - # Input Linear GEMM KV - # input1: (activations) [seql_k, seqs, embed_dim(1024)] - # input2: (weights) [embed_dim*2 (2048), embed_dim (1024)] (transpose [0,1]) - # output: [seql_k, seqs, embed_dim*2] - # GEMM: ( (seql_k*seqs) x embed_dim ) x ( embed_dim x embed_dim*2 ) = (seql_k*seqs x embed_dim*2) - if use_biases_t[0]: - input_lin_kv_results = torch.addmm( - input_biases_kv, - inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)), - input_weights_kv.transpose(0, 1), - beta=1.0, - alpha=1.0, - ) - else: - input_lin_kv_results = torch.mm( - inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)), - input_weights_kv.transpose(0, 1), - ) - input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1), input_weights_kv.size(0)) - - # Slice out k,v from one big Input Linear outuput (should only impact meta data, no copies!) - # Sequences and heads are combined to make the batch of the Batched GEMM - # input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)] - # input_lin_kv_results: [seql_k, batches=seqs*heads, 2, head_dim] - queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1) * heads, head_dim) - input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1) * heads, 2, head_dim) - keys = input_lin_kv_results[:, :, 0, :] - values = input_lin_kv_results[:, :, 1, :] - - # Matmul1 Batched GEMMs - # The output tensor is specified prior to the Batch GEMM because baddbmm requires its specification - # baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of - # a separate elementwise operation. - # Input1: (Queries) [seql_q, seqs*heads, head_dim] tranpose(0,1) - # Input2: (Keys) [seql_k, seqs*heads, head_dim] transpose(0,1) - # output: [seqs*heads, seql_q, seql_k] - # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k ) - matmul1_results = torch.empty( - (queries.size(1), queries.size(0), keys.size(0)), dtype=queries.dtype, device=torch.device("cuda") - ) - matmul1_results = torch.baddbmm( - matmul1_results, - queries.transpose(0, 1), - keys.transpose(0, 1).transpose(1, 2), - out=matmul1_results, - beta=0.0, - alpha=scale_t[0], - ) - - if mask is not None: - # Self Attention Time Mask - if use_time_mask: - assert len(mask.size()) == 2, "Timing mask is not 2D!" - assert mask.size(0) == mask.size(1), "Sequence length should match!" - mask = mask.to(torch.bool) - matmul1_results = matmul1_results.masked_fill_(mask, float("-inf")) - # Key Padding Mask - else: - batches, seql_q, seql_k = matmul1_results.size() - seqs = int(batches / heads) - matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k) - mask = mask.to(torch.bool) - matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float("-inf")) - matmul1_results = matmul1_results.view(seqs * heads, seql_q, seql_k) - - softmax_results = F.softmax(matmul1_results, dim=-1) - - # Dropout - is not executed for inference - if is_training: - dropout_results, dropout_mask = torch._fused_dropout(softmax_results, p=(1.0 - dropout_prob_t[0])) - else: - dropout_results = softmax_results - dropout_mask = null_tensor - - # Matmul2 Batched GEMMs - # The output tensor specification is needed here to specify the non-standard output. - # Given that pytorch cannot currently perform autograd with an output tensor specified, - # this requires a backward pass specified. - # Input1: from_softmax [seqs*heads, seql_q, seql_k] - # Input2: (values) [seql_v, seqs*heads, head_dim] transpose(0,1) - # Output: [seql_q, seqs*heads, head_dim] transpose(0,1) - # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = (seql_q x head_dim) - matmul2_results = torch.empty( - (dropout_results.size(1), dropout_results.size(0), values.size(2)), - dtype=dropout_results.dtype, - device=torch.device("cuda"), - ).transpose(1, 0) - matmul2_results = torch.bmm(dropout_results, values.transpose(0, 1), out=matmul2_results) - matmul2_results = ( - matmul2_results.transpose(0, 1).contiguous().view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2)) - ) - - # Output Linear GEMM - # Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim] - # Input2: (weights) [ embed_dim, embed_dim ] transpose(0,1) - # Output: [ seql_q, seqs, embed_dim ] - # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim ) - if use_biases_t[0]: - outputs = torch.addmm( - output_biases, - matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), - output_weights.transpose(0, 1), - beta=1.0, - alpha=1.0, - ) - else: - outputs = torch.mm( - matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), - output_weights.transpose(0, 1), - ) - outputs = outputs.view(inputs_q.size(0), inputs_q.size(1), output_weights.size(0)) - - ctx.save_for_backward( - use_biases_t, - heads_t, - scale_t, - matmul2_results, - dropout_results, - softmax_results, - input_lin_q_results, - input_lin_kv_results, - inputs_q, - inputs_kv, - input_weights_q, - input_weights_kv, - output_weights, - dropout_mask, - dropout_prob_t, - ) - - return outputs.detach() - - @staticmethod - def backward(ctx, output_grads): - ( - use_biases_t, - heads_t, - scale_t, - matmul2_results, - dropout_results, - softmax_results, - input_lin_q_results, - input_lin_kv_results, - inputs_q, - inputs_kv, - input_weights_q, - input_weights_kv, - output_weights, - dropout_mask, - dropout_prob_t, - ) = ctx.saved_tensors - - head_dim = inputs_q.size(2) // heads_t[0] - - # Slice out k,v from one big Input Linear outuput (should only impact meta data, no copies!) - # Sequences and heads are combined to make the batch of the Batched GEMM - # input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)] - # input_lin_kv_results: [seql_k, batches=seqs*heads, 2, head_dim] - queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1) * heads_t[0], head_dim) - input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1) * heads_t[0], 2, head_dim) - keys = input_lin_kv_results[:, :, 0, :] - values = input_lin_kv_results[:, :, 1, :] - - # Slice out k,v from one big set of gradients entering the input linear's bprop (should only impact meta data, no copies!) - # The gradients are identical in size to the Input Linear outputs. - # The tensor is declared before hand to properly slice out query, key, and value grads. - input_lin_kv_results_grads = torch.empty_like(input_lin_kv_results) - queries_grads = torch.empty_like(queries) - keys_grads = input_lin_kv_results_grads[:, :, 0, :] - values_grads = input_lin_kv_results_grads[:, :, 1, :] - - # Output Linear GEMM - DGRAD - # Input1: (data grads) [seql_q, seqs, embed_dim=heads*head_dim] - # Input2: (weights) [ embed_dim, embed_dim ] - # Output: [ seql_q, seqs, embed_dim ] - # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim ) - output_lin_grads = torch.mm( - output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), output_weights - ) - output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1), output_weights.size(1)) - # Output Linear GEMM - WGRAD - # Input1: (data grads) [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1) - # Input2: (activations) [seql_q*seqs, embed_dim ] - # Output: [ seql_q, seqs, embed_dim ] - # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim ) - output_weight_grads = torch.mm( - output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0, 1), - matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)), - ) - output_lin_grads = output_lin_grads.view( - output_grads.size(0), output_grads.size(1) * heads_t[0], head_dim - ).transpose(0, 1) - - if use_biases_t[0]: - output_bias_grads = torch.sum( - output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0 - ) - else: - output_bias_grads = None - - # Matmul2 - DGRAD1 - # Input1: (data grads) [seql_q, seqs*heads, head_dim] transpose(0,1) - # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2) - # Output: [seqs*heads, seql_q, seql_k] - # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k ) - matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0, 1).transpose(1, 2)) - # Matmul2 - DGRAD2 - # Input1: (data grads) [seql_q, seqs*heads, head_dim] transpose(0,1) - # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2) - # Output: [seqs*heads, seql_q, seql_k] - # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k ) - values_grads = torch.bmm(dropout_results.transpose(1, 2), output_lin_grads, out=values_grads.transpose(0, 1)) - - # Mask and Scaling for Dropout (not a publically documented op) - dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0 / (1.0 - dropout_prob_t[0])) - - # Softmax Grad (not a publically documented op) - ### softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results) # og - softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, torch.float32, grad_input=softmax_results) - - # Matmul1 - DGRAD1 - # Input1: (data grads) [seqs*heads, seql_q, seql_k] - # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1) - # Output: [seqs*heads, seql_q, head_dim] transpose(0,1) - # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim ) - queries_grads = torch.baddbmm( - queries_grads.transpose(0, 1), - softmax_grads, - keys.transpose(0, 1), - out=queries_grads.transpose(0, 1), - beta=0.0, - alpha=scale_t[0], - ) - # Matmul1 - DGRAD2 - # Input1: (data grads) [seqs*heads, seql_q, seql_k] transpose(1,2) - # Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1) - # Output: [seqs*heads, seql_k, head_dim] transpose(0,1) - # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim ) - keys_grads = torch.baddbmm( - keys_grads.transpose(0, 1), - softmax_grads.transpose(1, 2), - queries.transpose(0, 1), - out=keys_grads.transpose(0, 1), - beta=0.0, - alpha=scale_t[0], - ) - - # Input Q Linear GEMM - DGRAD - # input1: (data grads) [seql_q, seqs, embed_dim(1024)] - # input2: (weights) [embed_dim (1024), embed_dim (1024)] - # output: [seql_q, seqs, embed_dim] - # GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim) - queries_grads = queries_grads.transpose(0, 1).view(inputs_q.size(0) * inputs_q.size(1), heads_t[0] * head_dim) - input_q_grads = torch.mm(queries_grads, input_weights_q) - input_q_grads = input_q_grads.view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2)) - # Input KV Linear GEMM - DGRAD - # input1: (data grads) [seql_k, seqs, 2*embed_dim(2048)] - # input2: (weights) [embed_dim*2 (2048), embed_dim (1024)] - # output: [seql_k, seqs, embed_dim] - # GEMM: ( (seql_k*seqs) x 2*embed_dim ) x ( 2*embed_dim x embed_dim ) = (seql_k*seqs x embed_dim) - input_lin_kv_results_grads = input_lin_kv_results_grads.view( - inputs_kv.size(0) * inputs_kv.size(1), heads_t[0] * 2 * head_dim - ) - input_kv_grads = torch.mm(input_lin_kv_results_grads, input_weights_kv) - input_kv_grads = input_kv_grads.view(inputs_kv.size(0), inputs_kv.size(1), inputs_kv.size(2)) - # Input Q Linear GEMM - WGRAD - # input1: (data grads) [seql_q*seqs, embed_dim(1024)] - # input2: (activations) [seql_q*seqs, embed_dim(1024)] - # output: [embed_dim, embed_dim] - # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (embed_dim x embed_dim) - input_weight_q_grads = torch.mm( - queries_grads.transpose(0, 1), inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)) - ) - # Input KV Linear GEMM - WGRAD - # input1: (data grads) [seql_k*seqs, 2*embed_dim(2048)] - # input2: (activations) [seql_k*seqs, embed_dim(1024)] - # output: [2*embed_dim, embed_dim] - # GEMM: ( 2*embed_dim x seql_k*seqs ) x ( seql_k*seqs x embed_dim ) = (2*embed_dim x embed_dim) - input_weight_kv_grads = torch.mm( - input_lin_kv_results_grads.transpose(0, 1), - inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)), - ) - - if use_biases_t[0]: - input_bias_grads_q = torch.sum(queries_grads, 0) - input_bias_grads_kv = torch.sum(input_lin_kv_results_grads, 0) - else: - input_bias_grads_q = None - input_bias_grads_kv = None - - return ( - None, - None, - None, - None, - input_q_grads, - input_kv_grads, - input_weight_q_grads, - input_weight_kv_grads, - output_weight_grads, - input_bias_grads_q, - input_bias_grads_kv, - output_bias_grads, - None, - None, - ) - - -encdec_attn_func = EncdecAttnFunc.apply diff --git a/apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py b/apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py deleted file mode 100644 index 9431a49..0000000 --- a/apex/contrib/multihead_attn/fast_encdec_multihead_attn_func.py +++ /dev/null @@ -1,121 +0,0 @@ -import torch - -import fast_multihead_attn - - -class FastEncdecAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - use_time_mask, - is_training, - heads, - inputs_q, - inputs_kv, - input_weights_q, - input_weights_kv, - output_weights, - pad_mask, - dropout_prob, - ): - heads_t = torch.tensor([heads]) - dropout_prob_t = torch.tensor([dropout_prob]) - null_tensor = torch.tensor([]) - use_mask = pad_mask is not None - - ( - input_lin_q_results, - input_lin_kv_results, - softmax_results, - dropout_results, - dropout_mask, - matmul2_results, - outputs, - ) = fast_multihead_attn.encdec_multihead_attn_forward( - use_mask, - use_time_mask, - is_training, - heads, - inputs_q, - inputs_kv, - input_weights_q, - input_weights_kv, - output_weights, - pad_mask if use_mask else null_tensor, - dropout_prob, - ) - - ctx.save_for_backward( - heads_t, - matmul2_results, - dropout_results, - softmax_results, - input_lin_q_results, - input_lin_kv_results, - inputs_q, - inputs_kv, - input_weights_q, - input_weights_kv, - output_weights, - dropout_mask, - dropout_prob_t, - ) - - return outputs.detach() - - @staticmethod - def backward(ctx, output_grads): - ( - heads_t, - matmul2_results, - dropout_results, - softmax_results, - input_lin_q_results, - input_lin_kv_results, - inputs_q, - inputs_kv, - input_weights_q, - input_weights_kv, - output_weights, - dropout_mask, - dropout_prob_t, - ) = ctx.saved_tensors - - ( - input_q_grads, - input_kv_grads, - input_weight_q_grads, - input_weight_kv_grads, - output_weight_grads, - ) = fast_multihead_attn.encdec_multihead_attn_backward( - heads_t[0], - output_grads, - matmul2_results, - dropout_results, - softmax_results, - input_lin_q_results, - input_lin_kv_results, - inputs_q, - inputs_kv, - input_weights_q, - input_weights_kv, - output_weights, - dropout_mask, - dropout_prob_t[0], - ) - - return ( - None, - None, - None, - input_q_grads, - input_kv_grads, - input_weight_q_grads, - input_weight_kv_grads, - output_weight_grads, - None, - None, - ) - - -fast_encdec_attn_func = FastEncdecAttnFunc.apply diff --git a/apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py b/apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py deleted file mode 100644 index 320bebd..0000000 --- a/apex/contrib/multihead_attn/fast_encdec_multihead_attn_norm_add_func.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright (c) 2017-present, Facebook, Inc. -# All rights reserved. -# -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. - -import torch - -import fast_multihead_attn - - -class FastEncdecAttnNormAddFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - use_time_mask, - is_training, - heads, - inputs_q, - inputs_kv, - lyr_nrm_gamma_weights, - lyr_nrm_beta_weights, - input_weights_q, - input_weights_kv, - output_weights, - pad_mask, - dropout_prob, - ): - heads_t = torch.tensor([heads]) - dropout_prob_t = torch.tensor([dropout_prob]) - null_tensor = torch.tensor([]) - use_mask = pad_mask is not None - - ( - lyr_nrm_results, - lyr_nrm_mean, - lyr_nrm_invvar, - input_lin_q_results, - input_lin_kv_results, - softmax_results, - dropout_results, - dropout_mask, - matmul2_results, - dropout_add_mask, - outputs, - ) = fast_multihead_attn.encdec_multihead_attn_norm_add_forward( - use_mask, - use_time_mask, - is_training, - heads, - inputs_q, - inputs_kv, - lyr_nrm_gamma_weights, - lyr_nrm_beta_weights, - input_weights_q, - input_weights_kv, - output_weights, - pad_mask if use_mask else null_tensor, - dropout_prob, - ) - - ctx.save_for_backward( - heads_t, - matmul2_results, - dropout_results, - softmax_results, - input_lin_q_results, - input_lin_kv_results, - lyr_nrm_results, - lyr_nrm_mean, - lyr_nrm_invvar, - inputs_q, - inputs_kv, - lyr_nrm_gamma_weights, - lyr_nrm_beta_weights, - input_weights_q, - input_weights_kv, - output_weights, - dropout_mask, - dropout_add_mask, - dropout_prob_t, - ) - - return outputs.detach() - - @staticmethod - def backward(ctx, output_grads): - ( - heads_t, - matmul2_results, - dropout_results, - softmax_results, - input_lin_q_results, - input_lin_kv_results, - lyr_nrm_results, - lyr_nrm_mean, - lyr_nrm_invvar, - inputs_q, - inputs_kv, - lyr_nrm_gamma_weights, - lyr_nrm_beta_weights, - input_weights_q, - input_weights_kv, - output_weights, - dropout_mask, - dropout_add_mask, - dropout_prob_t, - ) = ctx.saved_tensors - - ( - input_q_grads, - input_kv_grads, - lyr_nrm_gamma_grads, - lyr_nrm_beta_grads, - input_weight_q_grads, - input_weight_kv_grads, - output_weight_grads, - ) = fast_multihead_attn.encdec_multihead_attn_norm_add_backward( - heads_t[0], - output_grads, - matmul2_results, - dropout_results, - softmax_results, - input_lin_q_results, - input_lin_kv_results, - lyr_nrm_results, - lyr_nrm_mean, - lyr_nrm_invvar, - inputs_q, - inputs_kv, - lyr_nrm_gamma_weights, - lyr_nrm_beta_weights, - input_weights_q, - input_weights_kv, - output_weights, - dropout_mask, - dropout_add_mask, - dropout_prob_t[0], - ) - - # import pdb; pdb.set_trace() - return ( - None, - None, - None, - input_q_grads, - input_kv_grads, - lyr_nrm_gamma_grads, - lyr_nrm_beta_grads, - input_weight_q_grads, - input_weight_kv_grads, - output_weight_grads, - None, - None, - ) - - -fast_encdec_attn_norm_add_func = FastEncdecAttnNormAddFunc.apply diff --git a/apex/contrib/multihead_attn/fast_self_multihead_attn_func.py b/apex/contrib/multihead_attn/fast_self_multihead_attn_func.py deleted file mode 100644 index 6b50fe2..0000000 --- a/apex/contrib/multihead_attn/fast_self_multihead_attn_func.py +++ /dev/null @@ -1,243 +0,0 @@ -import torch - -import fast_multihead_attn - - -class FastSelfAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - use_time_mask, - is_training, - heads, - inputs, - input_weights, - output_weights, - input_biases, - output_biases, - pad_mask, - mask_additive, - dropout_prob, - ): - use_biases_t = torch.tensor([input_biases is not None]) - heads_t = torch.tensor([heads]) - dropout_prob_t = torch.tensor([dropout_prob]) - null_tensor = torch.tensor([]) - use_mask = pad_mask is not None - mask_additive_t = torch.tensor([mask_additive]) - - if use_biases_t[0]: - if not mask_additive: - ( - input_lin_results, - softmax_results, - dropout_results, - dropout_mask, - matmul2_results, - outputs, - ) = fast_multihead_attn.self_attn_bias_forward( - use_mask, - use_time_mask, - is_training, - heads, - inputs, - input_weights, - output_weights, - input_biases, - output_biases, - pad_mask if use_mask else null_tensor, - dropout_prob, - ) - # fast_self_multihead_attn_bias.forward() \ - ctx.save_for_backward( - use_biases_t, - heads_t, - matmul2_results, - dropout_results, - softmax_results, - null_tensor, - null_tensor, - mask_additive_t, - input_lin_results, - inputs, - input_weights, - output_weights, - dropout_mask, - dropout_prob_t, - ) - - else: - ( - input_lin_results, - bmm1_results, - dropout_results, - dropout_mask, - matmul2_results, - outputs, - ) = fast_multihead_attn.self_attn_bias_additive_mask_forward( - use_mask, - use_time_mask, - is_training, - heads, - inputs, - input_weights, - output_weights, - input_biases, - output_biases, - pad_mask if use_mask else null_tensor, - dropout_prob, - ) - # fast_self_multihead_attn_bias_additive_mask.forward( \ - ctx.save_for_backward( - use_biases_t, - heads_t, - matmul2_results, - dropout_results, - null_tensor, - bmm1_results, - pad_mask, - mask_additive_t, - input_lin_results, - inputs, - input_weights, - output_weights, - dropout_mask, - dropout_prob_t, - ) - - else: - ( - input_lin_results, - softmax_results, - dropout_results, - dropout_mask, - matmul2_results, - outputs, - ) = fast_multihead_attn.self_attn_forward( - use_mask, - use_time_mask, - is_training, - heads, - inputs, - input_weights, - output_weights, - pad_mask if use_mask else null_tensor, - dropout_prob, - ) - # fast_self_multihead_attn.forward( \ - ctx.save_for_backward( - use_biases_t, - heads_t, - matmul2_results, - dropout_results, - softmax_results, - null_tensor, - null_tensor, - mask_additive_t, - input_lin_results, - inputs, - input_weights, - output_weights, - dropout_mask, - dropout_prob_t, - ) - return outputs.detach() - - @staticmethod - def backward(ctx, output_grads): - ( - use_biases_t, - heads_t, - matmul2_results, - dropout_results, - softmax_results, - bmm1_results, - pad_mask, - mask_additive_t, - input_lin_results, - inputs, - input_weights, - output_weights, - dropout_mask, - dropout_prob_t, - ) = ctx.saved_tensors - - if use_biases_t[0]: - if not mask_additive_t[0]: - ( - input_grads, - input_weight_grads, - output_weight_grads, - input_bias_grads, - output_bias_grads, - ) = fast_multihead_attn.self_attn_bias_backward( - heads_t[0], - output_grads, - matmul2_results, - dropout_results, - softmax_results, - input_lin_results, - inputs, - input_weights, - output_weights, - dropout_mask, - dropout_prob_t[0], - ) - # fast_self_multihead_attn_bias.backward( \ - - else: - ( - input_grads, - input_weight_grads, - output_weight_grads, - input_bias_grads, - output_bias_grads, - ) = fast_multihead_attn.self_attn_bias_additive_mask_backward( - heads_t[0], - output_grads, - matmul2_results, - dropout_results, - bmm1_results, - pad_mask, - input_lin_results, - inputs, - input_weights, - output_weights, - dropout_mask, - dropout_prob_t[0], - ) - # fast_self_multihead_attn_bias_additive_mask.backward( \ - - else: - input_bias_grads = None - output_bias_grads = None - input_grads, input_weight_grads, output_weight_grads = fast_multihead_attn.self_attn_backward( - heads_t[0], - output_grads, - matmul2_results, - dropout_results, - softmax_results, - input_lin_results, - inputs, - input_weights, - output_weights, - dropout_mask, - dropout_prob_t[0], - ) - # fast_self_multihead_attn.backward( \ - return ( - None, - None, - None, - input_grads, - input_weight_grads, - output_weight_grads, - input_bias_grads, - output_bias_grads, - None, - None, - None, - ) - - -fast_self_attn_func = FastSelfAttnFunc.apply diff --git a/apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py b/apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py deleted file mode 100644 index 7f110cb..0000000 --- a/apex/contrib/multihead_attn/fast_self_multihead_attn_norm_add_func.py +++ /dev/null @@ -1,135 +0,0 @@ -import torch - -import fast_multihead_attn - - -class FastSelfAttnNormAddFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - use_time_mask, - is_training, - heads, - inputs, - lyr_nrm_gamma_weights, - lyr_nrm_beta_weights, - input_weights, - output_weights, - pad_mask, - dropout_prob, - ): - heads_t = torch.tensor([heads]) - dropout_prob_t = torch.tensor([dropout_prob]) - null_tensor = torch.tensor([]) - use_mask = pad_mask is not None - - ( - lyr_nrm_results, - lyr_nrm_mean, - lyr_nrm_invvar, - input_lin_results, - softmax_results, - dropout_results, - dropout_mask, - matmul2_results, - dropout_add_mask, - outputs, - ) = fast_multihead_attn.self_attn_norm_add_forward( - use_mask, - use_time_mask, - is_training, - heads, - inputs, - lyr_nrm_gamma_weights, - lyr_nrm_beta_weights, - input_weights, - output_weights, - pad_mask if use_mask else null_tensor, - dropout_prob, - ) - # fast_self_multihead_attn_norm_add.forward( \ - - ctx.save_for_backward( - heads_t, - matmul2_results, - dropout_results, - softmax_results, - input_lin_results, - lyr_nrm_results, - lyr_nrm_mean, - lyr_nrm_invvar, - inputs, - lyr_nrm_gamma_weights, - lyr_nrm_beta_weights, - input_weights, - output_weights, - dropout_mask, - dropout_add_mask, - dropout_prob_t, - ) - - return outputs.detach() - - @staticmethod - def backward(ctx, output_grads): - ( - heads_t, - matmul2_results, - dropout_results, - softmax_results, - input_lin_results, - lyr_nrm_results, - lyr_nrm_mean, - lyr_nrm_invvar, - inputs, - lyr_nrm_gamma_weights, - lyr_nrm_beta_weights, - input_weights, - output_weights, - dropout_mask, - dropout_add_mask, - dropout_prob_t, - ) = ctx.saved_tensors - - ( - input_grads, - lyr_nrm_gamma_grads, - lyr_nrm_beta_grads, - input_weight_grads, - output_weight_grads, - ) = fast_multihead_attn.self_attn_norm_add_backward( - heads_t[0], - output_grads, - matmul2_results, - dropout_results, - softmax_results, - input_lin_results, - lyr_nrm_results, - lyr_nrm_mean, - lyr_nrm_invvar, - inputs, - lyr_nrm_gamma_weights, - lyr_nrm_beta_weights, - input_weights, - output_weights, - dropout_mask, - dropout_add_mask, - dropout_prob_t[0], - ) - # fast_self_multihead_attn_norm_add.backward( \ - - return ( - None, - None, - None, - input_grads, - lyr_nrm_gamma_grads, - lyr_nrm_beta_grads, - input_weight_grads, - output_weight_grads, - None, - None, - ) - - -fast_self_attn_norm_add_func = FastSelfAttnNormAddFunc.apply diff --git a/apex/contrib/multihead_attn/mask_softmax_dropout_func.py b/apex/contrib/multihead_attn/mask_softmax_dropout_func.py deleted file mode 100644 index b34eec4..0000000 --- a/apex/contrib/multihead_attn/mask_softmax_dropout_func.py +++ /dev/null @@ -1,64 +0,0 @@ -import torch - -import fast_multihead_attn - - -class MaskSoftmaxDropout(torch.autograd.Function): - @staticmethod - def forward(ctx, is_training, heads, inputs, pad_mask, mask_additive, dropout_prob): - heads_t = torch.tensor([heads]) - dropout_prob_t = torch.tensor([dropout_prob]) - null_tensor = torch.tensor([]) - use_mask = pad_mask is not None - use_mask_t = torch.tensor([use_mask]) - mask_additive_t = torch.tensor([mask_additive]) - - if mask_additive: - dropout_results, dropout_mask, softmax_results = fast_multihead_attn.additive_mask_softmax_dropout_forward( - use_mask, is_training, heads, inputs, pad_mask if use_mask else null_tensor, dropout_prob - ) - # fast_additive_mask_softmax_dropout.forward( \ - else: - dropout_results, dropout_mask, softmax_results = fast_multihead_attn.mask_softmax_dropout_forward( - use_mask, is_training, heads, inputs, pad_mask if use_mask else null_tensor, dropout_prob - ) - # fast_mask_softmax_dropout.forward( \ - - ctx.save_for_backward( - use_mask_t, - heads_t, - softmax_results, - dropout_mask, - pad_mask if use_mask else null_tensor, - mask_additive_t, - dropout_prob_t, - ) - - return dropout_results.detach() - - @staticmethod - def backward(ctx, output_grads): - ( - use_mask_t, - heads_t, - softmax_results, - dropout_mask, - pad_mask, - mask_additive_t, - dropout_prob_t, - ) = ctx.saved_tensors - - if mask_additive_t[0]: - input_grads = fast_multihead_attn.additive_mask_softmax_dropout_backward( - use_mask_t[0], heads_t[0], output_grads, softmax_results, dropout_mask, dropout_prob_t[0] - ) - # fast_additive_mask_softmax_dropout.backward( \ - else: - input_grads = fast_multihead_attn.mask_softmax_dropout_backward( - use_mask_t[0], heads_t[0], output_grads, softmax_results, dropout_mask, pad_mask, dropout_prob_t[0] - ) - # fast_mask_softmax_dropout.backward( \ - return None, None, input_grads, None, None, None - - -fast_mask_softmax_dropout_func = MaskSoftmaxDropout.apply diff --git a/apex/contrib/multihead_attn/self_multihead_attn.py b/apex/contrib/multihead_attn/self_multihead_attn.py deleted file mode 100644 index 2806c4d..0000000 --- a/apex/contrib/multihead_attn/self_multihead_attn.py +++ /dev/null @@ -1,255 +0,0 @@ -import math - -import torch -from torch import nn -from torch.nn import Parameter -import torch.nn.functional as F - -from .self_multihead_attn_func import self_attn_func -from .fast_self_multihead_attn_func import fast_self_attn_func -from .fast_self_multihead_attn_norm_add_func import fast_self_attn_norm_add_func -from apex.normalization.fused_layer_norm import FusedLayerNorm - - -@torch.jit.script -def jit_dropout_add(x, residual, prob, is_training): - # type: (Tensor, Tensor, float, bool) -> Tensor - out = F.dropout(x, p=prob, training=True) - out = residual + out - return out - - -class SelfMultiheadAttn(nn.Module): - """Multi-headed attention. - - See "Attention Is All You Need" for more details. - """ - - def __init__( - self, - embed_dim, - num_heads, - dropout=0.0, - bias=False, - include_norm_add=False, - impl="fast", - separate_qkv_params=False, - mask_additive=False, - ): - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" - self.bias = bias - self.include_norm_add = include_norm_add - self.impl = impl - self.scaling = self.head_dim ** -0.5 - self.separate_qkv_params = separate_qkv_params - self.mask_additive = mask_additive - if mask_additive: - assert self.include_norm_add == False, "additive mask not supported with layer norm" - assert impl == "default" or ( - impl == "fast" and bias - ), "additive mask not supported for fast mode without bias" - if separate_qkv_params: - self.q_weight = Parameter(torch.empty(embed_dim, embed_dim)) - self.k_weight = Parameter(torch.empty(embed_dim, embed_dim)) - self.v_weight = Parameter(torch.empty(embed_dim, embed_dim)) - else: - self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) - self.out_proj_weight = Parameter(torch.empty(embed_dim, embed_dim)) - if self.bias: - if separate_qkv_params: - self.q_bias = Parameter(torch.empty(embed_dim)) - self.k_bias = Parameter(torch.empty(embed_dim)) - self.v_bias = Parameter(torch.empty(embed_dim)) - else: - self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) - self.out_proj_bias = Parameter(torch.empty(embed_dim)) - else: - if separate_qkv_params: - self.register_parameter("q_bias", None) - self.register_parameter("k_bias", None) - self.register_parameter("v_bias", None) - self.q_bias = None - self.k_bias = None - self.v_bias = None - else: - self.register_parameter("in_proj_bias", None) - self.in_proj_bias = None - self.register_parameter("out_proj_bias", None) - self.out_proj_bias = None - if self.include_norm_add: - if impl == "fast": - self.lyr_nrm_gamma_weights = Parameter(torch.empty(embed_dim)) - self.lyr_nrm_beta_weights = Parameter(torch.empty(embed_dim)) - self.lyr_nrm = None - else: - self.register_parameter("lyr_norm_gamma_weights", None) - self.register_parameter("lyr_norm_beta_weights", None) - self.lyr_nrm_gamma_weights = None - self.lyr_nrm_beta_weights = None - self.lyr_nrm = FusedLayerNorm(embed_dim) - self.reset_parameters() - - if self.include_norm_add: - if impl == "fast": - self.attn_func = fast_self_attn_norm_add_func - elif impl == "default": - self.attn_func = self_attn_func - else: - assert False, "Unsupported impl: {} !".format(impl) - else: - if impl == "fast": - self.attn_func = fast_self_attn_func - elif impl == "default": - self.attn_func = self_attn_func - else: - assert False, "Unsupported impl: {} !".format(impl) - - def reset_parameters(self): - if self.separate_qkv_params: - nn.init.xavier_uniform_(self.q_weight) - nn.init.xavier_uniform_(self.k_weight) - nn.init.xavier_uniform_(self.v_weight) - else: - # in_proj_weight has shape [3 * hidden, hidden] but it should be - # initialized like a [hidden, hidden] matrix. - # sqrt(6 / (hidden + hidden)) / sqrt(6 / (3 * hidden + hidden)) = sqrt(2) - # therefore xavier_uniform gain should be set to sqrt(2). - nn.init.xavier_uniform_(self.in_proj_weight, gain=math.sqrt(2)) - nn.init.xavier_uniform_(self.out_proj_weight) - if self.bias: - if self.separate_qkv_params: - nn.init.constant_(self.q_bias, 0.0) - nn.init.constant_(self.k_bias, 0.0) - nn.init.constant_(self.v_bias, 0.0) - else: - nn.init.constant_(self.in_proj_bias, 0.0) - nn.init.constant_(self.out_proj_bias, 0.0) - if self.include_norm_add: - if self.impl == "fast": - nn.init.ones_(self.lyr_nrm_gamma_weights) - nn.init.zeros_(self.lyr_nrm_beta_weights) - else: - self.lyr_nrm.reset_parameters() - - def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True): - """Input shape: Time x Batch x Channel - - Self-attention can be implemented by passing in the same arguments for - query, key and value. Future timesteps can be masked with the - `mask_future_timesteps` argument. Padding elements can be excluded from - the key by passing a binary ByteTensor (`key_padding_mask`) with shape: - batch x src_len, where padding elements are indicated by 1s. - """ - if self.separate_qkv_params: - input_weights = ( - torch.cat( - [ - self.q_weight.view(self.num_heads, 1, self.head_dim, self.embed_dim), - self.k_weight.view(self.num_heads, 1, self.head_dim, self.embed_dim), - self.v_weight.view(self.num_heads, 1, self.head_dim, self.embed_dim), - ], - dim=1, - ) - .reshape(3 * self.embed_dim, self.embed_dim) - .contiguous() - ) - else: - input_weights = self.in_proj_weight - if self.bias: - if self.separate_qkv_params: - input_bias = ( - torch.cat( - [ - self.q_bias.view(self.num_heads, 1, self.head_dim), - self.k_bias.view(self.num_heads, 1, self.head_dim), - self.v_bias.view(self.num_heads, 1, self.head_dim), - ], - dim=1, - ) - .reshape(3 * self.embed_dim) - .contiguous() - ) - else: - input_bias = self.in_proj_bias - else: - input_bias = None - if key_padding_mask is not None: - assert attn_mask is None, "ERROR attn_mask and key_padding_mask should not be both defined!" - mask = key_padding_mask - elif attn_mask is not None: - assert self.mask_additive == False, "additive mask not supported for time mask" - mask = attn_mask - else: - mask = None - - if self.include_norm_add: - if self.impl == "fast": - outputs = self.attn_func( - attn_mask is not None, - is_training, - self.num_heads, - query, - self.lyr_nrm_gamma_weights, - self.lyr_nrm_beta_weights, - input_weights, - self.out_proj_weight, - mask, - self.dropout, - ) - else: - lyr_nrm_results = self.lyr_nrm(query) - outputs = self.attn_func( - attn_mask is not None, - is_training, - self.num_heads, - self.scaling, - lyr_nrm_results, - input_weights, - self.out_proj_weight, - input_bias, - self.out_proj_bias, - mask, - self.mask_additive, - self.dropout, - ) - if is_training: - outputs = jit_dropout_add(outputs, query, self.dropout, is_training) - else: - outputs = outputs + query - else: - if self.impl == "fast": - outputs = self.attn_func( - attn_mask is not None, - is_training, - self.num_heads, - query, - input_weights, - self.out_proj_weight, - input_bias, - self.out_proj_bias, - mask, - self.mask_additive, - self.dropout, - ) - else: - outputs = self.attn_func( - attn_mask is not None, - is_training, - self.num_heads, - self.scaling, - query, - input_weights, - self.out_proj_weight, - input_bias, - self.out_proj_bias, - mask, - self.mask_additive, - self.dropout, - ) - - return outputs, None diff --git a/apex/contrib/multihead_attn/self_multihead_attn_func.py b/apex/contrib/multihead_attn/self_multihead_attn_func.py deleted file mode 100644 index f26e704..0000000 --- a/apex/contrib/multihead_attn/self_multihead_attn_func.py +++ /dev/null @@ -1,308 +0,0 @@ -import torch -import torch.nn.functional as F - - -class SelfAttnFunc(torch.autograd.Function): - @staticmethod - def forward( - ctx, - use_time_mask, - is_training, - heads, - scale, - inputs, - input_weights, - output_weights, - input_biases, - output_biases, - mask, - is_additive_mask, - dropout_prob, - ): - use_biases_t = torch.tensor([input_biases is not None]) - heads_t = torch.tensor([heads]) - scale_t = torch.tensor([scale]) - dropout_prob_t = torch.tensor([dropout_prob]) - null_tensor = torch.tensor([]) - head_dim = inputs.size(2) // heads - - # Input Linear GEMM - # input1: (activations) [seql_q, seqs, embed_dim(1024)] - # input2: (weights) [embed_dim*3 (3072), embed_dim (1024)] (transpose [0,1]) - # output: [seql_q, seqs, embed_dim*3] - # GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim*3 ) = (seql_q*seqs x embed_dim*3) - if use_biases_t[0]: - input_lin_results = torch.addmm( - input_biases, - inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)), - input_weights.transpose(0, 1), - beta=1.0, - alpha=1.0, - ) - else: - input_lin_results = torch.mm( - inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)), input_weights.transpose(0, 1) - ) - input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1), input_weights.size(0)) - - # Slice out q,k,v from one big Input Linear outuput (should only impact meta data, no copies!) - # Sequences and heads are combined to make the batch of the Batched GEMM - # input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)] - # input_lin_results: [seql_q, batches=seqs*heads, 3, head_dim] - input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1) * heads, 3, head_dim) - queries = input_lin_results[:, :, 0, :] - keys = input_lin_results[:, :, 1, :] - values = input_lin_results[:, :, 2, :] - - # Matmul1 Batched GEMMs - # The output tensor is specified prior to the Batch GEMM because baddbmm requires its specification - # baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of - # a separate elementwise operation. - # Input1: (Queries) [seql_q, seqs*heads, head_dim] tranpose(0,1) - # Input2: (Keys) [seql_k, seqs*heads, head_dim] transpose(0,1) - # output: [seqs*heads, seql_q, seql_k] - # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k ) - matmul1_results = torch.empty( - (queries.size(1), queries.size(0), keys.size(0)), dtype=queries.dtype, device=torch.device("cuda") - ) - matmul1_results = torch.baddbmm( - matmul1_results, - queries.transpose(0, 1), - keys.transpose(0, 1).transpose(1, 2), - out=matmul1_results, - beta=0.0, - alpha=scale_t[0], - ) - - if mask is not None: - # Self Attention Time Mask - if use_time_mask: - assert len(mask.size()) == 2, "Timing mask is not 2D!" - assert mask.size(0) == mask.size(1), "Sequence length should match!" - mask = mask.to(torch.bool) - matmul1_results = matmul1_results.masked_fill_(mask, float("-inf")) - # Key Padding Mask - else: - batches, seql_q, seql_k = matmul1_results.size() - seqs = int(batches / heads) - matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k) - if is_additive_mask: - matmul1_results = matmul1_results + mask.unsqueeze(1).unsqueeze(2) - else: - mask = mask.to(torch.bool) - matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float("-inf")) - matmul1_results = matmul1_results.view(seqs * heads, seql_q, seql_k) - - softmax_results = F.softmax(matmul1_results, dim=-1) - - # Dropout - is not executed for inference - if is_training: - dropout_results, dropout_mask = torch._fused_dropout(softmax_results, p=(1.0 - dropout_prob_t[0])) - else: - dropout_results = softmax_results - dropout_mask = null_tensor - - # Matmul2 Batched GEMMs - # The output tensor specification is needed here to specify the non-standard output. - # Given that pytorch cannot currently perform autograd with an output tensor specified, - # this requires a backward pass specified. - # Input1: from_softmax [seqs*heads, seql_q, seql_k] - # Input2: (values) [seql_v, seqs*heads, head_dim] transpose(0,1) - # Output: [seql_q, seqs*heads, head_dim] transpose(0,1) - # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = (seql_q x head_dim) - matmul2_results = torch.empty( - (dropout_results.size(1), dropout_results.size(0), values.size(2)), - dtype=dropout_results.dtype, - device=torch.device("cuda"), - ).transpose(1, 0) - matmul2_results = torch.bmm(dropout_results, values.transpose(0, 1), out=matmul2_results) - matmul2_results = ( - matmul2_results.transpose(0, 1).contiguous().view(inputs.size(0), inputs.size(1), inputs.size(2)) - ) - - # Output Linear GEMM - # Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim] - # Input2: (weights) [ embed_dim, embed_dim ] transpose(0,1) - # Output: [ seql_q, seqs, embed_dim ] - # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim ) - if use_biases_t[0]: - outputs = torch.addmm( - output_biases, - matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)), - output_weights.transpose(0, 1), - beta=1.0, - alpha=1.0, - ) - else: - outputs = torch.mm( - matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)), output_weights.transpose(0, 1) - ) - outputs = outputs.view(inputs.size(0), inputs.size(1), output_weights.size(0)) - - ctx.save_for_backward( - use_biases_t, - heads_t, - scale_t, - matmul2_results, - dropout_results, - softmax_results, - input_lin_results, - inputs, - input_weights, - output_weights, - dropout_mask, - dropout_prob_t, - ) - - return outputs.detach() - - @staticmethod - def backward(ctx, output_grads): - ( - use_biases_t, - heads_t, - scale_t, - matmul2_results, - dropout_results, - softmax_results, - input_lin_results, - inputs, - input_weights, - output_weights, - dropout_mask, - dropout_prob_t, - ) = ctx.saved_tensors - - head_dim = inputs.size(2) // heads_t[0] - - # Slice out q,k,v from one big Input Linear outuput (should only impact meta data, no copies!) - # Sequences and heads are combined to make the batch of the Batched GEMM - # input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)] - # input_lin_results: [seql_q, batches=seqs*heads, 3, head_dim] - input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1) * heads_t[0], 3, head_dim) - queries = input_lin_results[:, :, 0, :] - keys = input_lin_results[:, :, 1, :] - values = input_lin_results[:, :, 2, :] - - # Slice out q,k,v from one big set of gradients entering the input linear's bprop (should only impact meta data, no copies!) - # The gradients are identical in size to the Input Linear outputs. - # The tensor is declared before hand to properly slice out query, key, and value grads. - input_lin_results_grads = torch.empty_like(input_lin_results) - queries_grads = input_lin_results_grads[:, :, 0, :] - keys_grads = input_lin_results_grads[:, :, 1, :] - values_grads = input_lin_results_grads[:, :, 2, :] - - # Output Linear GEMM - DGRAD - # Input1: (data grads) [seql_q, seqs, embed_dim=heads*head_dim] - # Input2: (weights) [ embed_dim, embed_dim ] - # Output: [ seql_q, seqs, embed_dim ] - # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim ) - output_lin_grads = torch.mm( - output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), output_weights - ) - output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1), output_weights.size(1)) - # Output Linear GEMM - WGRAD - # Input1: (data grads) [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1) - # Input2: (activations) [seql_q*seqs, embed_dim ] - # Output: [ seql_q, seqs, embed_dim ] - # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim ) - output_weight_grads = torch.mm( - output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0, 1), - matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)), - ) - output_lin_grads = output_lin_grads.view(inputs.size(0), inputs.size(1) * heads_t[0], head_dim).transpose(0, 1) - - if use_biases_t[0]: - output_bias_grads = torch.sum( - output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0 - ) - else: - output_bias_grads = None - - # Matmul2 - DGRAD1 - # Input1: (data grads) [seql_q, seqs*heads, head_dim] transpose(0,1) - # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2) - # Output: [seqs*heads, seql_q, seql_k] - # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k ) - matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0, 1).transpose(1, 2)) - # Matmul2 - DGRAD2 - # Input1: (data grads) [seql_q, seqs*heads, head_dim] transpose(0,1) - # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2) - # Output: [seqs*heads, seql_q, seql_k] - # GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k ) - values_grads = torch.bmm(dropout_results.transpose(1, 2), output_lin_grads, out=values_grads.transpose(0, 1)) - - # Mask and Scaling for Dropout (not a publically documented op) - dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0 / (1.0 - dropout_prob_t[0])) - - # Softmax Grad (not a publically documented op) - ### softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results) # og - softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, torch.float32, grad_input=softmax_results) - - # Matmul1 - DGRAD1 - # Input1: (data grads) [seqs*heads, seql_q, seql_k] - # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1) - # Output: [seqs*heads, seql_q, head_dim] transpose(0,1) - # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim ) - queries_grads = torch.baddbmm( - queries_grads.transpose(0, 1), - softmax_grads, - keys.transpose(0, 1), - out=queries_grads.transpose(0, 1), - beta=0.0, - alpha=scale_t[0], - ) - # Matmul1 - DGRAD2 - # Input1: (data grads) [seqs*heads, seql_q, seql_k] transpose(1,2) - # Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1) - # Output: [seqs*heads, seql_k, head_dim] transpose(0,1) - # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim ) - keys_grads = torch.baddbmm( - keys_grads.transpose(0, 1), - softmax_grads.transpose(1, 2), - queries.transpose(0, 1), - out=keys_grads.transpose(0, 1), - beta=0.0, - alpha=scale_t[0], - ) - - # Input Linear GEMM - DGRAD - # input1: (data grads) [seql_q, seqs, 3*embed_dim(3072)] - # input2: (weights) [embed_dim*3 (3072), embed_dim (1024)] - # output: [seql_q, seqs, embed_dim] - # GEMM: ( (seql_q*seqs) x 3*embed_dim ) x ( 3*embed_dim x embed_dim ) = (seql_q*seqs x embed_dim) - input_lin_results_grads = input_lin_results_grads.view( - inputs.size(0) * inputs.size(1), heads_t[0] * 3 * head_dim - ) - input_grads = torch.mm(input_lin_results_grads, input_weights) - input_grads = input_grads.view(inputs.size(0), inputs.size(1), inputs.size(2)) - # Input Linear GEMM - WGRAD - # input1: (data grads) [seql_q*seqs, 3*embed_dim(3072)] - # input2: (activations) [seql_q*seqs, embed_dim(1024)] - # output: [3*embed_dim, embed_dim] - # GEMM: ( 3*embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (3*embed_dim x embed_dim) - input_weight_grads = torch.mm( - input_lin_results_grads.transpose(0, 1), inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)) - ) - - if use_biases_t[0]: - input_bias_grads = torch.sum(input_lin_results_grads, 0) - else: - input_bias_grads = None - - return ( - None, - None, - None, - None, - input_grads, - input_weight_grads, - output_weight_grads, - input_bias_grads, - output_bias_grads, - None, - None, - ) - - -self_attn_func = SelfAttnFunc.apply diff --git a/apex/contrib/optimizers/__init__.py b/apex/contrib/optimizers/__init__.py deleted file mode 100644 index 1933b2f..0000000 --- a/apex/contrib/optimizers/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .fp16_optimizer import FP16_Optimizer -from .fused_adam import FusedAdam -from .fused_lamb import FusedLAMB diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py deleted file mode 100644 index 5500680..0000000 --- a/apex/contrib/optimizers/distributed_fused_adam.py +++ /dev/null @@ -1,1280 +0,0 @@ -import collections -import contextlib -import enum -import importlib -import inspect -import io -import math -import threading - -import torch -import amp_C -from apex.multi_tensor_apply import multi_tensor_applier -from torch.distributed.distributed_c10d import _get_default_group, _get_global_rank - -def _round_to_multiple(number, multiple, round_up=True): - """Assumes arguments are positive integers""" - return (number+multiple-1 if round_up else number) // multiple * multiple - -class DistributedFusedAdam(torch.optim.Optimizer): - """AdamW optimizer with ZeRO algorithm. - - Currently GPU-only. Requires Apex to be installed via - ``python setup.py install --cuda_ext --cpp_ext``. - - This implements the ZeRO-2 algorithm, which distributes the - optimizer state and gradients between parallel processes. In - particular, the parameters are flattened, grouped into fixed-size - buckets, and the optimizer state for each bucket is sharded over - the parallel processes. Options are provided to overlap the - gradient synchronization with the backward pass compute. - - Adam was proposed in `Adam: A Method for Stochastic - Optimization`_, AdamW in `Decoupled Weight Decay Regularization`_, - and ZeRO in `ZeRO: Memory Optimizations Toward Training Trillion - Parameter Models`_. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts - defining parameter groups. - lr (float, optional): learning rate. (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for - computing running averages of gradient and its square. - (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability. (default: 1e-8) - weight_decay (float, optional): weight decay (L2 penalty) - (default: 0) - amsgrad (boolean, optional): whether to use the AMSGrad - variant of this algorithm from the paper - `On the Convergence of Adam and Beyond`_ (default: False). - This is not yet supported. - dtype (torch.dtype, optional): datatype for optimizer state - (default: torch.float32) - grad_sync_dtype (torch.dtype, optional): datatype for gradient - synchronization (default: same as dtype) - param_sync_dtype (torch.dtype, optional): datatype for - parameter synchronization (default: same as dtype) - device (torch.device, optional): device for optimizer state - (default: cuda). Currently only supports GPU with one GPU - per process. - process_group (torch.distributed.ProcessGroup, optional): - parallel processes participating in optimizer (default: - default group in torch.distributed). This group is - interpreted as a 2D grid with dimensions - distributed_size x redundant_size. - distributed_process_group (torch.distributed.ProcessGroup, - optional): parallel processes to distribute optimizer - state over (default: same as process_group) - redundant_process_group (torch.distributed.ProcessGroup, - optional): parallel processes to replicate optimizer state - over (default: group only containing calling process) - average_grad_sync (bool, optional): whether to use average - reduction for gradient synchronization rather than sum - (default: True) - overlap_grad_sync(boolean, optional): whether to overlap - gradient synchronization with backward pass compute - (default: True) - bucket_cap_mb (float, optional): bucket size in megabytes - (default: 100) - pipeline_size (int, optional): number of buckets to - synchronize simultaneously (default: 2) - contiguous_grad_buffer (bool, optional): allocate gradient - buckets out of a large persistent buffer (default: False). - This allows individual parameter gradients to be accessed - externally (see grad_buffer_view function). It also - maximizes memory usage and may prevent overlapping - communication and compute. - - .. _Adam\: A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 - .. _ZeRO\: Memory Optimizations Toward Training Trillion Parameter Models: - https://arxiv.org/abs/1910.02054 - - """ - - class ParameterFragment: - """Buffer ranges for a parameter fragment - - Describes corresponding regions in parameter buffer and - parameter bucket. - - """ - def __init__( - self, - param_group_id, - param_id, - bucket_id, - param_range, - bucket_range, - in_local_shard, - shard_range, - shard_bucket_range, - shard_param_range, - ): - # Parameter group index - self.param_group_id = param_group_id - # Parameter index within parameter group - self.param_id = param_id - # Bucket index - self.bucket_id = bucket_id - # Range within flattened parameter buffer - self.param_range = param_range - # Range within bucket - self.bucket_range = bucket_range - # Whether fragment is in local shard of bucket - self.in_local_shard = in_local_shard - # Range within local shard - self.shard_range = shard_range - # Range of local fragment shard within bucket - self.shard_bucket_range = shard_bucket_range - # Range of local fragment shard within parameter - self.shard_param_range = shard_param_range - - class StateBucket: - def __init__(self, shard_size, dtype, device): - """Optimizer state for a bucket""" - # Buffer ranges corresponding to parameter fragments - self.fragments = [] - # Local shard of parameters - self.params_shard = torch.zeros([shard_size], dtype=dtype, device=device) - # Local shard of first moment estimate - self.exp_avg_shard = torch.zeros([shard_size], dtype=dtype, device=device) - # Local shard of second moment estimate - self.exp_avg_sq_shard = torch.zeros([shard_size], dtype=dtype, device=device) - - class GradientStatus(enum.Enum): - """Status of gradients within a bucket""" - # Gradients are ready to use - READY = enum.auto() - # Bucket is partially filled with unreduced gradients - PARTIALLY_FILLED = enum.auto() - # Bucket is fully filled with unreduced gradients - FULLY_FILLED = enum.auto() - # Asynchronous reduction is in progress - SYNCING = enum.auto() - - class GradientBucket: - """Gradient buffers and state for a bucket""" - def __init__(self): - # Local shard of gradients - self.grads_shard = None - # Local contribution to gradients - self.grads_bucket = None - # Buffer for gradient reduce-scatter - self.sync_grads_shard = None - # Status of gradients - self.status = DistributedFusedAdam.GradientStatus.READY - # Request object for asynchronous communication - self.sync_request = None - - def sync_wait(self): - """Wait for asynchronous communication to finish""" - if self.sync_request is not None: - self.sync_request.wait() - self.sync_request = None - - _step_supports_amp_scaling = True - - def __init__(self, - params, - lr=1e-3, - bias_correction=True, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0., - amsgrad=False, - dtype=torch.float32, - grad_sync_dtype=None, - param_sync_dtype=None, - device='cuda', - process_group=None, - distributed_process_group=None, - redundant_process_group=None, - average_grad_sync=True, - overlap_grad_sync=True, - bucket_cap_mb=100, - pipeline_size=2, - contiguous_grad_buffer=False, - ): - defaults = dict(lr=lr, bias_correction=bias_correction, - betas=betas, eps=eps, weight_decay=weight_decay) - super(DistributedFusedAdam, self).__init__(params, defaults) - - # Adam options - if amsgrad: - raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.') - - # Datatype options - if grad_sync_dtype is None: - grad_sync_dtype = dtype - if param_sync_dtype is None: - param_sync_dtype = dtype - supported_dtypes = [ - (torch.float32, torch.float16), - (torch.float32, torch.float32), - ] - if (dtype, grad_sync_dtype) not in supported_dtypes: - raise RuntimeError( - 'Invalid dtypes for DistributedFusedAdam ' - f'(dtype={dtype}, ' - f'grad_sync_dtype={grad_sync_dtype}, ' - f'param_sync_dtype={param_sync_dtype}))') - if device != 'cuda': - raise RuntimeError('DistributedFusedAdam only supports GPU') - self.dtype = dtype - self.grad_sync_dtype = grad_sync_dtype - self.param_sync_dtype = param_sync_dtype - self.device = device - - # Process groups - self.process_group = ( - _get_default_group() - if process_group is None - else process_group - ) - self.distributed_process_group = ( - self.process_group - if distributed_process_group is None - else distributed_process_group - ) - self.redundant_process_group = redundant_process_group - self.process_group_size = torch.distributed.get_world_size(self.process_group) - self.distributed_rank = torch.distributed.get_rank(self.distributed_process_group) - self.distributed_size = torch.distributed.get_world_size(self.distributed_process_group) - self.redundant_size = ( - 1 - if self.redundant_process_group is None - else torch.distributed.get_world_size(self.redundant_process_group) - ) - if self.process_group_size != self.distributed_size * self.redundant_size: - raise RuntimeError( - 'Invalid process group configuration ' - f'(process group size = {self.process_group_size}, ' - f'distributed process group size = {self.distributed_size}, ' - f'redundant process group size = {self.redundant_size})' - ) - try: - self._process_group_ranks = [ - _get_global_rank(self.process_group, local_rank) - for local_rank in range(self.distributed_size) - ] - except: - self._process_group_ranks = list(range(self.distributed_size)) - - # Use average reduction for grad sync - self.average_grad_sync = average_grad_sync - # Copy param grads to bucket as soon as available - self.greedy_grad_copy = True - # Synchronize grad buckets as soon as all grads are available - self.overlap_grad_sync = overlap_grad_sync - # Number of buckets to synchronize at a time - self.pipeline_size = pipeline_size - # Allocate contiguous buffer for gradients - self.contiguous_grad_buffer = contiguous_grad_buffer - - # Determine bucket sizes - dtype_size = torch.finfo(self.grad_sync_dtype).bits // 8 - self.alignment = 128 // dtype_size - bucket_size = 1024*1024*bucket_cap_mb / dtype_size - shard_size = int(bucket_size / self.distributed_size) - shard_size = _round_to_multiple(shard_size, self.alignment, round_up=False) - shard_size = max(shard_size, self.alignment) - bucket_size = shard_size * self.distributed_size - self.bucket_size = bucket_size - self.shard_size = shard_size - - # Load CUDA kernels - global fused_adam_cuda, distributed_adam_cuda - fused_adam_cuda = importlib.import_module("fused_adam_cuda") - distributed_adam_cuda = importlib.import_module("distributed_adam_cuda") - - # Optimizer state - self.state['buckets'] = [] - self.state['step'] = 0 - - # Objects for gradient synchronization - self._grads_buckets = collections.defaultdict(self.GradientBucket) - self._grads_generated = set() - self._pipeline_streams = [torch.cuda.Stream() for _ in range(self.pipeline_size)] - - # Divide gradients by factor before optimizer step. Used for - # grad clipping and gradient scaler. - self._inv_grad_scale = torch.full([1], 1.0, dtype=self.dtype, device=self.device) - # Norm of parameter gradients. Used for gradient clipping and - # gradient scaler. - self._grad_norm = None - - # Check if collectives have no_copy option - self._reduce_scatter_no_copy = ( - 'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args - ) - self._all_gather_no_copy = ( - 'no_copy' in inspect.getfullargspec(torch.distributed.all_gather).args - ) - self._gather_no_copy = ( - 'no_copy' in inspect.getfullargspec(torch.distributed.gather).args - ) - - # Attach hooks for gradient synchronization - self._register_post_backward_hooks() - - def _register_post_backward_hooks(self): - """Attach hooks for gradient synchronization - - Optimizer state for parameters are initialized lazily as they - are encountered in the backward pass. - - """ - self._num_grads = 0 - grad_buffer_size = 0 - self._lock = threading.Lock() - self._grad_accs = [] - for param_group_id, group in enumerate(self.param_groups): - for param_id, param in enumerate(group['params']): - torch.distributed.broadcast( - param, - src=self._process_group_ranks[0], - group=self.process_group, - ) - if param.requires_grad: - self._num_grads += 1 - - # Callback after gradient is generated - def wrapper(p, p_group_id, p_id): - p_tmp = p.expand_as(p) - grad_acc = p_tmp.grad_fn.next_functions[0][0] - def reduction_hook(*unused): - with self._lock: - if 'fragments' not in self.state[p]: - self._init_param_state(p, p_group_id, p_id) - if self.greedy_grad_copy: - self._grad_copy(p) - if self.overlap_grad_sync: - self._try_start_bucket_grad_sync( - params=[p], - ignore_last_bucket=True, - ) - grad_acc.register_hook(reduction_hook) - self._grad_accs.append(grad_acc) - wrapper(param, param_group_id, param_id) - - # Gradient size, with padding for alignment - grad_size = _round_to_multiple(param.numel(), self.alignment) - grad_buffer_size += grad_size - - # Allocate contiguous gradient buffer if needed - if self.contiguous_grad_buffer: - grad_buffer_size = _round_to_multiple( - grad_buffer_size, - self.bucket_size, - ) - self._grad_buffer = torch.zeros( - [grad_buffer_size], - dtype=self.dtype, - device=self.device, - ) - - def init_params(self, params=None): - """Initialize optimizer state for parameters - - Arguments: - params (iterable, optional): parameters to initialize - (default: all parameters) - - """ - - # Default cases - if isinstance(params, torch.Tensor): - params = [params] - elif params is None: - params = [] - for group in self.param_groups: - params.extend(group['params']) - - # Get indices corresponding to parameters - id_map = dict() - for param_group_id, group in enumerate(self.param_groups): - for param_id, param in enumerate(group['params']): - id_map[param] = (param_group_id, param_id) - - # Initialize parameters - for param in params: - if param in id_map and 'fragments' not in self.state[param]: - param_group_id, param_id = id_map[param] - self._init_param_state(param, param_group_id, param_id) - - def _init_param_state( - self, - param, - param_group_id, - param_id, - ): - """Initialize optimizer state for a parameter""" - - # Make sure there is at least one bucket - if not self.state['buckets']: - self.state['buckets'].append( - self.StateBucket(self.shard_size, self.dtype, self.device) - ) - - # Split parameter values into fragments - # Note: Each fragment resides within a bucket - param_start = 0 - param_size = param.numel() - self.state[param]['fragments'] = [] - while param_start < param_size: - - # Get current bucket - bucket_id = len(self.state['buckets']) - 1 - bucket = self.state['buckets'][bucket_id] - fragment_id = len(bucket.fragments) - - # Determine fragment position within bucket - if fragment_id == 0: - bucket_start = 0 - else: - _, bucket_start = bucket.fragments[-1].bucket_range - bucket_start = _round_to_multiple(bucket_start, self.alignment) - fragment_size = min(param_size-param_start, self.bucket_size-bucket_start) - param_end = param_start + fragment_size - bucket_end = bucket_start + fragment_size - - # Create new bucket if current one is full - if fragment_size <= 0: - self.state['buckets'].append( - self.StateBucket(self.shard_size, self.dtype, self.device) - ) - continue - - # Fragment position within local shard - shard_id = self.distributed_rank - shard_start = bucket_start - self.shard_size*shard_id - shard_end = bucket_end - self.shard_size*shard_id - shard_start = min(max(shard_start, 0), self.shard_size) - shard_end = min(max(shard_end, 0), self.shard_size) - in_local_shard = shard_start < shard_end - if in_local_shard: - shard_bucket_start = shard_start + self.shard_size*shard_id - shard_bucket_end = shard_bucket_start + shard_end - shard_start - shard_param_start = shard_bucket_start - bucket_start + param_start - shard_param_end = shard_param_start + shard_end - shard_start - else: - shard_bucket_start, shard_bucket_end = None, None - shard_param_start, shard_param_end = None, None - - # Record fragment info - fragment = self.ParameterFragment( - param_group_id=param_group_id, - param_id=param_id, - bucket_id=bucket_id, - param_range=(param_start,param_end), - bucket_range=(bucket_start,bucket_end), - in_local_shard=in_local_shard, - shard_range=(shard_start,shard_end), - shard_bucket_range=(shard_bucket_start,shard_bucket_end), - shard_param_range=(shard_param_start,shard_param_end), - ) - self.state[param]['fragments'].append(fragment) - bucket.fragments.append(fragment) - param_start = param_end - - # Initialize master param buffer - for fragment in self.state[param]['fragments']: - if fragment.in_local_shard: - bucket = self.state['buckets'][fragment.bucket_id] - param_start, param_end = fragment.shard_param_range - shard_start, shard_end = fragment.shard_range - model_param_fragment = param.view(-1)[param_start:param_end] - master_param_fragment = bucket.params_shard[shard_start:shard_end] - master_param_fragment.copy_(model_param_fragment) - - def zero_grad(self, set_to_none=True): - """Clear parameter gradients""" - - # Reset bucket buffers - self._grads_buckets.clear() - - # Construct views into contiguous grad buffer, if needed - if self.contiguous_grad_buffer: - self._grad_buffer.zero_() - for bucket_id in range(len(self.state['buckets'])): - bucket_start = bucket_id * self.bucket_size - bucket_end = bucket_start + self.bucket_size - bucket = self._grads_buckets[bucket_id] - bucket.grads_bucket = self._grad_buffer[bucket_start:bucket_end] - - # Reset param grads - for group in self.param_groups: - for param in group['params']: - if param.grad is None or set_to_none: - param.grad = None - else: - param.grad.zero_() - - # Reset other state - self._grads_generated = set() - self._inv_grad_scale = torch.full([1], 1.0, dtype=self.dtype, device=self.device) - self._grad_norm = None - - def _grad_copy(self, param): - """Copy parameter gradients to buckets""" - - # Copy param grad to buckets - for fragment in self.state[param]['fragments']: - - # Get fragment position - bucket_id = fragment.bucket_id - bucket = self._grads_buckets[bucket_id] - grad_start, grad_end = fragment.param_range - bucket_start, bucket_end = fragment.bucket_range - - # Set reduction status - if bucket.status == self.GradientStatus.SYNCING: - self._finish_bucket_grad_sync() - bucket.status = self.GradientStatus.PARTIALLY_FILLED - - # Allocate gradient buffer if needed - if bucket.grads_bucket is None: - if self.contiguous_grad_buffer: - grad_buffer_start = bucket_id * self.bucket_size - grad_buffer_end = grad_buffer_start + self.bucket_size - bucket.grads_bucket = self._grad_buffer[grad_buffer_start:grad_buffer_end] - else: - bucket.grads_bucket = torch.empty( - [self.bucket_size], - dtype=self.grad_sync_dtype, - device=self.device, - ) - bucket.grads_bucket.zero_() - - # Copy param grad to bucket - if param.grad is not None: - grad_in = param.grad.detach().view(-1)[grad_start:grad_end] - grad_out = bucket.grads_bucket[bucket_start:bucket_end] - if grad_in.data_ptr() != grad_out.data_ptr(): - grad_out.add_(grad_in) - - # Free param grad buffer - param.grad = None - - def grad_buffer_view(self, param): - """Construct view into grad buffer corresponding to param - - Assumes optimizer is using a contiguous grad buffer. - - """ - assert self.contiguous_grad_buffer - - # Figure out corresponding position in grad buffer - param_fragments = self.state[param]['fragments'] - start_bucket_id = param_fragments[0].bucket_id - start_bucket_offset, _ = param_fragments[0].bucket_range - end_bucket_id = param_fragments[-1].bucket_id - _, end_bucket_offset = param_fragments[-1].bucket_range - buffer_start = start_bucket_id * self.bucket_size + start_bucket_offset - buffer_end = end_bucket_id * self.bucket_size + end_bucket_offset - - # Construct view into grad buffer - flat_buffer = self._grad_buffer[buffer_start:buffer_end] - return flat_buffer.detach().view(param.size()) - - def _force_bucket_grad_sync(self): - """Ensure that all gradient buckets are synchronized""" - - # Synchronize all unsynchronized buckets - self._finish_bucket_grad_sync() - buckets = [ - bucket - for bucket_id, bucket in sorted(self._grads_buckets.items()) - if bucket.status != self.GradientStatus.READY - ] - if buckets: - self._start_bucket_grad_sync(buckets) - self._finish_bucket_grad_sync() - - # Fill any unsynchronized gradients with zeros - for bucket_id in range(len(self.state['buckets'])): - bucket = self._grads_buckets[bucket_id] - if bucket.grads_shard is None: - bucket.grads_shard = torch.zeros( - [self.shard_size], - dtype=self.grad_sync_dtype, - device=self.device, - ) - - # Reset set of generated gradients - self._grads_generated = set() - - def _try_start_bucket_grad_sync( - self, - params=[], - ignore_last_bucket=True, - ): - """Launches gradient synchronization if enough buckets are ready - - Gradient synchronization is asynchronous. Launches gradient - synchronization if all gradients have been generated or if - there are enough buckets ready to fill pipeline. - - Arguments: - params (iterable): parameters that have had their - gradients copied to buckets - ignore_last_bucket (bool): avoid synchronizing last bucket - until all gradients have been generated. This avoids - excessive synchronization when initializing buckets in - the first backward pass. - - """ - - # Register params that have generated grads - for param in params: - self._grads_generated.add(param) - for fragment in self.state[param]['fragments']: - bucket_id = fragment.bucket_id - bucket_fragments = self.state['buckets'][bucket_id].fragments - is_filled = True - for other_fragment in reversed(bucket_fragments): - param_group_id = other_fragment.param_group_id - param_id = other_fragment.param_id - other_param = self.param_groups[param_group_id]['params'][param_id] - if other_param not in self._grads_generated: - is_filled = False - break - if is_filled: - bucket = self._grads_buckets[bucket_id] - bucket.status = self.GradientStatus.FULLY_FILLED - - # Launch reductions if enough buckets are ready - if len(self._grads_generated) == self._num_grads: - self._force_bucket_grad_sync() - else: - filled_buckets = [] - for bucket_id, bucket in sorted(self._grads_buckets.items()): - if ignore_last_bucket and bucket_id == len(self.state['buckets'])-1: - continue - if bucket.status == self.GradientStatus.FULLY_FILLED: - filled_buckets.append(bucket) - pipeline_size = _round_to_multiple( - len(filled_buckets), - self.pipeline_size, - ) - if pipeline_size > 0: - self._start_bucket_grad_sync(filled_buckets[:pipeline_size]) - - def _start_bucket_grad_sync(self, buckets): - """Synchronize gradient buckets - - Gradient synchronization is asynchronous. Involves - reduce-scatter over distributed process group and allreduce - over redundant process group. - - """ - - # Call recursively if more buckets than streams - while len(buckets) > self.pipeline_size: - self._start_bucket_grad_sync(buckets[:self.pipeline_size]) - buckets = buckets[self.pipeline_size:] - self._finish_bucket_grad_sync() - - # Reduction operation - if self.average_grad_sync: - reduce_op = torch.distributed.ReduceOp.AVG - else: - reduce_op = torch.distributed.ReduceOp.SUM - - # Reduce gradients - main_stream = torch.cuda.current_stream() - for stream in self._pipeline_streams: - stream.wait_stream(main_stream) - for i, bucket in enumerate(buckets): - bucket.status = self.GradientStatus.SYNCING - stream = self._pipeline_streams[i % self.pipeline_size] - with torch.cuda.stream(stream): - - # Reduce-scatter over distributed process group - bucket.sync_wait() - if self.distributed_size == 1: - bucket.sync_grads_shard = bucket.grads_bucket - else: - with torch.cuda.stream(main_stream): - bucket.sync_grads_shard = torch.zeros( - [self.shard_size], - dtype=self.grad_sync_dtype, - device=self.device, - ) - grads_bucket_shards = [ - bucket.grads_bucket[i*self.shard_size:(i+1)*self.shard_size] - for i in range(self.distributed_size) - ] - if self._reduce_scatter_no_copy: - no_copy_kwarg = { 'no_copy': True } - else: - no_copy_kwarg = {} - bucket.sync_request = ( - torch.distributed.reduce_scatter( - bucket.sync_grads_shard, - grads_bucket_shards, - op=reduce_op, - group=self.distributed_process_group, - async_op=True, - **no_copy_kwarg, - ) - ) - - # All-reduce over redundant process group - # Note: Assuming reduce-scatters are finished in the - # order they are submitted, all-reduces should be - # submitted in a consistent order. There could be race - # conditions if wait doesn't finish in order. - if self.redundant_size > 1: - bucket.sync_wait() - bucket.sync_request = ( - torch.distributed.all_reduce( - bucket.sync_grads_shard, - op=reduce_op, - group=self.redundant_process_group, - async_op=True, - ) - ) - - def _finish_bucket_grad_sync(self): - """Wait for any gradient synchronizations that are in progress""" - for bucket_id, bucket in sorted(self._grads_buckets.items()): - if bucket.status == self.GradientStatus.SYNCING: - - # Finish asynchronous communication - bucket.sync_wait() - - # Accumulate gradient in local shard - if bucket.grads_shard is None: - bucket.grads_shard = bucket.sync_grads_shard - else: - bucket.grads_shard.add_(bucket.sync_grads_shard) - bucket.grads_bucket = None - bucket.sync_grads_shard = None - - # Reset status - bucket.status = self.GradientStatus.READY - - # Cached gradient norm has been invalidated - self._grad_norm = None - - @contextlib.contextmanager - def no_sync(self, greedy_grad_copy=False): - """Disable overlapped gradient synchronization - - Context manager that is similar to - torch.nn.parallel.DistributedDataParallel.no_sync. The - gradients can be synchronized by calling grad_sync or step. If - overlapped gradient synchronization is enabled, gradients can - also be synchronized by leaving the context and performing a - backward pass. - - Arguments: - greedy_grad_copy (bool, optional): copy parameter - gradients to buckets as soon as they are generated - (default: False) - - """ - old_greedy_grad_copy = self.greedy_grad_copy - old_overlap_grad_sync = self.overlap_grad_sync - self.greedy_grad_copy = greedy_grad_copy - self.overlap_grad_sync = False - try: - yield - finally: - self.greedy_grad_copy = old_greedy_grad_copy - self.overlap_grad_sync = old_overlap_grad_sync - - def grad_sync(self): - """Ensure that all gradients are synchronized""" - for bucket in self.state['buckets']: - for fragment in bucket.fragments: - param_group_id = fragment.param_group_id - param_id = fragment.param_id - param = self.param_groups[param_group_id]['params'][param_id] - if param.grad is not None: - self._grad_copy(param) - self._try_start_bucket_grad_sync( - params=[param], - ignore_last_bucket=False, - ) - self._force_bucket_grad_sync() - - def _local_grad_norm(self, parameters=[], norm_type=2.0): - """Local contribution to parameter gradient norm - - Returns square of 2-norm. Other norms are not yet supported. - - If no parameters are provided, the norm is computed for all - parameters in optimizer. Provided parameters are assumed to be - in optimizer. - - """ - norm_type = float(norm_type) - assert norm_type == 2.0 - - # Make sure that gradients have been reduced - self.grad_sync() - - if not parameters or len(parameters) == self._num_grads: - # Compute norm of all local gradients - dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda') - grad_norm_sq = multi_tensor_applier( - amp_C.multi_tensor_l2norm, - dummy_overflow_buf, - [[bucket.grads_shard for bucket in self._grads_buckets.values()]], - False, - )[0] ** 2 - else: - # Compute norm of selected local gradients - grads = [] - for param in parameters: - for fragment in self.state[param]['fragments']: - if fragment.in_local_shard: - bucket = self._grads_buckets[fragment.bucket_id] - shard_start, shard_end = fragment.shard_range - grads.append(bucket.grads_shard[shard_start:shard_end]) - if grads: - dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda') - grad_norm_sq = multi_tensor_applier( - amp_C.multi_tensor_l2norm, - dummy_overflow_buf, - [grads], - False, - )[0] ** 2 - else: - grad_norm_sq = torch.zeros([1], dtype=torch.float32, device=self.device) - - return grad_norm_sq.detach().view([]) - - def grad_norm(self, parameters=[], norm_type=2.0, force=False): - """Gradient norm of parameters in optimizer - - The norm is computed over all gradients together, as if they - were concatenated into a single vector. All provided - parameters must be managed by optimizer. - - The computed value is cached to avoid redundant communication. - - Arguments: - parameters (iterable, optional): an iterable of parameters - in optimizer (default: all parameters in optimizer). - norm_type (float or int, optional): type of the used - p-norm (default: 2). Only 2-norm is currently - supported. - force (bool, optional): ignore cached value and force norm - computation (default: False). - - """ - if force or self._grad_norm is None: - norm_type = float(norm_type) - assert norm_type == 2.0 - grad_norm_sq = self._local_grad_norm( - parameters=parameters, - norm_type=norm_type, - ) - torch.distributed.all_reduce( - grad_norm_sq, - op=torch.distributed.ReduceOp.SUM, - group=self.distributed_process_group, - ) - self._grad_norm = grad_norm_sq.sqrt() - return self._grad_norm.detach() - - def clip_grad_norm(self, max_norm, parameters=[], norm_type=2.0): - """Clips gradient norm of parameters in optimizer - - The norm is computed over all gradients together, as if they - were concatenated into a single vector. The scaling is - deferred until the optimizer step, which should be called - immediately after this function. - - The computed grad norm is cached to avoid redundant - communication. - - Arguments: - max_norm (float or int): max norm of the gradients - parameters (iterable, optional): an iterable of parameters - in optimizer (default: all parameters in optimizer). - norm_type (float or int, optional): type of the used - p-norm (default: 2) - - """ - assert max_norm > 0 - total_norm = self.grad_norm(parameters=parameters, norm_type=norm_type) - inv_clip_coef = (total_norm + 1e-6) / max_norm - self._inv_grad_scale = torch.clamp(inv_clip_coef, min=1.0).view(1) - return total_norm - - def step(self, closure=None, *, grad_scaler=None): - """Apply Adam optimizer step - - Arguments: - closure (callable, optional): closure to recompute loss - (default: None) - grad_scaler (torch.cuda.amp.GradScaler, optional): - gradient scaler (default: None) - - """ - - # Apply closure - loss = None - if closure is not None: - loss = closure() - - # Make sure that gradients have been reduced - self.grad_sync() - - # Apply gradient scaler if provided - # Note: We compute gradient norm to check for non-finite - # values. This is more conservative and compute intensive than - # directly checking, but it avoids extra communication if we - # have already computed gradient norm e.g. for gradient - # clipping. - if grad_scaler is not None: - grad_norm = self.grad_norm() - found_inf = torch.logical_not(torch.isfinite(grad_norm)) - scaler_state = grad_scaler._per_optimizer_states[id(self)] - scaler_state['found_inf_per_device'] = {found_inf.device: found_inf.float()} - if found_inf.item(): - return - else: - assert grad_scaler._scale is not None - self._inv_grad_scale *= grad_scaler._scale - inv_grad_scale = self._inv_grad_scale.item() - - # Construct workspace buffers - params_bucket_buffers = [ - torch.empty( - [self.bucket_size], - dtype=self.param_sync_dtype, - device=self.device, - ) - for _ in range(self.pipeline_size) - ] - if self.grad_sync_dtype == self.param_sync_dtype: - shard_start = self.distributed_rank * self.shard_size - shard_end = shard_start + self.shard_size - params_copy_buffers = [ - params_bucket[shard_start:shard_end] - for params_bucket in params_bucket_buffers - ] - else: - params_copy_buffers = [ - torch.empty( - [self.shard_size], - dtype=self.grad_sync_dtype, - device=self.device, - ) - for _ in range(self.pipeline_size) - ] - - # Apply optimizer step to each bucket and synchronize params - self.state['step'] += 1 - main_stream = torch.cuda.current_stream() - for stream in self._pipeline_streams: - stream.wait_stream(main_stream) - for bucket_id in range(len(self.state['buckets'])): - stream_id = bucket_id % self.pipeline_size - - # Bucket buffers - fragments = self.state['buckets'][bucket_id].fragments - shard_start = self.distributed_rank * self.shard_size - shard_end = shard_start + self.shard_size - params_bucket = params_bucket_buffers[stream_id] - params_bucket_shard = params_bucket[shard_start:shard_end] - params_shard = self.state['buckets'][bucket_id].params_shard - params_copy = params_copy_buffers[stream_id] - exp_avg = self.state['buckets'][bucket_id].exp_avg_shard - exp_avg_sq = self.state['buckets'][bucket_id].exp_avg_sq_shard - grads = self._grads_buckets[bucket_id].grads_shard - - # Perform compute on parallel stream - stream = self._pipeline_streams[stream_id] - with torch.cuda.stream(stream): - - # Find param fragments in local shard - buffers = collections.defaultdict(list) # p, m, v, g, p_copy - for fragment in fragments: - if fragment.in_local_shard: - param_group_id = fragment.param_group_id - shard_start, shard_end = fragment.shard_range - buffers[param_group_id].append([ - params_shard[shard_start:shard_end], - exp_avg[shard_start:shard_end], - exp_avg_sq[shard_start:shard_end], - grads[shard_start:shard_end], - params_copy[shard_start:shard_end], - ]) - - # Fuse param fragments if possible - if len(buffers) == 1: - group_id = list(buffers.keys())[0] - buffers[group_id] = [( - params_shard, - exp_avg, - exp_avg_sq, - grads, - params_copy, - )] - - # Apply optimizer step to each param group - for group_id, group_buffers in buffers.items(): - - # Get param group configs - group = self.param_groups[group_id] - beta1, beta2 = group['betas'] - bias_correction = 1 if group['bias_correction'] else 0 - eps = group['eps'] - weight_decay = group['weight_decay'] - - # Copy param group configs to GPU - num_fragments = len(group_buffers) - beta1 = torch.full([num_fragments], beta1, dtype=self.dtype, device='cuda') - beta2 = torch.full([num_fragments], beta2, dtype=self.dtype, device='cuda') - bias_correction = torch.full([num_fragments], bias_correction, dtype=torch.int32, device='cuda') - eps = torch.full([num_fragments], eps, dtype=self.dtype, device='cuda') - weight_decay = torch.full([num_fragments], weight_decay, dtype=self.dtype, device='cuda') - - # Apply Adam step - dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda') - multi_tensor_applier( - distributed_adam_cuda.multi_tensor_fused_adam, - dummy_overflow_buf, - list(zip(*group_buffers)), - beta1, - beta2, - bias_correction, - eps, - weight_decay, - group['lr'], - inv_grad_scale, - self.state['step'], - 1, # Set to 0 to apply eps inside sqrt - ) - - # Cast parameter dtype if needed - if params_copy.data_ptr() != params_bucket_shard.data_ptr(): - params_bucket_shard.copy_(params_copy) - - # Allgather updated parameters - if self.distributed_size > 1: - all_params_bucket_shards = [ - params_bucket[i*self.shard_size:(i+1)*self.shard_size] - for i in range(self.distributed_size) - ] - if self._all_gather_no_copy: - no_copy_kwarg = { 'no_copy': True } - else: - no_copy_kwarg = {} - torch.distributed.all_gather( - all_params_bucket_shards, - params_bucket_shard, - group=self.distributed_process_group, - **no_copy_kwarg, - ) - - # Copy values to param buffers - buffers = collections.defaultdict(list) # param_in, param_out - for fragment in fragments: - param_group_id = fragment.param_group_id - param_id = fragment.param_id - param = self.param_groups[param_group_id]['params'][param_id] - bucket_start, bucket_end = fragment.bucket_range - param_start, param_end = fragment.param_range - param_in = params_bucket[bucket_start:bucket_end] - param_out = param.detach().view(-1)[param_start:param_end] - if param_in.dtype == param_out.dtype: - # Just copy bytes if buffers have same type - param_in = param_in.view(torch.uint8) - param_out = param_out.view(torch.uint8) - buffers[(param.is_cuda, param.dtype)].append( - (param_in, param_out) - ) - for (is_cuda, dtype), dtype_buffers in buffers.items(): - fused_kernel_dtypes = ( - self.param_sync_dtype, - torch.float32, - torch.float16, - torch.uint8, - ) - if is_cuda and dtype in fused_kernel_dtypes: - dummy_overflow_buf = torch.zeros([1], dtype=torch.int32, device='cuda') - multi_tensor_applier( - fused_adam_cuda.maybe_cast_mt, - dummy_overflow_buf, - list(zip(*dtype_buffers)), - ) - else: - for param_in, param_out in dtype_buffers: - param_out.copy_(param_in) - - # Synchronize pipeline streams - for stream in self._pipeline_streams: - main_stream.wait_stream(stream) - - return loss - - def state_dict(self, gather_on_root=True): - """Get dictionary containing optimizer state - - Default behavior is to perform communication so that the - entire optimizer state is returned on the root rank in the - process group. In this case, all ranks in the process group - must enter this function and no value is returned on non-root - ranks. - - Arguments: - gather_on_root (bool, optional): Gather state from all - ranks on the root rank (default: True) - - """ - state_dict = super().state_dict() - if not gather_on_root: - return state_dict - - # Export local state to byte string - state_bytes = io.BytesIO() - torch.save(state_dict, state_bytes) - state_bytes.seek(0) - state_bytes_view = state_bytes.getbuffer() - - # Get data sizes on all ranks - local_state_size = len(state_bytes_view) - state_sizes = [None] * self.distributed_size - torch.distributed.all_gather_object( - state_sizes, - local_state_size, - group=self.process_group, - ) - max_state_size = max(state_sizes) - - # Construct workspace buffers - chunk_size = self.shard_size * torch.finfo(self.grad_sync_dtype).bits // 8 - if self.distributed_rank == 0: - gathered_state_bytes = [state_bytes.getvalue()] - gathered_state_bytes.extend(bytearray(size) for size in state_sizes[1:]) - gathered_chunks_buffers = [ - torch.empty( - [chunk_size * self.distributed_size], - dtype=torch.uint8, - device=self.device, - ) - for _ in range(self.pipeline_size) - ] - else: - chunk_buffers = [ - torch.empty( - [chunk_size], - dtype=torch.uint8, - device=self.device, - ) - for _ in range(self.pipeline_size) - ] - - # Split data into chunks and gather on root rank - # Note: Assuming we are using the NCCL backend, communication - # must happen on the GPU. We split the data into fixed-size - # chunks so that the GPU memory usage is limited to - # (chunk_size * distributed_size) bytes. - # TODO: Avoid chunking with direct communication between CPUs - main_stream = torch.cuda.current_stream() - for stream in self._pipeline_streams: - stream.wait_stream(main_stream) - for stream_id, offset in enumerate(range(0, max_state_size, chunk_size)): - stream_id %= self.pipeline_size - - # Buffers for chunk - if self.distributed_rank == 0: - gathered_chunks = [ - gathered_chunks_buffers[stream_id][i*chunk_size:(i+1)*chunk_size] - for i in range(self.distributed_size) - ] - else: - chunk = chunk_buffers[stream_id] - - # Perform communication on parallel stream - stream = self._pipeline_streams[stream_id] - with torch.cuda.stream(stream): - - # Copy to GPU - if self.distributed_rank != 0 and offset < local_state_size: - local_chunk_size = min(chunk_size, local_state_size-offset) - chunk[:local_chunk_size].copy_( - torch.frombuffer( - state_bytes_view, - dtype=torch.uint8, - count=local_chunk_size, - offset=offset, - ), - non_blocking=True, - ) - - # Gather on root - if self.distributed_rank == 0: - if self._gather_no_copy: - no_copy_kwarg = { 'no_copy': True } - else: - no_copy_kwarg = {} - torch.distributed.gather( - gathered_chunks[0], - gathered_chunks, - dst=self._process_group_ranks[0], - group=self.process_group, - **no_copy_kwarg, - ) - else: - torch.distributed.gather( - chunk, - dst=self._process_group_ranks[0], - group=self.process_group, - ) - - # Copy back to CPU - if self.distributed_rank == 0: - for rank in range(1, self.distributed_size): - if offset < state_sizes[rank]: - rank_chunk_size = min(chunk_size, state_sizes[rank]-offset) - torch.frombuffer( - gathered_state_bytes[rank], - dtype=torch.uint8, - count=rank_chunk_size, - offset=offset, - ).copy_( - gathered_chunks[rank][:rank_chunk_size], - non_blocking=True, - ) - - # Synchronize GPU - for stream in self._pipeline_streams: - main_stream.wait_stream(stream) - main_stream.synchronize() - - # Return gathered state data on root rank - if self.distributed_rank == 0: - return {'gathered_states': gathered_state_bytes} - else: - return None - - def load_state_dict(self, state_dict): - """Load optimizer state""" - - # State dict contains state for all ranks - if 'gathered_states' in state_dict: - - # Deallocate distributed optimizer state to reduce GPU - # memory usage - if 'buckets' in self.state: - del self.state['buckets'] - - # Get state for current rank and parse byte string - state_bytes = state_dict['gathered_states'][self.distributed_rank] - state_bytes = io.BytesIO(state_bytes) - state_dict = torch.load(state_bytes) - - return super().load_state_dict(state_dict) diff --git a/apex/contrib/optimizers/distributed_fused_lamb.py b/apex/contrib/optimizers/distributed_fused_lamb.py deleted file mode 100644 index c06e5b7..0000000 --- a/apex/contrib/optimizers/distributed_fused_lamb.py +++ /dev/null @@ -1,722 +0,0 @@ -import math -import torch -import importlib -import amp_C -from apex.multi_tensor_apply import multi_tensor_applier - -import torch.distributed.distributed_c10d as c10d - -class DistributedFusedLAMB(torch.optim.Optimizer): - - """Implements LAMB algorithm. - - Currently GPU-only. Requires Apex to be installed via - ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``. - - This version of fused LAMB implements 2 fusions. - - * Fusion of the LAMB update's elementwise operations - * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches. - - :class:`apex.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer:: - - opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....) - ... - opt.step() - - :class:`apex.optimizers.FusedLAMB` may be used with or without Amp. If you wish to use :class:`FusedLAMB` with Amp, - you may choose any ``opt_level``:: - - opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....) - model, opt = amp.initialize(model, opt, opt_level="O0" or "O1 or "O2") - ... - opt.step() - - In general, ``opt_level="O1"`` is recommended. - - LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups. - lr (float, optional): learning rate. (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its norm. (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability. (default: 1e-8) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - amsgrad (boolean, optional): whether to use the AMSGrad variant of this - algorithm from the paper `On the Convergence of Adam and Beyond`_ - NOT SUPPORTED now! (default: False) - adam_w_mode (boolean, optional): Apply L2 regularization or weight decay - True for decoupled weight decay(also known as AdamW) (default: True) - grad_averaging (bool, optional): whether apply (1-beta2) to grad when - calculating running averages of gradient. (default: True) - set_grad_none (bool, optional): whether set grad to None when zero_grad() - method is called. (default: True) - max_grad_norm (float, optional): value used to clip global grad norm - (default: 1.0) - use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0 - weight decay parameter (default: False) - step_supports_amp_scaling(boolean, optional): whether to use customized - gradient unscaling logic (default: True) - - .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: - https://arxiv.org/abs/1904.00962 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - class AtomicCounter(object): - def __init__(self): - self.value = 0 - self.order = [] - import threading - self._lock = threading.Lock() - - def add(self, idx): - with self._lock: - self.value += 1 - self.order.append(idx) - - def __init__(self, params, - lr=1e-3, bias_correction = True, grad_averaging=True, - betas=(0.9, 0.999), eps=1e-8, - weight_decay=0., max_grad_norm=0., - adam_w_mode=True, use_nvlamb=False, - step_supports_amp_scaling=True, overlap_reductions=True, - dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4, - dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0, - e5m2_allgather=False, verbose=False): - defaults = dict(lr=lr, bias_correction=bias_correction, - betas=betas, eps=eps, weight_decay=weight_decay, - grad_averaging=grad_averaging, - max_grad_norm=max_grad_norm) - - super(DistributedFusedLAMB, self).__init__(params, defaults) - - global fused_adam_cuda, distributed_lamb_cuda - fused_adam_cuda = importlib.import_module("fused_adam_cuda") - distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda") - - self._overflow_buf = torch.cuda.IntTensor([0]) - self._has_overflow = False - self.multi_tensor_lamb_compute_update_term = distributed_lamb_cuda.multi_tensor_lamb_compute_update_term - self.multi_tensor_lamb_update_weights = distributed_lamb_cuda.multi_tensor_lamb_update_weights - import amp_C - self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm - - self._grad_averaging = grad_averaging - self._adam_w_mode = 1 if adam_w_mode else 0 - self._use_nvlamb = use_nvlamb - self._step_supports_amp_scaling = step_supports_amp_scaling - self._is_accumulation_step = False - self._last_step = False - self._overlap_reductions = overlap_reductions - self._global_scale = None - self._num_blocks = dwu_num_blocks - self._num_chunks = dwu_num_chunks - self._e5m2_allgather = e5m2_allgather - self._verbose = verbose - self._L2_grad_norm = None - - self._current_process_group = c10d._get_default_group() - self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys()) - self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size - self._world_size = torch.distributed.get_world_size() - self._num_groups = self._world_size // self._group_size - self._rank_in_group = torch.distributed.get_rank() % self._group_size - - self._lr = torch.tensor(0.0, dtype=torch.float32, device='cuda') - - self._resume_from_checkpoint = False - self._step = torch.cuda.IntTensor([0]) - - # Master weight, moment, gradient buffers - self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None - - import inspect - assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option" - - self._num_rs_pg = dwu_num_rs_pg - self._num_ar_pg = dwu_num_ar_pg - self._num_ag_pg = dwu_num_ag_pg - if self._num_groups > 1: - self._ar_pg = [] - for dev_i in range(self._group_size): - ranks = [dev_i+j*self._group_size for j in range(self._num_groups)] - for i in range(self._num_ar_pg): - if self._verbose: - print(f"creating new group {i}: {ranks}") - grp = torch.distributed.new_group(ranks=ranks) - if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER: - if self._verbose: - print(f"group {i}: init barrier (device: {torch.cuda.current_device()})") - torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()]) - if self._verbose: - print(f"created new group {i}") - - if torch.distributed.get_rank() in ranks: - self._ar_pg.append(grp) - self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)] - #for ar_pg in self._ar_pg: - # torch.distributed.all_reduce(self._overflow_buf,group=ar_pg) - rs_ranks = [] - for group_i in range(self._num_groups): - rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)]) - self._rs_pg = [] - for group_i in range(self._num_groups): - ranks = rs_ranks[group_i] - for i in range(self._num_rs_pg): - grp = torch.distributed.new_group(ranks=ranks) - if torch.distributed.get_rank() in ranks: - self._rs_pg.append(grp) - l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks) - if torch.distributed.get_rank() in ranks: - self._l2_grad_norm_pg = l2_grad_norm_pg - #torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg) - self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)] - #for rs_pg in self._rs_pg: - # torch.distributed.all_reduce(self._overflow_buf,group=rs_pg) - if self._num_ag_pg == 0: - self._ag_pg = self._rs_pg - self._ag_st = self._rs_st - self._num_ag_pg = self._num_rs_pg - else: - self._ag_pg = [] - for group_i in range(self._num_groups): - ranks = rs_ranks[group_i] - for i in range(self._num_ag_pg): - grp = torch.distributed.new_group(ranks=ranks) - if torch.distributed.get_rank() in ranks: - self._ag_pg.append(grp) - self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)] - #for ag_pg in self._ag_pg: - # torch.distributed.all_reduce(self._overflow_buf,group=ag_pg) - self._l2_grad_norm_st = torch.cuda.Stream() - self._completion_st = torch.cuda.Stream() - self._step.record_stream(self._completion_st) - - self._reductions_works = [None]*self._num_blocks - self._allgather_works = [None]*self._num_blocks - - self._one = torch.cuda.IntTensor([1]) - - self._first_step = True - self._lazy_init_stage1_done, self._lazy_init_stage2_done = False, False - self._param_order = self.AtomicCounter() - - def _lazy_init_stage1(self): - if self._lazy_init_stage1_done: return - - p_offset = 0 - p_i = 0 - self._model_params = [] - self._grad_accs = [] - self._group_properties = [] - for group in self.param_groups: - prev = None - beta1, beta2 = group['betas'] - beta3 = 1.0 - beta1 if self._grad_averaging else 1.0 - bias_correction = 1 if group['bias_correction'] else 0 - eps = group['eps'] - weight_decay = group['weight_decay'] - for p in group['params']: - torch.distributed.broadcast(p, 0) - if not p.requires_grad: - continue - self._model_params.append(p) - self._group_properties.append(( - weight_decay, - bias_correction, - beta1, - beta2, - beta3, - eps - )) - p_grads_size = p.numel() - def wrapper(param, param_i): - param_tmp = param.expand_as(param) - grad_acc = param_tmp.grad_fn.next_functions[0][0] - def allreduce_hook(*unused): - if self._first_step: - # first time - self._param_order.add(param_i) - else: - idx = self._param_order.order.index(param_i) - self._do_overlapped_reduction(idx, param) - grad_acc.register_hook(allreduce_hook) - self._grad_accs.append(grad_acc) - wrapper(p, p_i) - p_offset += p_grads_size - # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters - # RNN is one example of consecutive parameters: - # (weight_ih, weight_hh, bias_ih, bias_hh) - if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()): - p_offset = ((p_offset + 63) // 64) * 64 - prev = p - p_i += 1 - self._grads_generated = [False]*len(self._model_params) - self._grads_fp16, self._grads_fp32 = [], [] - if self._overlap_reductions: - self._current_block = self._num_blocks - - self._net_total_param_size = p_offset - self._total_param_size = p_offset - dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size - self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size - self._block_size = self._total_param_size // self._num_blocks - self._chunk_size = self._block_size // self._num_chunks - self._shard_size = self._chunk_size // self._group_size - #print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size)) - - self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda') - self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda') - self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size - # initialize master weights, moments buffers if not loaded from checkpoint - if self._fp32_p is None: - self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda') - self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda') - self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda') - self._fp32_u = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda') - # FIXME: Rethink fp16 label since it's either uint8 or fp16 - self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda') - self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda') - - def _flat_split(p): - def __blockify(p): - return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)] - def __chunkify(p): - return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)] - def __shardify(p): - return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)] - list_of_blocks = __blockify(self._flat_grads) - list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks] - list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks] - return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards - self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads) - def _full_packed_split(p): - def __shardify(p): - return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)] - def __blockify(p): - return [p[block_id*self._num_chunks*self._shard_size:(block_id+1)*self._num_chunks*self._shard_size] for block_id in range(self._num_blocks)] - def __chunkify(p): - return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)] - list_of_mega_shards = __shardify(p) - list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards] - list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks] - return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks - self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params) - def _packed_split(p): - def __packed_blockify(p): - packed_block_size = self._num_chunks*self._shard_size - return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)] - def __packed_chunkify(p): - # in the packed format, each chunk contains one shard, so packed_chunk_size == self._shard_size - return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)] - list_of_blocks = __packed_blockify(p) - list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks] - return list_of_blocks, list_of_list_of_chunks - self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p) - self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m) - self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v) - self._fp32_u_blocks, self._fp32_u_chunks = _packed_split(self._fp32_u) - self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p) - self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g) - - self._lazy_init_stage1_done = True - - def _lazy_init_stage2(self): - if self._lazy_init_stage2_done: return - - self._param_order.order.reverse() - - # re-order model_params, grad_accs, group_properties lists - self._model_params = [self._model_params[i] for i in self._param_order.order] - self._grad_accs = [self._grad_accs[i] for i in self._param_order.order] - self._group_properties = [self._group_properties[i] for i in self._param_order.order] - - # re-collect grads info (size, offset) after ordering - prev = None - p_offset = 0 - self._grads_info = [] - self._individual_flat_grads = [] - for i, p in enumerate(self._model_params): - p_grads_size = p.numel() - self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset}) - self._individual_flat_grads.append(self._flat_grads[p_offset:p_offset+p_grads_size].view_as(p)) - # for the first iteration - self._do_overlapped_reduction(i, p) - p_offset += p_grads_size - # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters - # RNN is one example of consecutive parameters: - # (weight_ih, weight_hh, bias_ih, bias_hh) - if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()): - p_offset = ((p_offset + 63) // 64) * 64 - prev = p - - self._low_param_i = [0]*self._num_blocks - for block_id in range(self._num_blocks-1,-1,-1): - p_i = len(self._grads_info)-1 - while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size: - p_i -= 1 - self._low_param_i[block_id] = p_i - #print("self._low_param_i", self._low_param_i) - - # This paragraph does two things: - # 1) Copy model parameters into master buffer - # 2) Create tensor lists for unpacking new parameter tensor after all-gather - self._packed_flat_to_model_params_fp16 = [] - self._packed_flat_to_model_params_fp32 = [] - self._model_params_num = len(self._model_params) - self._contrib_tensor_list = [] - self._contrib_min_param_i, self._contrib_max_param_i = -1, -1 - self._contrib_update_frag_for_norm = [] - self._contrib_model_param_for_norm_fp16 = [] - self._contrib_model_param_for_norm_fp32 = [] - self._contrib_model_param_for_norm_is_fp16 = [] - self._model_param_is_contrib = [] - self._contrib_group_properties = [] - for shard_id in range(self._group_size): - for block_id in range(self._num_blocks): - for chunk_id in range(self._num_chunks): - flat_shard_start = (((block_id * self._num_chunks + chunk_id) * self._group_size) + shard_id) * self._shard_size - flat_shard_end = flat_shard_start + self._shard_size - for param_i, (p, grads_info, group_props) in enumerate(zip(self._model_params, self._grads_info, self._group_properties)): - flat_grad_start = grads_info["param_offset"] - flat_grad_end = flat_grad_start + grads_info["param_grads_size"] - clipped_start = (lambda a,b: a if a > b else b)(flat_grad_start, flat_shard_start) - clipped_end = (lambda a,b: a if a < b else b)(flat_grad_end, flat_shard_end) - if clipped_start < clipped_end: - grad_offset = clipped_start - flat_grad_start - grad_length = clipped_end - clipped_start - shard_offset = clipped_start - flat_shard_start - model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length] - new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length] - if model_param_fragment.dtype == torch.float16: - self._packed_flat_to_model_params_fp16.append( (new_param_packed_fragment, model_param_fragment) ) - else: - self._packed_flat_to_model_params_fp32.append( (new_param_packed_fragment, model_param_fragment) ) - if shard_id == self._rank_in_group: - self._model_param_is_contrib.append(param_i) - # copy model parameters into master buffer - master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length] - opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length] - opti_state_v_fragment = self._fp32_v_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length] - opti_state_u_fragment = self._fp32_u_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length] - opti_state_g_fragment = self._fp16_g_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length] - opti_state_p_fragment = self._fp16_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length] - #print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size()))) - if not self._resume_from_checkpoint: - master_param_fragment.copy_(model_param_fragment) - self._contrib_group_properties.append(group_props) - self._contrib_tensor_list.append((master_param_fragment, opti_state_m_fragment, opti_state_v_fragment, opti_state_u_fragment, opti_state_g_fragment, opti_state_p_fragment)) # p, m, v, u, g, p_copy - self._contrib_update_frag_for_norm.append(opti_state_u_fragment) - if p.dtype == torch.float16: - self._contrib_model_param_for_norm_fp16.append(p) - else: - self._contrib_model_param_for_norm_fp32.append(p) - self._contrib_model_param_for_norm_is_fp16.append(True if p.dtype == torch.float16 else False) - if self._contrib_min_param_i < 0: self._contrib_min_param_i = param_i - self._contrib_max_param_i = param_i - self._contrib_model_param_for_norm_num = len(self._contrib_model_param_for_norm_is_fp16) - if len(self._contrib_model_param_for_norm_fp16) == 0: self._contrib_model_param_for_norm_fp16 = None - if len(self._contrib_model_param_for_norm_fp32) == 0: self._contrib_model_param_for_norm_fp32 = None - self._contrib_model_param_for_norm_is_fp32 = torch.tensor([not is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda') - self._contrib_model_param_for_norm_is_fp16 = torch.tensor([is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda') - self._offsets = torch.tensor(self._model_param_is_contrib, dtype=torch.int64, device='cuda') - - p, m, v, u, g, p_copy = list(zip(*self._contrib_tensor_list)) - self._contrib_compute_update_term_tensor_list = [g, p, m, v, u] - self._contrib_update_weights_tensor_list = [u, p, p_copy] - - math_type = self._fp32_u.dtype - decay, bias_correction, beta1, beta2, beta3, epsilon = list(zip(*self._contrib_group_properties)) - self._contrib_beta1 = torch.tensor(beta1, dtype=math_type, device='cuda') - self._contrib_beta2 = torch.tensor(beta2, dtype=math_type, device='cuda') - self._contrib_beta3 = torch.tensor(beta3, dtype=math_type, device='cuda') - self._contrib_bias_correction = torch.tensor(bias_correction, dtype=torch.int, device='cuda') - self._contrib_epsilon = torch.tensor(epsilon, dtype=math_type, device='cuda') - self._contrib_weight_decay = torch.tensor(decay, dtype=math_type, device='cuda') - - self._packed_flat_to_model_params_fp16 = list(zip(*self._packed_flat_to_model_params_fp16)) if len(self._packed_flat_to_model_params_fp16) > 0 else None - self._packed_flat_to_model_params_fp32 = list(zip(*self._packed_flat_to_model_params_fp32)) if len(self._packed_flat_to_model_params_fp32) > 0 else None - - self._lazy_init_stage2_done = True - - self.complete_reductions() - self._first_step = False - - def set_is_accumulation_step(self, is_accumulation_step): - self._is_accumulation_step = is_accumulation_step - - def set_last_step(self, last_step): - self._last_step = last_step - - def _get_flush_block(self): - flush_block = [] - if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]: - num_grads = len(self._grads_generated) - contiguous_idx = num_grads - while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]: - contiguous_idx -= 1 - - if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size: - self._current_block -= 1 - start = self._current_block * self._block_size - end = (self._current_block+1) * self._block_size - flush_block = [start, end] - - return flush_block - - def _pipeline_block_reductions(self, block_id): - self._flatten_grad_mt(1.0/self._world_size) - - # Reduction within each node - # Changes gradient format from [block * chunk * shard] to [shard * block * chunk] - # The output format is the same as the fp32 master parameters - works = [None]*self._num_chunks - for chunk_id in range(self._num_chunks): - glob_chunk_id = block_id * self._num_chunks + chunk_id - rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg] - rs_stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(rs_stream): - works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True) - - # Reduction across nodes for each rank - if self._num_groups > 1: - for chunk_id in range(self._num_chunks): - glob_chunk_id = block_id * self._num_chunks + chunk_id - ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg] - with torch.cuda.stream(ar_stream): - works[chunk_id].wait() - works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True) - self._reductions_works[block_id] = works - - # Compute L2 grad norm - if block_id == 0: - with torch.cuda.stream(self._l2_grad_norm_st): - for block_id in range(self._num_blocks): - for chunk_id in range(self._num_chunks): - self._reductions_works[block_id][chunk_id].wait() - # Since the packed format is contiguous after reductions, only one norm is needed - l2_grad_norm_sq = torch.empty([1], device='cuda') - l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2 - torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg) - self._L2_grad_norm = l2_grad_norm_sq.sqrt() - - def __compute_contrib_param_norm(self): - if self._contrib_model_param_for_norm_fp16 is not None and self._contrib_model_param_for_norm_fp32 is not None: - gnorm_fp16 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp16], True)[1] - gnorm_fp32 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp32], True)[1] - gnorm = torch.empty(size=[self._contrib_model_param_for_norm_num], dtype=torch.bool, device='cuda') - gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp16, gnorm_fp16) - gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp32, gnorm_fp32) - elif self._contrib_model_param_for_norm_fp16 is not None: - gnorm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp16], True)[1] - elif self._contrib_model_param_for_norm_fp32 is not None: - gnorm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp32], True)[1] - return gnorm - - def __compute_contrib_update_norm(self): - l2_norm = torch.zeros(size=[self._model_params_num], dtype=torch.float32, device='cuda') - local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_update_frag_for_norm], True)[1] ** 2 - l2_norm.scatter_(dim=0, index=self._offsets, src=local_contrib_l2_norm) - torch.distributed.all_reduce(l2_norm, group=self._ag_pg[0]) - l2_norm = torch.sqrt(l2_norm) - return l2_norm - - def _pipeline_step(self): - global_scale = self.global_scale - max_grad_norm = self.defaults['max_grad_norm'] - global_grad_norm = self.L2_grad_norm - - # check global_grad_norm and fill overflow_buf - is_finite = (global_grad_norm + 1 > global_grad_norm).int() - self._overflow_buf = self._one * (is_finite ^ self._one) # toggle between 0 and 1 - - # increment step counter if no overflow - self._step += is_finite - self._completion_st.wait_stream(torch.cuda.current_stream()) - self._completion_st.wait_stream(self._l2_grad_norm_st) - - # Call step kernel once per step - # Call all-gather once per step - with torch.cuda.stream(self._completion_st): - for block_id in range(self._num_blocks): - for chunk_id in range(self._num_chunks): - self._reductions_works[block_id][chunk_id].wait() - param_norm = self.__compute_contrib_param_norm() - multi_tensor_applier(self.multi_tensor_lamb_compute_update_term, - self._overflow_buf, - self._contrib_compute_update_term_tensor_list, # g, p, m, v, u - self._contrib_beta1, - self._contrib_beta2, - self._contrib_beta3, - self._contrib_bias_correction, - self._step, - self._contrib_epsilon, - self._adam_w_mode, - self._contrib_weight_decay, - global_scale, - global_grad_norm, - max_grad_norm) - upd_norm = self.__compute_contrib_update_norm() - multi_tensor_applier(self.multi_tensor_lamb_update_weights, - self._overflow_buf, - self._contrib_update_weights_tensor_list, # u, p, p_copy - param_norm, - upd_norm, - self._offsets, - self._lr, - self._contrib_weight_decay, - global_grad_norm, - self._use_nvlamb) - torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True) - - def _flatten_grad_mt(self, scale): - if len(self._grads_fp16) > 0: - self._overflow_buf.zero_() - multi_tensor_applier( - amp_C.multi_tensor_scale, - self._overflow_buf, - list(zip(*self._grads_fp16)), - scale) - self._grads_fp16 = [] - if len(self._grads_fp32) > 0: - self._overflow_buf.zero_() - multi_tensor_applier( - amp_C.multi_tensor_scale, - self._overflow_buf, - list(zip(*self._grads_fp32)), - scale) - self._grads_fp32 = [] - - def _do_overlapped_reduction(self, param_i, param): - if not self._is_accumulation_step: - # handle overlapped reductions - if param.dtype == torch.float16: - self._grads_fp16.append( (param.grad, self._individual_flat_grads[param_i]) ) - else: - self._grads_fp32.append( (param.grad, self._individual_flat_grads[param_i]) ) - self._grads_generated[param_i]=True - if not self._first_step and not self._last_step: - if self._overlap_reductions: - flush_block = self._get_flush_block() - while flush_block: - block_id = flush_block[0] // self._block_size - self._pipeline_block_reductions(block_id) - flush_block = self._get_flush_block() - - def set_global_scale(self, global_scale): - """Set global scale. - """ - self._global_scale = global_scale - - @property - def global_scale(self): - return self._global_scale - - @property - def L2_grad_norm(self): - torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st) - return self._L2_grad_norm - - def complete_reductions(self): - """Complete reductions if full pipeline is not selected or overlap is not allowed. - """ - if self._last_step: - # zero out gradients that have not been completed yet - for param_i, grad_generated in enumerate(self._grads_generated): - if not grad_generated: - grad_info = self._grads_info[param_i] - param_offset = grad_info["param_offset"] - param_size = grad_info["param_grads_size"] - self._flat_grads[param_offset:param_offset+param_size].zero_() - self._grads_generated[param_i] = True - - if self._first_step or self._last_step or not self._overlap_reductions: - # nothing done so far, run full pipeline after reductions - for block_id in range(self._num_blocks-1,-1,-1): - self._pipeline_block_reductions(block_id) - - torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st) - - self._current_block = self._num_blocks - self._grads_generated = [False]*len(self._grads_info) - - def step(self, closure=None, grad_scaler=None): - loss = None - if closure is not None: - loss = closure() - - self._pipeline_step() - - if grad_scaler is not None: - found_inf = self._overflow_buf.float() - optimizer_state = grad_scaler._per_optimizer_states[id(self)] - current_device = torch.device('cuda', torch.cuda.current_device()) - optimizer_state["found_inf_per_device"][current_device] = found_inf - - self._completion_st.wait_stream(torch.cuda.current_stream()) - - with torch.cuda.stream(self._completion_st): - # Copy self._new_params to model params - with torch.no_grad(): - if self._packed_flat_to_model_params_fp16 is not None: - multi_tensor_applier( - fused_adam_cuda.maybe_cast_mt, - self._overflow_buf, - self._packed_flat_to_model_params_fp16) - if self._packed_flat_to_model_params_fp32 is not None: - multi_tensor_applier( - fused_adam_cuda.maybe_cast_mt, - self._overflow_buf, - self._packed_flat_to_model_params_fp32) - - torch.cuda.current_stream().wait_stream(self._completion_st) - - self._reductions_works = [None]*self._num_blocks - self._allgather_works = [None]*self._num_blocks - - return loss - - def state_dict(self): - """ - Returns a dict containing the current state of this :class:`DistributedFusedAdam` instance. - Example:: - checkpoint = {} - checkpoint['model'] = model.state_dict() - checkpoint['optimizer'] = optimizer.state_dict() - torch.save(checkpoint, "saved.pth") - """ - # save step, master weights and first/second moments - state_dict = {} - state_dict['step'] = self._step - state_dict['fp32_p'] = self._fp32_p - state_dict['fp32_m'] = self._fp32_m - state_dict['fp32_v'] = self._fp32_v - return state_dict - - def load_state_dict(self, state_dict): - """ - Loads a state_dict created by an earlier call to state_dict(). - If an DistributedFusedAdam instance was constructed from some ``init_optimizer``, - whose parameters in turn came from ``model``, it is expected that the user - will call ``model.load_state_dict()`` before - ``optimizer.load_state_dict()`` is called. - Example:: - model = torch.nn.Linear(D_in, D_out).cuda().half() - optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) - optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) - ... - checkpoint = torch.load("saved.pth") - model.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - """ - # restore step, master weights and first/second moments - self._step = state_dict['step'] - self._fp32_p = state_dict['fp32_p'].to(device="cuda") - self._fp32_m = state_dict['fp32_m'].to(device="cuda") - self._fp32_v = state_dict['fp32_v'].to(device="cuda") - self._resume_from_checkpoint = True diff --git a/apex/contrib/optimizers/fp16_optimizer.py b/apex/contrib/optimizers/fp16_optimizer.py deleted file mode 100755 index 0cbb63b..0000000 --- a/apex/contrib/optimizers/fp16_optimizer.py +++ /dev/null @@ -1,243 +0,0 @@ -import torch -from apex.multi_tensor_apply import multi_tensor_applier - -class FP16_Optimizer(object): - """ - :class:`FP16_Optimizer` A cutdown version of apex.fp16_utils.FP16_Optimizer. - Designed only to wrap apex.contrib.optimizers.FusedAdam, FusedSGD. - Refer to apex.fp16_utils documents for more information. - Example:: - model = torch.nn.Linear(D_in, D_out).cuda().half() - optimizer = apex.contrib.optimizers.FusedSGD(model.parameters()) - optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) - ... - # loss.backward() becomes: - optimizer.backward(loss) - ... - Example with dynamic loss scaling:: - ... - optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) - # optional arg to control dynamic loss scaling behavior - # dynamic_loss_args={'scale_window' : 500}) - # Usually, dynamic_loss_args is not necessary. - """ - - def __init__(self, - init_optimizer, - static_loss_scale=1.0, - dynamic_loss_scale=False, - dynamic_loss_args=None, - verbose=True): - - print("\nThis fp16_optimizer is designed to only work with apex.contrib.optimizers.*") - print("To update, use updated optimizers with AMP.") - # The fused optimizer does all the work. We need this layer for two reason: - # 1. maintain same user API from apex.fp16_utils - # 2. keep common stuff here in case we need to add new fused optimizer later - - if not torch.cuda.is_available: - raise SystemError("Cannot use fp16 without CUDA.") - self.optimizer = init_optimizer - - self.fp16_groups = [] # model params - self.fp32_groups = [] # master weights - - # iterate over param_groups - for param_group in self.optimizer.param_groups: - fp16_group = [] - fp32_group = [] - for p in param_group['params']: - fp16_group.append(p) - fp32_group.append(p.clone().float().detach()) - self.fp16_groups.append(fp16_group) - self.fp32_groups.append(fp32_group) - param_group['params'] = fp32_group - - if multi_tensor_applier.available: - import amp_C - self.overflow_buf = torch.cuda.IntTensor([0]) - self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm - else: - raise RuntimeError('FP16_Optimizer requires cuda extensions') - - # we may have a way of fusing dynamic scale. Do not support for now - if dynamic_loss_scale: - if dynamic_loss_args is not None: - raise SystemError("Do not support dynamic loss scale args for now.") - self.dynamic_loss_scale = True - self.cur_scale = 2**16 - self.cur_iter = 0 - self.last_overflow_iter = -1 - self.scale_factor = 2 - self.scale_window = 1000 - else: - self.dynamic_loss_scale = False - self.cur_iter = 0 - self.cur_scale = static_loss_scale - self.verbose = verbose - - def zero_grad(self, set_grads_to_None=True): - """ - Zero FP16 parameter grads. - """ - # FP32 grad should never exist. - # For speed, set model fp16 grad to None by default - for group in self.fp16_groups: - for p in group: - if set_grads_to_None: - p.grad = None - else: - if p.grad is not None: - p.grad.detach_() - p.grad.zero_() - - def step(self, closure=None): - """ - Not supporting closure. - """ - fp16_grads = [] - norm_groups = [] - skip = False - - for group in self.fp16_groups: - fp16_grad = [] - for i, p in enumerate(group): - fp16_grad.append(p.grad) - fp16_grads.append(fp16_grad) - - # nan check - self.overflow_buf.zero_() - for fp16_grad in fp16_grads: - if len(fp16_grad) > 0: - norm, norm_per_tensor = multi_tensor_applier(self.multi_tensor_l2norm, - self.overflow_buf, - [fp16_grad], True) - norm_groups.append(norm) - if self.overflow_buf.item() != 0: - skip = True - - if skip: - self._update_scale(skip) - return - - # norm is in fact norm*cur_scale - self.optimizer.step(grads=fp16_grads, - output_params=self.fp16_groups, - scale=self.cur_scale, - grad_norms=norm_groups) - - self._update_scale(False) - return - - def backward(self, loss): - """ - :attr:`backward` performs the following steps: - 1. fp32_loss = loss.float() - 2. scaled_loss = fp32_loss*loss_scale - 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves - """ - scaled_loss = (loss.float()) * self.cur_scale - scaled_loss.backward() - - def _update_scale(self, skip): - if self.dynamic_loss_scale: - if skip: - if self.verbose: - print("\nGrad overflow on iteration", self.cur_iter) - print("Using dynamic loss scale of", self.cur_scale) - self.cur_scale = max(self.cur_scale/self.scale_factor, 1) - self.last_overflow_iter = self.cur_iter - else: - if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0: - self.cur_scale *= self.scale_factor - else: - if skip: - print("\nGrad overflow on iteration", self.cur_iter) - print("Using static loss scale of", self.cur_scale) - self.cur_iter +=1 - return - - # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" - def _get_state(self): - return self.optimizer.state - - def _set_state(self, value): - self.optimizer.state = value - - state = property(_get_state, _set_state) - - # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" - # (for example, to adjust the learning rate) - def _get_param_groups(self): - return self.optimizer.param_groups - - def _set_param_groups(self, value): - self.optimizer.param_groups = value - - param_groups = property(_get_param_groups, _set_param_groups) - - def state_dict(self): - """ - Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. - This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict - of the contained Pytorch optimizer. - Example:: - checkpoint = {} - checkpoint['model'] = model.state_dict() - checkpoint['optimizer'] = optimizer.state_dict() - torch.save(checkpoint, "saved.pth") - """ - state_dict = {} - state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale - state_dict['cur_scale'] = self.cur_scale - state_dict['cur_iter'] = self.cur_iter - if state_dict['dynamic_loss_scale']: - state_dict['last_overflow_iter'] = self.last_overflow_iter - state_dict['scale_factor'] = self.scale_factor - state_dict['scale_window'] = self.scale_window - state_dict['optimizer_state_dict'] = self.optimizer.state_dict() - state_dict['fp32_groups'] = self.fp32_groups - return state_dict - - def load_state_dict(self, state_dict): - """ - Loads a state_dict created by an earlier call to state_dict(). - If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, - whose parameters in turn came from ``model``, it is expected that the user - will call ``model.load_state_dict()`` before - ``fp16_optimizer_instance.load_state_dict()`` is called. - Example:: - model = torch.nn.Linear(D_in, D_out).cuda().half() - optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) - optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) - ... - checkpoint = torch.load("saved.pth") - model.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - """ - # I think it should actually be ok to reload the optimizer before the model. - self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] - self.cur_scale = state_dict['cur_scale'] - self.cur_iter = state_dict['cur_iter'] - if state_dict['dynamic_loss_scale']: - self.last_overflow_iter = state_dict['last_overflow_iter'] - self.scale_factor = state_dict['scale_factor'] - self.scale_window = state_dict['scale_window'] - self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) - # At this point, the optimizer's references to the model's fp32 parameters are up to date. - # The optimizer's hyperparameters and internal buffers are also up to date. - # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still - # out of date. There are two options. - # 1: Refresh the master params from the model's fp16 params. - # This requires less storage but incurs precision loss. - # 2: Save and restore the fp32 master copies separately. - # We choose option 2. - # - # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device - # of their associated parameters, because it's possible those buffers might not exist yet in - # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been - # constructed in the same way as the one whose state_dict we are loading, the same master params - # are guaranteed to exist, so we can just copy_() from the saved master params. - for current, saved in zip(self.fp32_groups, state_dict['fp32_groups']): - for _current, _saved in zip(current, saved): - _current.data.copy_(_saved.data) diff --git a/apex/contrib/optimizers/fused_adam.py b/apex/contrib/optimizers/fused_adam.py deleted file mode 100644 index a823e7b..0000000 --- a/apex/contrib/optimizers/fused_adam.py +++ /dev/null @@ -1,206 +0,0 @@ -import types -import torch -import importlib -from apex.multi_tensor_apply import multi_tensor_applier - -class FusedAdam(torch.optim.Optimizer): - - """Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via - ``python setup.py install --cuda_ext --cpp_ext``. - - It has been proposed in `Adam: A Method for Stochastic Optimization`_. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups. - lr (float, optional): learning rate. (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square. (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability. (default: 1e-8) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - amsgrad (boolean, optional): whether to use the AMSGrad variant of this - algorithm from the paper `On the Convergence of Adam and Beyond`_ - (default: False) NOT SUPPORTED in FusedAdam! - eps_inside_sqrt (boolean, optional): in the 'update parameters' step, - adds eps to the bias-corrected second moment estimate before - evaluating square root instead of adding it to the square root of - second moment estimate as in the original paper. (default: False) - use_mt (boolean, optional): use multi tensor apply for lower launch - latency. (default: False) - - .. _Adam - A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__(self, params, - lr=1e-3, bias_correction = True, - betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False, - weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False, - amp_scale_adjustment=1.0): - global fused_adam_cuda - fused_adam_cuda = importlib.import_module("fused_adam_cuda") - - self._use_multi_tensor = False - if use_mt: - if not multi_tensor_applier.available: - print("Warning: multi_tensor_applier is unavailable") - else: - self._use_multi_tensor = True - self._overflow_buf = torch.cuda.IntTensor([0]) - - self._amp_scale_adjustment = amp_scale_adjustment - - if amsgrad: - raise RuntimeError('FusedAdam does not support the AMSGrad variant.') - defaults = dict(lr=lr, bias_correction=bias_correction, - betas=betas, eps=eps, weight_decay=weight_decay, - max_grad_norm=max_grad_norm) - super(FusedAdam, self).__init__(params, defaults) - self.eps_mode = 0 if eps_inside_sqrt else 1 - - def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - grads (list of tensors, optional): weight gradient to use for the - optimizer update. If gradients have type torch.half, parameters - are expected to be in type torch.float. (default: None) - output params (list of tensors, optional): A reduced precision copy - of the updated weights written out in addition to the regular - updated weights. Have to be of same type as gradients. (default: None) - scale (float, optional): factor to divide gradient tensor values - by before applying to weights. (default: 1) - """ - loss = None - if closure is not None: - loss = closure() - - if hasattr(self, "_amp_stash"): - grads = self._amp_stash.grads - output_params = self._amp_stash.output_params - scale = self._amp_stash.scale*self._amp_scale_adjustment - grad_norms = self._amp_stash.grad_norms - - if grads is None: - grads_group = [None]*len(self.param_groups) - # backward compatibility - # assuming a list/generator of parameter means single group - elif isinstance(grads, types.GeneratorType): - grads_group = [grads] - elif type(grads[0])!=list: - grads_group = [grads] - else: - grads_group = grads - - if output_params is None: - output_params_group = [None]*len(self.param_groups) - elif isinstance(output_params, types.GeneratorType): - output_params_group = [output_params] - elif type(output_params[0])!=list: - output_params_group = [output_params] - else: - output_params_group = output_params - - if grad_norms is None: - grad_norms = [None]*len(self.param_groups) - - for group, grads_this_group, output_params_this_group, grad_norm in zip(self.param_groups, grads_group, output_params_group, grad_norms): - if grads_this_group is None: - grads_this_group = [None]*len(group['params']) - if output_params_this_group is None: - output_params_this_group = [None]*len(group['params']) - - # compute combined scale factor for this group - combined_scale = scale - if group['max_grad_norm'] > 0: - # norm is in fact norm*scale - clip = ((grad_norm / scale) + 1e-6) / group['max_grad_norm'] - if clip > 1: - combined_scale = clip * scale - - bias_correction = 1 if group['bias_correction'] else 0 - - if self._use_multi_tensor: - if output_params: - tensorlists = [[],[],[],[],[]] - else: - tensorlists = [[],[],[],[]] - tensordevice = None - - for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group): - #note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients - if p.grad is None and grad is None: - continue - if grad is None: - grad = p.grad.data - if grad.is_sparse: - raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead') - - state = self.state[p] - - # State initialization - if len(state) == 0: - state['step'] = 0 - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p.data) - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p.data) - - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] - beta1, beta2 = group['betas'] - - state['step'] += 1 - - out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param - if self._use_multi_tensor: - pl = [p.data, exp_avg, exp_avg_sq, grad] - if output_param is not None: - pl.append(out_p) - - for tl, t in zip(tensorlists, pl): - tl.append(t) - - if tensordevice is None: - tensordevice = p.device - elif tensordevice != p.device: - raise RuntimeError('FusedAdam does not support use_mt with tensors on multiple device') - - else: - with torch.cuda.device(p.device): - fused_adam_cuda.adam(p.data, - out_p, - exp_avg, - exp_avg_sq, - grad, - group['lr'], - beta1, - beta2, - group['eps'], - combined_scale, - state['step'], - self.eps_mode, - bias_correction, - group['weight_decay']) - - if self._use_multi_tensor: - with torch.cuda.device(tensordevice): - multi_tensor_applier( - fused_adam_cuda.adam_mt, - self._overflow_buf, - tensorlists, - group['lr'], - beta1, - beta2, - group['eps'], - combined_scale, - state['step'], - self.eps_mode, - bias_correction, - group['weight_decay']) - - return loss diff --git a/apex/contrib/optimizers/fused_lamb.py b/apex/contrib/optimizers/fused_lamb.py deleted file mode 100644 index 81d8682..0000000 --- a/apex/contrib/optimizers/fused_lamb.py +++ /dev/null @@ -1,208 +0,0 @@ -import torch -import importlib -import math -from apex.multi_tensor_apply import multi_tensor_applier - -class FusedLAMB(torch.optim.Optimizer): - - """Implements LAMB algorithm. - - Currently GPU-only. Requires Apex to be installed via - ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--deprecated_fused_lamb" ./``. - - This version of fused LAMB implements 2 fusions. - - * Fusion of the LAMB update's elementwise operations - * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches. - - :class:`apex.contrib.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer:: - - opt = apex.contrib.optimizers.FusedLAMB(model.parameters(), lr = ....) - ... - opt.step() - - :class:`apex.optimizers.FusedLAMB` may be used with or without Amp. If you wish to use :class:`FusedLAMB` with Amp, - you may choose any ``opt_level``:: - - opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....) - model, opt = amp.initialize(model, opt, opt_level="O0" or "O1 or "O2") - ... - opt.step() - - In general, ``opt_level="O1"`` is recommended. - - LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups. - lr (float, optional): learning rate. (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its norm. (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability. (default: 1e-8) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - amsgrad (boolean, optional): whether to use the AMSGrad variant of this - algorithm from the paper `On the Convergence of Adam and Beyond`_ - NOT SUPPORTED now! (default: False) - adam_w_mode (boolean, optional): Apply L2 regularization or weight decay - True for decoupled weight decay(also known as AdamW) (default: True) - grad_averaging (bool, optional): whether apply (1-beta2) to grad when - calculating running averages of gradient. (default: True) - set_grad_none (bool, optional): whether set grad to None when zero_grad() - method is called. (default: True) - max_grad_norm (float, optional): value used to clip global grad norm - (default: 1.0) - - .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: - https://arxiv.org/abs/1904.00962 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__(self, params, lr=1e-3, bias_correction=True, - betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01, - amsgrad=False, adam_w_mode=True, - grad_averaging=True, set_grad_none=True, - max_grad_norm=1.0): - if amsgrad: - raise RuntimeError('FusedLAMB does not support the AMSGrad variant.') - defaults = dict(lr=lr, bias_correction=bias_correction, - betas=betas, eps=eps, weight_decay=weight_decay, - grad_averaging=grad_averaging, - max_grad_norm=max_grad_norm) - super(FusedLAMB, self).__init__(params, defaults) - if multi_tensor_applier.available: - import amp_C - self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm - self._dummy_overflow_buf = torch.cuda.IntTensor([0]) - fused_lamb_cuda = importlib.import_module("fused_lamb_cuda") - self.multi_tensor_lamb = fused_lamb_cuda.lamb - else: - raise RuntimeError('apex.contrib.optimizers.FusedLAMB requires cuda extensions') - - self.adam_w_mode = 1 if adam_w_mode else 0 - self.set_grad_none = set_grad_none - - def zero_grad(self): - if self.set_grad_none: - for group in self.param_groups: - for p in group['params']: - p.grad = None - else: - super(FusedLAMB, self).zero_grad() - - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - loss = closure() - - # create separate grad lists for fp32 and fp16 params - g_all_32, g_all_16 = [], [] - for group in self.param_groups: - for p in group['params']: - if p.grad is None: - continue - if p.dtype == torch.float32: - g_all_32.append(p.grad.data) - elif p.dtype == torch.float16: - g_all_16.append(p.grad.data) - else: - raise RuntimeError('FusedLAMB only support fp16 and fp32.') - - g_norm_32, g_norm_16 = 0.0, 0.0 - # compute grad norm for two lists - if len(g_all_32) > 0: - g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm, - self._dummy_overflow_buf, - [g_all_32], False)[0].item() - if len(g_all_16) > 0: - g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm, - self._dummy_overflow_buf, - [g_all_16], False)[0].item() - - # blend two grad norms to get global grad norm - global_grad_norm = math.sqrt(g_norm_32 * g_norm_32 + g_norm_16 * g_norm_16) - max_grad_norm = self.defaults['max_grad_norm'] - - for group in self.param_groups: - bias_correction = 1 if group['bias_correction'] else 0 - beta1, beta2 = group['betas'] - grad_averaging = 1 if group['grad_averaging'] else 0 - - # assume same step across group now to simplify things - # per parameter step can be easily support by making it tensor, or pass list into kernel - if 'step' in group: - group['step'] += 1 - else: - group['step'] = 1 - - # create lists for multi-tensor apply - g_16, p_16, m_16, v_16 = [], [], [], [] - g_32, p_32, m_32, v_32 = [], [], [], [] - - for p in group['params']: - if p.grad is None: - continue - if p.grad.data.is_sparse: - raise RuntimeError('FusedLAMB does not support sparse gradients, please consider SparseAdam instead') - - state = self.state[p] - # State initialization - if len(state) == 0: - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p.data) - # Exponential moving average of gradient values - state['exp_avg_sq'] = torch.zeros_like(p.data) - - if p.dtype == torch.float16: - g_16.append(p.grad.data) - p_16.append(p.data) - m_16.append(state['exp_avg']) - v_16.append(state['exp_avg_sq']) - elif p.dtype == torch.float32: - g_32.append(p.grad.data) - p_32.append(p.data) - m_32.append(state['exp_avg']) - v_32.append(state['exp_avg_sq']) - else: - raise RuntimeError('FusedLAMB only support fp16 and fp32.') - - if(len(g_16) > 0): - multi_tensor_applier(self.multi_tensor_lamb, - self._dummy_overflow_buf, - [g_16, p_16, m_16, v_16], - group['lr'], - beta1, - beta2, - group['eps'], - group['step'], - bias_correction, - group['weight_decay'], - grad_averaging, - self.adam_w_mode, - global_grad_norm, - max_grad_norm) - if(len(g_32) > 0): - multi_tensor_applier(self.multi_tensor_lamb, - self._dummy_overflow_buf, - [g_32, p_32, m_32, v_32], - group['lr'], - beta1, - beta2, - group['eps'], - group['step'], - bias_correction, - group['weight_decay'], - grad_averaging, - self.adam_w_mode, - global_grad_norm, - max_grad_norm) - - return loss diff --git a/apex/contrib/optimizers/fused_sgd.py b/apex/contrib/optimizers/fused_sgd.py deleted file mode 100644 index 83587c6..0000000 --- a/apex/contrib/optimizers/fused_sgd.py +++ /dev/null @@ -1,211 +0,0 @@ -import types -import torch -from torch.optim.optimizer import Optimizer, required - -from apex.multi_tensor_apply import multi_tensor_applier - -class FusedSGD(Optimizer): - r"""Implements stochastic gradient descent (optionally with momentum). - - This version of fused SGD implements 2 fusions. - * Fusion of the SGD update's elementwise operations - * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches. - - :class:`apex.contrib.optimizers.FusedSGD` should be used without AMP. - - :class:`apex.contrib.optimizers.FusedSGD` only works in the case where all parameters require grad. - - Nesterov momentum is based on the formula from - `On the importance of initialization and momentum in deep learning`__. - - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float): learning rate - momentum (float, optional): momentum factor (default: 0) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - dampening (float, optional): dampening for momentum (default: 0) - nesterov (bool, optional): enables Nesterov momentum (default: False) - - Example: - model = ... - model.half() - optimizer = apex.contrib.optimizers.FusedSGD(model.parameters()) - # wrap with FP16_Optimizer - optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) - optimizer.zero_grad() - ... - optimizer.backward(loss) - optmizer.step() - - __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf - - .. note:: - The implementation of SGD with Momentum/Nesterov subtly differs from - Sutskever et. al. and implementations in some other frameworks. - - Considering the specific case of Momentum, the update can be written as - - .. math:: - v = \rho * v + g \\ - p = p - lr * v - - where p, g, v and :math:`\rho` denote the parameters, gradient, - velocity, and momentum respectively. - - This is in contrast to Sutskever et. al. and - other frameworks which employ an update of the form - - .. math:: - v = \rho * v + lr * g \\ - p = p - v - - The Nesterov version is analogously modified. - """ - - def __init__(self, params, lr=required, momentum=0, dampening=0, - weight_decay=0, nesterov=False, - wd_after_momentum=False, - materialize_master_grads=True): - if lr is not required and lr < 0.0: - raise ValueError("Invalid learning rate: {}".format(lr)) - if momentum < 0.0: - raise ValueError("Invalid momentum value: {}".format(momentum)) - if weight_decay < 0.0: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - - defaults = dict(lr=lr, momentum=momentum, dampening=dampening, - weight_decay=weight_decay, nesterov=nesterov) - if nesterov and (momentum <= 0 or dampening != 0): - raise ValueError("Nesterov momentum requires a momentum and zero dampening") - super(FusedSGD, self).__init__(params, defaults) - - self.wd_after_momentum = wd_after_momentum - - if multi_tensor_applier.available: - import amp_C - # Skip buffer - self._dummy_overflow_buf = torch.cuda.IntTensor([0]) - self.multi_tensor_sgd = amp_C.multi_tensor_sgd - else: - raise RuntimeError('apex.contrib.optimizers.FusedSGD requires cuda extensions') - - def __setstate__(self, state): - super(FusedSGD, self).__setstate__(state) - for group in self.param_groups: - group.setdefault('nesterov', False) - - def get_momentums(self, params): - momentums = [] - first_run = True - for p in params: - param_state = self.state[p] - # torch.optim.SGD initializes momentum in the main loop, we have - # to do it here, and track whether or not we've done so, so that - # momentum application can be skipped in the main kernel. - if 'momentum_buffer' not in param_state: - first_run = True - buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) - momentums.append(buf) - else: - first_run = False - momentums.append(param_state['momentum_buffer']) - return momentums, first_run - - def step(self, closure=None, grads=None, output_params=None, scale=1., grad_norms=None): - """Performs a single optimization step. - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - grads (list of tensors, optional): weight gradient to use for the - optimizer update. If gradients have type torch.half, parameters - are expected to be in type torch.float. (default: None) - output_params (list of tensors, optional): A reduced precision copy - of the updated weights written out in addition to the regular - updated weights. Have to be of same type as gradients. (default: None) - scale (float, optional): factor to divide gradient tensor values - by before applying to weights. (default: 1) - """ - if hasattr(self, "_amp_stash"): - raise RuntimeError('apex.contrib.optimizers.FusedSGD should not be used with AMP.') - - loss = None - if closure is not None: - loss = closure() - - if grads is None: - raise RuntimeError('apex.contrib.optimizers.FusedSGD must be wrapped \ - with apex.contrib.optimizers.FP16_Optimizer \ - which provides grads.') - # backward compatibility - # assuming a list/generator of parameter means single group - elif isinstance(grads, types.GeneratorType): - grads_group = [grads] - elif type(grads[0])!=list: - grads_group = [grads] - else: - grads_group = grads - - if output_params is None: - raise RuntimeError('apex.contrib.optimizers.FusedSGD must be wrapped \ - with apex.contrib.optimizers.FP16_Optimizer \ - which provides output_params.') - elif isinstance(output_params, types.GeneratorType): - output_params_group = [output_params] - elif type(output_params[0])!=list: - output_params_group = [output_params] - else: - output_params_group = output_params - - for group, grads_this_group, output_params_this_group in zip(self.param_groups, - grads_group, - output_params_group): - if grads_this_group is None or output_params_this_group is None: - raise RuntimeError('apex.contrib.optimizers.FusedSGD only works \ - when all parameters require grad.') - - weight_decay = group['weight_decay'] - momentum = group['momentum'] - dampening = group['dampening'] - nesterov = group['nesterov'] - lr = group['lr'] - - first_runs = [True, True] - - # output_params_this_group: original weights (either fp16 or fp32) - # group['params']: master weights (fp32) - - # grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy - # fp32, fp32, fp32, No - fp32_grads = [g for (p, g) in zip(output_params_this_group, grads_this_group) if p.dtype == torch.float32] - fp32_params = [p2 for (p1, p2) in zip(output_params_this_group, group['params']) if p1.dtype == torch.float32] - fp32_momentums, first_runs[1] = self.get_momentums(fp32_params) - fp32_set = [fp32_grads, fp32_params, fp32_momentums] - - # fp16, fp32, fp32, Yes - fp16_grads = [g for (p, g) in zip(output_params_this_group, grads_this_group) if p.dtype == torch.float16] - fp32_from_fp16_params = [p2 for (p1, p2) in zip(output_params_this_group, group['params']) if p1.dtype == torch.float16] - fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params) - fp16_params = [p1 for (p1, p2) in zip(output_params_this_group, group['params']) if p1.dtype == torch.float16] - fp16_set = [fp16_grads, fp32_from_fp16_params, fp32_from_fp16_momentums, fp16_params] - - launch_sets = [fp16_set, fp32_set] - - for launch_set, first_run in zip(launch_sets, first_runs): - assert len(launch_set[0]) == len(launch_set[1]) - assert len(launch_set[0]) == len(launch_set[2]) - if len(launch_set[0]) > 0: - multi_tensor_applier( - self.multi_tensor_sgd, - self._dummy_overflow_buf, - launch_set, - weight_decay, - momentum, - dampening, - lr, - nesterov, - first_run, - self.wd_after_momentum, - 1.0/scale) - - return loss diff --git a/apex/contrib/peer_memory/__init__.py b/apex/contrib/peer_memory/__init__.py deleted file mode 100644 index 8d6fa54..0000000 --- a/apex/contrib/peer_memory/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .peer_memory import PeerMemoryPool -from .peer_halo_exchanger_1d import PeerHaloExchanger1d - diff --git a/apex/contrib/peer_memory/peer_halo_exchange_module_tests.py b/apex/contrib/peer_memory/peer_halo_exchange_module_tests.py deleted file mode 100644 index bd85354..0000000 --- a/apex/contrib/peer_memory/peer_halo_exchange_module_tests.py +++ /dev/null @@ -1,164 +0,0 @@ -import torch -from apex.contrib.peer_memory import PeerMemoryPool, PeerHaloExchanger1d -import peer_memory_cuda as pm - -# How to run: -# torchrun --nproc_per_node -# must be a power of 2 greater than 1. - - -# Output of this function is used as ground truth in module tests. -def nccl_halo_ex(peer_rank, peer_group_size, y, half_halo, explicit_nhwc, H_split): - if explicit_nhwc: - if H_split: - _, Hp, _, _ = list(y.shape) - H = Hp - 2*half_halo - top_out_halo = y[:,half_halo:2*half_halo,:,:] - top_inp_halo = y[:,:half_halo,:,:] - btm_out_halo = y[:,H:H+half_halo,:,:] - btm_inp_halo = y[:,H+half_halo:H+2*half_halo,:,:] - else: - _, _, Wp, _ = list(y.shape) - W = Wp - 2*half_halo - top_out_halo = y[:,:,half_halo:2*half_halo,:] - top_inp_halo = y[:,:,:half_halo,:] - btm_out_halo = y[:,:,W:W+half_halo,:] - btm_inp_halo = y[:,:,W+half_halo:W+2*half_halo,:] - else: - if H_split: - _, _, Hp, _ = list(y.shape) - H = Hp - 2*half_halo - top_out_halo = y[:,:,half_halo:2*half_halo,:] - top_inp_halo = y[:,:,:half_halo,:] - btm_out_halo = y[:,:,H:H+half_halo,:] - btm_inp_halo = y[:,:,H+half_halo:H+2*half_halo,:] - else: - _, _, _, Wp = list(y.shape) - W = Wp - 2*half_halo - top_out_halo = y[:,:,:,half_halo:2*half_halo] - top_inp_halo = y[:,:,:,:half_halo] - btm_out_halo = y[:,:,:,W:W+half_halo] - btm_inp_halo = y[:,:,:,W+half_halo:W+2*half_halo] - - mf = torch.channels_last if y.is_contiguous(memory_format=torch.channels_last) else torch.contiguous_format - top_out_halo = top_out_halo.contiguous() - btm_out_halo = btm_out_halo.contiguous() - - top_inp_halos = [torch.empty_like(top_out_halo) for _ in range(peer_group_size)] - torch.distributed.all_gather(top_inp_halos, top_out_halo) - btm_inp_halos = [torch.empty_like(btm_out_halo) for _ in range(peer_group_size)] - torch.distributed.all_gather(btm_inp_halos, btm_out_halo) - top_rank = (peer_rank + peer_group_size - 1) % peer_group_size - btm_rank = (peer_rank + 1) % peer_group_size - if peer_rank == 0: - top_inp_halo.zero_() - else: - top_inp_halo.copy_(btm_inp_halos[top_rank].to(memory_format=mf)) - if peer_rank == peer_group_size-1: - btm_inp_halo.zero_() - else: - btm_inp_halo.copy_(top_inp_halos[btm_rank].to(memory_format=mf)) - - -def single_test(peer_rank, peer_group_size, halo_ex, C, H, W, half_halo, dtype, memory_format, H_split, num_steps, numSM=1): - if memory_format == 1: - # 1 -> explicit nhwc - explicit_nhwc = True - if H_split: - y = torch.randn([1,H+2*half_halo,W,C], dtype=dtype, device='cuda') - ym = y[:,half_halo:H+half_halo,:,:] - else: - y = torch.randn([1,H,W+2*half_halo,C], dtype=dtype, device='cuda') - ym = y[:,:,half_halo:W+half_halo,:] - else: - # 2 -> native nhwc - # 3 -> nchw - explicit_nhwc = False - if H_split: - y = torch.randn([1,C,H+2*half_halo,W], dtype=dtype, device='cuda') - if memory_format == 2: - y = y.to(memory_format=torch.channels_last) - ym = y[:,:,half_halo:H+half_halo,:] - else: - y = torch.randn([1,C,H,W+2*half_halo], dtype=dtype, device='cuda') - if memory_format == 2: - y = y.to(memory_format=torch.channels_last) - ym = y[:,:,:,half_halo:W+half_halo] - y3 = y.clone() - list_y = [] - for step in range(num_steps): - halo_ex(y, H_split, explicit_nhwc, numSM) - list_y.append(y.clone()) - y.copy_(y3) - halo_ex.peer_pool.reset() - torch.distributed.barrier() - y2 = y3.clone() - list_y2 = [] - for step in range(num_steps): - nccl_halo_ex(peer_rank, peer_group_size, y2, half_halo, explicit_nhwc, H_split) - list_y2.append(y2.clone()) - y2.copy_(y3) - is_equal = [torch.all(torch.eq(yy,yy2)) for yy,yy2 in zip(list_y,list_y2)] - is_equal = torch.tensor(is_equal, dtype=torch.bool) - is_equal = torch.all(is_equal) - if peer_rank == 0: - if memory_format == 1: - memory_format_str = "explicit_nhwc" - elif memory_format == 2: - memory_format_str = "native nhwc" - elif memory_format == 3: - memory_format_str = "nchw" - else: - memory_format_str = "???" - if is_equal: - print("SUCCESS : N,C,H,W = 1,%d,%d,%d, half_halo=%d, %s, %s, %s" % (C,H,W,half_halo,str(dtype),memory_format_str,"H-split" if H_split else "W-split")) - else: - print("FAILURE : N,C,H,W = 1,%d,%d,%d, half_halo=%d, %s, %s, %s" % (C,H,W,half_halo,str(dtype),memory_format_str,"H-split" if H_split else "W-split")) - - # peer memory flag sync relies on there being at least one barrier per step - torch.distributed.barrier() - - -def H_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_steps): - Hr = 8*world_size - Hp = ((H + Hr - 1) // Hr) * 8 - - for i in range(4): - div = int(pow(2,i)) - single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 1, True, num_steps) - single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 2, True, num_steps) - single_test(rank, world_size, halo_ex, C*div, Hp//div, W//div, half_halo, torch.float16, 3, True, num_steps) - - -def W_split_tests(N, C, H, W, half_halo, rank, world_size, halo_ex, num_steps): - Wr = 8*world_size - Wp = ((W + Wr - 1) // Wr) * 8 - - for i in range(4): - div = int(pow(2,i)) - single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 1, False, num_steps) - single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 2, False, num_steps) - single_test(rank, world_size, halo_ex, C*div, H//div, Wp//div, half_halo, torch.float16, 3, False, num_steps) - - -def main(): - # for this trivial example peer_rank == rank and peer_group_size == world_size - - torch.distributed.init_process_group("nccl") - rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - torch.cuda.set_device(rank) - peer_ranks = [i for i in range(world_size)] - pool = PeerMemoryPool(64*1024, 2*1024*1024, peer_ranks) - - num_steps = 100 - - half_halo = 1 - halo_ex = PeerHaloExchanger1d(peer_ranks, rank, pool, half_halo) - - H_split_tests(1,64,336,200, half_halo,rank,world_size,halo_ex,num_steps) - W_split_tests(1,64,200,336, half_halo,rank,world_size,halo_ex,num_steps) - - -if __name__ == "__main__": - main() diff --git a/apex/contrib/peer_memory/peer_halo_exchanger_1d.py b/apex/contrib/peer_memory/peer_halo_exchanger_1d.py deleted file mode 100644 index cc25693..0000000 --- a/apex/contrib/peer_memory/peer_halo_exchanger_1d.py +++ /dev/null @@ -1,65 +0,0 @@ -import torch -from apex.contrib.peer_memory import PeerMemoryPool -import peer_memory_cuda as pm - -class PeerHaloExchanger1d: - def __init__(self, ranks, rank_in_group, peer_pool, half_halo): - self.peer_group_size = len(ranks) - self.ranks = ranks - self.peer_rank = rank_in_group - self.low_neighbor = (self.peer_rank + self.peer_group_size - 1) % self.peer_group_size - self.high_neighbor = (self.peer_rank + 1) % self.peer_group_size - self.low_zero = True if self.peer_rank == 0 else False - self.high_zero = True if self.peer_rank == self.peer_group_size - 1 else False - - self.peer_pool = peer_pool - self.signals = peer_pool.allocate_peer_tensors([2,4], torch.int32, False, False) - self.signals[self.peer_rank].zero_() - self.half_halo = half_halo - - def __call__(self, y, H_split=True, explicit_nhwc=False, numSM=1, diagnostics=False): - channels_last = y.is_contiguous(memory_format=torch.channels_last) and not explicit_nhwc - if H_split: - if explicit_nhwc: - _, Hs, _, _ = list(y.shape) - H = Hs - 2*self.half_halo - low_out_halo = y[:,self.half_halo:2*self.half_halo,:,:] - low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, False, True) - low_inp_halo = y[:,:self.half_halo,:,:] - high_out_halo = y[:,H:H+self.half_halo,:,:] - high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, False, True) - high_inp_halo = y[:,H+self.half_halo:H+2*self.half_halo,:,:] - else: - _, _, Hs, _ = list(y.shape) - H = Hs - 2*self.half_halo - low_out_halo = y[:,:,self.half_halo:2*self.half_halo,:] - low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, channels_last, True) - low_inp_halo = y[:,:,:self.half_halo,:] - high_out_halo = y[:,:,H:H+self.half_halo,:] - high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, channels_last, True) - high_inp_halo = y[:,:,H+self.half_halo:H+2*self.half_halo,:] - else: - if explicit_nhwc: - _, _, Ws, _ = list(y.shape) - W = Ws - 2*self.half_halo - low_out_halo = y[:,:,self.half_halo:2*self.half_halo,:] - low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, False, True) - low_inp_halo = y[:,:,:self.half_halo,:] - high_out_halo = y[:,:,W:W+self.half_halo,:] - high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, False, True) - high_inp_halo = y[:,:,W+self.half_halo:W+2*self.half_halo,:] - else: - _, _, _, Ws = list(y.shape) - W = Ws - 2*self.half_halo - low_out_halo = y[:,:,:,self.half_halo:2*self.half_halo] - low_tx = self.peer_pool.allocate_peer_tensors(list(low_out_halo.shape), low_out_halo.dtype, channels_last, True) - low_inp_halo = y[:,:,:,:self.half_halo] - high_out_halo = y[:,:,:,W:W+self.half_halo] - high_tx = self.peer_pool.allocate_peer_tensors(list(high_out_halo.shape), high_out_halo.dtype, channels_last, True) - high_inp_halo = y[:,:,:,W+self.half_halo:W+2*self.half_halo] - pm.push_pull_halos_1d( - diagnostics, explicit_nhwc, numSM, - self.low_zero, low_out_halo, low_tx[self.peer_rank], high_tx[self.low_neighbor], low_inp_halo, - self.high_zero, high_out_halo, high_tx[self.peer_rank], low_tx[self.high_neighbor], high_inp_halo, - self.signals[self.low_neighbor], self.signals[self.high_neighbor], self.signals[self.peer_rank] - ) diff --git a/apex/contrib/peer_memory/peer_memory.py b/apex/contrib/peer_memory/peer_memory.py deleted file mode 100644 index adb2182..0000000 --- a/apex/contrib/peer_memory/peer_memory.py +++ /dev/null @@ -1,87 +0,0 @@ -import torch -import numpy as np -import peer_memory_cuda as pm - -class PeerMemoryPool(object): - - def __init__(self, static_size, dynamic_size, peer_ranks=None): - rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - ngpus = min(torch.cuda.device_count(), world_size) - peer_group_size = ngpus - peer_group = rank // ngpus - peer_rank_base = peer_group * ngpus - peer_rank = rank - peer_rank_base - if peer_ranks is None: - peer_ranks = [i+peer_rank_base for i in range(peer_group_size)] - peer_rank_start = peer_rank_base - peer_rank_end = peer_rank_start + peer_group_size - 1 - for pr in peer_ranks: - assert(pr >= peer_rank_start and pr <= peer_rank_end), "%d :: peer_rank %d not on same node (ranks=[%d,%d])" % (rank, pr, peer_rank_start, peer_rank_end) - - self.alignment = 256 - self.static_size = ((static_size + self.alignment - 1) // self.alignment) * self.alignment - self.dynamic_size = ((dynamic_size + self.alignment - 1) // self.alignment) * self.alignment - - # allocate giant pool of device memory - self.raw = pm.allocate_raw(self.static_size+self.dynamic_size) - - # exchange peer pointers with nccl - raw_ipc = pm.get_raw_ipc_address(self.raw).cuda() - peer_raw_ipcs = [torch.empty_like(raw_ipc) for _ in range(world_size)] - torch.distributed.all_gather(peer_raw_ipcs, raw_ipc) - peer_raw_ipcs = torch.stack(peer_raw_ipcs).cpu() - - # extract IPC pointers for ranks on same node - peer_raw = pm.get_raw_peers(peer_raw_ipcs[peer_rank_base:peer_rank_base+ngpus], peer_rank, self.raw) - self.peer_raw = [peer_raw[peer_rank-peer_rank_base] for peer_rank in peer_ranks] - self.static_offset = 0 - self.dynamic_offset = 0 - self.peer_ranks = peer_ranks - - def __del__(self): - pm.free_raw(self.raw) - - def reset(self): - self.dynamic_offset = 0 - - def allocate_peer_tensors(self, shape, dtype, channels_last, dynamic): - nels = np.prod(shape) - if dtype == torch.float16: - elem_size = 2 - if dynamic: - start = ((self.dynamic_offset + self.alignment - 1) // self.alignment) * self.alignment - self.dynamic_offset = start + nels * elem_size - assert(self.dynamic_offset < self.dynamic_size), "Dynamic peer memory pool exhausted" - return [pm.blob_view_half(pr + self.static_size + start, shape, channels_last) for pr in self.peer_raw] - else: - start = ((self.static_offset + self.alignment - 1) // self.alignment) * self.alignment - self.static_offset = start + nels * elem_size - assert(self.static_offset < self.static_size), "Static peer memory pool exhausted" - return [pm.blob_view_half(pr + start, shape, channels_last) for pr in self.peer_raw] - if dtype == torch.float32: - elem_size = 4 - if dynamic: - start = ((self.dynamic_offset + self.alignment - 1) // self.alignment) * self.alignment - self.dynamic_offset = start + nels * elem_size - assert(self.dynamic_offset < self.dynamic_size), "Dynamic peer memory pool exhausted" - return [pm.blob_view_float(pr + self.static_size + start, shape, channels_last) for pr in self.peer_raw] - else: - start = ((self.static_offset + self.alignment - 1) // self.alignment) * self.alignment - self.static_offset = start + nels * elem_size - assert(self.static_offset < self.static_size), "Static peer memory pool exhausted" - return [pm.blob_view_float(pr + start, shape, channels_last) for pr in self.peer_raw] - if dtype == torch.int32: - elem_size = 4 - if dynamic: - start = ((self.dynamic_offset + self.alignment - 1) // self.alignment) * self.alignment - self.dynamic_offset = start + nels * elem_size - assert(self.dynamic_offset < self.dynamic_size), "Dynamic peer memory pool exhausted" - return [pm.blob_view_int(pr + self.static_size + start, shape, channels_last) for pr in self.peer_raw] - else: - start = ((self.static_offset + self.alignment - 1) // self.alignment) * self.alignment - self.static_offset = start + nels * elem_size - assert(self.static_offset < self.static_size), "Static peer memory pool exhausted" - return [pm.blob_view_int(pr + start, shape, channels_last) for pr in self.peer_raw] - else: - assert(False), "dtype %s not supported" % (str(dtype)) diff --git a/apex/contrib/sparsity/README.md b/apex/contrib/sparsity/README.md deleted file mode 100644 index 3468118..0000000 --- a/apex/contrib/sparsity/README.md +++ /dev/null @@ -1,134 +0,0 @@ -# Introduction to ASP - -This serves as a quick-start for ASP (Automatic SParsity), a tool that enables sparse training and inference for PyTorch models by adding 2 lines of Python. - -## Importing ASP - -``` -from apex.contrib.sparsity import ASP -``` - -## Initializing ASP - -Apart from the import statement, it is sufficient to add just the following line of code before the training phase to augment the model and the optimizer for sparse training/inference: - -``` -ASP.prune_trained_model(model, optimizer) -``` - -In the context of a typical PyTorch training loop, it might look like this: - -``` -ASP.prune_trained_model(model, optimizer) - -x, y = DataLoader(args) -for epoch in range(epochs): - y_pred = model(x) - loss = loss_function(y_pred, y) - loss.backward() - optimizer.step() - -torch.save(...) -``` - -The `prune_trained_model` step calculates the sparse mask and applies it to the weights. This is done once, i.e., sparse locations in the weights matrix remain fixed after this step. - -## Generate a Sparse Network - -The following approach serves as a guiding example on how to generate a pruned model that can use Sparse Tensor Cores in the NVIDIA Ampere Architecture. This approach generates a model for deployment, i.e. inference mode. - -``` -(1) Given a fully trained (dense) network, prune parameter values in a 2:4 sparse pattern. -(2) Fine-tune the pruned model with optimization method and hyper-parameters (learning-rate, schedule, number of epochs, etc.) exactly as those used to obtain the trained model. -(3) (If required) Quantize the model. -``` - -In code, below is a sketch on how to use ASP for this approach (steps 1 and 2 above). - -``` -model = define_model(..., pretrained=True) # define model architecture and load parameter tensors with trained values (by reading a trained checkpoint) -criterion = ... # compare ground truth with model predition; use the same criterion as used to generate the dense trained model -optimizer = ... # optimize model parameters; use the same optimizer as used to generate the dense trained model -lr_scheduler = ... # learning rate scheduler; use the same schedule as used to generate the dense trained model - -from apex.contrib.sparsity import ASP -ASP.prune_trained_model(model, optimizer) #pruned a trained model - -x, y = DataLoader(args) -for epoch in range(epochs): # train the pruned model for the same number of epochs as used to generate the dense trained model - y_pred = model(x) - loss = criterion(y_pred, y) - lr_scheduler.step() - loss.backward() - optimizer.step() - -torch.save(...) # saves the pruned checkpoint with sparsity masks -``` - -## Non-Standard Usage - -If your goal is to easily perpare a network for accelerated inference, please follow the recipe above. However, ASP can also be used to perform experiments in advanced techniques like training with sparsity from initialization. For example, in order to recompute the sparse mask in between training steps, use the following method: - -``` -ASP.compute_sparse_masks() -``` - -A more thorough example can be found in `./test/toy_problem.py`. - -## Advanced Usage: Channel Permutation - -We introduce channel permutations as an advanced method to maximize the accuracy of structured sparse networks. By permuting weight matrices along their channel dimension and adjusting the surrounding layers appropriately, we demonstrate accuracy recovery for even small, parameter-efficient networks, without affecting inference run-time. - -The final accuracy has a strong relationship with the quality of permutations. We provide the default algorithms to search for high-quality permutations. The permutation search process can be accelerated by the Apex CUDA extension: `apex.contrib.sparsity.permutation_search_kernels` - -If you want to use the GPU to accelerate the permutation search process, we recommend installing Apex with permutation search CUDA extension via - -``` -pip install -v --disable-pip-version-check --no-cache-dir --global-option="--permutation_search" ./ -``` - -If you want to disable the permutation search process, please pass the `allow_permutation=False` to `init_model_for_pruning` function. For example: - -``` -ASP.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d], allow_recompute_mask=False, allow_permutation=False) -``` - -Please notice, when using multi-GPUs we should set the identical random seed for all GPUs to make sure the same results generated in permutation search. The library has implemented the `set_identical_seed` function in `permutation_lib.py`, and be called in ASP library. We still suggest the users to set the identical random seed when using multi-GPUs in their code, the example code is as follows: - -``` -import torch -import numpy -import random - -torch.manual_seed(identical_seed) -torch.cuda.manual_seed_all(identical_seed) -numpy.random.seed(identical_seed) -random.seed(identical_seed) -torch.backends.cudnn.deterministic = True -torch.backends.cudnn.benchmark = False -``` - -## Reference Papers - -More details about sparsity support on the NVIDIA Ampere GPU with Sparse Tensor Cores can refer to our [white paper](https://arxiv.org/abs/2104.08378). - -``` -@article{mishra2021accelerating, - title={Accelerating sparse deep neural networks}, - author={Mishra, Asit and Latorre, Jorge Albericio and Pool, Jeff and Stosic, Darko and Stosic, Dusan and Venkatesh, Ganesh and Yu, Chong and Micikevicius, Paulius}, - journal={arXiv preprint arXiv:2104.08378}, - year={2021} -} -``` - -The details about sparsity with permutation can refer to our [paper](https://proceedings.neurips.cc/paper/2021/hash/6e8404c3b93a9527c8db241a1846599a-Abstract.html) published in *Thirty-fifth Conference on Neural Information Processing Systems* (**NeurIPS 2021**): - -``` -@article{pool2021channel, - title={Channel Permutations for N: M Sparsity}, - author={Pool, Jeff and Yu, Chong}, - journal={Advances in Neural Information Processing Systems}, - volume={34}, - year={2021} -} -``` diff --git a/apex/contrib/sparsity/__init__.py b/apex/contrib/sparsity/__init__.py deleted file mode 100644 index 661fd4a..0000000 --- a/apex/contrib/sparsity/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .sparse_masklib import create_mask -from .asp import ASP diff --git a/apex/contrib/sparsity/asp.py b/apex/contrib/sparsity/asp.py deleted file mode 100644 index 924024f..0000000 --- a/apex/contrib/sparsity/asp.py +++ /dev/null @@ -1,312 +0,0 @@ -import types -import torch -from .sparse_masklib import create_mask -from .permutation_lib import Permutation - -torchvision_imported=True -try: - import torchvision -except ImportError: - print("[ASP][Warning] torchvision cannot be imported.") - torchvision_imported=False - -import json -import os -import string -import time - -def eligible_modules(model, whitelist_layer_types, allowed_layer_names, disallowed_layer_names): - eligible_modules_list = [] - for name, mod in model.named_modules(): - if isinstance(mod, whitelist_layer_types) and name not in disallowed_layer_names: - if allowed_layer_names is not None and name not in allowed_layer_names: - continue - eligible_modules_list.append((name, mod)) - return eligible_modules_list - - -class ASP: - __model = None - __verbosity = 0 - __optimizer = None - __sparse_parameters = [] - __calculate_mask = None - __allow_permutation = True - __all_parameters = [] - __save_permutation_graph = False - __permutation_output_dir = '' - - @classmethod - def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d", - verbosity=3, - whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d], - allowed_layer_names=None, disallowed_layer_names=[], - allow_recompute_mask=False, custom_layer_dict={}, - allow_permutation=True): - """Call this method to modify your model to take advantage of sparse matrix multiplication. - Note that this call alone only augments the model with additional buffers needed for sparse MMA, - it does not enable use of sparse MMA. - - If you are starting with a fresh model: - - model = ... - ASP.init_model_for_pruning(model, mask_calculator, ...) - if (training) ASP.init_optimizer_for_pruning(optimizer) - ASP.compute_sparse_masks() // sparsity is off by default, call when youy want to enable it. - - If you are starting from a checkpoint: - - model = ... - ASP.init_model_for_pruning(model, mask_calculator, ...) - torch.load(...) - if (training) ASP.init_optimizer_for_pruning(optimizer) - - Arguments: - model The model - mask_calculator Either callable that computes mask given a tensor OR pattern string for sparse mask lib. - verbosity Integer controling verbosity level. - 0 -> Only errors. - 1 -> Errors and warnings. - 2 -> Errors, warnings and info. - 3 -> Errors, warnings, info and debug. - whitelist Module types approved for sparsity. - allowed_layer_names If not None, only layer names that appear in this list are considered for sparsity. - disallowed_layer_names If not [], only layer names that do not appear in this list are considered for sparsity. - allow_recompute_mask If True, stores pruned values so that dense weights can be restored. - Pruned weights are stored in CPU memory, hence this option does not increase GPU memory usage. - custom_layer_dict Dictionary of additional layer paremeters to sparsify. e.g. {CustomLinear: ['weight']} - allow_permutation If True, allow the input channel permutation to ease the influence of weight pruning. - - [Future] Support for allow_recompute_mask can be removed, it is not part of sparse inference recipe. - """ - assert (cls.__model is None), "ASP has been initialized already." - cls.__model = model - cls.__verbosity = verbosity - cls.__allow_permutation = allow_permutation - - if isinstance(mask_calculator, str): - def create_mask_from_pattern(param): - return create_mask(param, mask_calculator).bool() - cls.__calculate_mask = create_mask_from_pattern - else: - cls.__calculate_mask = mask_calculator #user defined function - - # function to extract variables that will be sparsified. - # idea is that you will add one of these functions for each module type that can be sparsified. - if torchvision_imported: - print("[ASP] torchvision is imported, can work with the MaskRCNN/KeypointRCNN from torchvision.") - sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torchvision.ops.misc.Conv2d: ['weight']} - else: - sparse_parameter_list = {torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight']} - if custom_layer_dict: # Update default list to include user supplied custom (layer type : parameter tensor), make sure this tensor type is something ASP knows how to prune - sparse_parameter_list.update(custom_layer_dict) - whitelist += list(custom_layer_dict.keys()) - - for module_type in whitelist: - assert (module_type in sparse_parameter_list), "Module %s :: Don't know how to sparsify module." % module.dtype() - - if allow_permutation: # find all named modules, extract parameters and decorate, used for offline permutation in K dim - for module_name, module in model.named_modules(): - module_type_str = str(type(module)).split("\'")[1] - if module_type_str == 'torch.nn.modules.container.Sequential' or module_type_str.startswith('torchvision.models'): - # filter out the 'torch.nn.modules.container.Sequential' type and the whole model, like 'torchvision.models.vgg.VGG' - continue - for p_name, p in module.named_parameters(): - cls.__all_parameters.append((module_name, module, p_name, p)) - if module_type_str == 'torch.nn.modules.batchnorm.BatchNorm2d': - # need to get the running_mean and running_var from model.state_dict(), as they are not the learnable parameters - module_mean_name = module_name + '.running_mean' - module_var_name = module_name + '.running_var' - for param_key in model.state_dict(): - if module_mean_name == param_key or module_var_name == param_key: - cls.__all_parameters.append((module_name, module, param_key.split(".")[-1], model.state_dict()[param_key])) - # add the __permutation_output_dir field to save the intermediate results for permutation - cls.__permutation_output_dir = '.' - # Set the corresponding params from ASP class to the Permutation class - Permutation.set_permutation_params_from_asp(cls.__model, cls.__sparse_parameters, cls.__all_parameters) - # Set the identical random seed for all GPUs to make sure the same results generated in permutation search - Permutation.set_identical_seed() - - # find all sparse modules, extract sparse parameters and decorate - def add_sparse_attributes(module_name, module): - sparse_parameters = sparse_parameter_list[type(module)] - for p_name, p in module.named_parameters(): - if p_name in sparse_parameters and p.requires_grad: - # check for NVIDIA's TC compatibility: we check along the horizontal direction - if p.dtype == torch.float32 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0): #User defines FP32 and APEX internally uses FP16 math - print("[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype))) - continue - if p.dtype == torch.float16 and ((p.size()[0] % 8) != 0 or (p.size()[1] % 16) != 0): #For Conv2d dim= K x CRS; we prune along C - print("[ASP] Auto skipping pruning %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype))) - continue - - if cls.__verbosity >= 3: - print("[ASP] Sparsifying %s::%s of size=%s and type=%s for sparsity" % (module_name, p_name, str(p.size()), str(p.dtype))) - - mask = torch.ones_like(p).bool() - buffname = p_name.split(".")[-1] # buffer names cannot contain "." - module.register_buffer('__%s_mma_mask' % buffname, mask) - if allow_recompute_mask: - pruned = torch.zeros_like(p).cpu() - module.register_buffer('__%s_mma_pruned_p' % buffname, pruned) - else: - pruned = None - cls.__sparse_parameters.append((module_name, module, p_name, p, mask, pruned)) - else: - if cls.__verbosity >= 3: - print("[ASP] Not sparsifying %s::%s of size=%s and type=%s" % (module_name, p_name, str(p.size()), str(p.dtype))) - - for name, sparse_module in eligible_modules(model, tuple(whitelist), allowed_layer_names, disallowed_layer_names): - add_sparse_attributes(name, sparse_module) - - @classmethod - def already_init_asp_model(cls): - """Call this method to check whether ASP has been initialized already. - """ - if cls.__model is None: - if cls.__verbosity >= 3: - print("[ASP] ASP has not been initialized.") - return False - else: - if cls.__verbosity >= 3: - print("[ASP] ASP has been initialized already.") - return True - - @classmethod - def init_optimizer_for_pruning(cls, optimizer): - """Call this method to monkey patch optimizer step function so that masks can be applied to - gradients and weights during training. - You must call init_model_for_pruning(...) before calling init_optimizer_for_pruning(...) - """ - assert (cls.__optimizer is None), "ASP has initialized optimizer already." - assert (cls.__calculate_mask is not None), "Called ASP.init_optimizer_for_pruning before ASP.init_model_for_pruning." - - # store pointer to original optimizer step method - cls.__optimizer = optimizer - cls.__optimizer.__step = optimizer.step - - def __step(opt_self, *args, **kwargs): - # prune gradients before step method - with torch.no_grad(): - for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters: - if p.grad is not None: #thx pjudd - p.grad.mul_(mask) - # call original optimizer step method - rval = opt_self.__step(*args, **kwargs) - # prune parameters after step method - with torch.no_grad(): - for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters: - p.mul_(mask) - return rval - cls.__optimizer.step = types.MethodType(__step, cls.__optimizer) - - @classmethod - def compute_sparse_masks(cls): - """Call this method to enable sparsity. - If init(...) was called with allow_recompute_mask=False AND sparsity is disabled, pruned field can be None. - """ - with torch.no_grad(): - if cls.__allow_permutation: - # Step 1: use the Torch.FX library to build the graph - # Step 2: permutation search with the customized kernel - # Notice: need to use the single GPU to build the Torch.FX graph - # The simplest without user intervention: - # A. try to import with the distributed mode of the original model - # B. if meet the error, import with the none-distributed mode of the original model - start_time_build_offline_permutation_graph = time.perf_counter() - try: - offline_permutation_fx_graph, success_in_build_offline_permutation_graph = Permutation.build_offline_permutation_graph(cls.__model.module, dump_fx_graph=cls.__save_permutation_graph, save_dumped_fx_graph=os.path.join(cls.__permutation_output_dir, 'model_offline_permutation_graph.json')) - print("\n[compute_sparse_masks] build offline permutation graph on distributed model.") - except AttributeError: - offline_permutation_fx_graph, success_in_build_offline_permutation_graph = Permutation.build_offline_permutation_graph(cls.__model, dump_fx_graph=cls.__save_permutation_graph, save_dumped_fx_graph=os.path.join(cls.__permutation_output_dir, 'model_offline_permutation_graph.json')) - print("\n[compute_sparse_masks] build offline permutation graph on none-distributed model.") - duration_build_offline_permutation_graph = time.perf_counter() - start_time_build_offline_permutation_graph - print("[compute_sparse_masks] Take {:.4f} seconds to finish build_offline_permutation_graph function.".format(duration_build_offline_permutation_graph)) - - # Step 3: off-line permutation to avoid the runtime overhead in deployment - if success_in_build_offline_permutation_graph: - start_time_apply_offline_permutation = time.perf_counter() - try: - Permutation.apply_offline_permutation(cls.__model.module, fx_graph=offline_permutation_fx_graph) - print("\n[compute_sparse_masks] apply offline permutation on distributed model.") - except AttributeError: - Permutation.apply_offline_permutation(cls.__model, fx_graph=offline_permutation_fx_graph) - print("\n[compute_sparse_masks] apply offline permutation on none-distributed model.") - duration_apply_offline_permutation = time.perf_counter() - start_time_apply_offline_permutation - print("[compute_sparse_masks] Take {:.4f} seconds to finish apply_offline_permutation function.\n".format(duration_apply_offline_permutation)) - else: - print("[compute_sparse_masks] skip applying offline permutation because there is no valid offline_permutation_fx_graph.") - # Finally, permutation search and off-line permutation is done, give the model back to ASP to generate the normal structured sparse mask - - for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters: - if mask.sum() < mask.numel(): # when recalculating masks - # restore dense parameter if allow_recompute_mask is enabled - assert (pruned is not None), "Unable to restore dense parameter because allow_recompute_mask == False" - p.add_(pruned.cuda()) - - mask.set_(cls.__calculate_mask(p)) - - if pruned is not None: # stow away pruned weights to cpu - pruned.set_((p * (~mask)).cpu()) - - p.mul_(mask) # in-place multiplication, so pruned weights are 0-values, hence checkpoint will have 0s for pruned weights - if cls.__verbosity >= 2: - print("[ASP] Enabled %.2f%% sparsity for %s::%s of size=%s and type=%s" % (100.0-100.0*mask.sum()/mask.numel(), module_name, p_name, str(p.size()), str(p.dtype))) - - @classmethod - def restore_pruned_weights(cls): - """Call this method to disable sparsity and restore all weights. - This will only work if init(...) was called with allow_recompute=True. - """ - with torch.no_grad(): - for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters: - if mask.sum() < mask.numel(): - assert (pruned is not None), "Unable to restore dense parameter because allow_recompute_mask == False" - p.add_(pruned.cuda()) - mask.fill_(1) - pruned.zero_() - if cls.__verbosity >= 2: - print("[ASP] Disabled sparsity for %s::%s (dense weights restored)" % (module_name, p_name)) - - @classmethod - def is_sparsity_enabled(cls): - """Call this method to determine if sparsity is enabled in the model. - The typical use case is right after checkpoint has been loaded. - """ - total,sp100,sp50 = 0,0,0 - for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters: - total += 1 - mask_sum = mask.sum() - mask_numel = mask.numel() - if mask_sum == mask_numel: - sp100 += 1 - elif mask_sum*2 == mask_numel: - sp50 += 1 - - assert (total == sp100 or total == sp50), "Inconsistent model sparsity" - if total == sp100: - return False - elif total == sp50: - return True - - @classmethod - def prune_trained_model(cls, model, optimizer): - # add mask buffers to model (init_model_for_pruning), augment optimizer (init_optimizer_for_pruning) and compute masks (compute_sparse_masks) - cls.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d], allow_recompute_mask=False) - cls.init_optimizer_for_pruning(optimizer) - cls.compute_sparse_masks() - - @classmethod - def set_permutation_saving_params(cls, allow_permutation=True, save_permutation_graph=False, permutation_output_dir='.'): - """This function is used to set the permutation saving related parameters in ASP class and inside of the Permutation class.""" - print("\n[ASP][set_permutation_saving_param] Set permutation saving related parameters") - print("\n[set_permutation_saving_param] Set permutation saving related parameters") - cls.__allow_permutation = allow_permutation - print("[set_permutation_saving_param]\t Allow permutation: {}".format(cls.__allow_permutation)) - cls.__save_permutation_graph = save_permutation_graph - print("[set_permutation_saving_param]\t Save permutation graphs: {}".format(cls.__save_permutation_graph)) - cls.__permutation_output_dir = permutation_output_dir - print("[set_permutation_saving_param]\t Permutation graphs saving dir: {}".format(cls.__permutation_output_dir)) - - Permutation.set_permutation_saving_params(allow_permutation, save_permutation_graph, permutation_output_dir) - diff --git a/apex/contrib/sparsity/permutation_lib.py b/apex/contrib/sparsity/permutation_lib.py deleted file mode 100644 index b7cf102..0000000 --- a/apex/contrib/sparsity/permutation_lib.py +++ /dev/null @@ -1,925 +0,0 @@ -import os -import torch -import json -import string -import time -try: - from .permutation_search_kernels import accelerated_search_for_good_permutation, sum_after_2_to_4 - print("[ASP][Info] permutation_search_kernels can be imported.") -except ImportError: - print("[ASP][Warning] permutation_search_kernels cannot be imported.") - print("[ASP][Warning] If you want to accelerate the permutation search process by GPU, please build APEX by following the instructions at https://github.com/NVIDIA/apex/blob/master/apex/contrib/sparsity/README.md") - -def convert_fx_node_name(fx_node_name): - converted_fx_node_name = fx_node_name - converted_fx_node_name = converted_fx_node_name.replace('_', '.') - return converted_fx_node_name - -def get_node_parent_children(fx_node): - # get node parent list, and convert node name to module name - node_parent_name_converted = [] - if len(fx_node.all_input_nodes) > 0: - node_parent = fx_node.all_input_nodes - for item in node_parent: - converted_item = convert_fx_node_name(item.name) - node_parent_name_converted.append(converted_item) - else: - node_parent = list('None') - node_parent_name_converted.append('None') - # get node children list, and convert node name to module name - node_children_name_converted = [] - if len(list(fx_node.users.keys())) > 0: - node_children = list(fx_node.users.keys()) - for item in node_children: - converted_item = convert_fx_node_name(item.name) - node_children_name_converted.append(converted_item) - else: - node_children = list('None') - node_children_name_converted.append('None') - return node_parent_name_converted, node_children_name_converted - - -class Permutation: - __model = None - __sparse_parameters = [] - __allow_permutation = False - __all_parameters = [] - __save_permutation_graph = False - __permutation_output_dir = '' - - @classmethod - def set_permutation_params_from_asp(cls, model, sparse_parameters, all_parameters): - """This function is used to set the permutation needed parameters from ASP class.""" - print("\n[set_permutation_params_from_asp] Set permutation needed parameters") - cls.__model = model - cls.__sparse_parameters = sparse_parameters - cls.__all_parameters = all_parameters - - @classmethod - def set_identical_seed(cls, identical_seed=1): - print("\n[set_identical_seed] Set the identical seed: {:} for all GPUs to make sure the same results generated in permutation search".format(identical_seed)) - torch.manual_seed(identical_seed) - torch.cuda.manual_seed_all(identical_seed) - import numpy as np - import random - np.random.seed(identical_seed) - random.seed(identical_seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - - @classmethod - def set_permutation_saving_params(cls, allow_permutation=False, save_permutation_graph=False, permutation_output_dir='.'): - """This function is used to set the permutation saving related parameters.""" - print("\n[permutation_lib][set_permutation_saving_param] Set permutation saving related parameters") - cls.__allow_permutation = allow_permutation - print("[set_permutation_saving_param]\t Allow permutation: {}".format(cls.__allow_permutation)) - cls.__save_permutation_graph = save_permutation_graph - print("[set_permutation_saving_param]\t Save permutation graphs: {}".format(cls.__save_permutation_graph)) - cls.__permutation_output_dir = permutation_output_dir - print("[set_permutation_saving_param]\t Permutation graphs saving dir: {}".format(cls.__permutation_output_dir)) - - @classmethod - def apply_offline_permutation(cls, model, fx_graph): - """This function is used to offline permutation for each node according to the the whole network graph built with Torch.FX.""" - print("\n[apply_offline_permutation] Offline permutation for each node according to the the whole network graph built with Torch.FX") - - # Firstly, we should transfer the sparse mask to all-one dense mask - cls.transfer_to_dense_mask() - - for node_name in fx_graph.keys(): - node_module_type = fx_graph.get(node_name).get('module_type') - - # check wheter the current layer can permute as plan, e.g., the flatten layer in VGG will change the shape and broke the permutation chain - # only need to check the 'is_node_real_parents_K_permuted', because the 'is_node_real_parents_C_permuted' has no influence to the children - node_real_parents = fx_graph.get(node_name).get('real_parents') - is_node_real_parents_K_permuted = True - if node_real_parents is not None: # filter out the 'unique_siblings' item - for real_parent_item in node_real_parents: - if fx_graph.get(real_parent_item).get('permutation_type') in ['K', 'KC']: - if fx_graph.get(real_parent_item).get('k_permuted') == 'False': - is_node_real_parents_K_permuted = False - - if fx_graph[node_name]['permutation_type'] == 'KC': # intermediate Conv, FC - C_permutation_sequence = cls.fetch_C_permutation_sequence_value(node_name, fx_graph) - K_permutation_sequence = cls.fetch_K_permutation_sequence_value(node_name, fx_graph) - print("\n[apply_offline_permutation] node_name: \'{:}\', node module type: \'{:}\', need to do offline permutation in K and C dims.".format(node_name, node_module_type)) - if is_node_real_parents_K_permuted == True: - fx_graph[node_name]['c_permuted'] = str(cls.apply_permutation_in_C_dim(node_name, C_permutation_sequence)) - fx_graph[node_name]['k_permuted'] = str(cls.apply_permutation_in_K_dim(node_name, K_permutation_sequence)) - else: - print("[apply_offline_permutation][warning] node_name: \'{:}\', its real parents have trouble in permutation in K dim, so skip the offline permutation in C dim.".format(node_name, node_module_type)) - fx_graph[node_name]['k_permuted'] = str(cls.apply_permutation_in_K_dim(node_name, K_permutation_sequence)) - elif fx_graph[node_name]['permutation_type'] == 'K': # BN, first layer Conv/FC - K_permutation_sequence = cls.fetch_K_permutation_sequence_value(node_name, fx_graph) - print("\n[apply_offline_permutation] node_name: \'{:}\', node module type: \'{:}\', need to do offline permutation in K dim.".format(node_name, node_module_type)) - if is_node_real_parents_K_permuted == True: - fx_graph[node_name]['k_permuted'] = str(cls.apply_permutation_in_K_dim(node_name, K_permutation_sequence)) - else: # for BN, if the previous Conv cannot do permutation in K dim, then no need to do permutation in K dim for this BN - print("[apply_offline_permutation][warning] node_name: \'{:}\', its real parents have trouble in permutation in K dim, so skip the offline permutation in K dim.".format(node_name, node_module_type)) - elif fx_graph[node_name]['permutation_type'] == 'C': # last layer FC/Conv - C_permutation_sequence = cls.fetch_C_permutation_sequence_value(node_name, fx_graph) - print("\n[apply_offline_permutation] node_name: \'{:}\', node module type: \'{:}\', need to do offline permutation in C dim.".format(node_name, node_module_type)) - if is_node_real_parents_K_permuted == True: - fx_graph[node_name]['c_permuted'] = str(cls.apply_permutation_in_C_dim(node_name, C_permutation_sequence)) - else: - print("[apply_offline_permutation][warning] node_name: \'{:}\', its real parents have trouble in permutation in K dim, so skip the offline permutation in C dim.".format(node_name, node_module_type)) - - if cls.__save_permutation_graph: - cls.save_graph_to_json(fx_graph, save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir, './model_graph_apply_offline_permutation.json')) # save the intermediate graph as JSON file for debugging - return fx_graph - - @classmethod - def transfer_to_dense_mask(cls): - """Call this method to transfer the sparse mask to all-one dense mask.""" - with torch.no_grad(): - for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters: - mask.fill_(1) - - @classmethod - def fetch_C_permutation_sequence_value(cls, node_name, fx_graph): - """This function is used to fetch the permutation sequence value in C dim from the unique_siblings record.""" - # C_permutation_sequence is the corresponding 'permutation_sequence' value stored in the fx_graph.get('unique_siblings') item which contains node_name - unique_siblings_groups = fx_graph.get('unique_siblings').get('name') - unique_siblings_groups_permutation_sequence = fx_graph.get('unique_siblings').get('permutation_sequence') - item_index = 0 - fetched_C_permutation_sequence = [] - for item in unique_siblings_groups: - if node_name in item: - fetched_C_permutation_sequence = unique_siblings_groups_permutation_sequence[item_index] - item_index = item_index + 1 - return fetched_C_permutation_sequence - - @classmethod - def fetch_K_permutation_sequence_value(cls, node_name, fx_graph): - """This function is used to fetch the permutation sequence value in K dim from the unique_siblings record.""" - # K_permutation_sequence is its real_children's corresponding 'permutation_sequence' value stored in the fx_graph.get('unique_siblings') item which contains real_children name - # we have the assumption that all the real children are in one unique_sibling group, so should share the same permutation_sequence value - unique_siblings_groups = fx_graph.get('unique_siblings').get('name') - unique_siblings_groups_permutation_sequence = fx_graph.get('unique_siblings').get('permutation_sequence') - node_real_children = fx_graph.get(node_name).get('real_children') - fetched_K_permutation_sequence = [] - if len(node_real_children) > 0: - node_representative_child = node_real_children[0] - fetched_K_permutation_sequence = cls.fetch_C_permutation_sequence_value(node_representative_child, fx_graph) - return fetched_K_permutation_sequence - - @classmethod - def apply_permutation_in_C_dim(cls, node_name, permutation_sequence): - """This function is used to permutation for a node in C dim. (Only need to handle the weight of the node) """ - print("[apply_permutation_in_C_dim] Permutation for node: \'{:}\' in C dim".format(node_name)) - if len(permutation_sequence) == 0: - print("[apply_permutation_in_C_dim] the permutation sequence is empty, fail to apply permutation in C dim.") - return False - is_node_in_sparse_parameters = False - success_permutation = False - for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters: - processed_module_name = ''.join(c for c in module_name if c not in string.punctuation).lower() - processed_node_name = ''.join(c for c in node_name if c not in string.punctuation).lower() - distributed_node_name = 'module.' + node_name - processed_distributed_node_name = 'module.' + processed_node_name - if (module_name == node_name) or (module_name == distributed_node_name) or (processed_module_name == processed_node_name) or (processed_module_name == processed_distributed_node_name): # Inception-V3, module_name: Conv2d_2a_3x3.conv, node_name: conv2d.1a.3x3.conv - print("[apply_permutation_in_C_dim] find the node: \'{:}\' in cls.__sparse_parameters, succeed to apply permutation in C dim.".format(node_name)) - is_node_in_sparse_parameters = True - temp_weight = torch.zeros_like(p) - temp_weight.copy_(p[:, permutation_sequence, ...]) - p.data.copy_(temp_weight) - success_permutation = True - if is_node_in_sparse_parameters == False: - # A special case: if the node itself not in sparse_module_names but one of its real_siblings in sparse_module_names, then the node will not do the permutation search, but it may need to apply the offline permutation in C dim according to the searched permutation sequence from its real_siblings in sparse_module_names - try: - for module_name_from_all_parameters, module_from_all_parameters, p_name_from_all_parameters, p_from_all_parameters in cls.__all_parameters: - if ((node_name == module_name_from_all_parameters) or ('module.' + node_name == module_name_from_all_parameters)) and p_name_from_all_parameters == "weight": - print("[apply_permutation_in_C_dim] cannot find the node: \'{:}\' in cls.__sparse_parameters, but can find in cls.__all_parameters.".format(node_name)) - temp_weight = torch.zeros_like(p_from_all_parameters) - temp_weight.copy_(p_from_all_parameters[:, permutation_sequence, ...]) - p_from_all_parameters.data.copy_(temp_weight) - success_permutation = True - print("[apply_permutation_in_C_dim] cannot find the node: \'{:}\' in cls.__sparse_parameters, after trying with cls.__all_parameters, succeed to apply permutation in C dim.".format(node_name)) - except: - success_permutation = False - print("[apply_permutation_in_C_dim] cannot find the node: \'{:}\' in cls.__sparse_parameters, after trying with cls.__all_parameters, still fail to apply permutation in C dim.".format(node_name)) - return success_permutation - - @classmethod - def apply_permutation_in_K_dim(cls, node_name, permutation_sequence): - """This function is used to permutation for a node in K dim. (Need to handle the weight/bias/running_mean/running_var of the node)""" - print("[apply_permutation_in_K_dim] Permutation for node: \'{:}\' in K dim".format(node_name)) - if len(permutation_sequence) == 0: - print("[apply_permutation_in_K_dim] the permutation sequence is empty, fail to apply permutation in K dim.") - return False - is_node_in_all_parameters = False - success_permutation = False - for module_name, module, p_name, p in cls.__all_parameters: - processed_module_name = ''.join(c for c in module_name if c not in string.punctuation).lower() - processed_node_name = ''.join(c for c in node_name if c not in string.punctuation).lower() - distributed_node_name = 'module.' + node_name - processed_distributed_node_name = 'module.' + processed_node_name - if (module_name == node_name) or (module_name == distributed_node_name) or (processed_module_name == processed_node_name) or (processed_module_name == processed_distributed_node_name): # Inception-V3, module_name: Conv2d_2a_3x3.conv, node_name: conv2d.1a.3x3.conv - print("[apply_permutation_in_K_dim] find the node: \'{:}\' with \'{:}\' in cls.__all_parameters, may succeed to apply permutation in K dim.".format(node_name, p_name)) - is_node_in_all_parameters = True - temp_weight = torch.zeros_like(p) - if p.shape[0] != len(permutation_sequence): - print("[apply_permutation_in_K_dim][warning] the node: \'{:}\' with shape: \'{:}\', cannot match the size of permutation sequence with len: \'{:}\', fail to apply permutation in K dim.".format(node_name, p.shape, len(permutation_sequence))) - success_permutation = False - else: - print("[apply_permutation_in_K_dim] the node: \'{:}\' with shape: \'{:}\', can match the size of permutation sequence with len: \'{:}\', succeed to apply permutation in K dim.".format(node_name, p.shape, len(permutation_sequence))) - temp_weight.copy_(p[permutation_sequence, ...]) - p.data.copy_(temp_weight) - success_permutation = True - if is_node_in_all_parameters == False: - print("[apply_permutation_in_K_dim] cannot find the node: \'{:}\' in cls.__all_parameters, fail to apply permutation in K dim.".format(node_name)) - success_permutation = False - return success_permutation - - @classmethod - def build_offline_permutation_graph(cls, model, dump_fx_graph=False, save_dumped_fx_graph='./model_offline_permutation_graph.json'): - """This function is used to refine the whole network graph built with Torch.FX with some extra infomation needed for offline permutation.""" - print("\n[build_offline_permutation_graph] Further refine the model graph built by Torch.FX for offline permutation") - # extract the output_dir, so all the intermediate fx_graph can be saved under that path - extract_output_dir=os.path.split(save_dumped_fx_graph)[0] - cls.__permutation_output_dir = extract_output_dir - fx_graph, success_in_build_fx_graph = cls.build_fx_graph(model, dump_fx_graph=dump_fx_graph, save_dumped_fx_graph=save_dumped_fx_graph) - if success_in_build_fx_graph: - fx_graph_after_find_real_parents = cls.find_real_parents(fx_graph) - fx_graph_after_find_real_children = cls.find_real_children(fx_graph_after_find_real_parents) - fx_graph_after_find_real_siblings = cls.find_real_siblings(fx_graph_after_find_real_children) - fx_graph_after_extract_all_unique_siblings = cls.extract_all_unique_siblings(fx_graph_after_find_real_siblings) - fx_graph_after_init_permutation_flag = cls.init_permutation_flag(fx_graph_after_extract_all_unique_siblings) - start_time_search_for_good_permutation = time.perf_counter() - fx_graph_after_search_for_good_permutation = cls.search_for_good_permutation(fx_graph_after_init_permutation_flag) - duration_search_for_good_permutation = time.perf_counter() - start_time_search_for_good_permutation - print("\n[build_offline_permutation_graph] Take {:.4f} seconds to finish search_for_good_permutation function.".format(duration_search_for_good_permutation)) - else: - fx_graph_after_search_for_good_permutation = {} - return fx_graph_after_search_for_good_permutation, success_in_build_fx_graph - - # Please notice the apply_offline_permutation step cannot fold into the above search_for_good_permutation step. - # Because the real_parent node needs to offline permutation in K direction according to the searched permutation sequence from its real_children. - # However, when we search_for_good_permutation for the node, its real_children have not been handled by search_for_good_permutation. - - if cls.__save_permutation_graph: - cls.save_graph_to_json(fx_graph_after_search_for_good_permutation, save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir, './model_graph_build_offline_permutation_graph.json')) # save the intermediate graph as JSON file for debugging - return fx_graph_after_search_for_good_permutation, success_in_build_fx_graph - - @classmethod - def search_for_good_permutation(cls, fx_graph): - """This function is used to: - 1. search for the good permutation sequence for each node weight, or each siblings_group weights by calling the permutation search kernels as ASP extension. - 2. add the searched permutation sequence for each node according to the whole network graph built with Torch.FX.""" - print("\n[search_for_good_permutation] Search for the good permutation sequence for each node according to the whole network graph built with Torch.FX") - - unique_siblings_groups = fx_graph.get('unique_siblings').get('name') - unique_siblings_groups_module_type = fx_graph.get('unique_siblings').get('module_type') - unique_siblings_groups_permutation_sequence = [] - item_index = 0 - for unique_siblings_group in unique_siblings_groups: # loop through all unique siblings groups that must share a permutation sequence - print("\n[search_for_good_permutation] this unique_siblings_group has {:} real siblings: \'{:}\', with module type: \'{:}\'.".format(len(unique_siblings_group), unique_siblings_group, unique_siblings_groups_module_type[item_index])) - item_index = item_index + 1 - - # concat the weight for layers in the same unique_siblings_group - matrix_group = None - for node_name in unique_siblings_group: - node_module_type = fx_graph.get(node_name).get('module_type') - print("[search_for_good_permutation] try to merge the weight for node: \'{:}\', with module type: \'{:}\'.".format(node_name, node_module_type)) - is_node_in_sparse_parameters = False - node_weight = torch.zeros(0) - for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters: - processed_module_name = ''.join(c for c in module_name if c not in string.punctuation).lower() - processed_node_name = ''.join(c for c in node_name if c not in string.punctuation).lower() - distributed_node_name = 'module.' + node_name - processed_distributed_node_name = 'module.' + processed_node_name - if (module_name == node_name) or (module_name == distributed_node_name) or (processed_module_name == processed_node_name) or (processed_module_name == processed_distributed_node_name): # Inception-V3, module_name: Conv2d_2a_3x3.conv, node_name: conv2d.1a.3x3.conv - module_type_from_sparse_parameters = str(type(module)) # e.g. - module_type_from_sparse_parameters = module_type_from_sparse_parameters[8:-2] - print("[search_for_good_permutation] find the node: \'{:}\' in cls.__sparse_parameters, module type match: \'{:}\'.".format(node_name, node_module_type==module_type_from_sparse_parameters)) - is_node_in_sparse_parameters = True - node_weight = torch.zeros_like(p) - node_weight.copy_(p) - # Need to handle the concat for layers with different R & S - shape = node_weight.shape - # 1d-tensor - if len(shape) == 1: - node_weight = node_weight.view(1, shape[0]) - # 2d-tensor (in, out) - elif len(shape) == 2: - node_weight = node_weight.view(shape[0], shape[1]) - # 3d-tensor (batch, in, out) - elif len(shape) == 3: - node_weight = node_weight.view(shape[0]*shape[1], shape[2]) - # 4d-tensor (in, out, h, w) - elif len(shape) == 4: - # convs - node_weight = node_weight.permute(2,3,0,1).contiguous().view(shape[2]*shape[3]*shape[0], shape[1]) - - if is_node_in_sparse_parameters == False: - print("[search_for_good_permutation] cannot find the node: \'{:}\' in cls.__sparse_parameters, no need to merge its weight for permutation.".format(node_name)) - else: - if matrix_group == None: - matrix_group = node_weight - else: - try: - if matrix_group.dim() == node_weight.dim(): - matrix_group = torch.cat((matrix_group, node_weight), dim=0) # concat the weights in K dimension, and keep the same C dimension - else: # e.g. when try to merge the Conv and FC layers - print("[search_for_good_permutation] matrix_group dim: {:} is not matched with node_weight dim: {:}.".format(matrix_group.dim(), node_weight.dim())) - print("[search_for_good_permutation] matrix_group shape: \'{:}\' is not matched with node_weight shape: \'{:}\'.".format(matrix_group.size(), node_weight.size())) - if matrix_group.dim() < node_weight.dim(): - while node_weight.dim() - matrix_group.dim() > 0: - matrix_group = matrix_group.unsqueeze(matrix_group.dim()) - else: - while matrix_group.dim() - node_weight.dim() > 0: - node_weight = node_weight.unsqueeze(node_weight.dim()) - print("[search_for_good_permutation] matrix_group shape: \'{:}\' is now matched with node_weight shape: \'{:}\'.".format(matrix_group.size(), node_weight.size())) - matrix_group = torch.cat((matrix_group, node_weight), dim=0) # concat the weights in K dimension, and keep the same C dimension - except: - print("[search_for_good_permutation][warning] cannot merge the weight for node: \'{:}\', with its weight shape: \'{:}\', the matrix_group shape: \'{:}\'.".format(node_name, node_weight.size(), matrix_group.size())) - continue - print("[search_for_good_permutation] have merged the weight for node: \'{:}\', with its weight shape: \'{:}\', the matrix_group shape: \'{:}\'.".format(node_name, node_weight.size(), matrix_group.size())) - - if matrix_group == None: # cannot find the node: \'{:}\' in cls.__sparse_parameters - input_channel_num = 0 - print("\n[search_for_good_permutation] init the all-zero list with length \'{:}\' for permutation search sequence of this unique_siblings_group.".format(input_channel_num)) - print("[search_for_good_permutation] no need to search the permutation_sequence for empty matrix_group.") - permutation_sequence = [0 for n in range(input_channel_num)] - unique_siblings_groups_permutation_sequence.append(permutation_sequence) - continue - else: - input_channel_num = matrix_group.size()[1] - print("\n[search_for_good_permutation] init the all-zero list with length \'{:}\' for permutation search sequence of this unique_siblings_group.".format(input_channel_num)) - permutation_sequence = [0 for n in range(input_channel_num)] - - # automatic check for skipping the permutation search process - original_magnitude = (torch.abs(matrix_group)).sum(dtype=torch.float64) - pruned_magnitude = sum_after_2_to_4(matrix_group.cpu().detach().numpy()) - diff_ratio = abs(original_magnitude - pruned_magnitude)/original_magnitude - epsilon = 1e-3 - print("\n[search_for_good_permutation] Original element abs sum: {:}, Pruned element abs sum: {:}, Diff ratio: {:}".format(original_magnitude, pruned_magnitude, diff_ratio)) - if diff_ratio < epsilon: - print("[search_for_good_permutation] Original element abs sum is almost same as the pruned element abs sum, further permutation search will not help, skipping!") - print("[search_for_good_permutation] Change the all-zero permutation search sequence to a sequential permutation search sequence.") - permutation_sequence = [n for n in range(input_channel_num)] - unique_siblings_groups_permutation_sequence.append(permutation_sequence) - continue - else: - print("[search_for_good_permutation] Original element abs sum is different from the pruned element abs sum, further permutation search will help, continue with the permutation search!") - - # call the permutation search CUDA kernels as ASP extension. - # users can provide prefer search strategy by providing a valid 'search_options' as a dictionary, - # or users can implement their customized 'accelerated_search_for_good_permutation' function. - search_options = {} - # No.1 Strategy: Exhaustive Search - # search_options['strategy'] = 'exhaustive' - # search_options['stripe_group_size'] = 8 - # search_options['escape_attempts'] = 100 - # No.2 Strategy: Progressive Channel Swap Search - # search_options['strategy'] = 'progressive channel swap' - # search_options['progressive_search_time_limit'] = 10 - # search_options['improvement_threshold'] = 1e-9 - # No.3 Strategy: User Defined Search - # search_options['strategy'] = 'user defined' - - # permutation search time is too long for matrix_group with large channel num - # change from Exhaustive Search to Progressive Channel Swap Search based on input matrix_group size - if input_channel_num > 2048: - search_options['strategy'] = 'progressive channel swap' - search_options['progressive_search_time_limit'] = 120 - search_options['improvement_threshold'] = 1e-9 - print("[search_for_good_permutation] Change to Progressive Channel Swap Search with {} seconds limitation, because the {} is too large and will leading too long permutation search time with Exhaustive Search.".format(search_options['progressive_search_time_limit'], input_channel_num)) - - start_time_accelerated_search_for_good_permutation = time.perf_counter() - permutation_sequence = accelerated_search_for_good_permutation(matrix_group, options=search_options) - duration_accelerated_search_for_good_permutation = time.perf_counter() - start_time_accelerated_search_for_good_permutation - print("[search_for_good_permutation] Take {:.4f} seconds to finish accelerated_search_for_good_permutation function.".format(duration_accelerated_search_for_good_permutation)) - unique_siblings_groups_permutation_sequence.append(permutation_sequence) - fx_graph['unique_siblings']['permutation_sequence'] = unique_siblings_groups_permutation_sequence - - if cls.__save_permutation_graph: - cls.save_graph_to_json(fx_graph, save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir, './model_graph_search_for_good_permutation.json')) # save the intermediate graph as JSON file for debugging - return fx_graph - - @classmethod - def init_permutation_flag(cls, fx_graph): - """This function is used to init the permutation flag for each node according to the whole network graph built with Torch.FX.""" - print("\n[init_permutation_flag] Init the permutation flag for each node according to the whole network graph built with Torch.FX") - sparse_module_names = [] - processed_sparse_module_names = [] # Inception-V3, module_name: Conv2d_2a_3x3.conv, node_name: conv2d.1a.3x3.conv - for module_name, module, p_name, p, mask, pruned in cls.__sparse_parameters: - sparse_module_names.append(module_name) - processed_module_name = ''.join(c for c in module_name if c not in string.punctuation).lower() - processed_sparse_module_names.append(processed_module_name) - for node_name in fx_graph.keys(): - processed_node_name = ''.join(c for c in node_name if c not in string.punctuation).lower() - distributed_node_name = 'module.' + node_name - processed_distributed_node_name = 'module.' + processed_node_name - node_module_type = fx_graph.get(node_name).get('module_type') - if node_module_type in ['torch.nn.modules.conv.Conv2d', 'torch.nn.modules.linear.Linear']: - node_parents = fx_graph.get(node_name).get('parents') - node_children = fx_graph.get(node_name).get('children') - node_real_parents = fx_graph.get(node_name).get('real_parents') - node_real_children = fx_graph.get(node_name).get('real_children') - node_groups_param = fx_graph.get(node_name).get('groups_param') - is_node_real_children_in_sparse_parameters = False - is_node_real_children_has_group_conv = False - for real_child_item in node_real_children: - processed_real_child_item = ''.join(c for c in real_child_item if c not in string.punctuation).lower() - distributed_real_child_item = 'module.' + real_child_item - processed_distributed_real_child_item = 'module.' + processed_real_child_item - if (real_child_item in sparse_module_names) or (processed_real_child_item in processed_sparse_module_names) or (distributed_real_child_item in sparse_module_names) or (processed_distributed_real_child_item in processed_sparse_module_names): - is_node_real_children_in_sparse_parameters = True - if (fx_graph.get(real_child_item).get('groups_param') not in ['None', '1']): - is_node_real_children_has_group_conv = True - is_node_real_parents_has_group_conv = False - for real_parent_item in node_real_parents: - # notice: we assume the if one item of real_parents need to permute in C or K dim, then the corresponding flag should be set - # so for all items of real_parents, they may not share the same 'permutation_type' (e.g., one item is Group Conv, etc.) - # that's why we also need to judge the 'is_node_real_parents_has_group_conv' - if (fx_graph.get(real_parent_item).get('groups_param') not in ['None', '1']): - is_node_real_parents_has_group_conv = True - # If the node itself is in sparse_module_names or one of its real_children in sparse_module_names, then it may need the offline permutation - if ((node_name in sparse_module_names) or (processed_node_name in processed_sparse_module_names) or (distributed_node_name in sparse_module_names) or (processed_distributed_node_name in processed_sparse_module_names)) or (is_node_real_children_in_sparse_parameters == True): - if node_groups_param not in ['None', '1']: - # for Group Conv, disable the permutation in 'C' and 'K' dim - fx_graph[node_name]['permutation_type'] = 'None' - elif ('x' in node_parents) or ((node_name not in sparse_module_names) and (processed_node_name not in processed_sparse_module_names) and (distributed_node_name not in sparse_module_names) and (processed_distributed_node_name not in processed_sparse_module_names)): - # for the first (due to it is connected to 'x' node or itself is not in sparse_module_names) or not NVIDIA's TC compatiable Conv/FC, only permutate the K direction - if is_node_real_children_has_group_conv == False: - fx_graph[node_name]['permutation_type'] = 'K' - fx_graph[node_name]['k_permuted'] = 'False' - else: # if node real_children contains Group Conv, disable the permutation for node in 'K' dim - fx_graph[node_name]['permutation_type'] = 'None' - elif ('output' in node_children) or (is_node_real_children_in_sparse_parameters == False): - # for the last (due to it is connected to 'output' node or to a node which is not in sparse_module_names) FC/Conv, only permutate the C direction - if is_node_real_parents_has_group_conv == False: - fx_graph[node_name]['permutation_type'] = 'C' - fx_graph[node_name]['c_permuted'] = 'False' - else: # if node real_parents contains Group Conv, disable the permutation for node in 'C' dim - fx_graph[node_name]['permutation_type'] = 'None' - else: - if (is_node_real_parents_has_group_conv == False) and (is_node_real_children_has_group_conv == False): - fx_graph[node_name]['permutation_type'] = 'KC' - fx_graph[node_name]['k_permuted'] = 'False' - fx_graph[node_name]['c_permuted'] = 'False' - elif is_node_real_parents_has_group_conv == True: # TODO: if node real_parents contains Group Conv, disable the permutation for node in 'C' dim - fx_graph[node_name]['permutation_type'] = 'K' - fx_graph[node_name]['k_permuted'] = 'False' - else: # if node real_children contains Group Conv, disable the permutation for node in 'K' dim - fx_graph[node_name]['permutation_type'] = 'C' - fx_graph[node_name]['c_permuted'] = 'False' - else: - fx_graph[node_name]['permutation_type'] = 'None' - elif node_module_type in ['torch.nn.modules.batchnorm.BatchNorm2d']: - node_real_parents = fx_graph.get(node_name).get('real_parents') - is_node_real_parents_need_K_permutation = False - is_node_real_parents_has_group_conv = False - for real_parent_item in node_real_parents: - # notice: we assume the if one item of real_parents need to permute in K dim, then the corresponding flag should be set - # as in most of the cases, BN only follows one Conv, so it should be OK for now - if fx_graph.get(real_parent_item).get('permutation_type') in ['K', 'KC']: - is_node_real_parents_need_K_permutation = True - if (fx_graph.get(real_parent_item).get('groups_param') not in ['None', '1']): - is_node_real_parents_has_group_conv = True - node_real_children = fx_graph.get(node_name).get('real_children') - is_node_real_children_in_sparse_parameters = False - for real_child_item in node_real_children: - processed_real_child_item = ''.join(c for c in real_child_item if c not in string.punctuation).lower() - distributed_real_child_item = 'module.' + real_child_item - processed_distributed_real_child_item = 'module.' + processed_real_child_item - if (real_child_item in sparse_module_names) or (processed_real_child_item in processed_sparse_module_names) or (distributed_real_child_item in sparse_module_names) or (processed_distributed_real_child_item in processed_sparse_module_names): - is_node_real_children_in_sparse_parameters = True - # Firstly, we should make sure the BN is not in the last (due to it is connected to a FC/Conv node which is not in sparse_module_names), then: - # If the real_parents of BN node are in sparse_module_names, then it may need the offline permutation - # Or if the real_parents of BN node just needs to permute in K dim - if (is_node_real_children_in_sparse_parameters == True) and (is_node_real_parents_need_K_permutation == True): - if (is_node_real_parents_has_group_conv == False) and (is_node_real_parents_need_K_permutation == True): - fx_graph[node_name]['permutation_type'] = 'K' - fx_graph[node_name]['k_permuted'] = 'False' - else: # if node real_parents contains Group Conv or does not need permutation in 'K' dim, disable the permutation for node in 'K' dim - fx_graph[node_name]['permutation_type'] = 'None' - else: - fx_graph[node_name]['permutation_type'] = 'None' - else: - fx_graph[node_name]['permutation_type'] = 'None' - - # A special case: if the node itself not in sparse_module_names but one of its real_siblings in sparse_module_names, then the node will not do the permutation search, but it may need to apply the offline permutation in C dim according to the searched permutation sequence from its real_siblings in sparse_module_names - # We make it as the post-processing, because if we add this to the previous logic, will make it too complex - # Post-processing Step No.1: - print("\n[init_permutation_flag] Post-processing Step No.1.") - node_change_permutation_due_to_siblings = [] - for node_name in fx_graph.keys(): - node_real_siblings = fx_graph.get(node_name).get('real_siblings') - if node_real_siblings is not None: - is_node_real_siblings_needs_C_permutation = False - for real_sibling_item in node_real_siblings: - if fx_graph.get(real_sibling_item).get('permutation_type') in ['C', 'KC']: - is_node_real_siblings_needs_C_permutation = True - if is_node_real_siblings_needs_C_permutation == True: - print("[init_permutation_flag] node_name: \'{:}\', one of its real siblings need do offline permutation in C dim.".format(node_name)) - node_original_permutation_type = fx_graph.get(node_name).get('permutation_type') - if node_original_permutation_type in ['C', 'KC']: - print("[init_permutation_flag] node_name: \'{:}\', its original permutation: \'{:}\' already includes C dim, no need to do No.1 post-processing change.".format(node_name, node_original_permutation_type)) - elif node_original_permutation_type == 'None': - fx_graph[node_name]['permutation_type'] = 'C' - print("[init_permutation_flag] node_name: \'{:}\', change its original permutation: \'{:}\' to new permutation: 'C'.".format(node_name, node_original_permutation_type)) - node_change_permutation_due_to_siblings.append(node_name) - elif node_original_permutation_type == 'K': - fx_graph[node_name]['permutation_type'] = 'KC' - print("[init_permutation_flag] node_name: \'{:}\', change its original permutation: \'{:}\' to new permutation: 'KC'.".format(node_name, node_original_permutation_type)) - node_change_permutation_due_to_siblings.append(node_name) - # Post-processing Step No.2: - print("\n[init_permutation_flag] Post-processing Step No.2.") - for node_name in fx_graph.keys(): - node_real_children = fx_graph.get(node_name).get('real_children') - node_module_type = fx_graph.get(node_name).get('module_type') - if (node_real_children is not None) and (node_module_type in ['torch.nn.modules.conv.Conv2d', 'torch.nn.modules.linear.Linear', 'torch.nn.modules.batchnorm.BatchNorm2d']): - is_node_real_children_has_node_change_permutation = False - for real_child_item in node_real_children: - if real_child_item in node_change_permutation_due_to_siblings: - is_node_real_children_has_node_change_permutation = True - if is_node_real_children_has_node_change_permutation == True: - print("[init_permutation_flag] node_name: \'{:}\', one of its real children has changed permutation due to its siblings.".format(node_name)) - node_original_permutation_type = fx_graph.get(node_name).get('permutation_type') - if node_original_permutation_type in ['K', 'KC']: - print("[init_permutation_flag] node_name: \'{:}\', its original permutation: \'{:}\' already includes K dim, no need to do No.2 post-processing change.".format(node_name, node_original_permutation_type)) - elif node_original_permutation_type == 'None': - fx_graph[node_name]['permutation_type'] = 'K' - print("[init_permutation_flag] node_name: \'{:}\', change its original permutation: \'{:}\' to new permutation: 'K'.".format(node_name, node_original_permutation_type)) - elif node_original_permutation_type == 'C': - fx_graph[node_name]['permutation_type'] = 'KC' - print("[init_permutation_flag] node_name: \'{:}\', change its original permutation: \'{:}\' to new permutation: 'KC'.".format(node_name, node_original_permutation_type)) - - if cls.__save_permutation_graph: - cls.save_graph_to_json(fx_graph, save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir, './model_graph_init_permutation_flag.json')) # save the intermediate graph as JSON file for debugging - return fx_graph - - @classmethod - def extract_all_unique_siblings(cls, fx_graph): - """This function is used to extrat all unique siblings for the whole network graph built with Torch.FX.""" - print("\n[extract_all_unique_siblings] Extract all unique siblings for the whole network graph built with Torch.FX") - all_unique_siblings_name = [] - all_unique_siblings_module_type = [] - for node_name in fx_graph.keys(): - fx_graph[node_name]['node_type'] = 'network_node' # use the 'node_type' to divide the real nodes apart from the auxiliary info node, like 'unique_siblings' node - node_module_type = fx_graph.get(node_name).get('module_type') - node_real_siblings = fx_graph.get(node_name).get('real_siblings') - node_real_siblings_module_type = fx_graph.get(node_name).get('real_siblings_module_type') - if node_real_siblings == []: - print("[extract_all_unique_siblings] node_name: \'{:}\', node module type: \'{:}\', has no real siblings.".format(node_name, node_module_type)) - # for the Conv/FC layers without real_siblings, then we should insert itself as an unique_siblings - if node_module_type in ['torch.nn.modules.conv.Conv2d', 'torch.nn.modules.linear.Linear']: - # direct insert will change the real_siblings info for the node in the fx_graph - node_real_siblings_with_node_itself = node_real_siblings.copy() - node_real_siblings_with_node_itself.insert(0, node_name) - node_real_siblings_module_type_with_node_itself = node_real_siblings_module_type.copy() - node_real_siblings_module_type_with_node_itself.insert(0, node_module_type) - all_unique_siblings_name.append(node_real_siblings_with_node_itself) - all_unique_siblings_module_type.append(node_real_siblings_module_type_with_node_itself) - else: - print("[extract_all_unique_siblings] node_name: \'{:}\', node module type: \'{:}\', has {:} real siblings: \'{:}\'.".format(node_name, node_module_type, len(node_real_siblings), node_real_siblings)) - # for the two duplicated siblings lists, the node names included should be the same. - # If the node name is already included in one of the unique_siblings_name list, which means the real_siblings of this node is duplicated with the unique_siblings_name list. - # Otherwise, we should insert the [real_siblings + node_name] as a new unique_siblings_name list. - has_include_siblings = False - for unique_siblings_item in all_unique_siblings_name: - if node_name in unique_siblings_item: - has_include_siblings = True - if has_include_siblings == False: - # direct insert will change the real_siblings info for the node in the fx_graph - node_real_siblings_with_node_itself = node_real_siblings.copy() - node_real_siblings_with_node_itself.insert(0, node_name) - node_real_siblings_module_type_with_node_itself = node_real_siblings_module_type.copy() - node_real_siblings_module_type_with_node_itself.insert(0, node_module_type) - all_unique_siblings_name.append(node_real_siblings_with_node_itself) - all_unique_siblings_module_type.append(node_real_siblings_module_type_with_node_itself) - - fx_graph['unique_siblings'] = {} - fx_graph['unique_siblings']['name'] = all_unique_siblings_name - fx_graph['unique_siblings']['module_type'] = all_unique_siblings_module_type - fx_graph['unique_siblings']['node_type'] = 'auxiliary_info_node' - - if cls.__save_permutation_graph: - cls.save_graph_to_json(fx_graph, save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir, './model_graph_extract_all_unique_siblings.json')) # save the intermediate graph as JSON file for debugging - return fx_graph - - @classmethod - def find_real_siblings(cls, fx_graph): - """This function is used to find all siblings for each node according to the whole network graph built with Torch.FX. - we need to find siblings recursively, because siblings may have siblings via other parents we don't know about. - """ - print("\n[find_real_siblings] Find all siblings for each node according to the whole network graph built with Torch.FX") - for node_name in fx_graph.keys(): - node_real_siblings_name = [] - node_real_siblings_module_type = [] - node_real_parents = fx_graph.get(node_name).get('real_parents') - node_module_type = fx_graph.get(node_name).get('module_type') - if node_module_type not in ['torch.nn.modules.conv.Conv2d', 'torch.nn.modules.linear.Linear']: - print("[find_real_siblings] node_name: \'{:}\', node module type: \'{:}\', has no real siblings.".format(node_name, node_module_type)) - else: - print("[find_real_siblings] node_name: \'{:}\', node module type: \'{:}\', may have real siblings.".format(node_name, node_module_type)) - # sibling means the nodes share the same real parent - for real_parent_item in node_real_parents: - for real_child_item in fx_graph.get(real_parent_item).get('real_children'): - if real_child_item != node_name: - sibling_module_type = fx_graph.get(real_child_item).get('module_type') - print("[find_real_siblings] node_name: \'{:}\', has one real sibling: \'{:}\', its real sibling module type: \'{:}\'.".format(node_name, real_child_item, sibling_module_type)) - node_real_siblings_name.append(real_child_item) - node_real_siblings_module_type.append(sibling_module_type) - - # remove the duplicated real siblings - exclusive_node_real_siblings_name = [] - exclusive_node_real_siblings_module_type = [] - item_index = 0 - duplicated_real_siblings = 0 - for item in node_real_siblings_name: - if item not in exclusive_node_real_siblings_name: - exclusive_node_real_siblings_name.append(item) - exclusive_node_real_siblings_module_type.append(node_real_siblings_module_type[item_index]) - else: - duplicated_real_siblings = duplicated_real_siblings + 1 - item_index = item_index + 1 - if duplicated_real_siblings > 0: - print("[find_real_siblings] node_name: \'{:}\', remove {:} duplicated real siblings.".format(node_name, duplicated_real_siblings)) - fx_graph[node_name]['real_siblings'] = exclusive_node_real_siblings_name - fx_graph[node_name]['real_siblings_module_type'] = exclusive_node_real_siblings_module_type - - if cls.__save_permutation_graph: - cls.save_graph_to_json(fx_graph, save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir, './model_graph_find_real_siblings.json')) # save the intermediate graph as JSON file for debugging - return fx_graph - - @classmethod - def recursive_find_real_children(cls, node_name, fx_graph): - """This function is used to recursively find the real children for each node according to the whole network graph built with Torch.FX. - Used as the sub-function of find_real_children. - """ - node_real_children_name = [] - node_real_children_module_type = [] - if node_name in fx_graph.keys(): # can be deleted, because node_name is already in the 'children' item in one node of the fx_graph - node_children = fx_graph.get(node_name).get('children') - node_module_type = fx_graph.get(node_name).get('module_type') - has_visit_children_num = 0 - has_real_children_num = 0 - sub_node_need_recursive_search = [] - while has_visit_children_num < len(node_children): - for child_name in node_children: - if child_name != 'output': # 'output' node has no 'module_type' - child_module_type = fx_graph.get(child_name).get('module_type') - if child_module_type in ['torch.nn.modules.conv.Conv2d', 'torch.nn.modules.linear.Linear']: - print("[recursive_find_real_children] node_name: \'{:}\', has one real child: \'{:}\', its real child module type: \'{:}\'.".format(node_name, child_name, child_module_type)) - node_real_children_name.append(child_name) - node_real_children_module_type.append(child_module_type) - has_real_children_num = has_real_children_num + 1 - else: - print("[recursive_find_real_children] node_name: \'{:}\', its child: \'{:}\' with module type: \'{:}\', needs recursive search.".format(node_name, child_name, child_module_type)) - sub_node_need_recursive_search.append(child_name) - else: - print("[recursive_find_real_children] node_name: \'{:}\', its child: \'{:}\' with no module type, is not its real child.".format(node_name, child_name)) - has_visit_children_num = has_visit_children_num + 1 - if len(sub_node_need_recursive_search) > 0: - for sub_node in sub_node_need_recursive_search: - if fx_graph.get(sub_node).get('real_children') == []: - sub_node_real_children_name, sub_node_real_children_module_type = cls.recursive_find_real_children(sub_node, fx_graph) - else: - # if the sub_node already find the 'real_children', no need to do recursive search - sub_node_real_children_name = fx_graph.get(sub_node).get('real_children') - sub_node_real_children_module_type = fx_graph.get(sub_node).get('real_children_module_type') - node_real_children_name.extend(sub_node_real_children_name) - node_real_children_module_type.extend(sub_node_real_children_module_type) - return node_real_children_name, node_real_children_module_type - - @classmethod - def find_real_children(cls, fx_graph): - """This function is used to find the real children for each node according to the whole network graph built with Torch.FX. - For example: - The real children of Conv is the subsequent Conv/FC. - The real children of BN or other no-need-permutataion layers is the subsequent Conv/FC. - """ - print("\n[find_real_children] Find the real children for each node according to the whole network graph built with Torch.FX") - from sys import version_info - if version_info.major == 3 and version_info.minor >= 8: - reversible_fx_graph_keys = fx_graph.keys() - else: # 'dict_keys' object is not reversible in previous of Python 3.8 - reversible_fx_graph_keys = list(fx_graph.keys()) - for node_name in reversed(reversible_fx_graph_keys): # as the optimization, we need to find the real children from back to front, to use the already saved 'real_children' - node_real_children_name = [] - node_real_children_module_type = [] - node_children = fx_graph.get(node_name).get('children') - node_module_type = fx_graph.get(node_name).get('module_type') - if node_module_type not in ['torch.nn.modules.conv.Conv2d', 'torch.nn.modules.linear.Linear']: - print("\n[find_real_children] node_name: \'{:}\', node module type: \'{:}\', children num: {:}, recursive to find real children.".format(node_name, node_module_type, len(node_children))) - node_real_children_name, node_real_children_module_type = cls.recursive_find_real_children(node_name, fx_graph) - else: # Quick method, but cannot get the real children for no-need-permutataion layers like BN - print("\n[find_real_children] node_name: \'{:}\', node module type: \'{:}\', children num: {:}, can directly find real children.".format(node_name, node_module_type, len(node_children))) - # if the node is in the 'real_parents' list of the other node, then the other node is the real children for this node - for other_node_name in fx_graph.keys(): - if (other_node_name != node_name) and (node_name in fx_graph.get(other_node_name).get('real_parents')): - child_module_type = fx_graph.get(other_node_name).get('module_type') - if child_module_type in ['torch.nn.modules.conv.Conv2d', 'torch.nn.modules.linear.Linear']: - print("[find_real_children] node_name: \'{:}\', has one real child: \'{:}\', its real child module type: \'{:}\'.".format(node_name, other_node_name, child_module_type)) - node_real_children_name.append(other_node_name) - node_real_children_module_type.append(child_module_type) - - # remove the duplicated real children - exclusive_node_real_children_name = [] - exclusive_node_real_children_module_type = [] - item_index = 0 - duplicated_real_children = 0 - for item in node_real_children_name: - if item not in exclusive_node_real_children_name: - exclusive_node_real_children_name.append(item) - exclusive_node_real_children_module_type.append(node_real_children_module_type[item_index]) - else: - duplicated_real_children = duplicated_real_children + 1 - item_index = item_index + 1 - if duplicated_real_children > 0: - print("[find_real_children] node_name: \'{:}\', remove {:} duplicated real children.".format(node_name, duplicated_real_children)) - fx_graph[node_name]['real_children'] = exclusive_node_real_children_name - fx_graph[node_name]['real_children_module_type'] = exclusive_node_real_children_module_type - - if cls.__save_permutation_graph: - cls.save_graph_to_json(fx_graph, save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir, './model_graph_find_real_children.json')) # save the intermediate graph as JSON file for debugging - return fx_graph - - @classmethod - def find_real_parents(cls, fx_graph): - """This function is used to find the real parents for each node according to the whole network graph built with Torch.FX. - For example: - The real parent of BN is the previous Conv/FC. - The real parent of Conv is the previous Conv/FC. - """ - print("\n[find_real_parents] Find the real parents for each node according to the whole network graph built with Torch.FX") - for node_name in fx_graph.keys(): - node_real_parents_name = [] - node_real_parents_module_type = [] - node_parents = fx_graph.get(node_name).get('parents') - print("[find_real_parents] node_name: \'{:}\', parents num: {:}".format(node_name, len(node_parents))) - - has_visit_parent_num = 0 - while has_visit_parent_num < len(node_parents): - for parent_name in node_parents: - if fx_graph.__contains__(parent_name): - parent_module_type = fx_graph.get(parent_name).get('module_type') - if parent_module_type in ['torch.nn.modules.conv.Conv2d', 'torch.nn.modules.linear.Linear']: - print("[find_real_parents] node_name: \'{:}\', has one real parent: \'{:}\', its real parent module type: \'{:}\'.".format(node_name, parent_name, parent_module_type)) - node_real_parents_name.append(parent_name) - node_real_parents_module_type.append(parent_module_type) - else: - print("[find_real_parents] node_name: \'{:}\', has one/several real parent(s): \'{:}\', its real parent module type: \'{:}\'.".format(node_name, fx_graph[parent_name]['real_parents'], fx_graph[parent_name]['real_parents_module_type'])) - for real_parent_item in fx_graph[parent_name]['real_parents']: - node_real_parents_name.append(real_parent_item) - for real_parent_module_type_item in fx_graph[parent_name]['real_parents_module_type']: - node_real_parents_module_type.append(real_parent_module_type_item) - else: - print("[find_real_parents] node_name: \'{:}\', has no real parent because this is the first node.".format(node_name)) - has_visit_parent_num = has_visit_parent_num + 1 - - # remove the duplicated real parents - exclusive_node_real_parents_name = [] - exclusive_node_real_parents_module_type = [] - exclusive_node_real_parents_groups_param = [] - item_index = 0 - duplicated_real_parents = 0 - for item in node_real_parents_name: - if item not in exclusive_node_real_parents_name: - exclusive_node_real_parents_name.append(item) - exclusive_node_real_parents_module_type.append(node_real_parents_module_type[item_index]) - exclusive_node_real_parents_groups_param.append(fx_graph.get(item).get('groups_param')) - else: - duplicated_real_parents = duplicated_real_parents + 1 - item_index = item_index + 1 - if duplicated_real_parents > 0: - print("[find_real_parents] node_name: \'{:}\', remove {:} duplicated real parents.".format(node_name, duplicated_real_parents)) - fx_graph[node_name]['real_parents'] = exclusive_node_real_parents_name - fx_graph[node_name]['real_parents_module_type'] = exclusive_node_real_parents_module_type - fx_graph[node_name]['real_parents_groups_param'] = exclusive_node_real_parents_groups_param - - if cls.__save_permutation_graph: - cls.save_graph_to_json(fx_graph, save_dumped_graph_path_with_name=os.path.join(cls.__permutation_output_dir, './model_graph_find_real_parent.json')) # save the intermediate graph as JSON file for debugging - return fx_graph - - @classmethod - def build_fx_graph(cls, model, dump_fx_graph=False, save_dumped_fx_graph='./model_fx_graph.json'): - """This function is used to build the whole network graph with Torch.FX features.""" - success = True - torch_version = str(torch.__version__) - torch_version_major = int(torch_version.split('.')[0]) - torch_version_minor = int(torch_version.split('.')[1]) - try: - torch_version_minimum = int(torch_version.split('.')[2]) - except ValueError: # support the none standard version - torch_version_minimum = torch_version.split('.')[2] - print("[build_fx_graph] The torch version is: {}, version major is: {}, version minor is: {}, version minimum is: {}".format(torch_version, torch_version_major, torch_version_minor, torch_version_minimum)) - if torch_version_major >= 1 and torch_version_minor >= 8: - print("[build_fx_graph] The Torch.FX is supported.") - else: # Torch.FX is introduced in torch 1.8.0 - print("[build_fx_graph] The Torch.FX is not supported. So cannot build the Torch.FX graph.") - success = False - network_fx_graph = {} - return network_fx_graph, success - - print("\n[build_fx_graph] Print the model structure with pure PyTorch function") - print(model) - - print("\n[build_fx_graph] Build the module name and type dictionary") - module_name_type_dict = {} - module_name_group_conv_dict = {} - for name, mod in model.named_modules(): - print("[build_fx_graph] module_name: {}, module type: {}".format(name, type(mod))) - module_name_type_dict[name] = str(type(mod)).split("\'")[1] - try: - print("[build_fx_graph] this module has \'group\' param with value: {}".format(mod.groups)) - module_name_group_conv_dict[name] = str(mod.groups) - except: - module_name_group_conv_dict[name] = 'None' - continue - - graph_module = cls.print_raw_fx_graph(model, print_tabular=True) - - # keep track of children and parents for each layer (could be call_module or call_function) - print("\n[build_fx_graph] Print the children and parents relationship for each layer") - network_fx_graph = {} - for node in graph_module.graph.nodes: - if node.op == 'placeholder': - print("[build_fx_graph] This is the \'input\' node: {:}".format(node.target)) - continue - elif node.op == 'get_attr': - print("[build_fx_graph] This is the \'get_attr\' node: {:}".format(node.target)) - continue - elif node.op == 'call_function': # e.g. 'adaptive.avg.pool2d', 'add', 'cat', 'flatten', 'floordiv', 'getattr', 'getitem', 'hardsigmoid', 'mean', 'mul', 'relu', 'transpose' - node_parent, node_children = get_node_parent_children(node) - converted_node_name=convert_fx_node_name(node.name) - print("[build_fx_graph] This is the \'call_function\' node: {:}, its parent list: {:}, its children list: {:}".format(converted_node_name, node_parent, node_children)) - network_fx_graph[converted_node_name] = {} - network_fx_graph[converted_node_name]['parents'] = node_parent - network_fx_graph[converted_node_name]['children'] = node_children - network_fx_graph[converted_node_name]['fx_op'] = 'call_function' - elif node.op == 'call_method': # e.g. 'chunk', 'contiguous', 'mean', 'size', 'unsqueeze', 'view' - node_parent, node_children = get_node_parent_children(node) - converted_node_name=convert_fx_node_name(node.name) - print("[build_fx_graph] This is the \'call_method\' node: {:}, its parent list: {:}, its children list: {:}".format(converted_node_name, node_parent, node_children)) - network_fx_graph[converted_node_name] = {} - network_fx_graph[converted_node_name]['parents'] = node_parent - network_fx_graph[converted_node_name]['children'] = node_children - network_fx_graph[converted_node_name]['fx_op'] = 'call_method' - continue - elif node.op == 'call_module': - node_parent, node_children = get_node_parent_children(node) - converted_node_name=convert_fx_node_name(node.name) - # check whether the converted_node_name is same as node.target, especially for ReLU case - if converted_node_name != node.target: - print("[build_fx_graph][warning] The target name from Torch.FX is \'{:}\', the manually converted node name is \'{:}\', not the same one, choose the converted node name".format(node.target, converted_node_name)) - # assume the modules share the same target name have the same type, because converted_node_name may not be obtained by model.named_modules(), like some ReLU (defined in forward function) - node_type = module_name_type_dict[node.target] - print("[build_fx_graph] This is the \'call_module\' node: {:}, its parent list: {:}, its children list: {:}, its type: {:}".format(converted_node_name, node_parent, node_children, node_type)) - network_fx_graph[converted_node_name] = {} - network_fx_graph[converted_node_name]['parents'] = node_parent - network_fx_graph[converted_node_name]['children'] = node_children - network_fx_graph[converted_node_name]['fx_op'] = 'call_module' - network_fx_graph[converted_node_name]['module_type'] = node_type - network_fx_graph[converted_node_name]['groups_param'] = module_name_group_conv_dict[node.target] - elif node.op == 'output': - print("[build_fx_graph] This is the \'output\' node: {:}".format(node.target)) - continue - - if dump_fx_graph: - print("\n[build_fx_graph] Dump the overall dict for children and parents relationship into JSON file") - cls.save_graph_to_json(network_fx_graph, save_dumped_graph_path_with_name=save_dumped_fx_graph) - - return network_fx_graph, success - - @classmethod - def print_raw_fx_graph(cls, model, print_tabular=False, generate_python_code=False): - """This function is used to print the intermediate representation (IR) - Graph representation with Torch.FX features.""" - from torch.fx import symbolic_trace - # Symbolic tracing frontend - captures the semantics of the module - try: - symbolic_traced : torch.fx.GraphModule = symbolic_trace(model) - except: - if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: - print("\n[print_raw_fx_graph] Meet the fatal fault when trying to symbolic trace the model with Torch.FX") - raise - exit(0) - - # High-level intermediate representation (IR) - Graph representation - print("\n[print_raw_fx_graph] Print the intermediate representation (IR) with Torch.FX") - print(symbolic_traced.graph) - - if print_tabular: - print("\n[print_raw_fx_graph] Print the intermediate representation (IR) with Torch.FX in a table format") - try: - symbolic_traced.graph.print_tabular() - except AttributeError: # to avoid the AttributeError: 'Graph' object has no attribute 'print_tabular' - print("[print_raw_fx_graph][Warning] \'print_tabular\' function is not supported in current Torch version. Skip!") - - # Code generation - valid Python code - if generate_python_code: - print("\n[print_raw_fx_graph] Create valid Python code matching the IR/Graph's semantics with Torch.FX") - print(symbolic_traced.code) - - return symbolic_traced - - @classmethod - def save_graph_to_json(cls, graph, save_dumped_graph_path_with_name='./model_fx_graph.json'): - """This function is used to same the graph into JSON file.""" - # use dumps to transfer the dict to JSON string - json_graph_str = json.dumps(graph) - with open(save_dumped_graph_path_with_name, 'w', encoding='utf-8') as dumped_graph_file: - dumped_graph_file.write(json_graph_str) # write the transferred JSON string into JSON file diff --git a/apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu b/apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu deleted file mode 100644 index c7b053c..0000000 --- a/apex/contrib/sparsity/permutation_search_kernels/CUDA_kernels/permutation_search_kernels.cu +++ /dev/null @@ -1,371 +0,0 @@ -#include -#include -#include -namespace py = pybind11; - -#define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); } -inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true) -{ - if (code != cudaSuccess) - { - fprintf(stderr,"GPUassert %d: %s %s %d\n", (int)code, cudaGetErrorString(code), file, line); - if (abort) exit(code); - } -} - -__device__ float group_2_to_4(float4 vals) -{ - vals.x = fabs(vals.x); - vals.y = fabs(vals.y); - vals.z = fabs(vals.z); - vals.w = fabs(vals.w); - - float sum0 = vals.x + vals.y; - float sum1 = vals.x + vals.z; - float sum2 = vals.x + vals.w; - float sum3 = vals.y + vals.z; - float sum4 = vals.y + vals.w; - float sum5 = vals.z + vals.w; - - float best_sum0 = fmax(sum0, sum1); - float best_sum1 = fmax(sum2, sum3); - float best_sum2 = fmax(sum4, sum5); - float best_sum = fmax(fmax(best_sum0, best_sum1), best_sum2); - - return best_sum; -} - -inline float* float_ptr_from_numpy(py::array_t& py_float) -{ - return (float*)py_float.data(); -} - -inline unsigned int* uint_ptr_from_numpy(py::array_t& py_uint) -{ - return (unsigned int*)py_uint.data(); -} - -__global__ void subset_sum_after_2_to_4(float* matrix, - unsigned int rows, - unsigned int cols, - unsigned int start_col, - unsigned int end_col, - float* output) -{ - // vectorize - float4* mat4 = (float4*) matrix; - cols /= 4; - start_col /= 4; - end_col /= 4; - - // each thread in a block takes some number of rows - size_t num_rows = max((int)ceilf((float)rows / (float)blockDim.x), 1); - size_t row_offset = num_rows * threadIdx.x; - // each block takes some number of columns - size_t num_cols = (end_col - start_col) / gridDim.x; - size_t col_offset = num_cols * blockIdx.x; - start_col += col_offset; - end_col = start_col + num_cols; - - float sum = 0.0f; - for ( unsigned int r = row_offset; r < row_offset + num_rows; ++r ) { - if (r < rows) { - for ( unsigned int c = start_col; c < end_col; c++ ) { - sum += group_2_to_4(mat4[r * cols + c]); - } - } - } - - atomicAdd(output, sum); -} - -// build the entire permute map at once -// each block handles one group of stripes -// each threads in the block handle all handle the same permutation at the same time on different rows before moving to the next permutation -__global__ void build_permute_map(float* matrix, - unsigned int rows, - unsigned int cols, - unsigned int* stripes, - unsigned int group_width, - unsigned int* permutations, - unsigned int num_permutations, - unsigned int perm_length, - float* output, - unsigned int* best_indices) -{ - // vectorize - float4* mat4 = (float4*) matrix; - cols /= 4; - - // each block handles a group of stripes - unsigned int* stripe_group = (unsigned int*)&stripes[blockIdx.x*group_width]; - - // shared memory: 32 threads each need 16*2 - extern __shared__ float pm_shared[32][32]; - float4* local_stripes = (float4*)&pm_shared[threadIdx.x]; - float* local_columns = (float*) &pm_shared[threadIdx.x]; - float4* permuted_stripes = (float4*) &local_stripes[4]; - float* permuted_columns = (float*) &local_columns[16]; - - // each thread handles all permutations in the row before moving on to the next row - size_t num_rows = max((int)ceilf((float)rows / (float)blockDim.x), 1); - size_t row_offset = num_rows * threadIdx.x; - - for ( unsigned int r = row_offset; r < row_offset + num_rows; ++r) { - if (r >= rows) - break; - - // load a row into smem - for ( unsigned int s = 0; s < group_width; ++s) { - unsigned int const stripe = stripe_group[s]; - local_stripes[s] = mat4[r*cols+stripe]; - } - - for ( unsigned int p = 0; p < num_permutations; ++p) { - unsigned int* permutation = &permutations[p*perm_length]; - float sum = 0.0f; - - // permute - #pragma unroll 4 - for ( unsigned int c = 0; c < group_width*4; ++c) { - permuted_columns[c] = local_columns[permutation[c]]; - } - - // sum 2:4 - for ( unsigned int s = 0; s < group_width; ++s) { - sum += group_2_to_4(permuted_stripes[s]); - } - - // update the running sum for this stripe group's permutation - atomicAdd(&output[blockIdx.x*num_permutations + p], sum); - } - } - - // at this point, each permutation's sum in this stripe group has been calculated - // now, find the best option - __syncthreads(); - - if (threadIdx.x == 0) { - unsigned int best_permutation = 0; - float best_magnitude = output[blockIdx.x*num_permutations]; - float base_magnitude = best_magnitude; - - //#pragma unroll 32 - for (unsigned int p = 1; p < num_permutations; ++p) { - float magnitude = output[blockIdx.x*num_permutations+p]; - if (magnitude > best_magnitude) { - best_permutation = p; - best_magnitude = magnitude; - } - } - - output[blockIdx.x*num_permutations] = best_magnitude - base_magnitude; - best_indices[blockIdx.x] = best_permutation; - } -} - - -void free_sum_after_2_to_4_memory(float** dmatrix, - float** dresult) -{ - cudaFree(*dmatrix); - cudaFree(*dresult); -} - -int set_up_sum_after_2_to_4_memory(float** dmatrix, - unsigned int rows, - unsigned int cols, - float** dresult) -{ - static unsigned int setupRows = 0; - static unsigned int setupCols = 0; - static bool allocated = false; - - int fresh_allocation = 0; - if (!allocated || - setupRows != rows || - setupCols != cols) - { - if (allocated) - free_sum_after_2_to_4_memory(dmatrix, dresult); - - gpuErrchk(cudaMalloc( (void**) dmatrix, rows*cols*sizeof(float))); - gpuErrchk(cudaMalloc( (void**) dresult, sizeof(float))); - - setupRows = rows; - setupCols = cols; - - fresh_allocation = 1; - } - - allocated = true; - - return fresh_allocation; -} - -int run_subset_sum_after_2_to_4(py::array_t& py_matrix, - unsigned int rows, - unsigned int cols, - unsigned int start_col, - unsigned int end_col, - unsigned int blocks, - unsigned int threads, - py::array_t& py_output) -{ - - static float* d_matrix; - static float* d_result; - - int fresh_allocation = set_up_sum_after_2_to_4_memory(&d_matrix, rows, cols, &d_result); - - float* matrix = float_ptr_from_numpy(py_matrix); - float* output = float_ptr_from_numpy(py_output); - - gpuErrchk(cudaMemcpy( d_matrix, matrix, rows*cols*sizeof(float), cudaMemcpyHostToDevice )); - gpuErrchk(cudaMemset( d_result, 0, sizeof(float))); - - subset_sum_after_2_to_4<<>>(d_matrix, rows, cols, start_col, end_col, d_result); - gpuErrchk(cudaDeviceSynchronize()); - - gpuErrchk(cudaMemcpy( output, d_result, sizeof(float), cudaMemcpyDeviceToHost )); - - return 0; -} - -void set_up_permute_map_memory(float** dmatrix, - unsigned int rows, - unsigned int cols, - unsigned int** dstripes, - unsigned int num_groups, - unsigned int group_width, - unsigned int** dpermutations, - unsigned int num_permutations, - unsigned int perm_length, - float** doutput, - unsigned int** dindices, - float** hresult, - unsigned int** hindices) -{ - static unsigned int setUpRows = 0; - static unsigned int setUpCols = 0; - static unsigned int setUpGroupWidth = 0; - static unsigned int setUpNumGroups = 0; - static unsigned int setUpNumPerms = 0; - static unsigned int setUpPermLength = 0; - - if (setUpRows != rows || - setUpCols != cols) { - if (*dmatrix != NULL) { gpuErrchk(cudaFree(*dmatrix)); *dmatrix = NULL; } - gpuErrchk(cudaMalloc( (void**) dmatrix, rows*cols*sizeof(float))); - } - - if (setUpGroupWidth < group_width || - setUpNumGroups < num_groups) { - if (*dstripes != NULL) { gpuErrchk(cudaFree(*dstripes)); *dstripes = NULL; } - gpuErrchk(cudaMalloc( (void**) dstripes, num_groups*group_width*sizeof(unsigned int))); - - if (setUpNumGroups < num_groups) { - if (*dindices != NULL) { gpuErrchk(cudaFree(*dindices)); *dindices = NULL; } - gpuErrchk(cudaMalloc( (void**) dindices, num_groups*sizeof(unsigned int))); - if (*hindices != NULL) { free(*hindices); *hindices = NULL; } - *hindices = (unsigned int*) malloc (num_groups*sizeof(unsigned int)); - } - } - - if (setUpNumPerms < num_permutations || - setUpPermLength < perm_length) { - if (*dpermutations != NULL) { gpuErrchk(cudaFree(*dpermutations)); *dpermutations = NULL; } - gpuErrchk(cudaMalloc( (void**) dpermutations, perm_length*num_permutations*sizeof(unsigned int))); - } - - if (setUpNumPerms < num_permutations || - setUpNumGroups < num_groups) { - if (*doutput != NULL) { gpuErrchk(cudaFree(*doutput)); *doutput = NULL; } - gpuErrchk(cudaMalloc( (void**) doutput, num_permutations*num_groups*sizeof(float))); - if (*hresult != NULL) { free(*hresult); *hresult = NULL; } - *hresult = (float*) malloc(num_permutations*num_groups*sizeof(float)); - } - - setUpRows = rows; - setUpCols = cols; - setUpGroupWidth = group_width; - setUpNumGroups = num_groups; - setUpNumPerms = num_permutations; - setUpPermLength = perm_length; -} - -int run_build_permute_map(py::array_t& py_matrix, - unsigned int rows, - unsigned int cols, - py::array_t& py_stripes, - unsigned int num_groups, - unsigned int group_width, - py::array_t& py_permutations, - //unsigned int num_permutations, - unsigned int perm_length, - py::array_t& py_improvements, - py::array_t& py_best_indices) -{ - static float* d_matrix = NULL; - static unsigned int* d_stripes = NULL; - static unsigned int* d_permutations = NULL; - static float* d_output = NULL; - static unsigned int* d_indices = NULL; - static float* hresult = NULL; - static unsigned int* hindices = NULL; - - //const unsigned int cols = py_matrix.size() / rows; - //const unsigned int num_groups = py_stripes.size() / group_width; - //const unsigned int perm_length = group_width * 4; // 2:4 sparsity - each stripe in the group is 4 elements wide - const unsigned int num_permutations = py_permutations.size() / perm_length; - - const unsigned int MAX_GROUPS_PER_LAUNCH = num_permutations <= 5775 ? 1820 : 40; - const unsigned int full_launches = num_groups / MAX_GROUPS_PER_LAUNCH; - const unsigned int final_launch = num_groups % MAX_GROUPS_PER_LAUNCH; - const unsigned int launches = full_launches + (final_launch != 0 ? 1 : 0); - - set_up_permute_map_memory(&d_matrix, rows, cols, &d_stripes, min(num_groups,MAX_GROUPS_PER_LAUNCH), group_width, &d_permutations, num_permutations, perm_length, &d_output, &d_indices, &hresult, &hindices); - - float* matrix = float_ptr_from_numpy(py_matrix); - unsigned int* stripes = uint_ptr_from_numpy(py_stripes); - unsigned int* permutations = uint_ptr_from_numpy(py_permutations); - float* improvements = float_ptr_from_numpy(py_improvements); - unsigned int* best_indices = uint_ptr_from_numpy(py_best_indices); - - gpuErrchk(cudaMemcpy( d_matrix, matrix, rows*cols*sizeof(float), cudaMemcpyHostToDevice )); - gpuErrchk(cudaMemcpy( d_permutations, permutations, num_permutations*perm_length*sizeof(unsigned int), cudaMemcpyHostToDevice )); - - unsigned int group_offset = 0; - for (unsigned int l = 0; l < launches; ++l) - { - unsigned int groups_this_launch = (l < full_launches) ? MAX_GROUPS_PER_LAUNCH : final_launch; - - gpuErrchk(cudaMemcpy( d_stripes, &stripes[group_offset*group_width], groups_this_launch*group_width*sizeof(unsigned int), cudaMemcpyHostToDevice )); - gpuErrchk(cudaMemset( d_output, 0, groups_this_launch*num_permutations*sizeof(float))); - gpuErrchk(cudaMemset( d_indices, 0, groups_this_launch*sizeof(unsigned int))); - - unsigned int shmem = 32*(32)*sizeof(float); - build_permute_map<<>>(d_matrix, rows, cols, d_stripes, group_width, d_permutations, num_permutations, perm_length, d_output, d_indices); - gpuErrchk(cudaDeviceSynchronize()); - - gpuErrchk(cudaMemcpy( hresult, d_output, num_permutations*groups_this_launch*sizeof(float), cudaMemcpyDeviceToHost )); - gpuErrchk(cudaMemcpy( hindices, d_indices, groups_this_launch*sizeof(unsigned int), cudaMemcpyDeviceToHost )); - - // thread0 stuck the minimum in the first slot of each group - for (unsigned int g = 0; g < groups_this_launch; ++g) { - improvements[group_offset+g] = hresult[g*num_permutations]; - best_indices[group_offset+g] = hindices[g]; - } - - group_offset += groups_this_launch; - } - - return 0; - -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("sum_after_2_to_4", &run_subset_sum_after_2_to_4, "matrix sum after applying 2:4 (CUDA)"); - m.def("build_permute_map", &run_build_permute_map, "optimize stripe groups (CUDA)"); -} \ No newline at end of file diff --git a/apex/contrib/sparsity/permutation_search_kernels/__init__.py b/apex/contrib/sparsity/permutation_search_kernels/__init__.py deleted file mode 100644 index bd29c6d..0000000 --- a/apex/contrib/sparsity/permutation_search_kernels/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .call_permutation_search_kernels import accelerated_search_for_good_permutation -from .permutation_utilities import sum_after_2_to_4 \ No newline at end of file diff --git a/apex/contrib/sparsity/permutation_search_kernels/call_permutation_search_kernels.py b/apex/contrib/sparsity/permutation_search_kernels/call_permutation_search_kernels.py deleted file mode 100644 index 56547d8..0000000 --- a/apex/contrib/sparsity/permutation_search_kernels/call_permutation_search_kernels.py +++ /dev/null @@ -1,74 +0,0 @@ -import numpy as np -from .permutation_utilities import * -from .exhaustive_search import Exhaustive_Search - -def accelerated_search_for_good_permutation(matrix_group, options=None): - """This function is used to call the permutation search CUDA kernels. - users can provide prefer search strategy by providing a valid 'options' as a dictionary, - or users can implement their customized 'accelerated_search_for_good_permutation' function. - """ - input_matrix = matrix_group.cpu().detach().numpy() - print("\n[accelerated_search_for_good_permutation] input matrix shape: \'{:}\'.".format(input_matrix.shape)) - - result = np.copy(input_matrix) - # init a sequential permutation search sequence - input_channel_num = matrix_group.size()[1] - permutation_sequence = [n for n in range(input_channel_num)] - duration = 0.0 - - if options == None: - options = {} - if 'strategy' not in options: # right now, the default permutation search strategy is: 'exhaustive' search - options['strategy'] = 'exhaustive' - print("[accelerated_search_for_good_permutation] the permutation strategy is: \'{:} search\'.".format(options['strategy'])) - - # define sub options for each search strategy - if options['strategy'] == 'exhaustive': - # right now, the default options for 'exhaustive' search is: 'exhaustive,8,100' - if 'stripe_group_size' not in options: - options['stripe_group_size'] = 8 - if 'escape_attempts' not in options: - options['escape_attempts'] = 100 - elif options['strategy'] == 'progressive channel swap': - # just swaps meaningful channels, keeping the good swaps, until the search time limit expires. - if 'progressive_search_time_limit' not in options: - options['progressive_search_time_limit'] = 60 - if 'improvement_threshold' not in options: - options['improvement_threshold'] = 1e-9 - - # execute the requested strategy - if options['strategy'] == 'exhaustive': - result, duration, permutation_sequence = Exhaustive_Search(result, stripe_group_size=options['stripe_group_size'], escape_attempts=options['escape_attempts']) - elif options['strategy'] == 'progressive channel swap': - real_swap_num = 0 - start_time = time.perf_counter() - while time.perf_counter() - start_time < options['progressive_search_time_limit']: - src = np.random.randint(result.shape[1]) - dst = np.random.randint(result.shape[1]) - src_group = int(src/4) - dst_group = int(dst/4) - if src_group == dst_group: # channel swapping within a stripe does nothing - continue - new_sum, improvement = try_swap(result, dst, src) - if improvement > options['improvement_threshold']: - result[...,[src,dst]] = result[...,[dst,src]] - permutation_sequence[src], permutation_sequence[dst] = permutation_sequence[dst], permutation_sequence[src] - real_swap_num += 1 - duration = time.perf_counter() - start_time - print("\tFinally swap {} channel pairs until the search time limit expires.".format(real_swap_num)) - elif options['strategy'] == 'user defined': # need to get the permutated matrix (result) by applying customized permutation search function - print("[accelerated_search_for_good_permutation] Use the user customized permutation search function!") - else: - print("[accelerated_search_for_good_permutation] Cannot find the implementation of the required strategy!") - print("[accelerated_search_for_good_permutation] Take {:.4f} seconds to search the permutation sequence.".format(duration)) - - # In the new version of Exhaustive_Search function, there’s no need to use the find_permutation(result, input_matrix) function - # to recover the permutation sequence applied to the input_matrix to get the result separately any more. - #start_time_find_permutation = time.perf_counter() - #permutation_sequence = find_permutation(result, input_matrix) - #duration_find_permutation = time.perf_counter() - start_time_find_permutation - #print("[accelerated_search_for_good_permutation] Take {:.4f} seconds to finish find_permutation function.".format(duration_find_permutation)) - #print("[accelerated_search_for_good_permutation] The permutation sequence is: {:}".format(permutation_sequence)) - #print("[accelerated_search_for_good_permutation] The length of permutation sequence is: {:}".format(len(permutation_sequence))) - - return permutation_sequence diff --git a/apex/contrib/sparsity/permutation_search_kernels/exhaustive_search.py b/apex/contrib/sparsity/permutation_search_kernels/exhaustive_search.py deleted file mode 100644 index 945b1eb..0000000 --- a/apex/contrib/sparsity/permutation_search_kernels/exhaustive_search.py +++ /dev/null @@ -1,371 +0,0 @@ -from .permutation_utilities import * - -################################################################################################################ -# Exhaustive -# Try them all -# - order of columns within a group doesn't matter -# - order of groups doesn't matter -# - we can eliminate effective duplicates by defining aunique combination to be a sorted list of sorted groups -################################################################################################################ - -#################################################################### -# generate unique permutations -#################################################################### - -# check if adding a column index to a current permutation would keep it in canonical form -# assumes that perm is in canonical form already! -def is_canonical(perm, col): - # if it's a new group - if len(perm) % 4 == 0: - # every column ID < col needs to be in the permutation already - for val in range(col): - if val not in perm: - return False - # this new group needs to be sorted w.r.t. the previous group - return col > perm[-4] - - # not a new group, just check to see if it will still be sorted - return col > perm[-1] - - -# recursive: build a unique permutation one column index at a time -def generate_unique_combinations(built_permutation, remaining_columns, full_permutation_list, group_width): - - # base case: nothing else to add - if len(remaining_columns) == 0: - full_permutation_list.append(np.copy(built_permutation)) - if len(full_permutation_list) % 1000000 == 0: - print(f"{len(full_permutation_list)} unique permutations found so far") - - # still more choices to make, so add each remaining column in turn column if it keeps everything sorted - else: - for c in range(len(remaining_columns)): - # to satisfy our immutables (values within groups are sorted, groups are globally sorted), - # only add this column if either: - # it's starting a new group and is larger than the previous group's first entry - # OR - # it's larger than the last value in the built_permutation - col_to_add = remaining_columns[c] - - if is_canonical(built_permutation, col_to_add): - # add the column to the running permutation, remove it from remaining columns - built_permutation.append(col_to_add) - remaining_columns.pop(c) - # recurse - generate_unique_combinations(built_permutation, remaining_columns, full_permutation_list, group_width) - # remove the most recent column and put it back on the remaining column list where we found it (sorted) - remaining_columns.insert(c, built_permutation.pop(-1)) - -import pickle -import os.path -from os import path -master_unique_permutation_list = {} -def generate_all_unique_combinations(C, M, must_use_all_groups = False): - global master_unique_permutation_list - if len(master_unique_permutation_list) == 0 and path.exists("master_list.pkl"): - with open("master_list.pkl","rb") as cache: - master_unique_permutation_list = pickle.load(cache) - - if (C,M) not in master_unique_permutation_list: - full_permutation_list = [] - generate_unique_combinations([0], [c for c in range(1,C)], full_permutation_list, M) - master_unique_permutation_list[(C,M)] = full_permutation_list - - with open("master_list.pkl", "wb") as cache: - pickle.dump(master_unique_permutation_list, cache) - - unique_permutations = master_unique_permutation_list[(C,M)] - - return unique_permutations - -# analytical solution -import math -def predict_unique_combinations(C, M): - assert(C%M==0) - G = int(C/M) - return int(int(math.factorial(C)) / (int(math.pow(math.factorial(M),G)) * math.factorial(G))) - -################################################################# -# exhaustively try all unique permutations -################################################################# - -# exhaustively search the entire matrix -def search_matrix(matrix, group_width): - # give up quickly if we'd go on forever - prediction = predict_unique_combinations(matrix.shape[1], group_width) - best_permutation = [c for c in range(matrix.shape[1])] - if prediction > 1e10: - print(f"There are {prediction} unique combinations with {matrix.shape[1]} columns and a group width of {group_width}, not searching.") - return matrix, prediction, best_permutation - - start_time = time.perf_counter() - full_permutation_list = generate_all_unique_combinations(matrix.shape[1], group_width) - - # found them, now try them - best_improvement = 0.0 - base_sum = sum_after_2_to_4(matrix) - for i in range(1,len(full_permutation_list)): - permutation = full_permutation_list[i] - permuted = matrix[:, permutation] - cur_improvement = sum_after_2_to_4(permuted) - base_sum - - if (cur_improvement > best_improvement): - best_improvement = cur_improvement - best_permutation = permutation - seconds = time.perf_counter() - start_time - return matrix[:, best_permutation], seconds, best_permutation, best_improvement - - -############# -# Stripe group handling -############# - -# gather stripes from a larger matrix into a single matrix -def collect_stripes(matrix, stripes, group_width): - subset = np.zeros((matrix.shape[0], len(stripes)*group_width)) - #print("[Debug][collect_stripes] matrix shape info: {}".format(matrix.shape)) - #print("[Debug][collect_stripes] subset info: {}, {}, {}".format(matrix.shape[0], len(stripes), group_width)) - for s,stripe in enumerate(stripes): - #print("[Debug][collect_stripes] s: {}, stripe: {}".format(s, stripe)) - subset[...,s*group_width:s*group_width+group_width] = matrix[...,stripe*group_width:stripe*group_width+group_width] - return subset - -# apply the stripe group permutation to the entire permutation -def apply_stripe_group_permutation(sgp, stripes, group_width, permutation): - new_permutation = permutation.copy() - for subset_idx in range(len(sgp)): - dst_stripe_idx = stripes[int(subset_idx / group_width)] - dst_col_idx = subset_idx % group_width - - subset_val = sgp[subset_idx] - src_stripe_idx = stripes[int(subset_val / group_width)] - src_col_idx = subset_val % group_width - - new_permutation[dst_stripe_idx*group_width + dst_col_idx] = permutation[src_stripe_idx*group_width + src_col_idx] - - return new_permutation - -# generate all possible stripe groups -def generate_stripe_groups(num_stripes, window_size): - stripe_array = [[c] for c in range(num_stripes)] - - next_stripe_array = [] - for w in range(1, window_size): - for g in range(len(stripe_array)): - start_c = stripe_array[g][w-1]+1 - group = stripe_array[g] - for c in range(start_c, num_stripes): - new_group = group.copy() - new_group.append(c) - next_stripe_array.append(new_group) - stripe_array = next_stripe_array - next_stripe_array = [] - - return set(tuple(stripe_array[g]) for g in range(len(stripe_array))) - -# It is not safe to just reset the stripe_set as None here. -# When calling the Exhaustive_Search in E2E search, the stripe_set will not be reset as None. -stripe_set = None -stripe_set_config = None -# build the stripe map -def build_stripe_map(matrix, group_width, window_size, stripe_map, stripe_ids, perm_map, used_stripes): - global stripe_set, stripe_set_config - #print("[Debug][build_stripe_map] Now the stripe_set value is: {}".format(stripe_set)) - - window_size = int(window_size / group_width) - - if stripe_set is None or stripe_set_config is None or stripe_set_config != (group_width, window_size): - num_stripes = int(matrix.shape[1] / group_width) - assert(group_width * num_stripes == matrix.shape[1]) - stripe_set = generate_stripe_groups(num_stripes, window_size) - #print("[Debug][build_stripe_map] Update stripe_set value as: {}".format(stripe_set)) - stripe_set_config = (group_width, window_size) - - # step through each, update the stripe_map/stripe_ids if necessary - updates = 0 - use_cuda = use_gpu() - gpu_list = [] - gpu_groups = [] - for i,s in enumerate(stripe_set): - sg = [] # build the group of stripes, check if any members changed - need_update = i >= len(stripe_map) - for stripe in s: - sg.append(stripe) - if stripe in used_stripes: - need_update = True - - # pre-populate if we're building fresh - if i >= len(stripe_map): - stripe_ids.append(sg) - stripe_map.append(0.) - perm_map.append([c for c in range(group_width * window_size)]) - - # update entries if needed (only stripe_map and perm_map) - if need_update: - updates += 1 - - if not use_cuda: # do the work here if using the CPU - subset = collect_stripes(matrix, sg, group_width) - sub_result, sub_duration, permutation, improvement = search_matrix(subset, group_width) - stripe_map[i] = improvement - perm_map[i] = permutation - else: # otherwise, just track the work needed to farm off to the GPU - gpu_groups.append(sg) - gpu_list.append(i) - - if use_cuda: # if using the GPU, perform the work - matrix_view = np.copy(matrix).astype(np.float32).flatten() - all_permutations = generate_all_unique_combinations(window_size*group_width, group_width) - num_permutations = len(all_permutations) - permutation_view = np.copy(np.asarray(all_permutations)).astype(np.uint32).flatten() - stripe_groups_view = np.asarray(gpu_groups).astype(np.uint32).flatten() - num_gpu_groups = len(gpu_list) - gpu_improvement = np.zeros((num_gpu_groups), dtype=np.float32).flatten() - gpu_permutation = np.zeros((num_gpu_groups), dtype=np.uint32).flatten() - result = permutation_search_cuda_kernels.build_permute_map(matrix_view, - matrix.shape[0], - matrix.shape[1], - stripe_groups_view, - num_gpu_groups, - window_size, - permutation_view, - window_size * group_width, - gpu_improvement, - gpu_permutation) - - # put the data where python expects it - for i in range(len(gpu_list)): - stripe_map[gpu_list[i]] = gpu_improvement[i] - perm_map[gpu_list[i]] = all_permutations[gpu_permutation[i]] - - return stripe_map, stripe_ids, perm_map - - -# start performing stripe checks -sm_perturbations = 0 -sm_perturbation_limit = 0 -def use_stripe_map(matrix, group_width, stripe_map, stripe_ids, perm_map, permutation): - global sm_perturbations, sm_perturbation_limit - used_stripes = [] - stripe_groups_optimized = 0 - improvement = 0.0 - - # set the traversal order - ix = np.flip(np.argsort(stripe_map)) # small to large --> large to small - - for i in range(len(ix)): - stripe_group_id = ix[i] - perm = perm_map[stripe_group_id].copy() - - if stripe_map[stripe_group_id] <= 0.0001: - # perturbations - if len(used_stripes) == 0 and sm_perturbations < sm_perturbation_limit: - sm_perturbations += 1 - # use this permutation, but swap two channels from left/right halves to include two stripes, no matter the group size - stripe_group_id = ix[np.random.randint(len(ix))] - perm = perm_map[stripe_group_id].copy() - # a little easier to escape from - src = np.random.randint(int(len(perm)/2)) - dst = int(len(perm)/2) + np.random.randint(int(len(perm)/2)) - perm[src],perm[dst] = perm[dst],perm[src] - else: - break - - stripe_group = stripe_ids[stripe_group_id] - - # don't work on stripes we've already touched - touched_stripe = False - for stripe in stripe_group: - if stripe in used_stripes: - touched_stripe = True - if touched_stripe: - continue - - # apply the permutation we've already found to this stripe group - subset = collect_stripes(matrix, stripe_group, group_width) - sub_result = subset[...,perm] - permutation = apply_stripe_group_permutation(perm, stripe_group, group_width, permutation) - - # scatter the results, track what changed - for s,stripe in enumerate(stripe_group): - # see if this group is in canonical form (entry 0 a multiple of 4, contiguous values)) - group = perm[s*group_width:s*group_width+group_width] # columns in this group of the used permutation - changed = False - if group[0] % 4 != 0: - changed = True - for c in range(1,group_width): - if group[c] != group[c-1]+1: - changed = True - break - # if it's not, then it changed - if changed: - used_stripes.append(stripe_group[s]) - - matrix[...,stripe*group_width:stripe*group_width+group_width] = sub_result[...,s*group_width:s*group_width+group_width] - - improvement += stripe_map[stripe_group_id] - stripe_groups_optimized += 1 - - return matrix, stripe_groups_optimized, stripe_map, stripe_ids, used_stripes, improvement, permutation - -# entry point for exhaustive searches - both the entire matrix, as well as stripe groups -def Exhaustive_Search(matrix, stripe_group_size=-1, escape_attempts=0, permutation=None): - global sm_perturbation_limit, sm_perturbations - sm_perturbations = 0 - sm_perturbation_limit = escape_attempts - if permutation is None: - permutation = [c for c in range(matrix.shape[1])] - - # It is much safer to reset the stripe_set as None in the entry point of Exhaustive_Search - global stripe_set, stripe_set_config - stripe_set = None - stripe_set_config = None - - # only support N:4 for now - group_width = 4 - - result = np.copy(matrix) - - # if the matrix is too large for a window size of 12, subdivide, then fix up with a global optimization with a window size of 8 - if group_width==4 and stripe_group_size==12 and matrix.shape[1] > 512: - stripe_split = int(matrix.shape[1]/2/group_width) - col_split = stripe_split * group_width - result[:,:col_split], durationL, permutation[:col_split] = Exhaustive_Search(result[:,:col_split], stripe_group_size=stripe_group_size, escape_attempts=escape_attempts, permutation=permutation[:col_split]) - result[:,col_split:], durationR, permutation[col_split:] = Exhaustive_Search(result[:,col_split:], stripe_group_size=stripe_group_size, escape_attempts=escape_attempts, permutation=permutation[col_split:]) - escape_attempts = max(escape_attempts, 100)*10 - result,duration,permutation = Exhaustive_Search(result, stripe_group_size=8, escape_attempts=escape_attempts, permutation=permutation) - return result, durationL+durationR+duration, permutation - - # small enough to optimize the entire matrix at once - if stripe_group_size != -1 and stripe_group_size < matrix.shape[1]: - stripe_map = [] - stripe_ids = [] - perm_map = [] - used_stripes = [] - optimized_groups_count = 0 - agg_improvement = 0. - cur_total_sum = sum_after_2_to_4(result) - - # in practice, this work will be cached ahead of time; doing it now. - # (Reading the cached list from disk can take several seconds, which shouldn't be counted against the search, but amortized over every layer in a network) - generate_all_unique_combinations(stripe_group_size, group_width) - - start_time = time.perf_counter() - - while True: - #print("[Debug][Exhaustive_Search] Before entering the build_stripe_map function.") - #print("[Debug][Exhaustive_Search] Now the stripe_set value is: {}".format(stripe_set)) - stripe_map, stripe_ids, perm_map = build_stripe_map(result, group_width, stripe_group_size, stripe_map, stripe_ids, perm_map, used_stripes) - result, stripe_groups_optimized, stripe_map, stripe_ids, used_stripes, improvement, permutation = use_stripe_map(result, group_width, stripe_map, stripe_ids, perm_map, permutation) - - # converged? - if len(used_stripes) == 0: - break - - duration = time.perf_counter() - start_time - - else: # no sliding window, single iteration - print(f"Matrix has {matrix.shape[1]} columns and the search window is only {stripe_group_size}: searching exhaustively") - result, duration, permutation, improvement = search_matrix(matrix, group_width) - - return result, duration, permutation diff --git a/apex/contrib/sparsity/permutation_search_kernels/permutation_utilities.py b/apex/contrib/sparsity/permutation_search_kernels/permutation_utilities.py deleted file mode 100644 index 55a18e1..0000000 --- a/apex/contrib/sparsity/permutation_search_kernels/permutation_utilities.py +++ /dev/null @@ -1,113 +0,0 @@ -import numpy as np -import time -import ctypes -import subprocess -import os -import math - -gpus_tested = False -gpus_found = 0 -kernels_found = True -try: - import permutation_search_cuda as permutation_search_cuda_kernels - print(f"Found permutation search CUDA kernels") -except ImportError: - print(f"Could not find permutation search CUDA kernels, falling back to CPU path") - kernels_found = False - -def use_gpu(initial_override = True): - global gpus_tested, gpus_found, kernels_found - if not gpus_tested: - if not initial_override: - gpus_tested = True - return False - - try: - gpus_found = str(subprocess.check_output(["nvidia-smi", "-L"])).count('UUID') - print(f"Found {gpus_found} gpus") - except: - gpus_found = 0 - print(f"Could not find nvidia-smi, please check your cuda installation") - - gpus_tested = True - - return gpus_found > 0 and kernels_found - -############################################################################################## -# pruning utilities -############################################################################################## -## apply 2:4 to some matrix -def apply_2_to_4(matrix): - for row in range(matrix.shape[0]): - for col in range(0,matrix.shape[1],4): - ix = np.argsort(np.abs(matrix[row,col:col+4])) - matrix[row,col+ix[0]] = 0.0 - matrix[row,col+ix[1]] = 0.0 - return matrix - -## find the sum of magnitudes if 2:4 were applied to a matrix -def sum_after_2_to_4(matrix): - #matrix = np.copy(matrix) - cur_sum = 0.0 - use_cuda = use_gpu() - if not use_cuda: - start_time = time.perf_counter() - for row in range(matrix.shape[0]): - for col in range(0,matrix.shape[1],4): - ix = np.argsort(np.abs(matrix[row,col:col+4])) - cur_sum += abs(matrix[row,col+ix[2]]) - cur_sum += abs(matrix[row,col+ix[3]]) - np_elapsed = time.perf_counter() - start_time - else: - matrix = matrix.astype(np.float32) - cuda_sum = np.zeros((1), dtype=np.float32) - start_time = time.perf_counter() - matrix_view = np.copy(matrix).flatten() - sum_view = cuda_sum.flatten() - blocks = max(int(matrix.shape[1]/4/2), 1) - threads = min(max(math.ceil(matrix.shape[0]/4), 1), 1024) - result = permutation_search_cuda_kernels.sum_after_2_to_4(matrix_view, - matrix.shape[0], - matrix.shape[1], - 0, - matrix.shape[1], - blocks, - threads, - sum_view) - cuda_elapsed = time.perf_counter() - start_time - #print(cuda_sum, cuda_elapsed, cur_sum, np_elapsed, np_elapsed/cuda_elapsed) - cur_sum = sum_view[0] - return cur_sum - -## try swapping columns and tracking magnitude after pruning -def try_swap(matrix, dst, src): - src_base = sum_after_2_to_4(matrix[...,int(src/4)*4:int(src/4)*4+4]) - dst_base = sum_after_2_to_4(matrix[...,int(dst/4)*4:int(dst/4)*4+4]) - - # swap - matrix[...,[src,dst]] = matrix[...,[dst,src]] - - # check the Nx4 slices of the swapped columns - src_sum = sum_after_2_to_4(matrix[...,int(src/4)*4:int(src/4)*4+4]) - dst_sum = sum_after_2_to_4(matrix[...,int(dst/4)*4:int(dst/4)*4+4]) - - # swap back - matrix[...,[src,dst]] = matrix[...,[dst,src]] - - return src_sum + dst_sum, (src_sum + dst_sum) - (src_base + dst_base) - -############################################################################################## -# permutation utilities -############################################################################################## - -## find the permutation needed to make matrix A look like matrix B -def find_permutation(A, B): - permutation = [] - for col in range(A.shape[1]): - Avals = A[...,col] - for bcol in range(B.shape[1]): - if np.all(Avals - B[...,bcol] == np.zeros(Avals.shape)): - permutation.append(bcol) - break - return permutation - diff --git a/apex/contrib/sparsity/sparse_masklib.py b/apex/contrib/sparsity/sparse_masklib.py deleted file mode 100644 index 48deb63..0000000 --- a/apex/contrib/sparsity/sparse_masklib.py +++ /dev/null @@ -1,184 +0,0 @@ -import sys -import torch -import numpy as np -import collections -from itertools import permutations - - -""" compute density (helper fn to compute % NNZs in a tensor) """ -def fill(x): - return float(x.nonzero().size(0))/torch.numel(x) - -""" reshape matrix into m-dimensional vectors: (h,w) -> (hw/m, m) """ -def reshape_1d(matrix, m): - # If not a nice multiple of m, fill with zeroes. - if matrix.shape[1] % m > 0: - mat = torch.cuda.FloatTensor(matrix.shape[0], matrix.shape[1] + (m-matrix.shape[1]%m)).fill_(0) - mat[:, :matrix.shape[1]] = matrix - shape = mat.shape - return mat.view(-1,m),shape - else: - return matrix.view(-1,m), matrix.shape - -""" return all possible m:n patterns in a 1d vector """ -valid_m4n2_1d_patterns = None -def compute_valid_1d_patterns(m,n): - # Early exit if patterns was already created. - global valid_m4n2_1d_patterns - - if m==4 and n==2 and valid_m4n2_1d_patterns is not None: return valid_m4n2_1d_patterns - patterns = torch.zeros(m) - patterns[:n] = 1 - valid_patterns = torch.tensor(list(set(permutations(patterns.tolist())))) - if m == 4 and n == 2: valid_m4n2_1d_patterns = valid_patterns - return valid_patterns - -""" m:n 1d structured best """ -def mn_1d_best(matrix, m, n): - # Find all possible patterns. - patterns = compute_valid_1d_patterns(m,n).cuda() - - # Find the best m:n pattern (sum of non-masked weights). - mask = torch.cuda.IntTensor(matrix.shape).fill_(1).view(-1,m) - mat,shape = reshape_1d(matrix,m) - pmax = torch.argmax(torch.matmul(mat.abs(),patterns.t()), dim=1) - mask[:] = patterns[pmax[:]] - mask = mask.view(matrix.shape) - return mask - -def m4n2_1d(mat, density): - return mn_1d_best(mat, 4, 2) - -""" - Below 2d-masking related code is targeted more for training (from scratch). - 2d-pruning of a weight tensor is done to accelerate DGRAD step during backprop - phase of training algorithm. Acceleration comes from using SpMMA instructions in - Tensor Cores of NVIDIA Ampere GPU Architecture - (note: this code does not do the acceleration, GPU kernels are required for this). - 1d pruning of weight tensor helps speed up FPROP step by pruning in 2:4 pattern - along the horizontal (logical) direction. - During DGRAD step, weight tensor is transposed. 2d pruning functions below, mask - weight tensor such that their transposed versions are also 2:4 sparse along the - horizontal (logical) direction. Thus, with 2d pruning, weight tensors are - 2:4 sparse along row and column directions. - """ - -""" m:n 2d structured pruning: greedy method to select mask """ -def mn_2d_greedy(matrix, m, n): - # Convert to numpy - mat = matrix.cpu().detach().numpy() - mask = np.ones(mat.shape, dtype=int) - - rowCount = int(mat.shape[0]/m) * m - colCount = int(mat.shape[1]/m) * m - for rowStartIdx in range(0, rowCount, m): - rowEndIdx = rowStartIdx + m - for colStartIdx in range(0, colCount, m): - colEndIdx = colStartIdx + m - matrixSub = np.absolute(np.squeeze(mat[rowStartIdx:rowEndIdx, colStartIdx:colEndIdx])) - maskSub = np.squeeze(mask[rowStartIdx:rowEndIdx, colStartIdx:colEndIdx]) - maskSub.fill(0.0) - matrixVecView = matrixSub.reshape(-1) - maskVecView = maskSub.reshape(-1) - linearIdx = np.argsort(matrixVecView) - matrixIdx = [(int(x/m), x % m) for x in linearIdx] - rowCounter = collections.Counter() - colCounter = collections.Counter() - for currIdx in range(len(linearIdx) - 1, -1, -1): - currMatrixEntry = matrixIdx[currIdx] - if (rowCounter[currMatrixEntry[0]] == n) or (colCounter[currMatrixEntry[1]] == n): - continue - #end if - maskSub[currMatrixEntry[0], currMatrixEntry[1]] = 1.0 - rowCounter[currMatrixEntry[0]] += 1 - colCounter[currMatrixEntry[1]] += 1 - - return torch.tensor(mask.cuda()) - -def m4n2_2d_greedy(mat, density): - return mn_2d_greedy(mat, 4, 2) - -""" return all possible m:n patterns in a mxn block. """ -valid_m4n2_2d_patterns = None -def compute_valid_2d_patterns(m,n): - # Early exit if patterns was already created. - global valid_m4n2_2d_patterns - if valid_m4n2_2d_patterns is not None: return valid_m4n2_2d_patterns - - patterns = torch.zeros(m) - patterns[:n] = 1 - patterns = list(set(permutations(patterns.tolist()))) - patterns = patterns + patterns - patterns = torch.empty(list(set(permutations(patterns,m)))) - - valid = ((patterns.sum(dim=1) <= n).sum(dim=1) == m).nonzero().view(-1) - valid_patterns = torch.empty(valid.shape[0],m,m) - valid_patterns[:] = patterns[valid[:]] - - if m == 4 and n == 2: valid_m4n2_2d_patterns = valid_patterns - return valid_patterns - -""" m:n 2d structured pruning: exhaustive method to select best mask """ -def mn_2d_best(matrix, m, n): - # Find all possible patterns. - patterns = compute_valid_2d_patterns(m,n).cuda() - - # Find the best m:n pattern (sum of non-masked weights). - mask = torch.cuda.IntTensor(matrix.shape).fill_(1) - mat = reshape_2d(matrix,m,m).abs() - pmax = torch.argmax(torch.matmul(mat,patterns.view(patterns.shape[0],m*m).t()), dim=2) - - # Copy best m:n patterns into mask. - mat = mat.view(mat.shape[0]*mat.shape[1],-1) - pmax = pmax.view(pmax.shape[0]*pmax.shape[1]).unsqueeze(1).expand(-1,mat.shape[1]) - patterns = patterns.view(patterns.shape[0],patterns.shape[1]*patterns.shape[2]) - mat = torch.gather(patterns,0,pmax) - mat = reshape_2d_inv(mat.view(matrix.shape[0]//m,matrix.shape[1]//m,m,m)) - mask.copy_(mat.type(mask.type())) - return mask - -def m4n2_2d_best(mat, density): - return mn_2d_best(mat, 4, 2) - - -""" returns a sparse mask """ -def create_mask(tensor, pattern="m4n2_1d", density=0.5): - # Reshape tensor and mask. - shape = tensor.shape - ttype = tensor.type() - t = tensor.float().contiguous() - - # 1d-tensor - if len(shape) == 1: - t = t.view(1, shape[0]) - func = getattr(sys.modules[__name__], pattern, None) - mask = func(t, density) - return mask.view(shape).type(ttype) - # 2d-tensor (in, out) - elif len(shape) == 2: - t = t.view(shape[0], shape[1]) - func = getattr(sys.modules[__name__], pattern, None) - mask = func(t, density) - return mask.view(shape).type(ttype) - # 3d-tensor (batch, in, out) - elif len(shape) == 3: - t = t.view(shape[0]*shape[1], shape[2]) - func = getattr(sys.modules[__name__], pattern, None) - mask = func(t, density) - return mask.view(shape).type(ttype) - # 4d-tensor (in, out, h, w) - elif len(shape) == 4: - """ - # transformers (bmm) - t = t.view(shape[0]*shape[1]*shape[2], shape[3]) - func = getattr(sys.modules[__name__], pattern, None) - mask = func(t, density) - return mask.view(shape).type(ttype) - """ - # convs - t = t.permute(2,3,0,1).contiguous().view(shape[2]*shape[3]*shape[0], shape[1]) - func = getattr(sys.modules[__name__], pattern, None) - mask = func(t, density) - mask = mask.view(shape[2], shape[3], shape[0], shape[1]).permute(2,3,0,1).contiguous() - return mask.view(shape).type(ttype) - diff --git a/apex/contrib/sparsity/test/checkpointing_test_part1.py b/apex/contrib/sparsity/test/checkpointing_test_part1.py deleted file mode 100644 index 34232b8..0000000 --- a/apex/contrib/sparsity/test/checkpointing_test_part1.py +++ /dev/null @@ -1,94 +0,0 @@ -from collections import OrderedDict - -import torch -from apex.optimizers import FusedAdam -from apex.contrib.sparsity import ASP - -def build_model(args): - od = OrderedDict() - for i in range(args.num_layers): - if i == 0: - od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features) - od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features]) - elif i == args.num_layers-1: - od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features) - od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features]) - else: - od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features) - od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features]) - return torch.nn.Sequential(od) - -def train_step(args, model, optimizer, input_batch, target_batch, step): - predicted_target = model(input_batch) - loss = ((predicted_target-target_batch)**2).sum() - loss.backward() - optimizer.step() - optimizer.zero_grad() - step = step + 1 - #print("Step %d :: loss=%e" % (step, loss.item())) - return step - -def train_loop(args, model, optimizer, step, num_steps): - for i in range(num_steps): - input_batch = torch.randn([args.batch_size, args.input_features]).cuda() - target_batch = torch.randn([args.batch_size, args.output_features]).cuda() - step = train_step(args, model, optimizer, input_batch, target_batch, step) - return step - -def main(args): - # - # PART1 - # - - torch.manual_seed(args.seed) - - model = build_model(args).cuda() - one_ll = next(model.children()).weight - optimizer = FusedAdam(model.parameters()) - ASP.init_model_for_pruning(model, args.pattern, verbosity=args.verbosity, whitelist=args.whitelist, allow_recompute_mask=args.allow_recompute_mask) - ASP.init_optimizer_for_pruning(optimizer) - - step = 0 - - # train for a few steps with dense weights - print("DENSE :: ",one_ll) - step = train_loop(args, model, optimizer, step, args.num_dense_steps) - - # simulate sparsity by inserting zeros into existing dense weights - ASP.compute_sparse_masks() - - # train for a few steps with sparse weights - print("SPARSE :: ",one_ll) - step = train_loop(args, model, optimizer, step, args.num_sparse_steps) - - torch.save({ - 'step': step, - 'verbosity': args.verbosity, - 'seed2': args.seed2, - 'pattern': args.pattern, - 'whitelist': args.whitelist, - 'allow_recompute_mask': args.allow_recompute_mask, - 'model_state_dict': model.state_dict(), - 'optimizer_state_dict': optimizer.state_dict(), - }, args.checkpoint_path) - -if __name__ == '__main__': - class Args: - verbosity=3 - seed = 4873 - seed2 = 99875 - pattern = "m4n2_2d_best" - whitelist = [torch.nn.Linear] - allow_recompute_mask = True - batch_size = 32 - input_features = 8 - output_features = 8 - hidden_features = 32 - num_layers = 4 - num_dense_steps = 2000 - num_sparse_steps = 3000 - num_sparse_steps_2 = 1000 - checkpoint_path = "part1.chkp" - args = Args() - - main(args) diff --git a/apex/contrib/sparsity/test/checkpointing_test_part2.py b/apex/contrib/sparsity/test/checkpointing_test_part2.py deleted file mode 100644 index d2b161c..0000000 --- a/apex/contrib/sparsity/test/checkpointing_test_part2.py +++ /dev/null @@ -1,79 +0,0 @@ -from collections import OrderedDict - -import torch -from apex.optimizers import FusedAdam -from apex.contrib.sparsity import ASP - -def build_model(args): - od = OrderedDict() - for i in range(args.num_layers): - if i == 0: - od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features) - od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features]) - elif i == args.num_layers-1: - od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features) - od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features]) - else: - od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features) - od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features]) - return torch.nn.Sequential(od) - -def train_step(args, model, optimizer, input_batch, target_batch, step): - predicted_target = model(input_batch) - loss = ((predicted_target-target_batch)**2).sum() - loss.backward() - optimizer.step() - optimizer.zero_grad() - step = step + 1 - #print("Step %d :: loss=%e" % (step, loss.item())) - return step - -def train_loop(args, model, optimizer, step, num_steps): - for i in range(num_steps): - input_batch = torch.randn([args.batch_size, args.input_features]).cuda() - target_batch = torch.randn([args.batch_size, args.output_features]).cuda() - step = train_step(args, model, optimizer, input_batch, target_batch, step) - return step - -def main(step, args, model_state_dict, optimizer_state_dict): - # - # PART2 - # - - model = build_model(args).cuda() - one_ll = next(model.children()).weight - optimizer = FusedAdam(model.parameters()) - ASP.init_model_for_pruning(model, args.pattern, verbosity=args.verbosity, whitelist=args.whitelist, allow_recompute_mask=args.allow_recompute_mask) - ASP.init_optimizer_for_pruning(optimizer) - - torch.manual_seed(args.seed2) - model.load_state_dict(model_state_dict) - optimizer.load_state_dict(optimizer_state_dict) - - print("Model sparsity is %s" % ("enabled" if ASP.is_sparsity_enabled() else "disabled")) - - # train for a few steps with sparse weights - print("SPARSE :: ",one_ll) - step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2) - -if __name__ == '__main__': - checkpoint = torch.load("part1.chkp") - class Args: - verbosity = checkpoint['verbosity'] - seed = 4873 - seed2 = checkpoint['seed2'] - pattern = checkpoint['pattern'] - whitelist = checkpoint['whitelist'] - allow_recompute_mask = checkpoint['allow_recompute_mask'] - batch_size = 32 - input_features = 8 - output_features = 8 - hidden_features = 32 - num_layers = 4 - num_dense_steps = 2000 - num_sparse_steps = 3000 - num_sparse_steps_2 = 1000 - checkpoint_path = "part1.chkp" - args = Args() - - main(checkpoint['step'], args, checkpoint['model_state_dict'], checkpoint['optimizer_state_dict']) diff --git a/apex/contrib/sparsity/test/checkpointing_test_reference.py b/apex/contrib/sparsity/test/checkpointing_test_reference.py deleted file mode 100644 index 57ea51a..0000000 --- a/apex/contrib/sparsity/test/checkpointing_test_reference.py +++ /dev/null @@ -1,96 +0,0 @@ -from collections import OrderedDict - -import torch -from apex.optimizers import FusedAdam -from apex.contrib.sparsity import ASP - -# -# Reference run for checkpointing test (part1 + part2) -# - -def build_model(args): - od = OrderedDict() - for i in range(args.num_layers): - if i == 0: - od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features) - od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features]) - elif i == args.num_layers-1: - od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features) - od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features]) - else: - od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features) - od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features]) - return torch.nn.Sequential(od) - -def train_step(args, model, optimizer, input_batch, target_batch, step): - predicted_target = model(input_batch) - loss = ((predicted_target-target_batch)**2).sum() - loss.backward() - optimizer.step() - optimizer.zero_grad() - step = step + 1 - #print("Step %d :: loss=%e" % (step, loss.item())) - return step - -def train_loop(args, model, optimizer, step, num_steps): - for i in range(num_steps): - input_batch = torch.randn([args.batch_size, args.input_features]).cuda() - target_batch = torch.randn([args.batch_size, args.output_features]).cuda() - step = train_step(args, model, optimizer, input_batch, target_batch, step) - return step - -def main(args): - # - # PART1 - # - - torch.manual_seed(args.seed) - - model = build_model(args).cuda() - one_ll = next(model.children()).weight - optimizer = FusedAdam(model.parameters()) - ASP.init_model_for_pruning(model, args.pattern, whitelist=args.whitelist, allow_recompute_mask=args.allow_recompute_mask) - ASP.init_optimizer_for_pruning(optimizer) - - step = 0 - - # train for a few steps with dense weights - print("DENSE :: ",one_ll) - step = train_loop(args, model, optimizer, step, args.num_dense_steps) - - # simulate sparsity by inserting zeros into existing dense weights - ASP.compute_sparse_masks() - - # train for a few steps with sparse weights - print("SPARSE :: ",one_ll) - step = train_loop(args, model, optimizer, step, args.num_sparse_steps) - - # - # PART 2 - # - - torch.manual_seed(args.seed2) - - # train for a few steps with sparse weights - print("SPARSE :: ",one_ll) - step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2) - -if __name__ == '__main__': - class Args: - seed = 4873 - seed2 = 99875 - pattern = "m4n2_2d_best" - whitelist = [torch.nn.Linear] - allow_recompute_mask = True - batch_size = 32 - input_features = 8 - output_features = 8 - hidden_features = 32 - num_layers = 4 - num_dense_steps = 2000 - num_sparse_steps = 3000 - num_sparse_steps_2 = 1000 - checkpoint_path = "part1.chkp" - args = Args() - - main(args) diff --git a/apex/contrib/sparsity/test/toy_problem.py b/apex/contrib/sparsity/test/toy_problem.py deleted file mode 100644 index 2145323..0000000 --- a/apex/contrib/sparsity/test/toy_problem.py +++ /dev/null @@ -1,87 +0,0 @@ -from collections import OrderedDict - -import torch -from apex.optimizers import FusedAdam -from apex.contrib.sparsity import ASP - -def build_model(args): - od = OrderedDict() - for i in range(args.num_layers): - if i == 0: - od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.input_features, args.hidden_features) - od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features]) - elif i == args.num_layers-1: - od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.output_features) - od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.output_features]) - else: - od['linear_layer_%d' % (i+1)] = torch.nn.Linear(args.hidden_features, args.hidden_features) - od['layer_norm_%d' % (i+1)] = torch.nn.LayerNorm([args.batch_size, args.hidden_features]) - return torch.nn.Sequential(od) - -def train_step(args, model, optimizer, input_batch, target_batch, step): - predicted_target = model(input_batch) - loss = ((predicted_target-target_batch)**2).sum() - loss.backward() - optimizer.step() - optimizer.zero_grad() - step = step + 1 - #print("Step %d :: loss=%e" % (step, loss.item())) - return step - -def train_loop(args, model, optimizer, step, num_steps): - for i in range(num_steps): - input_batch = torch.randn([args.batch_size, args.input_features]).cuda() - target_batch = torch.randn([args.batch_size, args.output_features]).cuda() - step = train_step(args, model, optimizer, input_batch, target_batch, step) - return step - -def main(args): - model = build_model(args).cuda() - one_ll = next(model.children()).weight - optimizer = FusedAdam(model.parameters()) - # only prune linear layers, even though we also support conv1d, conv2d and conv3d - ASP.init_model_for_pruning(model, "m4n2_1d", whitelist=[torch.nn.Linear], allow_recompute_mask=True) - ASP.init_optimizer_for_pruning(optimizer) - - step = 0 - - # train for a few steps with dense weights - print("DENSE :: ",one_ll) - step = train_loop(args, model, optimizer, step, args.num_dense_steps) - - # simulate sparsity by inserting zeros into existing dense weights - ASP.compute_sparse_masks() - - # train for a few steps with sparse weights - print("SPARSE :: ",one_ll) - step = train_loop(args, model, optimizer, step, args.num_sparse_steps) - - # recompute sparse masks - ASP.compute_sparse_masks() - - # train for a few steps with sparse weights - print("SPARSE :: ",one_ll) - step = train_loop(args, model, optimizer, step, args.num_sparse_steps_2) - - # turn off sparsity - print("SPARSE :: ",one_ll) - ASP.restore_pruned_weights() - - # train for a few steps with dense weights - print("DENSE :: ",one_ll) - step = train_loop(args, model, optimizer, step, args.num_dense_steps_2) - -if __name__ == '__main__': - class Args: - batch_size = 32 - input_features = 16 - output_features = 8 - hidden_features = 40 - num_layers = 4 - num_dense_steps = 2000 - num_sparse_steps = 3000 - num_sparse_steps_2 = 1000 - num_dense_steps_2 = 1500 - args = Args() - - main(args) diff --git a/apex/contrib/test/clip_grad/test_clip_grad.py b/apex/contrib/test/clip_grad/test_clip_grad.py deleted file mode 100644 index 3c1cc90..0000000 --- a/apex/contrib/test/clip_grad/test_clip_grad.py +++ /dev/null @@ -1,162 +0,0 @@ -import random -import unittest - -import torch -from apex.contrib.clip_grad import clip_grad_norm_ - -def make_params( - num_params, - sizes=[1,2,3,4,5], - num_dims=[1,2,3], - dtypes=[torch.float32], - devices=['cuda'], - make_copy=False, -): - """Construct parameters with random configurations""" - - # Construct parameters - params = [] - for _ in range(num_params): - dims = [random.choice(sizes) for _ in range(random.choice(num_dims))] - dtype = random.choice(dtypes) - device = random.choice(devices) - p = torch.nn.Parameter(torch.randn(dims, dtype=dtype, device=device)) - p.grad = torch.randn_like(p) - params.append(p) - - # Copy parameters if needed - if make_copy: - params_copy = [] - for p in params: - p_copy = p.clone().detach() - p_copy.grad = p.grad.clone().detach() - params_copy.append(p_copy) - return params, params_copy - else: - return params - -class ClipGradNormTest(unittest.TestCase): - - def setUp(self, seed=1234): - random.seed(seed) - torch.manual_seed(seed) - - def test_matches_pytorch( - self, - num_params=41, - dtypes=[torch.float32, torch.float16, torch.float64], - devices=['cuda', 'cpu'], - max_norm=0.54321, - norm_type=2.0, - rtol=1e-3, - atol=1e-20, - ): - """Make sure PyTorch and Apex gradient clipping produce same results""" - - # Construct identical sets of parameters - torch_params, apex_params = make_params( - num_params, - dtypes=dtypes, - devices=devices, - make_copy=True, - ) - - # Apply gradient clipping - torch_norm = torch.nn.utils.clip_grad_norm_( - torch_params, - max_norm, - norm_type=norm_type, - ) - apex_norm = clip_grad_norm_( - apex_params, - max_norm, - norm_type=norm_type, - ) - - # Make sure PyTorch and Apex get same results - torch.testing.assert_close( - apex_norm, torch_norm, - rtol=rtol, - atol=atol, - check_dtype=False, - ) - for torch_p, apex_p in zip(torch_params, apex_params): - torch.testing.assert_close( - apex_p, torch_p, - rtol=0, - atol=0, - ) # Params should be unaffected - torch.testing.assert_close( - apex_p.grad, torch_p.grad, - rtol=rtol, - atol=atol, - ) - - def test_matches_pytorch_fp16(self): - self.test_matches_pytorch(num_params=11, dtypes=[torch.float16]) - - def test_matches_pytorch_fp32(self): - self.test_matches_pytorch(dtypes=[torch.float32], rtol=1e-6) - - def test_matches_pytorch_fp64(self): - self.test_matches_pytorch(dtypes=[torch.float64], rtol=1e-15) - - def test_matches_pytorch_cpu(self): - self.test_matches_pytorch(devices=['cpu']) - - def test_matches_pytorch_infnorm(self): - self.test_matches_pytorch(norm_type=float('inf')) - - def test_matches_pytorch_1norm(self): - self.test_matches_pytorch(norm_type=1.0) - - def test_raises_on_mismatch(self): - - # Construct different sets of parameters - torch_params, apex_params = make_params(7, make_copy=True) - with torch.no_grad(): - torch_params[0].grad.view(-1)[0] = 1.23 - apex_params[0].grad.view(-1)[0] = 3.21 - - # Apply gradient clipping - torch_norm = torch.nn.utils.clip_grad_norm_( - torch_params, - 0.54321, - ) - apex_norm = clip_grad_norm_( - apex_params, - 0.54321, - ) - - # Make sure PyTorch and Apex get different results - self.assertRaises( - AssertionError, - torch.testing.assert_close, - apex_norm, torch_norm, - rtol=1e-3, - atol=1e-20, - check_dtype=False, - ) - for torch_p, apex_p in zip(torch_params, apex_params): - self.assertRaises( - AssertionError, - torch.testing.assert_close, - apex_p.grad, torch_p.grad, - rtol=1e-3, - atol=1e-20, - ) - - def test_raises_on_nan(self): - params = make_params(5, num_dims=[1]) - params[2].grad[-1] = float('NaN') - self.assertRaises( - RuntimeError, clip_grad_norm_, params, 1.0, error_if_nonfinite=True) - - def test_raises_on_inf(self): - params = make_params(5, num_dims=[1]) - params[2].grad[-1] = float('inf') - self.assertRaises( - RuntimeError, clip_grad_norm_, params, 1.0, error_if_nonfinite=True) - -if __name__ == "__main__": - unittest.main() diff --git a/apex/contrib/test/conv_bias_relu/test_conv_bias_relu.py b/apex/contrib/test/conv_bias_relu/test_conv_bias_relu.py deleted file mode 100644 index 350257c..0000000 --- a/apex/contrib/test/conv_bias_relu/test_conv_bias_relu.py +++ /dev/null @@ -1,105 +0,0 @@ -import copy -import math -import random -import unittest - -import torch -import torch.nn.functional as F - -HAS_CONV_BIAS_RELU = None -try: - from apex.contrib.conv_bias_relu import ConvBiasReLU, ConvBias, ConvBiasMaskReLU -except ImportError as e: - HAS_CONV_BIAS_RELU = False -else: - HAS_CONV_BIAS_RELU = True - - -@unittest.skipIf(not HAS_CONV_BIAS_RELU, "`apex.contrib.conv_bias_relu` is not found.") -class FusedDenseTest(unittest.TestCase): - def setUp(self, seed=0): - torch.manual_seed(seed) - - self.batch_size = random.randint(1, 64) - self.in_channels = random.randint(1, 64) * 8 - self.out_channels = random.randint(1, 64) * 8 - self.in_height = self.in_width = random.randint(5, 100) - self.conv_kernel_size = random.randint(1, 5) - self.conv_pad = random.randint(0, int(self.conv_kernel_size / 2)) - self.conv_stride = random.randint(1, 5) - self.conv_dilation = 1 - self.out_height = self.out_width = \ - math.floor((self.in_height + 2 * self.conv_pad - \ - self.conv_dilation * (self.conv_kernel_size - 1) - 1) / self.conv_stride + 1) - - self.x = torch.randint(low=-16, high=16, - size=[self.batch_size, self.in_channels, self.in_height, self.in_width]) \ - .cuda().to(memory_format=torch.channels_last).float() - self.x_ = self.x.clone() - self.x.requires_grad_() - self.x_.requires_grad_() - - self.mask = torch.randn([self.batch_size, self.out_channels, self.out_height, self.out_width]).cuda().to(memory_format=torch.channels_last) - self.mask = (self.mask > 0).to(torch.int8) - self.mask_ = self.mask.clone() - - self.conv1 = torch.nn.Conv2d(self.in_channels, self.out_channels, self.conv_kernel_size, - stride=self.conv_stride, padding=self.conv_pad).cuda().to(memory_format=torch.channels_last) - self.conv1_ = copy.deepcopy(self.conv1) - - print() - print('> input=[{}, {}, {}, {}]'.format(self.batch_size, self.in_channels, self.in_height, self.in_width)) - print('> kernel=[{}, {}, {}, {}], stride={}, pad={}'.format(self.out_channels, self.in_channels, - self.conv_kernel_size, self.conv_kernel_size, - self.conv_stride, self.conv_pad)) - - def test_conv_bias_relu(self): - with torch.cuda.amp.autocast(dtype=torch.half): - out = ConvBiasReLU(self.x, self.conv1.weight, self.conv1.bias.reshape(1, -1, 1, 1), self.conv_pad, self.conv_stride) - loss = (out.float()**2).sum() / out.numel() - loss.backward() - with torch.cuda.amp.autocast(dtype=torch.half): - out_ = F.relu(self.conv1_(self.x_)) - loss_ = (out_**2).sum() / out_.numel() - loss_.backward() - - self.assertTrue(torch.allclose(self.x_, self.x, atol=1e-3, rtol=1e-3, equal_nan=True)) - self.assertTrue(torch.allclose(self.conv1_.bias.grad, self.conv1.bias.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) - self.assertTrue(torch.allclose(self.conv1_.weight.grad, self.conv1.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) - self.assertTrue(torch.allclose(self.x_.grad, self.x.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) - - def test_conv_bias(self): - with torch.cuda.amp.autocast(dtype=torch.half): - out = ConvBias(self.x, self.conv1.weight, self.conv1.bias.reshape(1, -1, 1, 1), self.conv_pad, self.conv_stride) - loss = (out.float()**2).sum() / out.numel() - loss.backward() - - with torch.cuda.amp.autocast(dtype=torch.half): - out_ = self.conv1_(self.x_) - loss_ = (out_**2).sum() / out_.numel() - loss_.backward() - - self.assertTrue(torch.allclose(self.x_, self.x, atol=1e-3, rtol=1e-3, equal_nan=True)) - self.assertTrue(torch.allclose(self.conv1_.bias.grad, self.conv1.bias.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) - self.assertTrue(torch.allclose(self.conv1_.weight.grad, self.conv1.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) - self.assertTrue(torch.allclose(self.x_.grad, self.x.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) - - def test_conv_bias_mask_relu(self): - with torch.cuda.amp.autocast(dtype=torch.half): - out = ConvBiasMaskReLU(self.x, self.conv1.weight, self.conv1.bias.reshape(1, -1, 1, 1), self.mask, self.conv_pad, self.conv_stride) - loss = (out.float()**2).sum() / out.numel() - loss.backward() - with torch.cuda.amp.autocast(dtype=torch.half): - out_ = F.relu(self.conv1_(self.x_) * self.mask_) - loss_ = (out_**2).sum() / out_.numel() - loss_.backward() - - self.assertTrue(torch.allclose(self.x_, self.x, atol=1e-3, rtol=1e-3, equal_nan=True)) - self.assertTrue(torch.allclose(self.conv1_.bias.grad, self.conv1.bias.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) - self.assertTrue(torch.allclose(self.conv1_.weight.grad, self.conv1.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) - self.assertTrue(torch.allclose(self.x_.grad, self.x.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) - - -if __name__ == '__main__': - unittest.main() - diff --git a/apex/contrib/test/fmha/test_fmha.py b/apex/contrib/test/fmha/test_fmha.py deleted file mode 100644 index 00970ee..0000000 --- a/apex/contrib/test/fmha/test_fmha.py +++ /dev/null @@ -1,136 +0,0 @@ -############################################################################### -# Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of the NVIDIA CORPORATION nor the -# names of its contributors may be used to endorse or promote products -# derived from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY -# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -############################################################################### - -import math -import sys -import unittest - -import torch -import numpy as np - -import fmhalib as mha - - -def _get_device_properties(device = torch.device("cuda")): - # type: (str or torch.device) -> Tuple[int, int] - properties = torch.cuda.get_device_properties(device) - return properties.major, properties.minor - - -def py_mha(qkv, amask, b, s, h, d): - qkv = qkv.view(b, s, h, 3, d) - q = qkv[:, :, :, 0, :].permute(0,2,1,3) - k = qkv[:, :, :, 1, :].permute(0,2,1,3) - v = qkv[:, :, :, 2, :].permute(0,2,1,3) - p = torch.matmul(q.float(), k.permute(0,1,3,2).float()) - p_masked = p / math.sqrt(d) + (1.0 - amask) * -10000.0 - s = torch.softmax(p_masked, -1).to(qkv.dtype) - ctx = torch.matmul(s, v) - ctx = ctx.permute(0,2,1,3).contiguous() - - ctx.retain_grad() - - return ctx - -@unittest.skipIf(not _get_device_properties() == (8, 0), "FMHA only supports sm80") -class TestFMHA(unittest.TestCase): - - def run_test(self, s: int, b: int, zero_tensors: bool): - print(f'Test s={s} b={b}, zero_tensors={zero_tensors}') - - torch.manual_seed(1234) - torch.cuda.manual_seed(1234) - - dtype = torch.float16 - device = torch.device('cuda') - - h = 16 - d = 64 - - slens = [s] * b - a = torch.tensor(np.array([0] + slens), dtype=torch.int32) - amask = torch.ones(b,h,s,s, dtype=dtype, device=device) - seqlens = torch.tensor(slens, dtype=torch.int32, device=device) - cu_seqlens = torch.cumsum(a, 0).to(dtype=torch.int32, device=device) - total = cu_seqlens[-1].item() - - qkv = torch.randn((b,s,h,3,d), device=device, dtype=dtype) - - qkv_vs = qkv.permute(0,1,3,2,4).contiguous().view(b*s, 3, h,d) - - qkv.requires_grad = True - - if b < 4: - ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, True, zero_tensors, None) - else: - ctx, S_ = mha.fwd(qkv_vs, cu_seqlens, 0.0, s, True, False, zero_tensors, None) - ctx = ctx.view(b,s,h,d) - - ctx_ref = py_mha(qkv, amask, b,s,h,d) - self.assertTrue(torch.allclose(ctx_ref.float(), ctx.float(), atol=1e-3)) - - labels = torch.randn_like(ctx_ref) - diff = ctx_ref - labels - l = (diff * diff).sum() / b - l.backward() - - dw = ctx_ref.grad.permute(0,2,1,3) - - dw2 = dw.permute(0,2,1,3).clone().detach().contiguous() - - if b < 4: - dqkv2, _, _ = mha.bwd_nl(dw2, qkv_vs, S_, cu_seqlens, 0.0, s, zero_tensors) - else: - dqkv2, _ = mha.bwd(dw2, qkv_vs, S_, cu_seqlens, 0.0, s, zero_tensors) - - dqkv2 = dqkv2.permute(0,2,1,3).view(b,s, h,3,d) - - self.assertTrue(torch.allclose(qkv.grad.float(), dqkv2.float(), atol=1e-3)) - - def test_128(self): - self.run_test(128, 32, False) - self.run_test(128, 32, True) - - def test_256(self): - self.run_test(256, 32, False) - self.run_test(256, 32, True) - - def test_384(self): - self.run_test(384, 32, False) - self.run_test(384, 32, True) - - def test_512(self): - self.run_test(512, 32, False) - self.run_test(512, 32, True) - self.run_test(512, 2, False) - self.run_test(512, 2, True) - self.run_test(512, 3, False) - self.run_test(512, 3, True) - - -if __name__ == '__main__': - unittest.main() diff --git a/apex/contrib/test/focal_loss/test_focal_loss.py b/apex/contrib/test/focal_loss/test_focal_loss.py deleted file mode 100644 index 546e967..0000000 --- a/apex/contrib/test/focal_loss/test_focal_loss.py +++ /dev/null @@ -1,69 +0,0 @@ -import unittest - -import torch -import torch.nn.functional as F - -reference_available = True -try: - from torchvision.ops.focal_loss import sigmoid_focal_loss -except ImportError: - reference_available = False - -from apex.contrib.focal_loss import focal_loss - - -@unittest.skipIf(not reference_available, "Reference implementation `torchvision.ops.focal_loss.sigmoid_focal_loss` is not available.") -class FocalLossTest(unittest.TestCase): - - N_SAMPLES = 12 - N_CLASSES = 8 - ALPHA = 0.24 - GAMMA = 2.0 - REDUCTION = "sum" - - def test_focal_loss(self) -> None: - if not reference_available: - self.skipTest("This test needs `torchvision` for `torchvision.ops.focal_loss.sigmoid_focal_loss`.") - else: - x = torch.randn(FocalLossTest.N_SAMPLES, FocalLossTest.N_CLASSES).cuda() - with torch.no_grad(): - x_expected = x.clone() - x_actual = x.clone() - x_expected.requires_grad_() - x_actual.requires_grad_() - - classes = torch.randint(0, FocalLossTest.N_CLASSES, (FocalLossTest.N_SAMPLES,)).cuda() - with torch.no_grad(): - y = F.one_hot(classes, FocalLossTest.N_CLASSES).float() - - expected = sigmoid_focal_loss( - x_expected, - y, - alpha=FocalLossTest.ALPHA, - gamma=FocalLossTest.GAMMA, - reduction=FocalLossTest.REDUCTION, - ) - - actual = sum([focal_loss.FocalLoss.apply( - x_actual[i:i+1], - classes[i:i+1].long(), - torch.ones([], device="cuda"), - FocalLossTest.N_CLASSES, - FocalLossTest.ALPHA, - FocalLossTest.GAMMA, - 0.0, - ) for i in range(FocalLossTest.N_SAMPLES)]) - - # forward parity - torch.testing.assert_close(expected, actual) - - expected.backward() - actual.backward() - - # grad parity - torch.testing.assert_close(x_expected.grad, x_actual.grad) - - -if __name__ == "__main__": - torch.manual_seed(42) - unittest.main() diff --git a/apex/contrib/test/fused_dense/test_fused_dense.py b/apex/contrib/test/fused_dense/test_fused_dense.py deleted file mode 100644 index 301ebf6..0000000 --- a/apex/contrib/test/fused_dense/test_fused_dense.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch -import unittest -import torch.nn.functional as F -from apex import fused_dense -from torch import nn -from apex import amp - -class FusedDenseTest(unittest.TestCase): - def setUp(self, seed=0): - torch.manual_seed(seed) - #torch.cuda.manual_seed_all(seed) - - self.seq_length = 512 - self.sequences = 3 - self.hidden_dim = 1024 - - self.ref_inputs = torch.randn(self.sequences*self.seq_length, self.hidden_dim, - dtype=torch.float16, device=torch.device("cuda")).int().half().requires_grad_(True) - - self.tst_inputs = self.ref_inputs.clone().detach().requires_grad_(True) - self.dense = fused_dense.FusedDense(1024, 3072) - self.dense.half() - self.dense.cuda() - - - def test_fused_dense(self) : - y_tst = self.dense(self.tst_inputs) - y_ref = torch.matmul(self.ref_inputs,self.dense.weight.t())+self.dense.bias - dy = torch.randn_like(y_tst).half() - y_tst.backward(dy) - dw_ref = torch.matmul(dy.t(), self.ref_inputs) - dx_ref = torch.matmul(dy, self.dense.weight.clone()) - db_ref = dy.sum(0, False) - - - self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(y_ref, y_tst, atol=1e-3, rtol=1e-3, equal_nan=True)) - self.assertTrue(torch.allclose(dw_ref, self.dense.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) - self.assertTrue(torch.allclose(dx_ref, self.tst_inputs.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) - self.assertTrue(torch.allclose(db_ref, self.dense.bias.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) - - -if __name__ == '__main__': - unittest.main() diff --git a/apex/contrib/test/groupbn/test_groupbn.py b/apex/contrib/test/groupbn/test_groupbn.py deleted file mode 100644 index 3df7917..0000000 --- a/apex/contrib/test/groupbn/test_groupbn.py +++ /dev/null @@ -1,185 +0,0 @@ -import torch -import unittest -import numpy as np -import random -from apex.contrib.groupbn.batch_norm import BatchNorm2d_NHWC - -def generate_uniform_tensor(size, np_dtype, pyt_dtype, device): - array = None - while array is None or np.isnan(array).any(): - array = np.random.uniform(low=-1.0, high=1.0, size=size).astype(np_dtype) - return torch.from_numpy(array).to(device).to(pyt_dtype) - -def to_channels_last(tensor): - return tensor.permute(0, 2, 3, 1).contiguous() - -def to_channels_first(tensor): - return tensor.permute(0, 3, 1, 2).contiguous() - -class Bn(torch.nn.BatchNorm2d): - def __init__(self, planes, mode): - super(Bn, self).__init__(planes, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) - self.mode = mode - - def forward(self, x, z=None): - out = super().forward(x) - if self.mode == 'bn_add_relu': - out = out.add_(z) - if self.mode != 'bn': - out = out.relu_() - return out - -def bn_nhwc_bwd_ref(grad_y, x, mu, ivar, gamma): - sum_dim_c = (0, 1, 2) - grad_y_f32 = grad_y.float() - x_f32 = x.float() - N = x.shape[0] * x.shape[1] * x.shape[2] # nhw - ones = torch.ones(x.shape, dtype=torch.float32, device='cuda') - - xmu = x_f32 - mu - xhat = xmu * ivar - - dbias = torch.sum(grad_y_f32, dim=sum_dim_c) - - dscale = torch.sum(grad_y_f32 * xhat, dim=sum_dim_c) - - dx1 = (gamma * ivar) / N - dx2 = (N * grad_y_f32) - (dbias * ones) - dx3 = -xhat * dscale - dx = dx1 * (dx2 + dx3) - dx = dx.half() - return dx, dscale, dbias - -class TestGroupBN(unittest.TestCase): - - def setUp(self, seed=5, verbose=False): - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - np.random.seed(seed) - self.verbose = verbose - - def test_bn(self): - self.run_group_bn('bn') - - def test_bn_relu(self): - self.run_group_bn('bn_relu') - - def test_bn_add_relu(self): - self.run_group_bn('bn_add_relu') - - def run_group_bn(self, mode): - if self.verbose: - print('Running {}'.format(mode)) - - tensor_sizes = [ - (120, 64, 75, 75), - (120, 128, 38, 38)] - - for i in range(len(tensor_sizes)): - tensor_size = tensor_sizes[i] - num_channels = tensor_size[1] - - # Create input data - input_data = generate_uniform_tensor(tensor_size, np.float16, torch.half, 'cuda') - np.save('input.npy', input_data.detach().cpu().numpy()) - input_data.requires_grad = True - - gbn_input = torch.from_numpy(np.load('input.npy')).cuda().half() - gbn_input.requires_grad = True - - residual_data = None - gbn_residual_data = None - if mode == 'bn': - fuse_relu = False - else: - fuse_relu = True - if mode == 'bn_add_relu': - residual_data = generate_uniform_tensor(tensor_size, np.float16, torch.half, 'cuda') - gbn_residual_data = to_channels_last(residual_data) - - bn_grad = generate_uniform_tensor(input_data.shape, np.float16, torch.half, 'cuda') - - # Create models - batchnorm_model = Bn(num_channels, mode).cuda() - group_batchnorm = BatchNorm2d_NHWC(num_channels, fuse_relu=fuse_relu, bn_group=1,torch_channels_last=False).cuda() - - # Run reference forward - bn_output = batchnorm_model(input_data, residual_data) - - # Run GBN forward - gbn_input_data = to_channels_last(gbn_input) - gbn_output = group_batchnorm(gbn_input_data, gbn_residual_data) - - torch.cuda.synchronize() - - # Run reference backward - # (Use the same input and parameters as GBN) - gbn_grad = to_channels_last(bn_grad) - grad = gbn_grad.clone().detach() - input_data = torch.from_numpy(np.load('input.npy')).cuda().half() - input_data = to_channels_last(input_data) - if mode != 'bn': - grad[gbn_output <= 0] = 0 - bn_output_grad, _, _ = bn_nhwc_bwd_ref( \ - grad, - input_data, - group_batchnorm.minibatch_mean, - group_batchnorm.minibatch_riv, - group_batchnorm.weight) - bn_output_grad = to_channels_first(bn_output_grad) - - # Run GBN backward - gbn_output.backward(gbn_grad) - torch.cuda.synchronize() - - gbn_output = to_channels_first(gbn_output) - gbn_output_grad = gbn_input.grad.detach().clone().cpu() - - ########################## Validate results ########################## - if self.verbose: - print('Validate activation') - self.validate(bn_output.shape, bn_output, gbn_output) - if self.verbose: - print('Validate grad') - self.validate(bn_output_grad.shape, bn_output_grad, gbn_output_grad, is_grad=True) - - def validate(self, tensors, output_ref, output_test, is_grad=False): - output_ref = output_ref.detach().cpu().numpy() - output_test = output_test.detach().cpu().numpy() - - if self.verbose: - print('>>> tensor_size\t{}'.format(tensors)) - print("sum_output_ref {}, isnan {}, max {}, min {}".format( - np.sum(output_ref, dtype=float), np.isnan(output_ref).any(), np.max(output_ref), np.min(output_ref))) - print("sum_output_test {}, isnan {}, max {}, min {}".format( - np.sum(output_test, dtype=float), np.isnan(output_test).any(), np.max(output_test), np.min(output_test))) - - ret = np.array_equal(output_ref, output_test) - if not ret: - ret_allclose = np.allclose( - output_ref, output_test, rtol=1e-3, atol=1e-3, equal_nan=True) - if self.verbose: - print('{}\tshape {}\tidentical {}\tclose {}'.format('cpu/gpu', tensors, ret, ret_allclose)) - output_ref = output_ref.flatten() - output_test = output_test.flatten() - if not ret: - sub = np.absolute(output_ref - output_test) - norm_diff = np.average(sub) - rel = np.divide(sub, np.absolute(output_ref)) - rel[rel == np.inf] = 0 - max_abs_idx = np.argmax(sub) - max_rel_idx = np.argmax(rel) - if self.verbose: - print('max_diff {}, max_rel_diff {}, norm_diff {}'.format(np.max(sub), np.max(rel), np.average(sub))) - print('max_abs pair [{}] {} {}'.format(max_abs_idx, output_ref[max_abs_idx], output_test[max_abs_idx])) - print('max_rel pair [{}] {} {}'.format(max_rel_idx, output_ref[max_rel_idx], output_test[max_rel_idx])) - - result = ret or ret_allclose or (is_grad and norm_diff < 1e-4) - - if self.verbose: - print("Result {}".format("PASS" if result else "FAIL")) - - self.assertTrue(result) - -if __name__ == '__main__': - unittest.main() diff --git a/apex/contrib/test/groupbn/test_groupbn_channel_last.py b/apex/contrib/test/groupbn/test_groupbn_channel_last.py deleted file mode 100644 index 5ae36e3..0000000 --- a/apex/contrib/test/groupbn/test_groupbn_channel_last.py +++ /dev/null @@ -1,194 +0,0 @@ -import torch -import unittest -import numpy as np -import random -from apex.contrib.groupbn.batch_norm import BatchNorm2d_NHWC - -def generate_uniform_tensor(size, np_dtype, pyt_dtype, device): - array = None - while array is None or np.isnan(array).any(): - array = np.random.uniform(low=-1.0, high=1.0, size=size).astype(np_dtype) - return torch.from_numpy(array).to(device).to(pyt_dtype) - -def to_channels_last(tensor): - #return tensor.permute(0, 2, 3, 1).contiguous() - return tensor.to(memory_format = torch.channels_last) - -def to_channels_first(tensor): - #return tensor.permute(0, 3, 1, 2).contiguous() - return tensor.to(memory_format = torch.contiguous_format) - -class Bn(torch.nn.BatchNorm2d): - def __init__(self, planes, mode): - super(Bn, self).__init__(planes, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) - self.mode = mode - - def forward(self, x, z=None): - out = super().forward(x) - if self.mode == 'bn_add_relu': - out = out.add_(z) - if self.mode != 'bn': - out = out.relu_() - return out - -def bn_nhwc_bwd_ref(grad_y, x, mu, ivar, gamma): - grad_y = grad_y.permute(0, 2, 3, 1).contiguous() - x = x.permute(0, 2, 3, 1).contiguous() - sum_dim_c = (0, 1, 2) - grad_y_f32 = grad_y.float() - x_f32 = x.float() - N = x.shape[0] * x.shape[1] * x.shape[2] # nhw - ones = torch.ones(x.shape, dtype=torch.float32, device='cuda') - - xmu = x_f32 - mu - - xhat = xmu * ivar - dbias = torch.sum(grad_y_f32, dim=sum_dim_c) - - dscale = torch.sum(grad_y_f32 * xhat, dim=sum_dim_c) - - dx1 = (gamma * ivar) / N - dx2 = (N * grad_y_f32) - (dbias * ones) - dx3 = -xhat * dscale - dx23 = dx2 + dx3 - dx = dx1 * (dx23) - dx = dx.half() - dx = dx.permute(0, 3, 1, 2).contiguous() - return dx, dscale, dbias - -class TestGroupBNChannelLast(unittest.TestCase): - - def setUp(self, seed=5, verbose=False): - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - np.random.seed(seed) - self.verbose = verbose - - def test_bn_channel_last(self): - self.run_group_bn_channel_last('bn') - - def test_bn_relu_channel_last(self): - self.run_group_bn_channel_last('bn_relu') - - def test_bn_add_relu_channel_last(self): - self.run_group_bn_channel_last('bn_add_relu') - - def run_group_bn_channel_last(self, mode): - if self.verbose: - print('Running {}'.format(mode)) - - tensor_sizes = [ - (120, 64, 75, 75), - (120, 128, 38, 38)] - - for i in range(len(tensor_sizes)): - tensor_size = tensor_sizes[i] - num_channels = tensor_size[1] - - # Create input data - input_data = generate_uniform_tensor(tensor_size, np.float16, torch.half, 'cuda') - np.save('input.npy', input_data.detach().cpu().numpy()) - input_data.requires_grad = True - - gbn_input = torch.from_numpy(np.load('input.npy')).cuda().half() - gbn_input.requires_grad = True - - residual_data = None - gbn_residual_data = None - if mode == 'bn': - fuse_relu = False - else: - fuse_relu = True - if mode == 'bn_add_relu': - residual_data = generate_uniform_tensor(tensor_size, np.float16, torch.half, 'cuda') - gbn_residual_data = to_channels_last(residual_data) - - bn_grad = generate_uniform_tensor(input_data.shape, np.float16, torch.half, 'cuda') - - # Create models - batchnorm_model = Bn(num_channels, mode).cuda() - group_batchnorm = BatchNorm2d_NHWC(num_channels, fuse_relu=fuse_relu, bn_group=1, torch_channels_last=True).cuda() - - # Run reference forward - bn_output = batchnorm_model(input_data, residual_data) - - # Run GBN forward - gbn_input_data = to_channels_last(gbn_input) - #gbn_input_data = gbn_input - gbn_output = group_batchnorm(gbn_input_data, gbn_residual_data) - - torch.cuda.synchronize() - - # Run reference backward - # (Use the same input and parameters as GBN) - gbn_grad = to_channels_last(bn_grad) - #gbn_grad = bn_grad - grad = gbn_grad.clone().detach() - input_data = torch.from_numpy(np.load('input.npy')).cuda().half() - input_data = to_channels_last(input_data) - if mode != 'bn': - grad[gbn_output <= 0] = 0 - bn_output_grad, _, _ = bn_nhwc_bwd_ref( \ - grad, - input_data, - group_batchnorm.minibatch_mean, - group_batchnorm.minibatch_riv, - group_batchnorm.weight) - bn_output_grad = to_channels_first(bn_output_grad) - - # Run GBN backward - gbn_output.backward(gbn_grad) - torch.cuda.synchronize() - - gbn_output = to_channels_first(gbn_output) - gbn_output_grad = gbn_input.grad.detach().clone().cpu() - - ########################## Validate results ########################## - if self.verbose: - print('Validate activation') - self.validate(bn_output.shape, bn_output, gbn_output) - if self.verbose: - print('Validate grad') - self.validate(bn_output_grad.shape, bn_output_grad, gbn_output_grad, is_grad=True) - - def validate(self, tensors, output_ref, output_test, is_grad=False): - output_ref = output_ref.detach().cpu().numpy() - output_test = output_test.detach().cpu().numpy() - - if self.verbose: - print('>>> tensor_size\t{}'.format(tensors)) - print("sum_output_ref {}, isnan {}, max {}, min {}".format( - np.sum(output_ref, dtype=float), np.isnan(output_ref).any(), np.max(output_ref), np.min(output_ref))) - print("sum_output_test {}, isnan {}, max {}, min {}".format( - np.sum(output_test, dtype=float), np.isnan(output_test).any(), np.max(output_test), np.min(output_test))) - - ret = np.array_equal(output_ref, output_test) - if not ret: - ret_allclose = np.allclose( - output_ref, output_test, rtol=1e-3, atol=1e-3, equal_nan=True) - if self.verbose: - print('{}\tshape {}\tidentical {}\tclose {}'.format('cpu/gpu', tensors, ret, ret_allclose)) - output_ref = output_ref.flatten() - output_test = output_test.flatten() - if not ret: - sub = np.absolute(output_ref - output_test) - norm_diff = np.average(sub) - rel = np.divide(sub, np.absolute(output_ref)) - rel[rel == np.inf] = 0 - max_abs_idx = np.argmax(sub) - max_rel_idx = np.argmax(rel) - if self.verbose: - print('max_diff {}, max_rel_diff {}, norm_diff {}'.format(np.max(sub), np.max(rel), np.average(sub))) - print('max_abs pair [{}] {} {}'.format(max_abs_idx, output_ref[max_abs_idx], output_test[max_abs_idx])) - print('max_rel pair [{}] {} {}'.format(max_rel_idx, output_ref[max_rel_idx], output_test[max_rel_idx])) - - result = ret or ret_allclose or (is_grad and norm_diff < 1e-4) - - if self.verbose: - print("Result {}".format("PASS" if result else "FAIL")) - - self.assertTrue(result) - -if __name__ == '__main__': - unittest.main() - diff --git a/apex/contrib/test/index_mul_2d/test_index_mul_2d.py b/apex/contrib/test/index_mul_2d/test_index_mul_2d.py deleted file mode 100644 index d8f37ea..0000000 --- a/apex/contrib/test/index_mul_2d/test_index_mul_2d.py +++ /dev/null @@ -1,106 +0,0 @@ -import random -import unittest - -import torch -import torch.nn.functional as F - -HAS_INDEX_MUL_2D_RELU = None -try: - from apex.contrib.index_mul_2d import index_mul_2d -except ImportError as e: - HAS_INDEX_MUL_2D_RELU = False -else: - HAS_INDEX_MUL_2D_RELU = True - - -@unittest.skipIf(not HAS_INDEX_MUL_2D_RELU, "`apex.contrib.index_mul_2d` is not found.") -class IndexMul2dTest(unittest.TestCase): - def setUp(self, seed=0): - torch.manual_seed(seed) - - self.input1_size = random.randint(1, 1000) - self.input2_size = random.randint(1, 100000) - self.feature_size = random.randint(1, 256) - - self.input1_float = torch.randn(size=(self.input1_size, self.feature_size),).cuda() - self.input2_float = torch.randn(size=(self.input2_size, self.feature_size),).cuda() - self.index1 = torch.randint(low=0, high=self.input1_size, size=(self.input2_size,)).cuda() - - self.input1_float_ = self.input1_float.clone() - self.input2_float_ = self.input2_float.clone() - - self.input1_float.requires_grad_() - self.input1_float_.requires_grad_() - self.input2_float.requires_grad_() - self.input2_float_.requires_grad_() - - self.input1_half = torch.randn(size=(self.input1_size, self.feature_size),).cuda().half() - self.input2_half = torch.randn(size=(self.input2_size, self.feature_size),).cuda().half() - - self.input1_half_ = self.input1_half.clone() - self.input2_half_ = self.input2_half.clone() - - self.input1_half.requires_grad_() - self.input2_half.requires_grad_() - self.input1_half_.requires_grad_() - self.input2_half_.requires_grad_() - - def test_index_mul_float(self): - out = index_mul_2d(self.input1_float, self.input2_float, self.index1) - energy = (out.float()**2).sum() / out.numel() - force = torch.autograd.grad( - energy, - self.input1_float, - grad_outputs=torch.ones_like(energy), - create_graph=True, - )[0] - loss = (out.float()**2).sum() / out.numel() + (force.float()**2).sum() - loss.backward() - - out_ = self.input1_float_[self.index1] * self.input2_float_ - energy_ = (out_.float()**2).sum() / out.numel() - force_ = torch.autograd.grad( - energy_, - self.input1_float_, - grad_outputs=torch.ones_like(energy), - create_graph=True, - )[0] - loss = (out_.float()**2).sum() / out_.numel() + (force_.float()**2).sum() - loss.backward() - - self.assertTrue(torch.allclose(self.input1_float, self.input1_float_, atol=1e-3, rtol=1e-3, equal_nan=True)) - self.assertTrue(torch.allclose(self.input2_float, self.input2_float_, atol=1e-3, rtol=1e-3, equal_nan=True)) - self.assertTrue(torch.allclose(self.input1_float.grad, self.input1_float_.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) - self.assertTrue(torch.allclose(self.input2_float.grad, self.input2_float_.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) - - def test_index_mul_half(self): - out = index_mul_2d(self.input1_half, self.input2_half, self.index1) - energy = (out.float()**2).sum() / out.numel() - force = torch.autograd.grad( - energy, - self.input1_half, - grad_outputs=torch.ones_like(energy), - create_graph=True, - )[0] - loss = (out.float()**2).sum() / out.numel() + (force.float()**2).sum() - loss.backward() - - out_ = self.input1_half_[self.index1] * self.input2_half_ - energy_ = (out_.float()**2).sum() / out.numel() - force_ = torch.autograd.grad( - energy_, - self.input1_half_, - grad_outputs=torch.ones_like(energy), - create_graph=True, - )[0] - loss = (out_.float()**2).sum() / out_.numel() + (force_.float()**2).sum() - loss.backward() - - self.assertTrue(torch.allclose(self.input1_half, self.input1_half_, atol=1e-3, rtol=1e-3, equal_nan=True)) - self.assertTrue(torch.allclose(self.input2_half, self.input2_half_, atol=1e-3, rtol=1e-3, equal_nan=True)) - self.assertTrue(torch.allclose(self.input1_half.grad, self.input1_half_.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) - self.assertTrue(torch.allclose(self.input2_half.grad, self.input2_half_.grad, atol=1e-3, rtol=1e-3, equal_nan=True)) - -if __name__ == '__main__': - unittest.main() - diff --git a/apex/contrib/test/layer_norm/test_fast_layer_norm.py b/apex/contrib/test/layer_norm/test_fast_layer_norm.py deleted file mode 100644 index 66590e9..0000000 --- a/apex/contrib/test/layer_norm/test_fast_layer_norm.py +++ /dev/null @@ -1,277 +0,0 @@ -import unittest -import sys -import os - -import numpy as np -import torch - -import fast_layer_norm as fln -from apex.contrib.layer_norm.layer_norm import FastLayerNorm - - -class GPUTimer: - def __init__(self, stream): - self.start_ = torch.cuda.Event(enable_timing=True) - self.stop_ = torch.cuda.Event(enable_timing=True) - self.stream_ = stream - - def start(self): - self.stream_.record_event(self.start_) - - def stop(self): - self.stream_.record_event(self.stop_) - - def sync(self): - self.stream_.synchronize() - - def millis(self): - return self.start_.elapsed_time(self.stop_) - - -def size_in_bytes(t): - return torch.numel(t) * t.element_size() - - -def metrics(y_ref, y, epsilon=1e-6): - y_ref = y_ref.float() - y = y.float() - relerr, mse = ( - (y_ref - y).abs().sum() / (y_ref.abs().sum() + epsilon), - (y_ref - y).square().mean(), - ) - return relerr.item(), mse.item() - - -device = torch.device("cuda") -fp32 = torch.float32 -fp16 = torch.float16 -bf16 = torch.bfloat16 - - -def backward_(dz, x, mu, rs, gamma): - - wtype = gamma.dtype - itype = x.dtype - otype = dz.dtype - ctype = mu.dtype - mu = mu.unsqueeze(1) - rs = rs.unsqueeze(1) - - hidden_size = gamma.numel() - y = rs * (x.to(ctype) - mu) - dbeta = dz.view(-1, hidden_size).sum(0, dtype=ctype) - dgamma = (dz * y).view(-1, hidden_size).sum(0, dtype=ctype) - dy = dz.view(-1, hidden_size).to(ctype) * gamma.unsqueeze(0).to(ctype) - mdy = dy.mean(1, keepdim=True, dtype=ctype) - - mdyy = (dy * y).mean(1, keepdim=True, dtype=ctype) - dx = rs * (dy - mdyy * y - mdy) - - return dx.to(itype), dgamma.to(wtype), dbeta.to(wtype) - - -def benchmark_(S, B, hidden_size, itype, wtype, runs=100): - epsilon = 1e-5 - - x = torch.randn((S * B, hidden_size), dtype=itype, device=device) - beta = torch.randn(hidden_size, dtype=wtype, device=device) - gamma = torch.randn(hidden_size, dtype=wtype, device=device) - dz = torch.randn(x.shape, dtype=wtype, device=device) - - stream = torch.cuda.Stream() - with torch.cuda.stream(stream): - - timer = GPUTimer(stream) - - # warmup - for r in range(runs): - z, mu, rsigma = fln.ln_fwd(x, gamma, beta, epsilon) - - timer.start() - for r in range(runs): - z, mu, rsigma = fln.ln_fwd(x, gamma, beta, epsilon) - timer.stop() - timer.sync() - - total_bytes_fwd = sum([size_in_bytes(t) for t in [x, z, gamma, beta, mu, rsigma]]) - - ms_fwd = timer.millis() / runs - - print( - "[FWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec".format( - ms_fwd, total_bytes_fwd * 1e-6 / ms_fwd - ) - ) - - timer.start() - for r in range(runs): - dx, dgamma, dbeta, dbp, dgp = fln.ln_bwd(dz, x, mu, rsigma, gamma) - timer.stop() - timer.sync() - - total_bytes_bwd = sum( - [ - size_in_bytes(t) - for t in [dz, x, mu, rsigma, gamma, dx, dgamma, dbeta, dbp, dbp, dgp, dgp] - ] - ) - - ms_bwd = timer.millis() / runs - - print( - "[BWD] Time: {:.4f}ms Throughput: {:.4f} GB/sec".format( - ms_bwd, total_bytes_bwd * 1e-6 / ms_bwd - ) - ) - - -def test_(S, B, hidden_size, itype, wtype, ctype=fp32): - - seed = 1243 - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - - otype = wtype - print("========================================================") - print(f"S={S} B={B} Hidden={hidden_size} {itype} {wtype}") - print("--------------------------------------------------------") - - x = torch.randn(S * B, hidden_size, dtype=itype, device=device) - gamma = torch.randn(hidden_size, dtype=wtype, device=device) * 0.2 - beta = torch.randn(hidden_size, dtype=wtype, device=device) * 0.2 - epsilon = 1e-5 - - x.requires_grad = True - gamma.requires_grad = True - beta.requires_grad = True - - mu_ref = x.mean(1, dtype=ctype, keepdim=True) - v = torch.square(x - mu_ref).mean(1, dtype=ctype, keepdim=True) - rs_ref = torch.rsqrt(v + epsilon) - y_ref = rs_ref * (x.to(ctype) - mu_ref) - z_ref = (gamma.unsqueeze(0) * (y_ref).to(otype) + beta.unsqueeze(0)).to(otype) - - mu_ref = mu_ref.flatten() - rs_ref = rs_ref.flatten() - - dz = torch.randn_like(z_ref) - - # z_ref.backward(dz) - # dx_ref = x.grad - # dgamma_ref = gamma.grad - # dbeta_ref = beta.grad - - dx_ref, dg_ref, db_ref = backward_(dz, x, mu_ref, rs_ref, gamma) - - z, mu, rs = fln.ln_fwd(x, gamma, beta, epsilon) - dx, dg, db, dg_part, db_part = fln.ln_bwd(dz, x, mu, rs, gamma) - - re_z, mse_z = metrics(z_ref, z) - re_mu, mse_mu = metrics(mu_ref, mu) - re_rs, mse_rs = metrics(rs_ref, rs) - - re_dx, mse_dx = metrics(dx_ref, dx) - re_dg, mse_dg = metrics(dg_ref, dg) - re_db, mse_db = metrics(db_ref, db) - - print(f" z: relerr={re_z :.4e} mse={mse_z :.4e}") - print(f"mu: relerr={re_mu:.4e} mse={mse_mu:.4e}") - print(f"rs: relerr={re_mu:.4e} mse={mse_mu:.4e}") - - print(f"dx: relerr={re_dx:.4e} mse={mse_dx:.4e}") - print(f"dg: relerr={re_dg:.4e} mse={mse_dg:.4e}") - print(f"db: relerr={re_db:.4e} mse={mse_db:.4e}") - - def check_err(x, relerr): - tol = 1e-3 if x.dtype == torch.float16 else 5e-6 - return relerr < tol - - return [ - check_err(x, re) - for x, re in zip([z, mu, rs, dx, dg, db], [re_z, re_mu, re_rs, re_dx, re_dg, re_db]) - ] - - -class TestFastLayerNorm(unittest.TestCase): - def assertAll(self, l): - if not all(l): - print(l) - for x in l: - self.assertTrue(x) - - def test_all_configs(self): - - hidden_sizes = [ - 768, - 1024, - 1536, - 2048, - 2304, - 3072, - 3840, - 4096, - 5120, - 6144, - 8192, - 10240, - 12288, - 12800, - 14336, - 15360, - 16384, - 18432, - 20480, - 24576, - 25600, - 30720, - 32768, - 40960, - 49152, - 65536, - ] - - for h in hidden_sizes: - with self.subTest(f"hidden_size={h}"): - self.assertAll(test_(256, 2, h, fp32, fp32)) - self.assertAll(test_(256, 2, h, fp16, fp16)) - self.assertAll(test_(256, 2, h, fp32, fp16)) - # self.assertAll(test_(256, 2, h, bf16, bf16)) - # self.assertAll(test_(256, 2, h, fp32, bf16)) - - def test_run_benchmark(self): - for (S, B, hidden_size, runs) in ( - (512, 32, 768, 1000), - (512, 32, 1024, 1000), - (512, 8, 4096, 1000), - (512, 8, 5120, 1000), - (512, 8, 6144, 1000), - (256, 2, 20480, 500), - (256, 2, 25600, 500), - (256, 2, 40960, 250), - (256, 2, 65536, 250), - ): - with self.subTest(f"(S, B, hidden_size)=({S}, {B}, {hidden_size})"): - benchmark_(S, B, hidden_size, fp16, fp16, runs) - - def test_compat_with_autocast(self): - autocast_dtypes = ( - # (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,) - (torch.half,) - ) - input_shape = (512, 32, 768) - layer_norm = FastLayerNorm(input_shape[-1]).cuda() - input = torch.randn(input_shape).cuda() - - for dtype in autocast_dtypes: - layer_norm.zero_grad(set_to_none=True) - with self.subTest(f"autocast_dtype={dtype}"): - with torch.cuda.amp.autocast(enabled=True, dtype=dtype): - out = layer_norm(input) - self.assertEqual(dtype, out.dtype) - grad = torch.randn_like(out) - out.backward(grad) - self.assertEqual(torch.float32, layer_norm.weight.grad.dtype) - - -if __name__ == "__main__": - unittest.main() diff --git a/apex/contrib/test/multihead_attn/test_encdec_multihead_attn.py b/apex/contrib/test/multihead_attn/test_encdec_multihead_attn.py deleted file mode 100644 index 836fe84..0000000 --- a/apex/contrib/test/multihead_attn/test_encdec_multihead_attn.py +++ /dev/null @@ -1,136 +0,0 @@ -import torch - -import unittest - -from apex.contrib.multihead_attn import EncdecMultiheadAttn - -class EncdecMultiheadAttnTest(unittest.TestCase): - def setUp(self, seed=1234): - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - self.seq_length = 80 - self.sequences = 10 - self.hidden_dim = 1024 - self.heads = 16 - self.dropout_prob = 0.0 - - self.ref_layer = EncdecMultiheadAttn(self.hidden_dim, - self.heads, - dropout=self.dropout_prob, - bias=False, - include_norm_add=False, - impl='default') - self.ref_layer.cuda().half() - self.ref_layer.reset_parameters() - self.ref_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim, - dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - self.ref_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim, - dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - - # Reset seed so parameters are identical - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - self.tst_layer = EncdecMultiheadAttn(self.hidden_dim, - self.heads, - dropout=self.dropout_prob, - bias=False, - include_norm_add=False, - impl='fast') - self.tst_layer.cuda().half() - self.tst_layer.reset_parameters() - - self.tst_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim, - dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - self.tst_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim, - dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - - def test_encdec_multihead_attn(self) : - grads = torch.randn_like(self.tst_inputs_q) - - ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q, - self.ref_inputs_k, - self.ref_inputs_k, - key_padding_mask=None, - need_weights=False, - attn_mask=None, - is_training=True) - - tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q, - self.tst_inputs_k, - self.tst_inputs_k, - key_padding_mask=None, - need_weights=False, - attn_mask=None, - is_training=True) - - self.ref_inputs_q.backward(grads) - self.tst_inputs_q.backward(grads) - - self.assertTrue(torch.allclose(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)) - self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3)) - - def test_encdec_multihead_attn_time_mask(self) : - grads = torch.randn_like(self.tst_inputs_q) - time_mask_byte = torch.triu(torch.ones(self.tst_inputs_q.size(0), self.tst_inputs_k.size(0), device=torch.device("cuda"), dtype=torch.uint8), 1) - time_mask_bool = time_mask_byte.to(torch.bool) - - ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q, - self.ref_inputs_k, - self.ref_inputs_k, - key_padding_mask=None, - need_weights=False, - attn_mask=time_mask_bool, - is_training=True) - - tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q, - self.tst_inputs_k, - self.tst_inputs_k, - key_padding_mask=None, - need_weights=False, - attn_mask=time_mask_byte, - is_training=True) - - self.ref_inputs_q.backward(grads) - self.tst_inputs_q.backward(grads) - - self.assertTrue(torch.allclose(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)) - self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3)) - - def test_encdec_multihead_attn_pad_mask(self) : - grads = torch.randn_like(self.tst_inputs_q) - pad_mask_byte = torch.tril(torch.ones(self.tst_inputs_k.size(1), self.tst_inputs_k.size(0), device=torch.device("cuda"), dtype=torch.uint8), 1) - pad_mask_bool = pad_mask_byte.to(torch.bool) - - ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q, - self.ref_inputs_k, - self.ref_inputs_k, - key_padding_mask=pad_mask_bool, - need_weights=False, - attn_mask=None, - is_training=True) - - tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q, - self.tst_inputs_k, - self.tst_inputs_k, - key_padding_mask=pad_mask_byte, - need_weights=False, - attn_mask=None, - is_training=True) - - self.ref_inputs_q.backward(grads) - self.tst_inputs_q.backward(grads) - - self.assertTrue(torch.allclose(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)) - self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3)) - - -if __name__ == '__main__': - unittest.main() diff --git a/apex/contrib/test/multihead_attn/test_encdec_multihead_attn_norm_add.py b/apex/contrib/test/multihead_attn/test_encdec_multihead_attn_norm_add.py deleted file mode 100644 index 2ab3009..0000000 --- a/apex/contrib/test/multihead_attn/test_encdec_multihead_attn_norm_add.py +++ /dev/null @@ -1,78 +0,0 @@ -import torch - -import unittest - -from apex.contrib.multihead_attn import EncdecMultiheadAttn - -class EncdecMultiheadAttnNormAddTest(unittest.TestCase): - def setUp(self, seed=1234): - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - self.seq_length = 80 - self.sequences = 10 - self.hidden_dim = 1024 - self.heads = 16 - self.dropout_prob = 0.0 - - self.ref_layer = EncdecMultiheadAttn(self.hidden_dim, - self.heads, - dropout=self.dropout_prob, - bias=False, - include_norm_add=True, - impl='default') - self.ref_layer.cuda().half() - self.ref_layer.reset_parameters() - self.ref_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim, - dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - self.ref_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim, - dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - - # Reset seed so parameters are identical - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - self.tst_layer = EncdecMultiheadAttn(self.hidden_dim, - self.heads, - dropout=self.dropout_prob, - bias=False, - include_norm_add=True, - impl='fast') - self.tst_layer.cuda().half() - self.tst_layer.reset_parameters() - - self.tst_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim, - dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - self.tst_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim, - dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - - def test_encdec_multihead_attn_norm_add(self) : - grads = torch.randn_like(self.tst_inputs_q) - - for _ in range(5) : - ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q, - self.ref_inputs_k, - self.ref_inputs_k, - key_padding_mask=None, - need_weights=False, - attn_mask=None, - is_training=True) - - tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q, - self.tst_inputs_k, - self.tst_inputs_k, - key_padding_mask=None, - need_weights=False, - attn_mask=None, - is_training=True) - - self.ref_inputs_q.backward(grads) - self.tst_inputs_q.backward(grads) - - self.assertTrue(torch.allclose(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)) - self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3)) - -if __name__ == '__main__': - unittest.main() diff --git a/apex/contrib/test/multihead_attn/test_fast_self_multihead_attn_bias.py b/apex/contrib/test/multihead_attn/test_fast_self_multihead_attn_bias.py deleted file mode 100644 index b4bbf34..0000000 --- a/apex/contrib/test/multihead_attn/test_fast_self_multihead_attn_bias.py +++ /dev/null @@ -1,77 +0,0 @@ -import torch - -import unittest - -from apex.contrib.multihead_attn import SelfMultiheadAttn - -class SelfMultiheadAttnTest(unittest.TestCase): - def setUp(self, seed=1234): - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - self.seq_length = 80 - self.sequences = 10 - self.hidden_dim = 1024 - self.heads = 16 - self.dropout_prob = 0.0 - - self.ref_layer = SelfMultiheadAttn(self.hidden_dim, - self.heads, - dropout=self.dropout_prob, - bias=True, - include_norm_add=False, - separate_qkv_params=True, - mask_additive=True, - impl='default') - self.ref_layer.cuda().half() - self.ref_layer.reset_parameters() - self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, - dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - # Reset seed so parameters are identical - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - self.tst_layer = SelfMultiheadAttn(self.hidden_dim, - self.heads, - dropout=self.dropout_prob, - bias=True, - include_norm_add=False, - separate_qkv_params=True, - mask_additive=True, - impl='fast') - self.tst_layer.cuda().half() - self.tst_layer.reset_parameters() - - self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, - dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - - def test_self_multihead_attn_additive_mask(self) : - grads = torch.randn_like(self.tst_inputs) - mask = ((torch.randn(self.sequences, self.seq_length) > 0) * -10000.0).half().cuda() - - ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, - self.ref_inputs, - self.ref_inputs, - key_padding_mask=mask, - need_weights=False, - attn_mask=None, - is_training=True) - - tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, - self.tst_inputs, - self.tst_inputs, - key_padding_mask=mask, - need_weights=False, - attn_mask=None, - is_training=True) - - - self.ref_inputs.backward(grads) - self.tst_inputs.backward(grads) - - self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)) - self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3)) - -if __name__ == '__main__': - unittest.main() diff --git a/apex/contrib/test/multihead_attn/test_mha_fused_softmax.py b/apex/contrib/test/multihead_attn/test_mha_fused_softmax.py deleted file mode 100644 index 60d9541..0000000 --- a/apex/contrib/test/multihead_attn/test_mha_fused_softmax.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch -import unittest -import torch.nn.functional as F -from apex.contrib.multihead_attn import fast_mask_softmax_dropout_func - -class FusedSoftmaxTest(unittest.TestCase): - def setUp(self, seed=1234): - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - self.seq_length = 80 - self.sequences = 10 - self.hidden_dim = 1024 - self.heads = 16 - self.dropout_prob = 0.0 - - self.mask = (torch.randn(self.sequences,self.seq_length)>0).cuda() - self.mask = self.mask.half()*-10000 - self.ref_inputs = torch.randn(self.heads * self.sequences, self.seq_length, self.seq_length, - dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - - self.tst_inputs = self.ref_inputs.clone().detach().requires_grad_(True) - - def test_fused_softmax(self) : - grads = torch.randn_like(self.tst_inputs) - y_ref = self.ref_inputs.view(self.sequences, self.heads, self.seq_length, self.seq_length) - y_ref = y_ref + self.mask.unsqueeze(1).unsqueeze(2) - y_ref = y_ref.view(self.sequences*self.heads, self.seq_length, self.seq_length) - y_ref = F.softmax(y_ref, dim=-1) - y_ref = torch._fused_dropout(y_ref, 1.0) - - y_tst = fast_mask_softmax_dropout_func(True, self.heads, self.tst_inputs, self.mask, True, 0.0) - y_ref[0].backward(grads) - y_tst.backward(grads) - - self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(y_ref[0], y_tst, atol=1e-3, rtol=1e-3)) - self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3)) - - -if __name__ == '__main__': - unittest.main() diff --git a/apex/contrib/test/multihead_attn/test_self_multihead_attn.py b/apex/contrib/test/multihead_attn/test_self_multihead_attn.py deleted file mode 100644 index 10d779f..0000000 --- a/apex/contrib/test/multihead_attn/test_self_multihead_attn.py +++ /dev/null @@ -1,130 +0,0 @@ -import torch - -import unittest - -from apex.contrib.multihead_attn import SelfMultiheadAttn - -class SelfMultiheadAttnTest(unittest.TestCase): - def setUp(self, seed=1234): - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - self.seq_length = 80 - self.sequences = 10 - self.hidden_dim = 1024 - self.heads = 16 - self.dropout_prob = 0.0 - - self.ref_layer = SelfMultiheadAttn(self.hidden_dim, - self.heads, - dropout=self.dropout_prob, - bias=False, - include_norm_add=False, - impl='default') - self.ref_layer.cuda().half() - self.ref_layer.reset_parameters() - self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, - dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - - # Reset seed so parameters are identical - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - self.tst_layer = SelfMultiheadAttn(self.hidden_dim, - self.heads, - dropout=self.dropout_prob, - bias=False, - include_norm_add=False, - impl='fast') - self.tst_layer.cuda().half() - self.tst_layer.reset_parameters() - - self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, - dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - - def test_self_multihead_attn(self): - grads = torch.randn_like(self.tst_inputs) - - ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, - self.ref_inputs, - self.ref_inputs, - key_padding_mask=None, - need_weights=False, - attn_mask=None, - is_training=True) - - tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, - self.tst_inputs, - self.tst_inputs, - key_padding_mask=None, - need_weights=False, - attn_mask=None, - is_training=True) - - self.ref_inputs.backward(grads) - self.tst_inputs.backward(grads) - - self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)) - self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3)) - - def test_self_multihead_attn_time_mask(self) : - grads = torch.randn_like(self.tst_inputs) - time_mask_byte= torch.triu(torch.ones(self.tst_inputs.size(0), self.tst_inputs.size(0), device=torch.device("cuda"), dtype=torch.uint8), 1) - time_mask_bool= time_mask_byte.to(torch.bool) - - ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, - self.ref_inputs, - self.ref_inputs, - key_padding_mask=None, - need_weights=False, - attn_mask=time_mask_bool, - is_training=True) - - tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, - self.tst_inputs, - self.tst_inputs, - key_padding_mask=None, - need_weights=False, - attn_mask=time_mask_byte, - is_training=True) - - - self.ref_inputs.backward(grads) - self.tst_inputs.backward(grads) - - self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)) - self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3)) - - def test_self_multihead_attn_pad_mask(self) : - grads = torch.randn_like(self.tst_inputs) - pad_mask_byte = torch.tril(torch.ones(self.tst_inputs.size(1), self.tst_inputs.size(0), device=torch.device("cuda"), dtype=torch.uint8), 1) - pad_mask_bool = pad_mask_byte.to(torch.bool) - - ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, - self.ref_inputs, - self.ref_inputs, - key_padding_mask=pad_mask_bool, - need_weights=False, - attn_mask=None, - is_training=True) - - tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, - self.tst_inputs, - self.tst_inputs, - key_padding_mask=pad_mask_byte, - need_weights=False, - attn_mask=None, - is_training=True) - - - self.ref_inputs.backward(grads) - self.tst_inputs.backward(grads) - - self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)) - self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3)) - -if __name__ == '__main__': - unittest.main() diff --git a/apex/contrib/test/multihead_attn/test_self_multihead_attn_norm_add.py b/apex/contrib/test/multihead_attn/test_self_multihead_attn_norm_add.py deleted file mode 100644 index 125656f..0000000 --- a/apex/contrib/test/multihead_attn/test_self_multihead_attn_norm_add.py +++ /dev/null @@ -1,73 +0,0 @@ -import torch - -import unittest - -from apex.contrib.multihead_attn import SelfMultiheadAttn - -class SelfMultiheadAttnNormAddTest(unittest.TestCase): - def setUp(self, seed=1234): - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - self.seq_length = 80 - self.sequences = 10 - self.hidden_dim = 1024 - self.heads = 16 - self.dropout_prob = 0.0 - - self.ref_layer = SelfMultiheadAttn(self.hidden_dim, - self.heads, - dropout=self.dropout_prob, - bias=False, - include_norm_add=True, - impl='default') - self.ref_layer.cuda().half() - self.ref_layer.reset_parameters() - self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, - dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - - # Reset seed so parameters are identical - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - self.tst_layer = SelfMultiheadAttn(self.hidden_dim, - self.heads, - dropout=self.dropout_prob, - bias=False, - include_norm_add=True, - impl='fast') - self.tst_layer.cuda().half() - self.tst_layer.reset_parameters() - - self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim, - dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True) - - def test_self_multihead_attn_norm_add(self) : - grads = torch.randn_like(self.tst_inputs) - - for _ in range(0, 5) : - ref_outputs,_ = self.ref_layer.forward(self.ref_inputs, - self.ref_inputs, - self.ref_inputs, - key_padding_mask=None, - need_weights=False, - attn_mask=None, - is_training=True) - - tst_outputs,_ = self.tst_layer.forward(self.tst_inputs, - self.tst_inputs, - self.tst_inputs, - key_padding_mask=None, - need_weights=False, - attn_mask=None, - is_training=True) - - self.ref_inputs.backward(grads) - self.tst_inputs.backward(grads) - - self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)) - self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3)) - -if __name__ == '__main__': - unittest.main() diff --git a/apex/contrib/test/optimizers/test_dist_adam.py b/apex/contrib/test/optimizers/test_dist_adam.py deleted file mode 100644 index bd23ce2..0000000 --- a/apex/contrib/test/optimizers/test_dist_adam.py +++ /dev/null @@ -1,391 +0,0 @@ -from contextlib import contextmanager -import io -import os - -import torch -from torch.testing._internal import common_utils -from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase - -class SimpleModel(torch.nn.Module): - - def __init__(self, num_layers, size): - super().__init__() - self.layers = torch.nn.ModuleList([ - torch.nn.Linear(size, size, bias=(i%3==0)) - for i in range(num_layers) - ]) - - def forward(self, x): - y = 0 - for i, l in enumerate(self.layers): - y += (i+1) * l(x) - return y - -def make_models( - num_layers, - size, - dtype=torch.float32, - param_sync_dtype=None, - device='cuda', - overlap_communication=True, -): - - # Construct models with same parameters - ref_model = SimpleModel(num_layers, size).to(dtype=dtype, device=device) - dist_model = SimpleModel(num_layers, size).to(dtype=dtype, device=device) - with torch.no_grad(): - for ref_param, dist_param in zip(dist_model.parameters(), - ref_model.parameters()): - dist_param.copy_(ref_param) - - # Initialize reference model with data-parallelism - rank = torch.distributed.get_rank() - ref_model = torch.nn.parallel.DistributedDataParallel( - ref_model, - device_ids=[rank] if device=='cuda' else None, - output_device=rank if device=='cuda' else None, - ) - - # Construct optimizers with same hyperparameters - optim_args = dict(lr=0.1, betas=(0.1,0.2), eps=0.25, weight_decay=0.1) - ref_optim = torch.optim.AdamW( - [ - {'params': list(ref_model.parameters())[1::2], 'lr': 0.2}, - {'params': list(ref_model.parameters())[0::2]}, - ], - **optim_args, - ) - dist_optim = DistributedFusedAdam( - [ - {'params': list(dist_model.parameters())[1::2], 'lr': 0.2}, - {'params': list(dist_model.parameters())[0::2]}, - ], - overlap_grad_sync=overlap_communication, - bucket_cap_mb=71/(4*1024*1024), - dtype=torch.float32, - param_sync_dtype=param_sync_dtype, - **optim_args, - ) - - return ref_model, ref_optim, dist_model, dist_optim - -@contextmanager -def dummy_context(): - try: - yield - finally: - pass - -class TestDistributedFusedAdam(NcclDistributedTestBase): - - seed = 1234 - - def test_matches_pytorch( - self, - num_layers=11, - layer_size=7, - batch_size=3, - num_steps=3, - micro_batch_steps=3, - overlap_communication=True, - use_nosync=True, - dtype=torch.float32, - param_sync_dtype=None, - device='cuda', - rtol=None, - atol=None, - ): - - torch.manual_seed(self.seed + self.rank) - - # Identical models with data-parallel and ZeRO - ref_model, ref_optim, dist_model, dist_optim = make_models( - num_layers, - layer_size, - dtype=dtype, - param_sync_dtype=param_sync_dtype, - device=device, - overlap_communication=overlap_communication, - ) - - # Training loop - for step in range(num_steps): - - # Reset gradients - ref_optim.zero_grad() - dist_optim.zero_grad() - - # Forward and backward passes - for micro_step in range(micro_batch_steps): - - # Synthetic data - x = torch.rand(batch_size, layer_size) - 0.5 - dy = torch.rand_like(x) - 0.5 - x = x.to(dtype=dtype, device=device) - dy = dy.to(dtype=dtype, device=device) - - # Reference implementation - x_ref = x.detach().clone().requires_grad_(True) - y_ref = ref_model(x_ref) - y_ref.backward(dy) - - # Distributed implementation - x_dist = x.detach().clone().requires_grad_(True) - y_dist = dist_model(x_dist) - backward_context = dummy_context - if use_nosync and micro_step < micro_batch_steps-1: - backward_context = dist_optim.no_sync - with backward_context(): - y_dist.backward(dy) - - # Check that data tensors match - torch.testing.assert_close( - y_dist, y_ref, rtol=rtol, atol=atol) - torch.testing.assert_close( - x_dist.grad, x_ref.grad, rtol=rtol, atol=atol) - - # Optimization step - ref_optim.step() - dist_optim.step() - - # Check that parameters match - for ref_param, dist_param in zip(ref_model.parameters(), - dist_model.parameters()): - torch.testing.assert_close( - dist_param, ref_param, rtol=rtol, atol=atol) - - def test_matches_pytorch_no_overlap(self): - self.test_matches_pytorch( - overlap_communication=False, - use_nosync=False, - ) - - def test_matches_pytorch_sync_every_step(self): - self.test_matches_pytorch(use_nosync=False) - - def test_matches_pytorch_fp64(self): - self.test_matches_pytorch( - dtype=torch.float64, - rtol=1.3e-6, - atol=1e-5, - ) - - def test_matches_pytorch_fp16(self): - self.test_matches_pytorch( - dtype=torch.float16, - rtol=1e-2, - atol=1e-2, - ) - - def test_matches_pytorch_allgather_fp16(self): - self.test_matches_pytorch( - dtype=torch.float32, - param_sync_dtype=torch.float16, - rtol=1e-2, - atol=1e-2, - ) - - def test_raises_on_mismatch(self): - - torch.manual_seed(self.seed + self.rank) - - # Identical models with data-parallel and ZeRO - num_layers = 11 - layer_size = 7 - ref_model, ref_optim, dist_model, dist_optim = make_models( - num_layers, - layer_size, - ) - - # Only perform training step with distributed model - dist_optim.zero_grad() - x = torch.rand(3, layer_size) + 0.5 - x = x.to(dtype=torch.float32, device='cuda') - dy = torch.rand_like(x) + 0.5 - y = dist_model(x) - y.backward(dy) - dist_optim.step() - - # Check that parameters do not match - for ref_param, dist_param in zip(ref_model.parameters(), - dist_model.parameters()): - self.assertRaises( - AssertionError, - torch.testing.assert_close, - dist_param, ref_param, - ) - - def test_clip_grad_norm(self): - - torch.manual_seed(self.seed + self.rank) - - # Identical models with data-parallel and ZeRO - ref_model, ref_optim, dist_model, dist_optim = make_models(1, 1) - - # Training steps with pre-determined gradients - xs = [3, 1, 4, 1, 5, 9] - dys = [1, -1, 1, -1, 1, -1] - for x, dy in zip(xs, dys): - x = torch.tensor([x], dtype=torch.float32, device='cuda') - dy = torch.tensor([dy], dtype=torch.float32, device='cuda') - - # Reference implementation - ref_optim.zero_grad() - y_ref = ref_model(x.detach()) - y_ref.backward(dy.detach()) - ref_grad_norm = torch.nn.utils.clip_grad_norm_(ref_model.parameters(), 3.5) - ref_optim.step() - - # Distributed implementation - dist_optim.zero_grad() - y_dist = dist_model(x.detach()) - y_dist.backward(dy.detach()) - dist_grad_norm = dist_optim.clip_grad_norm(3.5) - dist_optim.step() - - # Check that parameters match - torch.testing.assert_close(dist_grad_norm, ref_grad_norm) - for ref_param, dist_param in zip(ref_model.parameters(), - dist_model.parameters()): - torch.testing.assert_close(dist_param, ref_param) - - def test_grad_scaler(self): - - torch.manual_seed(self.seed + self.rank) - - # Identical models with data-parallel and ZeRO - ref_model, ref_optim, dist_model, dist_optim = make_models(1, 1) - grad_scaler_args = dict( - init_scale=3.21, - growth_factor=1.23, - backoff_factor=0.876, - growth_interval=1, - ) - ref_scaler = torch.cuda.amp.GradScaler(**grad_scaler_args) - dist_scaler = torch.cuda.amp.GradScaler(**grad_scaler_args) - - # Training steps with pre-determined gradients - xs = [3, 1, 4, 1, 5, 9] - dys = [1, float('inf'), 1, 1, float('nan'), -1] - for x, dy in zip(xs, dys): - x = torch.tensor([x], dtype=torch.float32, device='cuda') - dy = torch.tensor([dy], dtype=torch.float32, device='cuda') - - # Reference implementation - ref_optim.zero_grad() - y_ref = ref_model(x.detach()) - ref_scaler.scale(y_ref).backward(dy.detach()) - ref_scaler.step(ref_optim) - ref_scaler.update() - - # Distributed implementation - dist_optim.zero_grad() - y_dist = dist_model(x.detach()) - dist_scaler.scale(y_dist).backward(dy.detach()) - dist_scaler.step(dist_optim) - dist_scaler.update() - - # Check that parameters match - for ref_param, dist_param in zip(ref_model.parameters(), - dist_model.parameters()): - torch.testing.assert_close(dist_param, ref_param) - - def test_checkpoint(self): - - # Construct two models with same config and different params - num_layers = 5 - layer_size = 2 - torch.manual_seed(self.seed + self.rank) - _, _, model_save, optim_save = make_models(num_layers, layer_size) - _, _, model_load, optim_load = make_models(num_layers, layer_size) - - # Train one of the models - num_steps = 3 - micro_batch_steps = 2 - batch_size = 4 - for step in range(num_steps): - optim_save.zero_grad() - for micro_step in range(micro_batch_steps): - x = torch.rand(batch_size, layer_size) - 0.5 - dy = torch.rand_like(x) - 0.5 - x = x.cuda() - dy = dy.cuda() - y = model_save(x) - y.backward(dy) - optim_save.step() - - # Make sure models are different - for param_save, param_load in zip(model_save.parameters(), - model_load.parameters()): - self.assertRaises( - AssertionError, - torch.testing.assert_close, - param_load, param_save, - ) - - # Save state on root rank and load on all ranks - state_dict = { - 'model': model_save.state_dict(), - 'optim': optim_save.state_dict(), - } - if self.rank == 0: - state_bytes = io.BytesIO() - torch.save(state_dict, state_bytes) - state_bytes = [state_bytes.getvalue()] - else: - state_bytes = [None] - torch.distributed.broadcast_object_list(state_bytes, src=0) - state_bytes = io.BytesIO(state_bytes[0]) - state_dict = torch.load(state_bytes, map_location='cuda') - model_load.load_state_dict(state_dict['model']) - optim_load.load_state_dict(state_dict['optim']) - - # Make sure models are identical - for param_save, param_load in zip(model_save.parameters(), - model_load.parameters()): - torch.testing.assert_close(param_load, param_save) - - # Train both models - num_steps = 3 - micro_batch_steps = 3 - batch_size = 5 - for step in range(num_steps): - - # Reset gradients - optim_save.zero_grad() - optim_load.zero_grad() - - # Forward and backward passes - for micro_step in range(micro_batch_steps): - - # Synthetic data - x = torch.rand(batch_size, layer_size) - 0.5 - dy = torch.rand_like(x) - 0.5 - x = x.cuda() - dy = dy.cuda() - - # Forward and backward pass - x_save = x.detach().clone().requires_grad_(True) - y_save = model_save(x_save) - y_save.backward(dy) - x_load = x.detach().clone().requires_grad_(True) - y_load = model_load(x_load) - y_load.backward(dy) - - # Check that data tensors match - torch.testing.assert_close(y_load, y_save) - torch.testing.assert_close(x_load.grad, x_save.grad) - - # Optimizer step - optim_save.step() - optim_load.step() - - # Check that parameters match - for param_save, param_load in zip(model_save.parameters(), - model_load.parameters()): - torch.testing.assert_close(param_load, param_save) - -if __name__ == "__main__": - # Assume script has been run with torchrun - common_utils.run_tests() diff --git a/apex/contrib/test/run_rocm_extensions.py b/apex/contrib/test/run_rocm_extensions.py deleted file mode 100644 index c780198..0000000 --- a/apex/contrib/test/run_rocm_extensions.py +++ /dev/null @@ -1,26 +0,0 @@ -import unittest -import sys - - -test_dirs = ["groupbn", "fused_dense", "layer_norm", "multihead_attn", "transducer", "focal_loss", "index_mul_2d", "."] # "." for test_label_smoothing.py -ROCM_BLACKLIST = [ - "layer_norm" -] - -runner = unittest.TextTestRunner(verbosity=2) - -errcode = 0 - -for test_dir in test_dirs: - if test_dir in ROCM_BLACKLIST: - continue - suite = unittest.TestLoader().discover(test_dir) - - print("\nExecuting tests from " + test_dir) - - result = runner.run(suite) - - if not result.wasSuccessful(): - errcode = 1 - -sys.exit(errcode) diff --git a/apex/contrib/test/test_label_smoothing.py b/apex/contrib/test/test_label_smoothing.py deleted file mode 100644 index 70e9f3d..0000000 --- a/apex/contrib/test/test_label_smoothing.py +++ /dev/null @@ -1,128 +0,0 @@ -import torch -from apex.contrib import xentropy as label_smoothing -import unittest - -import warnings -import random -import numpy as np -import time - -def label_smoothing_raw(x, target, padding_idx, smoothing): - logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32) - - non_pad_mask = (target != padding_idx) - nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) - nll_loss = nll_loss.squeeze(1)[non_pad_mask] - smooth_loss = -logprobs.mean(dim=-1)[non_pad_mask] - loss = (1.0 - smoothing) * nll_loss + smoothing * smooth_loss - return loss - -def label_smoothing_opt_1(x, target, padding_idx, smoothing): - logprobs = torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float32) - - pad_mask = (target == padding_idx) - ll_loss = logprobs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1) - smooth_loss = logprobs.mean(dim=-1) - loss = (smoothing - 1.0) * ll_loss - smoothing * smooth_loss - loss.masked_fill_(pad_mask, 0) - return loss - -class LabelSmoothingTest(unittest.TestCase): - def setUp(self, seed=1234): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - # Set pytorch print precision - torch.set_printoptions(precision=10) - - def gen_test_inputs(self, N, T, H, smoothing, padding_idx): - logits = torch.randn((N*T, H), dtype=torch.half, device='cuda', - requires_grad=True) - labels = torch.randint(0, H, [N*T], device='cuda') - for i in random.sample(range(N*T), N*T//6): - labels[i] = padding_idx - half_to_float = (logits.dtype == torch.half) - - return logits, labels, half_to_float - - def print_max_diff_elem(self, ref, tst): - ref, tst = ref.flatten(), tst.flatten() - diff = (ref - tst).abs().max() - idx = (ref - tst).abs().argmax() - print("Max atol idx: {}, diff: {:.6f}, ref: {:.6f}, tst: {:.6f}".format( - idx, diff, ref[idx], tst[idx])) - - def test_label_smoothing_function(self): - # Set label smoothing configuration - smoothing, padding_idx = 0.1, 0 - N, T, H = 128, 74, 32320 - iters = 10 - loss_func = label_smoothing.SoftmaxCrossEntropyLoss.apply - - for i in range(iters): - logits, labels, half_to_float = self.gen_test_inputs( - N, T, H, smoothing, padding_idx) - - # Run original softmax cross entropy with label smoothing - logits.grad = None - losses = label_smoothing_raw(logits, labels, padding_idx, smoothing) - loss = losses.sum() - loss.backward() - - ref_loss = loss.clone().detach() - ref_grad = logits.grad.clone().detach() - - # Run optimized softmax cross entropy with label smoothing - logits.grad = None - losses = loss_func(logits, labels, smoothing, padding_idx, half_to_float) - loss = losses.sum() - loss.backward() - - val_loss = loss.clone().detach() - val_grad = logits.grad.clone().detach() - - # Validate - self.print_max_diff_elem(ref_grad, val_grad) - self.assertTrue(torch.allclose(ref_loss, val_loss, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(ref_grad, val_grad, atol=1e-5, rtol=1e-5)) - - def test_label_smoothing_perf(self): - # Set label smoothing configuration - smoothing, padding_idx = 0.1, 0 - N, T, H = 128, 74, 32320 - iters = 1000 - loss_func = label_smoothing.SoftmaxCrossEntropyLoss.apply - print() - - logits, labels, half_to_float = self.gen_test_inputs( - N, T, H, smoothing, padding_idx) - - # Run original softmax cross entropy with label smoothing - torch.cuda.synchronize() - ts = time.time() - for i in range(iters): - logits.grad = None - losses = label_smoothing_raw(logits, labels, padding_idx, smoothing) - loss = losses.sum() / N - loss.backward() - torch.cuda.synchronize() - print("Raw time {:.2f} s elapsed for {} iterations, norm {:.4f}".format( - time.time() - ts, iters, logits.grad.norm())) - - # Run optimized softmax cross entropy with label smoothing - torch.cuda.synchronize() - ts = time.time() - for i in range(iters): - logits.grad = None - losses = loss_func(logits, labels, smoothing, padding_idx, half_to_float) - loss = losses.sum() / N - loss.backward() - torch.cuda.synchronize() - print("Opt time {:.2f} s elapsed for {} iterations, norm {:.4f}".format( - time.time() - ts, iters, logits.grad.norm())) - -if __name__ == '__main__': - unittest.main() - diff --git a/apex/contrib/test/transducer/test_transducer_joint.py b/apex/contrib/test/transducer/test_transducer_joint.py deleted file mode 100755 index 120865e..0000000 --- a/apex/contrib/test/transducer/test_transducer_joint.py +++ /dev/null @@ -1,163 +0,0 @@ -import torch -import unittest -from apex.contrib.transducer import TransducerJoint -import transducer_ref - -class TransducerJointTest(unittest.TestCase): - def setUp(self, seed=1234): - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - def gen_input(self, for_vector_kernel): - self.B = 4 - T_min = 51 - T_max = 101 - U_min = 12 - U_max = 25 - if for_vector_kernel: - H = 512 - else: - H = 509 - dtype = torch.float16 - device = "cuda" - - self.f_tst = torch.randn((self.B, T_max, H), dtype=dtype, requires_grad=True, device=device) - self.g_tst = torch.randn((self.B, U_max, H), dtype=dtype, requires_grad=True, device=device) - self.h_grad = torch.randn(self.B, T_max, U_max, H, dtype=dtype, device=device) - self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device) - self.g_len = torch.randint(U_min, U_max+1, (self.B,), dtype=torch.int, device=device) - self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max - self.g_len[torch.randint(0, self.B, (1,)).item()] = U_max - self.dropout_prob = 0.5 - - # Make sure gradients from out-of-bound locations are zero. This should be guaranteed by - # the loss function - for b in range(self.B): - self.h_grad[b, self.f_len[b]:, :, :] = 0 - self.h_grad[b, :, self.g_len[b]:, :] = 0 - self.h_grad_packed = self._pack(self.h_grad, self.f_len, self.g_len) - - - def _pack(self, x, f_len, g_len): - B = x.size(0) - list_x = [] - for b in range(B): - list_x_row = [x[b, t, :g_len[b]] for t in range(f_len[b])] - x_row = torch.cat(list_x_row) - list_x.append(x_row) - x_packed = torch.cat(list_x).data.clone() - x_packed.requires_grad = True - batch_offset = torch.cumsum(f_len * g_len, dim=0) - return x_packed - - def _unpack(self, x, f_len, g_len): - batch_offset = torch.cumsum(f_len * g_len, dim=0) - x_unpacked = torch.zeros_like(self.h_grad, dtype=torch.uint8) - B = self.h_grad.size(0) - H = self.h_grad.size(-1) - for b in range(B): - my_batch_offset = 0 if b == 0 else batch_offset[b-1] - my_f_len = f_len[b] - my_g_len = g_len[b] - for t in range(my_f_len): - x_unpacked[b, t, :my_g_len] = x[my_batch_offset + t*my_g_len : - my_batch_offset + t*my_g_len + my_g_len] - return x_unpacked - - def run_transducer_joint(self, for_vector_kernel, pack_output, relu, dropout): - self.gen_input(for_vector_kernel=for_vector_kernel) - # Generate reference - f_ref = self.f_tst.data.clone() - g_ref = self.g_tst.data.clone() - f_ref.requires_grad = True - g_ref.requires_grad = True - - my_joint = TransducerJoint(pack_output=pack_output, relu=relu, dropout=dropout, - dropout_prob=self.dropout_prob, probe_mask=True) - if not pack_output: - h_tst = my_joint( f=self.f_tst, - g=self.g_tst, - f_len=self.f_len, - g_len=self.g_len) - h_tst.backward(self.h_grad) - if dropout: - mask = my_joint.mask_probe[0] - else: - batch_offset = torch.cumsum(self.f_len * self.g_len, dim=0) - h_tst = my_joint( f=self.f_tst, - g=self.g_tst, - f_len=self.f_len, - g_len=self.g_len, - batch_offset=batch_offset, - packed_batch=batch_offset[-1]) - h_tst.backward(self.h_grad_packed) - if dropout: - mask_packed = my_joint.mask_probe[0] - mask = self._unpack(mask_packed, self.f_len, self.g_len) - - # reference - h_ref, f_grad_ref, g_grad_ref \ - = transducer_ref.transducer_joint_reference(f=f_ref, - g=g_ref, - h_grad=self.h_grad, - f_len=self.f_len, - g_len=self.g_len, - pack_output=pack_output, - relu=relu, - dropout=dropout, - dropout_prob=self.dropout_prob, - mask=mask if dropout else None) - - f_grad_tst = self.f_tst.grad - g_grad_tst = self.g_tst.grad - - self.assertTrue(torch.allclose(h_ref, h_tst, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(f_grad_ref, f_grad_tst, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(g_grad_ref, g_grad_tst, atol=1e-4, rtol=1e-4)) - - def test_transducer_joint(self): - self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False) - - def test_transducer_joint_vec(self): - self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=False, dropout=False) - - @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89") - def test_transducer_joint_pack(self): - self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=False, dropout=False) - - def test_transducer_joint_vec_pack(self): - self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=False, dropout=False) - - def test_transducer_joint_relu(self): - self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False) - - def test_transducer_joint_vec_relu(self): - self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=False) - - @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89") - def test_transducer_joint_pack_relu(self): - self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=False) - - def test_transducer_joint_vec_pack_relu(self): - self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=False) - - @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89") - def test_transducer_joint_relu_dropout(self): - self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True) - - @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89") - def test_transducer_joint_vec_relu_dropout(self): - self.run_transducer_joint(for_vector_kernel=True, pack_output=False, relu=True, dropout=True) - - @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89") - def test_transducer_joint_pack_relu_dropout(self): - self.run_transducer_joint(for_vector_kernel=False, pack_output=True, relu=True, dropout=True) - - @unittest.skip("Skipped the test on ROCm. Please also refer to https://github.com/ROCmSoftwarePlatform/apex/issues/89") - def test_transducer_joint_vec_pack_relu_dropout(self): - self.run_transducer_joint(for_vector_kernel=True, pack_output=True, relu=True, dropout=True) - - - -if __name__ == '__main__': - unittest.main() diff --git a/apex/contrib/test/transducer/test_transducer_loss.py b/apex/contrib/test/transducer/test_transducer_loss.py deleted file mode 100755 index 82f5bd3..0000000 --- a/apex/contrib/test/transducer/test_transducer_loss.py +++ /dev/null @@ -1,133 +0,0 @@ -import torch -import unittest -from apex.contrib.transducer import TransducerLoss -import transducer_ref - -class TransducerLossTest(unittest.TestCase): - def setUp(self, seed=1234): - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - def gen_input(self, scalar_t, for_vector_kernel): - self.B = 5 - T_min = 23 - T_max = 51 - U_min = 12 - U_max = 25 - V = 16 if for_vector_kernel else 14 - self.blank_idx = V - 1 - device = "cuda" - - self.x_tst = torch.randn((self.B, T_max, U_max, V), dtype=scalar_t, requires_grad=True, - device=device) - self.y = torch.randint(0, self.blank_idx, (self.B, U_max-1), dtype=torch.int, device=device) - self.f_len = torch.randint(T_min, T_max+1, (self.B,), dtype=torch.int, device=device) - self.y_len = torch.randint(U_min-1, U_max, (self.B,), dtype=torch.int, device=device) - self.f_len[torch.randint(0, self.B, (1,)).item()] = T_max - self.y_len[torch.randint(0, self.B, (1,)).item()] = U_max-1 - self.x_tst_packed, self.batch_offset = self._pack(self.x_tst) - # Generate reference - x_ref = self.x_tst.data.clone() - x_ref.requires_grad = True - loss_grad = torch.ones(x_ref.size(0), dtype=x_ref.dtype, device=x_ref.device)/x_ref.size(0) - _, _, self.grad_ref, self.loss_ref \ - = transducer_ref.transducer_loss_reference( x=x_ref, - label=self.y, - f_len=self.f_len, - y_len=self.y_len, - blank_idx=self.blank_idx, - loss_grad=loss_grad) - - def _pack(self, x): - list_x = [] - for b in range(self.B): - list_x_row = [x[b, t, : self.y_len[b]+1] for t in range(self.f_len[b])] - x_row = torch.cat(list_x_row) - list_x.append(x_row) - x_packed = torch.cat(list_x).data.clone() - x_packed.requires_grad = True - batch_offset = torch.cumsum(self.f_len * (self.y_len+1), dim=0) - return x_packed, batch_offset - - def _unpack(self, x): - x_unpacked = torch.zeros(self.B, self.f_len.max(), self.y_len.max()+1, x.size(-1), - dtype=x.dtype, device=x.device) - for b in range(self.B): - my_batch_offset = 0 if b == 0 else self.batch_offset[b-1] - my_f_len = self.f_len[b] - my_g_len = self.y_len[b] + 1 - for t in range(my_f_len): - for u in range(my_g_len): - x_unpacked[b, t, u] = x[my_batch_offset + t*my_g_len + u] - return x_unpacked - - def run_transducer_loss(self, scalar_t, fuse_softmax_backward, packed_input, for_vector_kernel): - self.gen_input(scalar_t, for_vector_kernel) - my_loss = TransducerLoss( fuse_softmax_backward=fuse_softmax_backward, - packed_input=packed_input) - if not packed_input: - loss_tst = my_loss( x=self.x_tst, - label=self.y, - f_len=self.f_len, - y_len=self.y_len, - blank_idx=self.blank_idx) - loss_tst.mean().backward() - grad_tst = self.x_tst.grad - else: - loss_tst = my_loss( x=self.x_tst_packed, - label=self.y, - f_len=self.f_len, - y_len=self.y_len, - blank_idx=self.blank_idx, - batch_offset=self.batch_offset, - max_f_len=max(self.f_len)) - loss_tst.mean().backward() - grad_tst_packed = self.x_tst_packed.grad - grad_tst = self._unpack(grad_tst_packed) - - return loss_tst, grad_tst - - def test_transducer_loss_fp32(self): - loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float32, - fuse_softmax_backward=False, - packed_input=False, - for_vector_kernel=False) - self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-5, rtol=1e-5)) - - def test_transducer_loss_fp16(self): - loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16, - fuse_softmax_backward=False, - packed_input=False, - for_vector_kernel=False) - self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)) - - def test_transducer_loss_fp16_backward_fusion(self): - loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16, - fuse_softmax_backward=True, - packed_input=False, - for_vector_kernel=False) - self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)) - - def test_transducer_loss_fp16_backward_fusion_packed(self): - loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16, - fuse_softmax_backward=True, - packed_input=True, - for_vector_kernel=False) - self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)) - - def test_transducer_loss_fp16_backward_fusion_packed_vec(self): - loss_tst, grad_tst = self.run_transducer_loss( scalar_t=torch.float16, - fuse_softmax_backward=True, - packed_input=True, - for_vector_kernel=True) - self.assertTrue(torch.allclose(self.loss_ref, loss_tst, atol=1e-5, rtol=1e-5)) - self.assertTrue(torch.allclose(self.grad_ref, grad_tst, atol=1e-4, rtol=1e-3)) - - - -if __name__ == '__main__': - unittest.main() \ No newline at end of file diff --git a/apex/contrib/test/transducer/transducer_ref.py b/apex/contrib/test/transducer/transducer_ref.py deleted file mode 100755 index de34279..0000000 --- a/apex/contrib/test/transducer/transducer_ref.py +++ /dev/null @@ -1,112 +0,0 @@ -import torch -import numpy as np -import pdb - -def transducer_loss_reference(x, label, f_len, y_len, blank_idx, loss_grad): - def log_sum_exp(a, b): - if (a >= b): - return a + torch.log(1 + torch.exp(b-a)) - else: - return b + torch.log(1 + torch.exp(a-b)) - - def forward_alpha(x, label, f_len, y_len, blank_idx): - B, T, U, V = x.size() - acc_t = torch.float32 if x.dtype in [torch.float16, torch.float32] else x.dtype - alpha = torch.zeros((B, T, U), dtype=acc_t, device=x.device) - for b in range(B): - alpha[b, 0, 0] = 0 - for t in range(1, f_len[b]): - alpha[b, t, 0] = alpha[b, t-1, 0] + x[b, t-1, 0, blank_idx] - for u in range(1, y_len[b]+1): - alpha[b, 0, u] = alpha[b, 0, u-1] + x[b, 0, u-1, label[b, u-1]] - for t in range(1, f_len[b]): - for u in range(1, y_len[b]+1): - curr_ = alpha[b, t-1, u] + x[b, t-1, u, blank_idx] - next_ = alpha[b, t, u-1] + x[b, t, u-1, label[b, u-1]] - alpha[b, t, u] = log_sum_exp(curr_, next_) - return alpha - - def forward_beta(x, label, f_len, y_len, blank_idx): - B, T, U, V = x.shape - acc_t = torch.float32 if x.dtype in [torch.float16, torch.float32] else x.dtype - beta = torch.zeros((B, T, U), dtype=acc_t, device=x.device) - for b in range(B): - beta[b, f_len[b]-1, y_len[b]] = x[b, f_len[b]-1, y_len[b], blank_idx] - for t in range(f_len[b]-2, -1, -1): - beta[b, t, y_len[b]] = beta[b, t+1, y_len[b]] + x[b, t, y_len[b], blank_idx] - for u in range(y_len[b]-1, -1, -1): - beta[b, f_len[b]-1, u] = beta[b, f_len[b]-1, u+1] + x[b, f_len[b]-1, u, label[b, u]] - for t in range(f_len[b]-2, -1, -1): - for u in range(y_len[b]-1, -1, -1): - curr_ = beta[b, t+1, u] + x[b, t, u, blank_idx] - next_ = beta[b, t, u+1] + x[b, t, u, label[b, u]] - beta[b, t, u] = log_sum_exp(curr_, next_) - return beta - - def backward(x, label, f_len, y_len, alpha, beta, loss_grad, blank_idx): - grad = torch.zeros_like(x) - B, T, U, V = x.size() - for b in range(B): - common_factor = torch.log(loss_grad[b]) + alpha - beta[b, 0, 0] - # next - for u in range(y_len[b]): - grad[b, :f_len[b], u, label[b, u]] = -torch.exp(common_factor[b, :f_len[b], u] - + beta[b, :f_len[b], u+1] - + x[b, :f_len[b], u, label[b, u]]) - - # current - grad[b, :f_len[b]-1, :y_len[b]+1, blank_idx] \ - = -torch.exp(common_factor[b, :f_len[b]-1, :y_len[b]+1] - + beta[b, 1:f_len[b], :y_len[b]+1] - + x[b, :f_len[b]-1, :y_len[b]+1, blank_idx]) - - grad[b, f_len[b]-1, y_len[b], blank_idx] = -torch.exp(common_factor[b, f_len[b]-1, y_len[b]] - + x[b, f_len[b]-1, y_len[b], blank_idx]) - - return grad - - x_log = torch.nn.functional.log_softmax(x, dim=-1) - alpha = forward_alpha(x_log, label, f_len, y_len, blank_idx) - beta = forward_beta(x_log, label, f_len, y_len, blank_idx) - grad = backward(x_log, label, f_len, y_len, alpha, beta, - loss_grad, blank_idx) - x_log.backward(grad) - loss = -beta[:, 0, 0] - loss = loss.to(x.dtype) - return alpha, beta, x.grad, loss - - -def transducer_joint_reference(f, g, h_grad, f_len, g_len, pack_output, relu, dropout, - dropout_prob=0, mask=None): - if dropout and mask == None: - raise NotImplementedError("mask needs to supplied to test dropout.") - B, T, H = f.size() - U = g.size(1) - f_expand = f.unsqueeze(dim=2) - g_expand = g.unsqueeze(dim=1) - h = f_expand + g_expand - if relu: - h = torch.nn.functional.relu(h) - if dropout: - h *= mask - scale = 1/(1-dropout_prob) - h *= scale - h.backward(h_grad) - - if pack_output == False: - # intentionally set don't-care region to -1 to test if transducer joint - # write these regions to avoid NaN and inf - for b in range(B): - h[b, f_len[b]:] = -1 - h[b, :, g_len[b]:] = -1 - - return h, f.grad, g.grad - - # packing - list_to_pack = [] - for b in range(B): - list_to_pack.append(h[b, :f_len[b], :g_len[b], :].reshape(-1, H)) - h_packed = torch.cat(list_to_pack) - return h_packed, f.grad, g.grad - - diff --git a/apex/contrib/transducer/__init__.py b/apex/contrib/transducer/__init__.py deleted file mode 100755 index bd5dbf6..0000000 --- a/apex/contrib/transducer/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .transducer import TransducerJoint -from .transducer import TransducerLoss \ No newline at end of file diff --git a/apex/contrib/transducer/transducer.py b/apex/contrib/transducer/transducer.py deleted file mode 100755 index 7843962..0000000 --- a/apex/contrib/transducer/transducer.py +++ /dev/null @@ -1,195 +0,0 @@ -import torch -import transducer_loss_cuda -import transducer_joint_cuda - -class TransducerJoint(torch.nn.Module): - """Transducer joint - Detail of this loss function can be found in: Sequence Transduction with Recurrent Neural - Networks - - Arguments: - pack_output (bool, optional): whether to pack the output in a compact form with don't-care - data being removed. (default: False) - relu (bool, optional): apply ReLU to the output of the joint operation. Requires opt=1 - (default: False) - dropout (bool, optional): apply dropout to the output of the joint operation. Requires opt=1 - (default: False) - opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a tiled algorithm. - (default: 1) - fwd_tile_size (int, optional): tile size used in forward operation. This argument will be - ignored if opt != 1. (default: 4) - dropout_prob (float, optional): dropout probability. (default: 0.0) - probe_mask (bool, optional): a flag used to probe the mask generated by ReLU and/or dropout - operation. When this argument is set to True, the mask can be accessed through - self.mask_probe. (default: false) - """ - - def __init__(self, pack_output=False, relu=False, dropout=False, opt=1, fwd_tile_size=4, - dropout_prob=0, probe_mask=False): - super(TransducerJoint, self).__init__() - self.pack_output = pack_output - self.relu = relu - self.dropout = dropout - self.dropout_prob = dropout_prob - self.opt = opt - self.fwd_tile_size = fwd_tile_size - self.dummy_batch_offset = torch.empty(0) - masked = self.relu or self.dropout - self.mask_probe = [] if masked and probe_mask else None - if masked and opt != 1: - raise NotImplementedError("ReLU and dropout fusion is only supported with opt=1") - - - def forward(self, f, g, f_len, g_len, batch_offset=None, packed_batch=0): - """Forward operation of transducer joint - - Arguments: - f (tensor): transcription vector from encode block of shape (B, T, H). - g (tensor): prediction vector form predict block of shape (B, U, H). - f_len (tensor): length of transcription vector for each batch. - g_len (tensor): length of prediction vector minus 1 for each batch. - batch_offset (tensor, optional): tensor containing the offset of each batch - in the results. For example, batch offset can be obtained from: - batch_offset = torch.cumsum(f_len*g_len, dim=0) - This argument is required if pack_output == True, and is ignored if - pack_output == False. (default: None) - packed_batch (int, optional): the batch size after packing. This argument is - ignored if pack_output == False. (default: 0) - """ - my_batch_offset = batch_offset if self.pack_output else self.dummy_batch_offset - if self.pack_output and (batch_offset is None or packed_batch == 0): - raise Exception("Please specify batch_offset and packed_batch when packing is enabled") - dropout = self.dropout and self.training # only dropout for training - return TransducerJointFunc.apply(f, g, f_len, g_len, self.pack_output, self.relu, dropout, - my_batch_offset, packed_batch, self.opt, - self.fwd_tile_size, self.dropout_prob, self.mask_probe) - - -class TransducerLoss(torch.nn.Module): - """Transducer loss - Detail of this loss function can be found in: Sequence Transduction with Recurrent Neural - Networks - - Arguments: - fuse_softmax_backward (bool, optional) whether to fuse the backward of transducer loss with - softmax. (default: True) - opt (int, optional): pick the optimization level in [0, 1]. opt=1 picks a more optimized - algorithm. In some cases, opt=1 might fall back to opt=0. (default: 1) - packed_input (bool, optional): whether to pack the output in a compact form with don't-care - data being removed. (default: False) - """ - def __init__(self, fuse_softmax_backward=True, opt=1, packed_input=False): - super(TransducerLoss, self).__init__() - self.fuse_softmax_backward = fuse_softmax_backward - self.opt = opt - self.packed_input = packed_input - self.dummy_batch_offset = torch.empty(0) - - - def forward(self, x, label, f_len, y_len, blank_idx, batch_offset=None, max_f_len=None, - debug_list=None): - """Forward operation of transducer joint - - Arguments: - x (tensor): input tensor to the loss function with a shape of (B, T, U, H). - label (tensor): labels for the input data. - f_len (tensor): lengths of the inputs in the time dimension for each batch. - y_len (tensor): lengths of the labels for each batch. - blank_idx (int): index for the null symbol. - batch_offset (tensor, optional): tensor containing the offset of each batch - in the input. For example, batch offset can be obtained from: - batch_offset = torch.cumsum(f_len*(y_len+1), dim=0) - This argument is required if packed_input == True, and is ignored if - packed_input == False. (default: None) - max_f_len (int, optional): maximum length of the input in the time dimension. - For example, it can be obtained as - max_f_len = max(f_len) - This argument is required if packed_input == True, and is ignored if - packed_input == False. (default: None) - (default: None) - debug_list (list, optional): when an empty list is supplied, Alpha and Beta generated - in the forward operation will be attached to this list for debug purpose. - (default: None) - """ - if self.packed_input: - if batch_offset is None or max_f_len is None: - raise Exception("Please specify batch_offset and max_f_len when packing is \ - enabled") - my_batch_offset = batch_offset - my_max_f_len = max_f_len - else: - my_batch_offset = self.dummy_batch_offset - my_max_f_len = x.size(1) - return TransducerLossFunc.apply(x, label, f_len, y_len, my_batch_offset, my_max_f_len, - blank_idx, self.fuse_softmax_backward, debug_list, - self.opt, self.packed_input) - -class TransducerLossFunc(torch.autograd.Function): - @staticmethod - def forward(ctx, x, label, f_len, y_len, batch_offset, max_f_len, blank_idx, - fuse_softmax_backward, debug_list, opt, packed_input): - if fuse_softmax_backward == False: - with torch.enable_grad(): - x = torch.nn.functional.log_softmax(x, dim=-1) - else: - x = torch.nn.functional.log_softmax(x, dim=-1) - alpha, beta, loss = transducer_loss_cuda.forward( x, label, f_len, y_len, batch_offset, - max_f_len, blank_idx, opt, packed_input) - if debug_list == []: - debug_list += [alpha, beta] - ctx.save_for_backward(x, alpha, beta, f_len, y_len, label, batch_offset) - ctx.blank_idx = blank_idx - ctx.fuse_softmax_backward = fuse_softmax_backward - ctx.opt = opt - ctx.packed_input = packed_input - ctx.max_f_len = max_f_len - return loss - - @staticmethod - def backward(ctx, loss_grad): - x, alpha, beta, f_len, y_len, label, batch_offset = ctx.saved_tensors - x_grad = transducer_loss_cuda.backward( x, loss_grad, alpha, beta, f_len, y_len, label, - batch_offset, ctx.max_f_len, ctx.blank_idx, ctx.opt, - ctx.fuse_softmax_backward, ctx.packed_input) - if ctx.fuse_softmax_backward == False: - x_grad = x.backward(x_grad) - return x_grad, None, None, None, None, None, None, None, None, None, None - -class TransducerJointFunc(torch.autograd.Function): - @staticmethod - def forward(ctx, f, g, f_len, g_len, pack_output, relu, dropout, batch_offset, packed_batch, - opt, fwd_tile_size, dropout_prob, mask_probe): - h = transducer_joint_cuda.forward(f, g, f_len, g_len, batch_offset, packed_batch, opt, - pack_output, relu, dropout, dropout_prob, fwd_tile_size) - masked = relu or dropout - if masked: - ctx.save_for_backward(h[1], f_len, g_len, batch_offset) - if mask_probe is not None: - mask_probe.append(h[1]) - else: - ctx.save_for_backward(f_len, g_len, batch_offset) - - ctx.pack_output = pack_output - ctx.masked = relu or dropout - ctx.max_f_len = f.size(1) - ctx.max_g_len = g.size(1) - ctx.scale = 1 / (1-dropout_prob) if dropout and dropout_prob != 1 else 1 - return h[0] - - @staticmethod - def backward(ctx, loss_grad): - if ctx.masked: - mask, f_len, g_len, batch_offset = ctx.saved_tensors - inp = [loss_grad, mask] - else: - f_len, g_len, batch_offset = ctx.saved_tensors - inp = [loss_grad] - - f_grad, g_grad = transducer_joint_cuda.backward( inp, f_len, g_len, batch_offset, - ctx.max_f_len, ctx.max_g_len, - ctx.pack_output, ctx.scale) - - return f_grad, g_grad, None, None, None, None, None, None, None, None, None, None, None, \ - None, None, None - - diff --git a/apex/contrib/xentropy/__init__.py b/apex/contrib/xentropy/__init__.py deleted file mode 100644 index 7dff6a2..0000000 --- a/apex/contrib/xentropy/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -try: - import torch - import xentropy_cuda - from .softmax_xentropy import SoftmaxCrossEntropyLoss - del torch - del xentropy_cuda - del softmax_xentropy -except ImportError as err: - print("apex was installed without --xentropy flag, contrib.xentropy is not available") diff --git a/apex/contrib/xentropy/softmax_xentropy.py b/apex/contrib/xentropy/softmax_xentropy.py deleted file mode 100644 index 33fbf8b..0000000 --- a/apex/contrib/xentropy/softmax_xentropy.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch -import xentropy_cuda - -class SoftmaxCrossEntropyLoss(torch.autograd.Function): - @staticmethod - def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, half_to_float=False): - losses, max_log_sum_exp = xentropy_cuda.forward( - logits, labels, smoothing, half_to_float) - losses.masked_fill_(labels==padding_idx, 0) - - ctx.save_for_backward(logits, max_log_sum_exp, labels, - torch.FloatTensor([smoothing]), - torch.LongTensor([padding_idx])) - - return losses - - @staticmethod - def backward(ctx, grad_loss): - logits, max_log_sum_exp, labels, smoothing, padding_idx = ctx.saved_tensors - - if not grad_loss.is_contiguous(): - grad_loss = grad_loss.contiguous() - grad_loss.masked_fill_(labels==padding_idx.item(), 0) - grad_logits = xentropy_cuda.backward( - grad_loss.contiguous(), logits, max_log_sum_exp, - labels, smoothing.item()) - - return grad_logits, None, None, None, None diff --git a/apex/fp16_utils/README.md b/apex/fp16_utils/README.md deleted file mode 100644 index 941de17..0000000 --- a/apex/fp16_utils/README.md +++ /dev/null @@ -1,16 +0,0 @@ -fp16_optimizer.py contains `FP16_Optimizer`, a Python class designed to wrap an existing Pytorch optimizer and automatically enable master parameters and loss scaling in a manner transparent to the user. To use `FP16_Optimizer`, only two lines of one's Python model need to change. - -#### [FP16_Optimizer API documentation](https://nvidia.github.io/apex/fp16_utils.html#automatic-management-of-master-params-loss-scaling) - -#### [Simple examples with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/FP16_Optimizer_simple) - -#### [Imagenet with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/imagenet) - -#### [word_language_model with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/word_language_model) - - -fp16_util.py contains a number of utilities to manually manage master parameters and loss scaling, if the user chooses. - -#### [Manual management documentation](https://nvidia.github.io/apex/fp16_utils.html#manual-master-parameter-management) - -The [Imagenet with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/imagenet) and [word_language_model with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/word_language_model) directories also contain `main.py` files that demonstrate manual management of master parameters and static loss scaling. These examples illustrate what sort of operations `FP16_Optimizer` is performing automatically. diff --git a/apex/fp16_utils/__init__.py b/apex/fp16_utils/__init__.py deleted file mode 100644 index c7bb1f5..0000000 --- a/apex/fp16_utils/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -from .fp16util import ( - BN_convert_float, - network_to_half, - prep_param_lists, - model_grads_to_master_grads, - master_params_to_model_params, - tofp16, - to_python_float, - clip_grad_norm, - convert_module, - convert_network, - FP16Model, -) - -from .fp16_optimizer import FP16_Optimizer -from .loss_scaler import LossScaler, DynamicLossScaler diff --git a/apex/fp16_utils/fp16_optimizer.py b/apex/fp16_utils/fp16_optimizer.py deleted file mode 100755 index 7c0dd39..0000000 --- a/apex/fp16_utils/fp16_optimizer.py +++ /dev/null @@ -1,554 +0,0 @@ -import torch -from torch import nn -from torch.autograd import Variable -from torch.nn.parameter import Parameter -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors - -from ..amp._amp_state import _amp_state, maybe_print -from ..amp.scaler import LossScaler -from ..multi_tensor_apply import multi_tensor_applier -from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm - -# TODO: Update overflow check + downscale to use Carl's fused kernel. -class FP16_Optimizer(object): - def __init__(self, - init_optimizer, - static_loss_scale=1.0, - dynamic_loss_scale=False, - dynamic_loss_args=None, - verbose=True): - print("Warning: FP16_Optimizer is deprecated and dangerous, and will be deleted soon. " - "If it still works, you're probably getting lucky. " - "For mixed precision, use the documented API https://nvidia.github.io/apex/amp.html, with opt_level=O1.") - - if not torch.cuda.is_available: - raise SystemError("Cannot use fp16 without CUDA.") - - self.verbose = verbose - - self.optimizer = init_optimizer - # init_state_dict sets up an alternative way to cast per-param state tensors. - # Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary. - # init_state_dict = init_optimizer.state_dict() - - self.fp16_groups = [] - self.fp32_from_fp16_groups = [] - self.fp32_from_fp32_groups = [] - for i, param_group in enumerate(self.optimizer.param_groups): - self.maybe_print("FP16_Optimizer processing param group {}:".format(i)) - fp16_params_this_group = [] - fp32_params_this_group = [] - fp32_from_fp16_params_this_group = [] - for i, param in enumerate(param_group['params']): - if param.requires_grad: - if param.type() == 'torch.cuda.HalfTensor': - self.maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}" - .format(param.size())) - fp16_params_this_group.append(param) - master_param = param.detach().clone().float() - master_param.requires_grad = True - param_group['params'][i] = master_param - fp32_from_fp16_params_this_group.append(master_param) - # Reset existing state dict key to the new master param. - # We still need to recast per-param state tensors, if any, to FP32. - if param in self.optimizer.state: - self.optimizer.state[master_param] = self.optimizer.state.pop(param) - elif param.type() == 'torch.cuda.FloatTensor': - self.maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}" - .format(param.size())) - fp32_params_this_group.append(param) - param_group['params'][i] = param - else: - raise TypeError("Wrapped parameters must be either " - "torch.cuda.FloatTensor or torch.cuda.HalfTensor. " - "Received {}".format(param.type())) - - self.fp16_groups.append(fp16_params_this_group) - self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group) - self.fp32_from_fp32_groups.append(fp32_params_this_group) - - self.all_fp16_params = [] - for group in self.fp16_groups: - self.all_fp16_params += group - - self.all_fp32_from_fp16_params = [] - for group in self.fp32_from_fp16_groups: - self.all_fp32_from_fp16_params += group - - self.all_fp32_from_fp32_params = [] - for group in self.fp32_from_fp32_groups: - self.all_fp32_from_fp32_params += group - - # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors - self.optimizer.load_state_dict(self.optimizer.state_dict()) - # alternative way to cast per-param state tensors: - # self.optimizer.load_state_dict(init_state_dict) - - if dynamic_loss_scale: - self.dynamic_loss_scale = True - if dynamic_loss_args is not None: - self.loss_scaler = LossScaler("dynamic", **dynamic_loss_args) - else: - self.loss_scaler = LossScaler("dynamic") - else: - self.dynamic_loss_scale = False - self.loss_scaler = LossScaler(static_loss_scale) - - self.overflow = False - self.first_closure_call_this_step = True - - self.clip_grad_norm = clip_grad_norm - - # TODO: Centralize exposure and import error checking for the C backend. - if multi_tensor_applier.available: - import amp_C - self.multi_tensor_scale = amp_C.multi_tensor_scale - self._dummy_overflow_buf = torch.cuda.IntTensor([0]); - - # Having self.maybe_print distinct from _amp_state.maybe_print is another artifact - # of having to support FP16_Optimizer separately, for the time being. - def maybe_print(self, msg): - if self.verbose: - print(msg) - - def __getstate__(self): - raise RuntimeError("FP16_Optimizer should be serialized using state_dict().") - - def __setstate__(self, state): - raise RuntimeError("FP16_Optimizer should be deserialized using load_state_dict().") - - def zero_grad(self, set_grads_to_None=False): - """ - Zero fp32 and fp16 parameter grads. - """ - # In principle, only the .grad attributes of the model params need to be zeroed, - # because gradients are copied into the FP32 master params. However, we zero - # all gradients owned by the optimizer, just to be safe: - for group in self.optimizer.param_groups: - for p in group['params']: - if set_grads_to_None: - p.grad = None - else: - if p.grad is not None: - p.grad.detach_() - p.grad.zero_() - - # Zero fp16 gradients owned by the model: - for fp16_group in self.fp16_groups: - for param in fp16_group: - if set_grads_to_None: - param.grad = None - else: - if param.grad is not None: - param.grad.detach_() # as in torch.optim.optimizer.zero_grad() - param.grad.zero_() - - # Should not be used anymore. - # def _check_overflow(self): - # params = [] - # for group in self.fp16_groups: - # for param in group: - # params.append(param) - # for group in self.fp32_from_fp32_groups: - # for param in group: - # params.append(param) - # self.overflow = self.loss_scaler.has_overflow(params) - - # def _update_scale(self, has_overflow=False): - # self.loss_scaler.update_scale(has_overflow) - - def _master_params_to_model_params(self): - if multi_tensor_applier.available: - if len(self.all_fp16_params) > 0: - multi_tensor_applier( - self.multi_tensor_scale, - self._dummy_overflow_buf, - [self.all_fp32_from_fp16_params, self.all_fp16_params], - 1.0) - else: - for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups): - master_params_to_model_params(fp16_group, fp32_from_fp16_group) - - # To consider: Integrate distributed with this wrapper by registering a hook on each variable - # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream. - # def _model_grads_to_master_grads(self): - # for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups): - # model_grads_to_master_grads(fp16_group, fp32_from_fp16_group) - - # def _downscale_master(self): - # if self.loss_scale != 1.0: - # for group in self.optimizer.param_groups: - # for param in group['params']: - # if param.grad is not None: - # param.grad.data.mul_(1./self.loss_scale) - - def clip_master_grads(self, max_norm, norm_type=2): - """ - Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``. - - Args: - max_norm (float or int): max norm of the gradients - norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for - infinity norm. - - Returns: - Total norm of the current fp32 gradients (viewed as a single vector). - - .. warning:: - Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``). - """ - if not self.overflow: - fp32_params = [] - for param_group in self.optimizer.param_groups: - for param in param_group['params']: - fp32_params.append(param) - return self.clip_grad_norm(fp32_params, max_norm, norm_type) - else: - return -1 - - def state_dict(self): - """ - Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. - This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict - of the contained Pytorch optimizer. - Example:: - - checkpoint = {} - checkpoint['model'] = model.state_dict() - checkpoint['optimizer'] = optimizer.state_dict() - torch.save(checkpoint, "saved.pth") - """ - state_dict = {} - state_dict['loss_scaler'] = self.loss_scaler - state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale - state_dict['overflow'] = self.overflow - state_dict['first_closure_call_this_step'] = self.first_closure_call_this_step - state_dict['optimizer_state_dict'] = self.optimizer.state_dict() - state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups - return state_dict - - def load_state_dict(self, state_dict): - """ - Loads a state_dict created by an earlier call to state_dict(). - If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, - whose parameters in turn came from ``model``, it is expected that the user - will call ``model.load_state_dict()`` before - ``fp16_optimizer_instance.load_state_dict()`` is called. - - Example:: - - model = torch.nn.Linear(D_in, D_out).cuda().half() - optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) - optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) - ... - checkpoint = torch.load("saved.pth") - model.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - """ - # I think it should actually be ok to reload the optimizer before the model. - self.loss_scaler = state_dict['loss_scaler'] - self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] - self.overflow = state_dict['overflow'] - self.first_closure_call_this_step = state_dict['first_closure_call_this_step'] - self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) - # At this point, the optimizer's references to the model's fp32 parameters are up to date. - # The optimizer's hyperparameters and internal buffers are also up to date. - # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still - # out of date. There are two options. - # 1: Refresh the master params from the model's fp16 params. - # This requires less storage but incurs precision loss. - # 2: Save and restore the fp32 master copies separately. - # We choose option 2. - # - # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device - # of their associated parameters, because it's possible those buffers might not exist yet in - # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been - # constructed in the same way as the one whose state_dict we are loading, the same master params - # are guaranteed to exist, so we can just copy_() from the saved master params. - for current_group, saved_group in zip(self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']): - for current, saved in zip(current_group, saved_group): - current.data.copy_(saved.data) - - def step(self, closure=None): # could add clip option. - """ - If no closure is supplied, :attr:`step` should be called after - ``fp16_optimizer_obj.backward(loss)``. - :attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to - :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params - originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run - another forward pass using their model. - - If a closure is supplied, :attr:`step` may be called without a prior call to - :attr:`backward(loss)`. - This control flow is identical to `ordinary Pytorch optimizer use`_ with closures. - However, the user should take care that any ``loss.backward()`` call within the closure - has been replaced by ``fp16_optimizer_obj.backward(loss)``. - - Args: - closure (optional): Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor. closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss. - - Example with closure:: - - # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an - # existing pytorch optimizer. - for input, target in dataset: - def closure(): - optimizer.zero_grad() - output = model(input) - loss = loss_fn(output, target) - # loss.backward() becomes: - optimizer.backward(loss) - return loss - optimizer.step(closure) - - .. warning:: - Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling. - - .. _`ordinary Pytorch optimizer use`: - http://pytorch.org/docs/master/optim.html#optimizer-step-closure - """ - - scale = self.loss_scaler.loss_scale() - # To consider: Should this be in step(), or update_master_grads? It works either way, - # but I should make it consistent with the Amp control flow, which updates the scale - # during backward context manager exit. - # self._update_scale(self.overflow) - - if self.overflow: - # Using _amp_state.maybe_print instead of self.print here is intentional. - maybe_print("Gradient overflow. Skipping step, reducing " + - "loss scale to {}".format(self.loss_scaler.loss_scale())) - return - - if closure is not None: - retval = self._step_with_closure(closure) - else: - # torch.cuda.nvtx.range_push("pytorch optimizer step") - retval = self.optimizer.step() - # torch.cuda.nvtx.range_pop() - - self._master_params_to_model_params() - - return retval - - def _step_with_closure(self, closure): - def wrapped_closure(): - # helpful for debugging - # print("Calling wrapped_closure, first_closure_call_this_step = {}" - # .format(self.first_closure_call_this_step)) - if self.first_closure_call_this_step: - # We expect that the fp16 params are initially fresh on entering self.step(), - # so _master_params_to_model_params() is unnecessary the first time wrapped_closure() - # is called within self.optimizer.step(). - self.first_closure_call_this_step = False - else: - # If self.optimizer.step() internally calls wrapped_closure more than once, - # it may update the fp32 params after each call. However, self.optimizer - # doesn't know about the fp16 params at all. If the fp32 params get updated, - # we can't rely on self.optimizer to refresh the fp16 params. We need - # to handle that manually: - self._master_params_to_model_params() - # Our API expects the user to give us ownership of the backward() call by - # replacing all calls to loss.backward() with optimizer.backward(loss). - # This requirement holds whether or not the call to backward() is made within a closure. - # If the user is properly calling optimizer.backward(loss) within "closure," - # calling closure() here will give the fp32 master params fresh gradients - # for the optimizer to play with, so all wrapped_closure needs to do is call - # closure() and return the loss. - temp_loss = closure() - while(self.overflow): - scale = self.loss_scaler.loss_scale() - # self._update_scale(self.overflow) # now done at the end of backward - print("OVERFLOW within closure! Skipping step, reducing loss scale to {}".format( - self.loss_scaler.loss_scale())) - temp_loss = closure() - return temp_loss - - retval = self.optimizer.step(wrapped_closure) - - self.first_closure_call_this_step = True - - return retval - - def backward(self, loss, update_master_grads=True, retain_graph=False): - """ - :attr:`backward` performs the following conceptual steps: - - 1. fp32_loss = loss.float() (see first Note below) - 2. scaled_loss = fp32_loss*loss_scale - 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined). - 4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32. - 5. Finally, master grads are divided by loss_scale. - - In this way, after :attr:`backward`, the master params have fresh gradients, - and :attr:`step` may be called. - - .. note:: - :attr:`backward` internally converts the loss to fp32 before applying the loss scale. - This provides some additional safety against overflow if the user has supplied an - fp16 loss value. - However, for maximum overflow safety, the user should - compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to - :attr:`backward`. - - .. warning:: - The gradients found in a model's leaves after the call to - :attr:`backward` should not be regarded as valid in general, - because it's possible - they have been scaled (and in the case of dynamic loss scaling, - the scale factor may change over time). - If the user wants to inspect gradients after a call to :attr:`backward`, - only the master gradients should be regarded as valid. These can be retrieved via - :attr:`inspect_master_grad_data()`. - - Args: - loss: The loss output by the user's model. loss may be either float or half (but see first Note above). - update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`. - retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below). - - Example:: - - # Ordinary operation: - optimizer.backward(loss) - - # Naive operation with multiple losses (technically valid, but less efficient): - # fp32 grads will be correct after the second call, but - # the first call incurs an unnecessary fp16->fp32 grad copy. - optimizer.backward(loss1) - optimizer.backward(loss2) - - # More efficient way to handle multiple losses: - # The fp16->fp32 grad copy is delayed until fp16 grads from all - # losses have been accumulated. - optimizer.backward(loss1, update_master_grads=False) - optimizer.backward(loss2, update_master_grads=False) - optimizer.update_master_grads() - """ - # To consider: try multiple backward passes using retain_grad=True to find - # a loss scale that works. After you find a loss scale that works, do a final dummy - # backward pass with retain_graph=False to tear down the graph. Doing this would avoid - # discarding the iteration, but probably wouldn't improve overall efficiency. - scaled_loss = loss.float()*self.loss_scaler.loss_scale() - scaled_loss.backward(retain_graph=retain_graph) - if update_master_grads: - self.update_master_grads() - - def update_master_grads(self): - # torch.cuda.nvtx.range_push("update_master_grads") - """ - Copy the ``.grad`` attribute from stored references to fp16 parameters to - the ``.grad`` attribute of the fp32 master parameters that are directly - updated by the optimizer. :attr:`update_master_grads` only needs to be called if - ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``. - """ - # if self.dynamic_loss_scale: - # self._check_overflow() - # if self.overflow: return - # self._model_grads_to_master_grads() - # self._downscale_master() - # Use the one-shot multi-tensor apply kernel - self.loss_scaler.clear_overflow_state() - if len(self.all_fp16_params) > 0: - # print("Model grads before") - # print([param.grad.data for param in self.all_fp16_params]) - # I'm ONLY writing this as an incremental way to make some tests pass until - # I can refactor the tests as well. - # FP16_Optimizer should not be used by anyone. - model_grads = [] - master_grads = [] - for model_param, master_param in zip(self.all_fp16_params, - self.all_fp32_from_fp16_params): - if model_param.grad is not None: - model_grads.append(model_param.grad) - if master_param.grad is None: - master_param.grad = torch.empty_like(master_param) - master_grads.append(master_param.grad) - self.loss_scaler.unscale( - model_grads, - master_grads, - self.loss_scaler.loss_scale()) - # print("Master grads after") - # print([param.grad.data for param in self.all_fp32_from_fp16_params]) - if len(self.all_fp32_from_fp32_params) > 0: - model_grads = [] - master_grads = [] - for model_param, master_param in zip(self.all_fp32_from_fp32_params, - self.all_fp32_from_fp32_params): - if model_param.grad is not None: - model_grads.append(model_param.grad) - master_grads.append(master_param.grad) - # print("Model grads before") - # print([param.grad.data for param in self.all_fp32_from_fp32_params]) - self.loss_scaler.unscale( - model_grads, - master_grads, - self.loss_scaler.loss_scale()) - # print("Master grads after") - # print([param.grad.data for param in self.all_fp32_from_fp32_params]) - # quit() - self.overflow = self.loss_scaler.update_scale() - # torch.cuda.nvtx.range_pop() - - - def inspect_master_grad_data(self): - """ - When running with :class:`FP16_Optimizer`, - ``.grad`` attributes of a model's fp16 leaves should not be - regarded as truthful, because they might be scaled. - After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered, - the fp32 master params' ``.grad`` - attributes will contain valid gradients properly divided by the loss scale. However, - because :class:`FP16_Optimizer` flattens some parameters, accessing them may be - nonintuitive. :attr:`inspect_master_grad_data` - allows those gradients to be viewed with shapes corresponding to their associated model leaves. - - Returns: - List of lists (one list for each parameter group). The list for each parameter group - is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group. - """ - if self.overflow: - print("Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. " - "Gradients are currently invalid (may be inf, nan, or stale). Returning None.") - return None - else: - # The optimizer owns only references to master params. - master_grads_data = [] - for param_group in self.optimizer.param_groups: - master_grads_this_group = [] - for param in param_group['params']: - if param.grad is not None: - master_grads_this_group.append(param.grad.data) - else: - master_grads_this_group.append(None) - master_grads_data.append(master_grads_this_group) - return master_grads_data - - - # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" - def _get_loss_scale(self): - return self.loss_scaler.loss_scale() - - def _set_loss_scale(self, value): - self.loss_scaler._loss_scale = value - - loss_scale = property(_get_loss_scale, _set_loss_scale) - - # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" - def _get_state(self): - return self.optimizer.state - - def _set_state(self, value): - self.optimizer.state = value - - state = property(_get_state, _set_state) - - # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" - # (for example, to adjust the learning rate) - def _get_param_groups(self): - return self.optimizer.param_groups - - def _set_param_groups(self, value): - self.optimizer.param_groups = value - - param_groups = property(_get_param_groups, _set_param_groups) - diff --git a/apex/fp16_utils/fp16util.py b/apex/fp16_utils/fp16util.py deleted file mode 100644 index dcdc344..0000000 --- a/apex/fp16_utils/fp16util.py +++ /dev/null @@ -1,187 +0,0 @@ -import torch -import torch.nn as nn -from torch.autograd import Variable -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors - - -class tofp16(nn.Module): - """ - Utility module that implements:: - - def forward(self, input): - return input.half() - """ - - def __init__(self): - super(tofp16, self).__init__() - - def forward(self, input): - return input.half() - - -def BN_convert_float(module): - """ - Utility function for network_to_half(). - - Retained for legacy purposes. - """ - if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: - module.float() - for child in module.children(): - BN_convert_float(child) - return module - - -def network_to_half(network): - """ - Convert model to half precision in a batchnorm-safe way. - - Retained for legacy purposes. It is recommended to use FP16Model. - """ - return nn.Sequential(tofp16(), BN_convert_float(network.half())) - - -def convert_module(module, dtype): - """ - Converts a module's immediate parameters and buffers to dtype. - """ - for param in module.parameters(recurse=False): - if param is not None: - if param.data.dtype.is_floating_point: - param.data = param.data.to(dtype=dtype) - if param._grad is not None and param._grad.data.dtype.is_floating_point: - param._grad.data = param._grad.data.to(dtype=dtype) - - for buf in module.buffers(recurse=False): - if buf is not None and buf.data.dtype.is_floating_point: - buf.data = buf.data.to(dtype=dtype) - - -def convert_network(network, dtype): - """ - Converts a network's parameters and buffers to dtype. - """ - for module in network.modules(): - if isinstance(module, torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: - continue - convert_module(module, dtype) - if isinstance(module, torch.nn.RNNBase) or isinstance(module, torch.nn.modules.rnn.RNNBase): - module.flatten_parameters() - return network - - -class FP16Model(nn.Module): - """ - Convert model to half precision in a batchnorm-safe way. - """ - - def __init__(self, network): - super(FP16Model, self).__init__() - self.network = convert_network(network, dtype=torch.half) - - def forward(self, *inputs): - inputs = tuple(t.half() for t in inputs) - return self.network(*inputs) - - -def backwards_debug_hook(grad): - raise RuntimeError("master_params recieved a gradient in the backward pass!") - -def prep_param_lists(model, flat_master=False): - """ - Creates a list of FP32 master parameters for a given model, as in - `Training Neural Networks with Mixed Precision: Real Examples`_. - - Args: - model (torch.nn.Module): Existing Pytorch model - flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization. - Returns: - A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master gradients. If ``flat_master=True``, ``master_params`` will be a list with one element. - - Example:: - - model_params, master_params = prep_param_lists(model) - - .. warning:: - Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`. - - .. _`Training Neural Networks with Mixed Precision: Real Examples`: - http://on-demand.gputechconf.com/gtc/2018/video/S81012/ - """ - model_params = [param for param in model.parameters() if param.requires_grad] - - if flat_master: - # Give the user some more useful error messages - try: - # flatten_dense_tensors returns a contiguous flat array. - # http://pytorch.org/docs/master/_modules/torch/_utils.html - master_params = _flatten_dense_tensors([param.data for param in model_params]).float() - except: - print("Error in prep_param_lists: model may contain a mixture of parameters " - "of different types. Use flat_master=False, or use F16_Optimizer.") - raise - master_params = torch.nn.Parameter(master_params) - master_params.requires_grad = True - # master_params.register_hook(backwards_debug_hook) - if master_params.grad is None: - master_params.grad = master_params.new(*master_params.size()) - return model_params, [master_params] - else: - master_params = [param.clone().float().detach() for param in model_params] - for param in master_params: - param.requires_grad = True - return model_params, master_params - - -def model_grads_to_master_grads(model_params, master_params, flat_master=False): - """ - Copy model gradients to master gradients. - - Args: - model_params: List of model parameters created by :func:`prep_param_lists`. - master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`. - """ - if flat_master: - # The flattening may incur one more deep copy than is necessary. - master_params[0].grad.data.copy_( - _flatten_dense_tensors([p.grad.data for p in model_params])) - else: - for model, master in zip(model_params, master_params): - if model.grad is not None: - if master.grad is None: - master.grad = Variable(master.data.new(*master.data.size())) - master.grad.data.copy_(model.grad.data) - else: - master.grad = None - - -def master_params_to_model_params(model_params, master_params, flat_master=False): - """ - Copy master parameters to model parameters. - - Args: - model_params: List of model parameters created by :func:`prep_param_lists`. - master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. - """ - if flat_master: - for model, master in zip(model_params, - _unflatten_dense_tensors(master_params[0].data, model_params)): - model.data.copy_(master) - else: - for model, master in zip(model_params, master_params): - model.data.copy_(master.data) - -# Backward compatibility fixes - -def to_python_float(t): - if hasattr(t, 'item'): - return t.item() - else: - return t[0] - -TORCH_MAJOR = int(torch.__version__.split('.')[0]) -TORCH_MINOR = int(torch.__version__.split('.')[1]) -if TORCH_MAJOR == 0 and TORCH_MINOR <= 4: - clip_grad_norm = torch.nn.utils.clip_grad_norm -else: - clip_grad_norm = torch.nn.utils.clip_grad_norm_ diff --git a/apex/fp16_utils/loss_scaler.py b/apex/fp16_utils/loss_scaler.py deleted file mode 100644 index b9f32fe..0000000 --- a/apex/fp16_utils/loss_scaler.py +++ /dev/null @@ -1,186 +0,0 @@ -import torch - -# item() is a recent addition, so this helps with backward compatibility. -def to_python_float(t): - if hasattr(t, 'item'): - return t.item() - else: - return t[0] - -class LossScaler: - """ - Class that manages a static loss scale. This class is intended to interact with - :class:`FP16_Optimizer`, and should not be directly manipulated by the user. - - Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to - :class:`FP16_Optimizer`'s constructor. - - Args: - scale (float, optional, default=1.0): The loss scale. - """ - - def __init__(self, scale=1): - self.cur_scale = scale - - # `params` is a list / generator of torch.Variable - def has_overflow(self, params): - return False - - # `x` is a torch.Tensor - def _has_inf_or_nan(x): - return False - - def update_scale(self, overflow): - pass - - @property - def loss_scale(self): - return self.cur_scale - - def scale_gradient(self, module, grad_in, grad_out): - return tuple(self.loss_scale * g for g in grad_in) - - def backward(self, loss, retain_graph=False): - scaled_loss = loss*self.loss_scale - scaled_loss.backward(retain_graph=retain_graph) - -class DynamicLossScaler: - """ - Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler` - indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of - :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler` - operates, because the default options can be changed using the - the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor. - - Loss scaling is designed to combat the problem of underflowing gradients encountered at long - times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss - scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are - encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has - occurred. - :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch, - and :class:`DynamicLossScaler` adjusts the loss scale to a lower value. - If a certain number of iterations occur without overflowing gradients detected, - :class:`DynamicLossScaler` increases the loss scale once more. - In this way :class:`DynamicLossScaler` attempts to "ride the edge" of - always using the highest loss scale possible without incurring overflow. - - Args: - init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.` - scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. - scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. - """ - - def __init__(self, - init_scale=2**32, - scale_factor=2., - scale_window=1000): - self.cur_scale = init_scale - self.cur_iter = 0 - self.last_overflow_iter = -1 - self.scale_factor = scale_factor - self.scale_window = scale_window - - # `params` is a list / generator of torch.Variable - def has_overflow(self, params): - for p in params: - if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data): - return True - - return False - - # `x` is a torch.Tensor - def _has_inf_or_nan(x): - try: - # if x is half, the .float() incurs an additional deep copy, but it's necessary if - # Pytorch's .sum() creates a one-element tensor of the same type as x - # (which is true for some recent version of pytorch). - cpu_sum = float(x.float().sum()) - # More efficient version that can be used if .sum() returns a Python scalar - # cpu_sum = float(x.sum()) - except RuntimeError as instance: - # We want to check if inst is actually an overflow exception. - # RuntimeError could come from a different error. - # If so, we still want the exception to propagate. - if "value cannot be converted" not in instance.args[0]: - raise - return True - else: - if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: - return True - return False - - # `overflow` is boolean indicating whether the gradient overflowed - def update_scale(self, overflow): - if overflow: - # self.cur_scale /= self.scale_factor - self.cur_scale = max(self.cur_scale/self.scale_factor, 1) - self.last_overflow_iter = self.cur_iter - else: - if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0: - self.cur_scale *= self.scale_factor - self.cur_iter += 1 - - @property - def loss_scale(self): - return self.cur_scale - - def scale_gradient(self, module, grad_in, grad_out): - return tuple(self.loss_scale * g for g in grad_in) - - def backward(self, loss, retain_graph=False): - scaled_loss = loss*self.loss_scale - scaled_loss.backward(retain_graph=retain_graph) - -############################################################## -# Example usage below here -- assuming it's in a separate file -############################################################## -""" -TO-DO separate out into an example. -if __name__ == "__main__": - import torch - from torch.autograd import Variable - from dynamic_loss_scaler import DynamicLossScaler - - # N is batch size; D_in is input dimension; - # H is hidden dimension; D_out is output dimension. - N, D_in, H, D_out = 64, 1000, 100, 10 - - # Create random Tensors to hold inputs and outputs, and wrap them in Variables. - x = Variable(torch.randn(N, D_in), requires_grad=False) - y = Variable(torch.randn(N, D_out), requires_grad=False) - - w1 = Variable(torch.randn(D_in, H), requires_grad=True) - w2 = Variable(torch.randn(H, D_out), requires_grad=True) - parameters = [w1, w2] - - learning_rate = 1e-6 - optimizer = torch.optim.SGD(parameters, lr=learning_rate) - loss_scaler = DynamicLossScaler() - - for t in range(500): - y_pred = x.mm(w1).clamp(min=0).mm(w2) - loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale - print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) - print('Iter {} scaled loss: {}'.format(t, loss.data[0])) - print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) - - # Run backprop - optimizer.zero_grad() - loss.backward() - - # Check for overflow - has_overflow = DynamicLossScaler.has_overflow(parameters) - - # If no overflow, unscale grad and update as usual - if not has_overflow: - for param in parameters: - param.grad.data.mul_(1. / loss_scaler.loss_scale) - optimizer.step() - # Otherwise, don't do anything -- ie, skip iteration - else: - print('OVERFLOW!') - - # Update loss scale for next iteration - loss_scaler.update_scale(has_overflow) - -""" diff --git a/apex/fused_dense/__init__.py b/apex/fused_dense/__init__.py deleted file mode 100644 index 83d12ca..0000000 --- a/apex/fused_dense/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .fused_dense import * diff --git a/apex/fused_dense/fused_dense.py b/apex/fused_dense/fused_dense.py deleted file mode 100644 index def9236..0000000 --- a/apex/fused_dense/fused_dense.py +++ /dev/null @@ -1,85 +0,0 @@ -import torch -from torch import nn -import fused_dense_cuda -from .. import amp -#implements fused GEMM+bias in forward pass using mlp_cuda from apex -class FusedDenseFunc(torch.autograd.Function): - @staticmethod - def forward(ctx, input, weight, bias): - ctx.save_for_backward(input, weight) - output = fused_dense_cuda.linear_bias_forward(input, weight, bias) - return output - - @staticmethod - def backward(ctx, grad_output): - input, weight = ctx.saved_tensors - grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_backward(input, weight, grad_output) - return grad_input, grad_weight, grad_bias - -class DenseNoBiasFunc(torch.autograd.Function): - @staticmethod - def forward(ctx, input, weight): - ctx.save_for_backward(input, weight) - output = torch.matmul(input, weight.t()) - return output - - @staticmethod - def backward(ctx, grad_output): - input, weight = ctx.saved_tensors - grad_input = grad_output.mm(weight) - grad_weight = grad_output.t().mm(input) - return grad_input, grad_weight - - -class FusedDenseGeluDenseFunc(torch.autograd.Function): - @staticmethod - def forward(ctx, input, weight1, bias1, weight2, bias2): - ctx.save_for_backward(input, weight1, weight2) - output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(input, weight1, bias1, weight2, bias2) - ctx.save_for_backward(input, weight1, weight2, gelu_in, output1) - return output2 - - @staticmethod - def backward(ctx, grad_output): - input, weight1, weight2, gelu_in, output1 = ctx.saved_tensors - grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_gelu_linear_backward(input, gelu_in, output1, weight1, weight2, grad_output) - return grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 - - -fused_dense_function = amp.half_function(FusedDenseFunc.apply) -dense_no_bias_function = amp.half_function(DenseNoBiasFunc.apply) -fused_dense_gelu_dense_function = amp.half_function(FusedDenseGeluDenseFunc.apply) - -class FusedDense(nn.Module): - def __init__(self, in_features, out_features, bias=True): - super(FusedDense, self).__init__() - self.in_features = in_features - self.out_features = out_features - self.weight = nn.Parameter(torch.empty(out_features, in_features)) - if bias: - self.bias = nn.Parameter(torch.empty(out_features)) - else: - #assert False, "no-bias option not added yet" - self.register_parameter('bias', None) - - def forward(self, input): - if self.bias is not None: - return fused_dense_function(input, self.weight, self.bias) - else: - return dense_no_bias_function(input, self.weight) - -class FusedDenseGeluDense(nn.Module): - def __init__(self, in_features, intermediate_features, out_features, bias=True): - super(FusedDenseGeluDense, self).__init__() - assert bias == True, "DenseGeluDense module without bias is currently not supported" - self.in_features = in_features - self.intermediate_features = intermediate_features - self.out_features = out_features - self.weight1 = nn.Parameter(torch.empty(intermediate_features, in_features)) - self.bias1 = nn.Parameter(torch.empty(intermediate_features)) - self.weight2 = nn.Parameter(torch.empty(out_features, intermediate_features)) - self.bias2 = nn.Parameter(torch.empty(out_features)) - - def forward(self, input): - return fused_dense_gelu_dense_function(input, self.weight1, self.bias1, self.weight2, self.bias2) - diff --git a/apex/mlp/__init__.py b/apex/mlp/__init__.py deleted file mode 100644 index f2f30f7..0000000 --- a/apex/mlp/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .mlp import * diff --git a/apex/mlp/mlp.py b/apex/mlp/mlp.py deleted file mode 100644 index bae38f3..0000000 --- a/apex/mlp/mlp.py +++ /dev/null @@ -1,79 +0,0 @@ -from copy import copy -import math -import torch -from torch import nn -import mlp_cuda -from .. import amp - -class MlpFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, bias, activation, *args): - output = mlp_cuda.forward(bias, activation, args) - ctx.save_for_backward(*args) - ctx.outputs = output - ctx.bias = bias - ctx.activation = activation - return output[0] - - @staticmethod - def backward(ctx, grad_o): - grads = mlp_cuda.backward(ctx.bias, ctx.activation, grad_o, ctx.outputs, ctx.saved_tensors) - del ctx.outputs - return (None, None, *grads) - -mlp_function = amp.half_function(MlpFunction.apply) - -class MLP(torch.nn.Module): - """Launch MLP in C++ - - Args: - mlp_sizes (list of int): MLP sizes. Example: [1024,1024,1024] will create 2 MLP layers with shape 1024x1024 - bias (bool): Default True: - relu (bool): Default True - """ - def __init__(self, mlp_sizes, bias=True, activation='relu'): - super(MLP, self).__init__() - self.num_layers = len(mlp_sizes) - 1 - self.mlp_sizes = copy(mlp_sizes) - self.bias = 1 if bias else 0 - - if activation is 'none': - self.activation = 0 - elif activation is 'relu': - self.activation = 1 - elif activation is 'sigmoid': - self.activation = 2 - else: - raise TypeError("activation must be relu or none.") - - self.weights = [] - self.biases = [] - for i in range(self.num_layers): - w = torch.nn.Parameter(torch.empty(mlp_sizes[i+1], mlp_sizes[i])) - self.weights.append(w) - name = 'weight_{}'.format(i) - setattr(self, name, w) - if self.bias: - b = torch.nn.Parameter(torch.empty(mlp_sizes[i+1])) - self.biases.append(b) - name = 'bias_{}'.format(i) - setattr(self, name, b) - - self.reset_parameters() - - def reset_parameters(self): - for weight in self.weights: - dimsum = weight.size(0) + weight.size(1) - std = math.sqrt(2. / float(dimsum)) - nn.init.normal_(weight, 0., std) - if self.bias: - for bias in self.biases: - std = math.sqrt(1. / float(bias.size(0))) - nn.init.normal_(bias, 0., std) - - def forward(self, input): - return mlp_function(self.bias, self.activation, input, *self.weights, *self.biases) - - def extra_repr(self): - s = F"MLP sizes: {self.mlp_sizes}, Bias={self.bias}, activation={self.activation}" - return s diff --git a/apex/multi_tensor_apply/__init__.py b/apex/multi_tensor_apply/__init__.py deleted file mode 100644 index 31e2a53..0000000 --- a/apex/multi_tensor_apply/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .multi_tensor_apply import MultiTensorApply - -multi_tensor_applier = MultiTensorApply(256*32) -multi_tensor_applier_l2norm = MultiTensorApply(2048*32) - diff --git a/apex/multi_tensor_apply/multi_tensor_apply.py b/apex/multi_tensor_apply/multi_tensor_apply.py deleted file mode 100644 index 346c6e5..0000000 --- a/apex/multi_tensor_apply/multi_tensor_apply.py +++ /dev/null @@ -1,30 +0,0 @@ -import torch - -class MultiTensorApply(object): - available = False - warned = False - - def __init__(self, chunk_size): - try: - import amp_C - MultiTensorApply.available = True - self.chunk_size = chunk_size - except ImportError as err: - MultiTensorApply.available = False - MultiTensorApply.import_err = err - - def check_avail(self): - if MultiTensorApply.available == False: - raise RuntimeError( - "Attempted to call MultiTensorApply method, but MultiTensorApply " - "is not available, possibly because Apex was installed without " - "--cpp_ext --cuda_ext. Original import error message:", - MultiTensorApply.import_err) - - def __call__(self, op, noop_flag_buffer, tensor_lists, *args): - self.check_avail() - - return op(self.chunk_size, - noop_flag_buffer, - tensor_lists, - *args) diff --git a/apex/normalization/__init__.py b/apex/normalization/__init__.py deleted file mode 100644 index c649913..0000000 --- a/apex/normalization/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm, FusedRMSNorm, MixedFusedRMSNorm diff --git a/apex/normalization/fused_layer_norm.py b/apex/normalization/fused_layer_norm.py deleted file mode 100644 index aaf00d1..0000000 --- a/apex/normalization/fused_layer_norm.py +++ /dev/null @@ -1,437 +0,0 @@ -import importlib -import numbers - -import torch -from torch.nn.parameter import Parameter -from torch.nn import init -from torch.nn import functional as F - -from apex._autocast_utils import _cast_if_autocast_enabled - -global fused_layer_norm_cuda -fused_layer_norm_cuda = None - - -# Reference implementation from Huggingface -def manual_rms_norm(input, normalized_shape, weight, eps): - # layer norm should always be calculated in float32 - dims = tuple(i for i in range(-1, -len(normalized_shape)-1, -1)) - variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True) - input = input * torch.rsqrt(variance + eps) - - if weight is None: - return input - - # convert into half-precision if necessary - if weight.dtype in [torch.float16, torch.bfloat16]: - input = input.to(self.weight.dtype) - - return weight * input - - -class FusedLayerNormAffineFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, input, weight, bias, normalized_shape, eps): - global fused_layer_norm_cuda - if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") - ctx.normalized_shape = normalized_shape - ctx.eps = eps - input_ = input.contiguous() - weight_ = weight.contiguous() - bias_ = bias.contiguous() - output, mean, invvar = fused_layer_norm_cuda.forward_affine( - input_, ctx.normalized_shape, weight_, bias_, ctx.eps - ) - ctx.save_for_backward(input_, weight_, bias_, mean, invvar) - return output - - @staticmethod - def backward(ctx, grad_output): - input_, weight_, bias_, mean, invvar = ctx.saved_tensors - grad_input = grad_weight = grad_bias = None - grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine( - grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps - ) - return grad_input, grad_weight, grad_bias, None, None - - -class FusedRMSNormAffineFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, input, weight, normalized_shape, eps): - global fused_layer_norm_cuda - if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") - ctx.normalized_shape = normalized_shape - ctx.eps = eps - input_ = input.contiguous() - weight_ = weight.contiguous() - output, invvar = fused_layer_norm_cuda.rms_forward_affine( - input_, ctx.normalized_shape, weight_, ctx.eps) - ctx.save_for_backward(input_, weight_, invvar) - return output - - @staticmethod - def backward(ctx, grad_output): - input_, weight_, invvar = ctx.saved_tensors - grad_input = grad_weight = None - grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine( - grad_output.contiguous(), invvar, input_, ctx.normalized_shape, weight_, ctx.eps - ) - return grad_input, grad_weight, None, None - - -class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction): - - @staticmethod - def forward(ctx, input, weight, bias, normalized_shape, eps): - global fused_layer_norm_cuda - if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") - ctx.normalized_shape = normalized_shape - ctx.eps = eps - input_ = input.contiguous() - weight_ = weight.contiguous() - bias_ = bias.contiguous() - output, mean, invvar = fused_layer_norm_cuda.forward_affine_mixed_dtypes( - input_, ctx.normalized_shape, weight_, bias_, ctx.eps - ) - ctx.save_for_backward(input_, weight_, bias_, mean, invvar) - return output - - -class FusedRMSNormAffineMixedDtypesFunction(FusedRMSNormAffineFunction): - - @staticmethod - def forward(ctx, input, weight, normalized_shape, eps): - global fused_layer_norm_cuda - if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") - ctx.normalized_shape = normalized_shape - ctx.eps = eps - input_ = input.contiguous() - weight_ = weight.contiguous() - output, invvar = fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes( - input_, ctx.normalized_shape, weight_, ctx.eps - ) - - ctx.save_for_backward(input_, weight_, invvar) - return output - - -class FusedLayerNormFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, input, normalized_shape, eps): - global fused_layer_norm_cuda - if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") - ctx.normalized_shape = normalized_shape - ctx.eps = eps - input_ = input.contiguous() - output, mean, invvar = fused_layer_norm_cuda.forward(input_, ctx.normalized_shape, ctx.eps) - ctx.save_for_backward(input_, mean, invvar) - return output - - @staticmethod - def backward(ctx, grad_output): - input_, mean, invvar = ctx.saved_tensors - grad_input = None - grad_input = fused_layer_norm_cuda.backward( - grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, ctx.eps - ) - return grad_input, None, None - - -class FusedRMSNormFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, input, normalized_shape, eps): - global fused_layer_norm_cuda - if fused_layer_norm_cuda is None: - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") - ctx.normalized_shape = normalized_shape - ctx.eps = eps - input_ = input.contiguous() - output, invvar = fused_layer_norm_cuda.rms_forward(input_, ctx.normalized_shape, ctx.eps) - ctx.save_for_backward(input_, invvar) - return output - - @staticmethod - def backward(ctx, grad_output): - input_, invvar = ctx.saved_tensors - grad_input = None - grad_input = fused_layer_norm_cuda.rms_backward( - grad_output.contiguous(), invvar, input_, ctx.normalized_shape, ctx.eps - ) - return grad_input, None, None - - -def fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps) - with torch.cuda.amp.autocast(enabled=False): - return FusedLayerNormAffineFunction.apply(*args) - - -def fused_layer_norm(input, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, normalized_shape, eps) - with torch.cuda.amp.autocast(enabled=False): - return FusedLayerNormFunction.apply(*args) - - -def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps) - with torch.cuda.amp.autocast(enabled=False): - return FusedLayerNormAffineMixedDtypesFunction.apply(*args) - - -def fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps) - with torch.cuda.amp.autocast(enabled=False): - return FusedRMSNormAffineFunction.apply(*args) - - -def fused_rms_norm(input, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, normalized_shape, eps) - with torch.cuda.amp.autocast(enabled=False): - return FusedRMSNormFunction.apply(*args) - - -def mixed_dtype_fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps) - with torch.cuda.amp.autocast(enabled=False): - return FusedRMSNormAffineMixedDtypesFunction.apply(*args) - - -class FusedLayerNorm(torch.nn.Module): - r"""Applies Layer Normalization over a mini-batch of inputs as described in - the paper `Layer Normalization`_ . - - Currently only runs on cuda() tensors. - - .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta - - The mean and standard-deviation are calculated separately over the last - certain number dimensions which have to be of the shape specified by - :attr:`normalized_shape`. - :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of - :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``. - - .. note:: - Unlike Batch Normalization and Instance Normalization, which applies - scalar scale and bias for each entire channel/plane with the - :attr:`affine` option, Layer Normalization applies per-element scale and - bias with :attr:`elementwise_affine`. - - This layer uses statistics computed from input data in both training and - evaluation modes. - - Args: - normalized_shape (int or list or torch.Size): input shape from an expected input - of size - - .. math:: - [* \times \text{normalized}\_\text{shape}[0] \times \text{normalized}\_\text{shape}[1] - \times \ldots \times \text{normalized}\_\text{shape}[-1]] - - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. - eps: a value added to the denominator for numerical stability. Default: 1e-5 - elementwise_affine: a boolean value that when set to ``True``, this module - has learnable per-element affine parameters initialized to ones (for weights) - and zeros (for biases). Default: ``True``. - - Shape: - - Input: :math:`(N, *)` - - Output: :math:`(N, *)` (same shape as input) - - Examples:: - - >>> input = torch.randn(20, 5, 10, 10) - >>> # With Learnable Parameters - >>> m = apex.normalization.FusedLayerNorm(input.size()[1:]) - >>> # Without Learnable Parameters - >>> m = apex.normalization.FusedLayerNorm(input.size()[1:], elementwise_affine=False) - >>> # Normalize over last two dimensions - >>> m = apex.normalization.FusedLayerNorm([10, 10]) - >>> # Normalize over last dimension of size 10 - >>> m = apex.normalization.FusedLayerNorm(10) - >>> # Activating the module - >>> output = m(input) - - .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450 - """ - - def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): - super().__init__() - - global fused_layer_norm_cuda - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") - - if isinstance(normalized_shape, numbers.Integral): - normalized_shape = (normalized_shape,) - self.normalized_shape = torch.Size(normalized_shape) - self.eps = eps - self.elementwise_affine = elementwise_affine - if self.elementwise_affine: - self.weight = Parameter(torch.empty(*normalized_shape)) - self.bias = Parameter(torch.empty(*normalized_shape)) - else: - self.register_parameter("weight", None) - self.register_parameter("bias", None) - self.reset_parameters() - - def reset_parameters(self): - if self.elementwise_affine: - init.ones_(self.weight) - init.zeros_(self.bias) - - def forward(self, input): - if not input.is_cuda: - return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) - if self.elementwise_affine: - return fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps) - else: - return fused_layer_norm(input, self.normalized_shape, self.eps) - - def extra_repr(self): - return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__) - - -class FusedRMSNorm(torch.nn.Module): - r"""Applies RMS Normalization over a mini-batch of inputs - - Currently only runs on cuda() tensors. - - .. math:: - y = \frac{x}{\mathrm{RMS}[x]} * \gamma - - The root-mean-square is calculated separately over the last - certain number dimensions which have to be of the shape specified by - :attr:`normalized_shape`. - :math:`\gamma` is a learnable affine transform parameter of - :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``. - `epsilon` is added to the mean-square, then the root of the sum is taken. - - .. note:: - Unlike Batch Normalization and Instance Normalization, which applies - scalar scale and bias for each entire channel/plane with the - :attr:`affine` option, RMS Normalization applies per-element scale - with :attr:`elementwise_affine`. - - This layer uses statistics computed from input data in both training and - evaluation modes. - - Args: - normalized_shape (int or list or torch.Size): input shape from an expected input - of size - - .. math:: - [* \times \text{normalized}\_\text{shape}[0] \times \text{normalized}\_\text{shape}[1] - \times \ldots \times \text{normalized}\_\text{shape}[-1]] - - If a single integer is used, it is treated as a singleton list, and this module will - normalize over the last dimension which is expected to be of that specific size. - eps: a value added to the denominator for numerical stability. Default: 1e-5 - elementwise_affine: a boolean value that when set to ``True``, this module - has learnable per-element affine parameters initialized to ones (for weights) - and zeros (for biases). Default: ``True``. - - Shape: - - Input: :math:`(N, *)` - - Output: :math:`(N, *)` (same shape as input) - - Examples:: - - >>> input = torch.randn(20, 5, 10, 10) - >>> # With Learnable Parameters - >>> m = apex.normalization.FusedRMSNorm(input.size()[1:]) - >>> # Without Learnable Parameters - >>> m = apex.normalization.FusedRMSNorm(input.size()[1:], elementwise_affine=False) - >>> # Normalize over last two dimensions - >>> m = apex.normalization.FusedRMSNorm([10, 10]) - >>> # Normalize over last dimension of size 10 - >>> m = apex.normalization.FusedRMSNorm(10) - >>> # Activating the module - >>> output = m(input) - - .. _`Root Mean Square Layer Normalization`: https://arxiv.org/pdf/1910.07467.pdf - """ - - def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): - super().__init__() - - global fused_layer_norm_cuda - fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") - - if isinstance(normalized_shape, numbers.Integral): - normalized_shape = (normalized_shape,) - self.normalized_shape = torch.Size(normalized_shape) - self.eps = eps - self.elementwise_affine = elementwise_affine - if self.elementwise_affine: - self.weight = Parameter(torch.empty(*normalized_shape)) - else: - self.register_parameter("weight", None) - self.reset_parameters() - - def reset_parameters(self): - if self.elementwise_affine: - init.ones_(self.weight) - - def forward(self, input): - if not input.is_cuda: - return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps) - - if self.elementwise_affine: - return fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps) - else: - return fused_rms_norm(input, self.normalized_shape, self.eps) - - def extra_repr(self): - return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__) - - -# NOTE (mkozuki): Why "mixed"? -# MixedFusedLayerNorm differs from FusedLayerNorm in that this layer norm uses parameter's dtype -# as output tensor's dtype while FusedLayerNorm uses input tensor's dtype for output tensor's dtype. -# See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp" -class MixedFusedLayerNorm(FusedLayerNorm): - - def __init__(self, normalized_shape, eps=1e-5, **kwargs): - if "elementwise_affine" in kwargs: - import warnings - warnings.warn("MixedFusedLayerNorm does not support `elementwise_affine` argument") - elementwise_affine = kwargs.pop("elementwise_affine") - if not elementwise_affine: - raise RuntimeError("MixedFusedLayerNorm does not support `elementwise_affine = False`") - - super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=True) - - def forward(self, input: torch.Tensor): - # NOTE (mkozuki): CPU path is here mainly for unittest sake. - if not input.is_cuda: - return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) - return mixed_dtype_fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps) - - -# MixedFusedLayerNorm differs from FusedLayerNorm in that this layer norm uses parameter's dtype -# as output tensor's dtype while FusedLayerNorm uses input tensor's dtype for output tensor's dtype. -# See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp" -class MixedFusedRMSNorm(FusedRMSNorm): - - def __init__(self, normalized_shape, eps=1e-5, **kwargs): - if "elementwise_affine" in kwargs: - import warnings - warnings.warn("MixedFusedRMSNorm does not support `elementwise_affine` argument") - elementwise_affine = kwargs.pop("elementwise_affine") - if not elementwise_affine: - raise RuntimeError("MixedFusedRMSNorm does not support `elementwise_affine = False`") - - super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=True) - - def forward(self, input: torch.Tensor): - # NOTE (mkozuki): CPU path is here mainly for unittest sake. - # TODO Manual RMS Norm Implementation Here - if not input.is_cuda: - return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps) - return mixed_dtype_fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps) diff --git a/apex/optimizers/__init__.py b/apex/optimizers/__init__.py deleted file mode 100644 index 888a4af..0000000 --- a/apex/optimizers/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .fused_sgd import FusedSGD -from .fused_adam import FusedAdam -from .fused_novograd import FusedNovoGrad -from .fused_lamb import FusedLAMB -from .fused_adagrad import FusedAdagrad -from .fused_mixed_precision_lamb import FusedMixedPrecisionLamb -from .fused_lars import FusedLARS diff --git a/apex/optimizers/fused_adagrad.py b/apex/optimizers/fused_adagrad.py deleted file mode 100644 index 8d1ef6f..0000000 --- a/apex/optimizers/fused_adagrad.py +++ /dev/null @@ -1,122 +0,0 @@ -import torch -from apex.multi_tensor_apply import multi_tensor_applier - - -class FusedAdagrad(torch.optim.Optimizer): - """Implements Adagrad algorithm. - - Currently GPU-only. Requires Apex to be installed via - ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``. - - This version of fused Adagrad implements 2 fusions. - * Fusion of the Adagrad update's elementwise operations - * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches. - - :class:`apex.optimizers.FusedAdagrad`'s usage is identical to any ordinary Pytorch optimizer:: - opt = apex.optimizers.FusedAdagrad(model.parameters(), lr = ....) - ... - opt.step() - - :class:`apex.optimizers.FusedAdagrad` may be used with or without Amp. If you wish to use :class:`FusedAdagrad` with Amp, - you may choose any ``opt_level``:: - opt = apex.optimizers.FusedAdagrad(model.parameters(), lr = ....) - model, opt = amp.initialize(model, opt, opt_level="O0" or "O1 or "O2") - ... - opt.step() - In general, ``opt_level="O1"`` is recommended. - - It has been proposed in `Adaptive Subgradient Methods for Online Learning - and Stochastic Optimization`_. - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-2) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-10) - adagrad_w_mode (boolean, optional): Apply L2 regularization or weight decay - True for decoupled weight decay (also known as AdamW) (default: False) - - .. _Adaptive Subgradient Methods for Online Learning and Stochastic - Optimization: http://jmlr.org/papers/v12/duchi11a.html - """ - def __init__(self, params, lr=1e-2, eps=1e-10, - weight_decay=0., set_grad_none=True, adagrad_w_mode=False): - - defaults = dict(lr=lr, eps=eps, weight_decay=weight_decay) - super(FusedAdagrad, self).__init__(params, defaults) - self.adagrad_w_mode = 1 if adagrad_w_mode else 0 - self.set_grad_none = set_grad_none - - if multi_tensor_applier.available: - import amp_C - # Skip buffer - self._dummy_overflow_buf = torch.cuda.IntTensor([0]) - self.multi_tensor_adagrad = amp_C.multi_tensor_adagrad - else: - raise RuntimeError('apex.optimizers.FusedAdagrad requires cuda extensions') - - def zero_grad(self): - if self.set_grad_none: - for group in self.param_groups: - for p in group['params']: - p.grad = None - else: - super(FusedAdagrad, self).zero_grad() - - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - loss = closure() - - for group in self.param_groups: - # create lists for multi-tensor apply - g_16, p_16, h_16 = [], [], [] - g_32, p_32, h_32 = [], [], [] - - for p in group['params']: - if p.grad is None: - continue - if p.grad.data.is_sparse: - raise RuntimeError('FusedAdagrad does not support sparse gradients') - - state = self.state[p] - # State initialization - if len(state) == 0: - # Exponential moving average of gradient values - state['sum'] = torch.zeros_like(p.data) - if p.dtype in {torch.float16, torch.bfloat16}: - g_16.append(p.grad.data) - p_16.append(p.data) - h_16.append(state['sum']) - elif p.dtype == torch.float32: - g_32.append(p.grad.data) - p_32.append(p.data) - h_32.append(state['sum']) - else: - raise RuntimeError('FusedAdagrad only support fp16, bfloat16 and fp32.') - - if(len(g_16) > 0): - multi_tensor_applier(self.multi_tensor_adagrad, - self._dummy_overflow_buf, - [g_16, p_16, h_16], - group['lr'], - group['eps'], - self.adagrad_w_mode, - group['weight_decay']) - if(len(g_32) > 0): - multi_tensor_applier(self.multi_tensor_adagrad, - self._dummy_overflow_buf, - [g_32, p_32, h_32], - group['lr'], - group['eps'], - self.adagrad_w_mode, - group['weight_decay']) - - return loss \ No newline at end of file diff --git a/apex/optimizers/fused_adam.py b/apex/optimizers/fused_adam.py deleted file mode 100644 index bc8bb15..0000000 --- a/apex/optimizers/fused_adam.py +++ /dev/null @@ -1,193 +0,0 @@ -import torch -from apex.multi_tensor_apply import multi_tensor_applier - -class FusedAdam(torch.optim.Optimizer): - - """Implements Adam algorithm. - - Currently GPU-only. Requires Apex to be installed via - ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``. - - This version of fused Adam implements 2 fusions. - - * Fusion of the Adam update's elementwise operations - * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches. - - :class:`apex.optimizers.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``, - or ``torch.optim.Adam`` with ``adam_w_mode=False``:: - - opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....) - ... - opt.step() - - :class:`apex.optimizers.FusedAdam` may be used with or without Amp. If you wish to use :class:`FusedAdam` with Amp, - you may choose any ``opt_level``:: - - opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....) - model, opt = amp.initialize(model, opt, opt_level="O0" or "O1 or "O2") - ... - opt.step() - - In general, ``opt_level="O1"`` is recommended. - - - .. warning:: - A previous version of :class:`FusedAdam` allowed a number of additional arguments to ``step``. These additional arguments - are now deprecated and unnecessary. - - Adam was been proposed in `Adam: A Method for Stochastic Optimization`_. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups. - lr (float, optional): learning rate. (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square. (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability. (default: 1e-8) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - amsgrad (boolean, optional): whether to use the AMSGrad variant of this - algorithm from the paper `On the Convergence of Adam and Beyond`_ - (default: False) NOT SUPPORTED in FusedAdam! - adam_w_mode (boolean, optional): Apply L2 regularization or weight decay - True for decoupled weight decay(also known as AdamW) (default: True) - set_grad_none (bool, optional): whether set grad to None when zero_grad() - method is called. (default: True) - - .. _Adam - A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__(self, params, lr=1e-3, bias_correction=True, - betas=(0.9, 0.999), eps=1e-8, adam_w_mode=True, - weight_decay=0., amsgrad=False, set_grad_none=True): - - if amsgrad: - raise RuntimeError('FusedAdam does not support the AMSGrad variant.') - defaults = dict(lr=lr, bias_correction=bias_correction, - betas=betas, eps=eps, weight_decay=weight_decay) - super(FusedAdam, self).__init__(params, defaults) - self.adam_w_mode = 1 if adam_w_mode else 0 - self.set_grad_none = set_grad_none - if multi_tensor_applier.available: - import amp_C - # Skip buffer - self._dummy_overflow_buf = torch.cuda.IntTensor([0]) - self.multi_tensor_adam = amp_C.multi_tensor_adam - else: - raise RuntimeError('apex.optimizers.FusedAdam requires cuda extensions') - - def zero_grad(self): - if self.set_grad_none: - for group in self.param_groups: - for p in group['params']: - p.grad = None - else: - super(FusedAdam, self).zero_grad() - - def step(self, closure=None, grads=None, output_params=None, scale=None, grad_norms=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - - The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes. - """ - if any(p is not None for p in [grads, output_params, scale, grad_norms]): - raise RuntimeError('FusedAdam has been updated. Simply initialize it identically to torch.optim.Adam, and call step() with no arguments.') - loss = None - if closure is not None: - loss = closure() - - for group in self.param_groups: - bias_correction = 1 if group['bias_correction'] else 0 - beta1, beta2 = group['betas'] - - # assume same step across group now to simplify things - # per parameter step can be easily support by making it tensor, or pass list into kernel - if 'step' in group: - group['step'] += 1 - else: - group['step'] = 1 - - # create lists for multi-tensor apply - g_16, p_16, m_16, v_16 = [], [], [], [] - g_bf, p_bf, m_bf, v_bf = [], [], [], [] - g_32, p_32, m_32, v_32 = [], [], [], [] - - for p in group['params']: - if p.grad is None: - continue - if p.grad.data.is_sparse: - raise RuntimeError('FusedAdam does not support sparse gradients, please consider SparseAdam instead') - - state = self.state[p] - # State initialization - if len(state) == 0: - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p.data) - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p.data) - - if p.dtype in {torch.float16, torch.bfloat16}: - g_16.append(p.grad.data) - p_16.append(p.data) - m_16.append(state['exp_avg']) - v_16.append(state['exp_avg_sq']) - elif p.dtype == torch.bfloat16: - g_bf.append(p.grad) - p_bf.append(p) - m_bf.append(state['exp_avg']) - v_bf.append(state['exp_avg_sq']) - elif p.dtype == torch.float32: - g_32.append(p.grad.data) - p_32.append(p.data) - m_32.append(state['exp_avg']) - v_32.append(state['exp_avg_sq']) - else: - raise RuntimeError('FusedAdam only support fp16, bfloat16 and fp32.') - - if(len(g_16) > 0): - multi_tensor_applier(self.multi_tensor_adam, - self._dummy_overflow_buf, - [g_16, p_16, m_16, v_16], - group['lr'], - beta1, - beta2, - group['eps'], - group['step'], - self.adam_w_mode, - bias_correction, - group['weight_decay']) - if g_bf: - multi_tensor_applier( - self.multi_tensor_adam, - self._dummy_overflow_buf, - [g_bf, p_bf, m_bf, v_bf], - group['lr'], - beta1, - beta2, - group['eps'], - group['step'], - self.adam_w_mode, - bias_correction, - group['weight_decay'], - ) - if(len(g_32) > 0): - multi_tensor_applier(self.multi_tensor_adam, - self._dummy_overflow_buf, - [g_32, p_32, m_32, v_32], - group['lr'], - beta1, - beta2, - group['eps'], - group['step'], - self.adam_w_mode, - bias_correction, - group['weight_decay']) - - - return loss diff --git a/apex/optimizers/fused_lamb.py b/apex/optimizers/fused_lamb.py deleted file mode 100644 index a77e0cd..0000000 --- a/apex/optimizers/fused_lamb.py +++ /dev/null @@ -1,215 +0,0 @@ -import torch -from apex.multi_tensor_apply import multi_tensor_applier, multi_tensor_applier_l2norm - -class FusedLAMB(torch.optim.Optimizer): - - """Implements LAMB algorithm. - - Currently GPU-only. Requires Apex to be installed via - ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``. - - This version of fused LAMB implements 2 fusions. - - * Fusion of the LAMB update's elementwise operations - * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches. - - :class:`apex.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer:: - - opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....) - ... - opt.step() - - :class:`apex.optimizers.FusedLAMB` may be used with or without Amp. If you wish to use :class:`FusedLAMB` with Amp, - you may choose any ``opt_level``:: - - opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....) - model, opt = amp.initialize(model, opt, opt_level="O0" or "O1 or "O2") - ... - opt.step() - - In general, ``opt_level="O1"`` is recommended. - - LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups. - lr (float, optional): learning rate. (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its norm. (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability. (default: 1e-8) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - amsgrad (boolean, optional): whether to use the AMSGrad variant of this - algorithm from the paper `On the Convergence of Adam and Beyond`_ - NOT SUPPORTED now! (default: False) - adam_w_mode (boolean, optional): Apply L2 regularization or weight decay - True for decoupled weight decay(also known as AdamW) (default: True) - grad_averaging (bool, optional): whether apply (1-beta2) to grad when - calculating running averages of gradient. (default: True) - set_grad_none (bool, optional): whether set grad to None when zero_grad() - method is called. (default: True) - max_grad_norm (float, optional): value used to clip global grad norm - (default: 1.0) - use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0 - weight decay parameter (default: False) - - .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: - https://arxiv.org/abs/1904.00962 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__(self, params, lr=1e-3, bias_correction=True, - betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01, - amsgrad=False, adam_w_mode=True, - grad_averaging=True, set_grad_none=True, - max_grad_norm=1.0, use_nvlamb=False): - if amsgrad: - raise RuntimeError('FusedLAMB does not support the AMSGrad variant.') - defaults = dict(lr=lr, bias_correction=bias_correction, - betas=betas, eps=eps, weight_decay=weight_decay, - grad_averaging=grad_averaging, - max_grad_norm=max_grad_norm) - super(FusedLAMB, self).__init__(params, defaults) - if multi_tensor_applier.available and multi_tensor_applier_l2norm.available: - import amp_C - self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm - # Skip buffer - self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device) - self.multi_tensor_lamb = amp_C.multi_tensor_lamb - else: - raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions') - - self.adam_w_mode = 1 if adam_w_mode else 0 - self.set_grad_none = set_grad_none - self.use_nvlamb = use_nvlamb - - def zero_grad(self): - if self.set_grad_none: - for group in self.param_groups: - for p in group['params']: - p.grad = None - else: - super(FusedLAMB, self).zero_grad() - - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - loss = closure() - - # create separate grad lists for fp32 and fp16 params - g_all_32, g_all_16 = [], [] - for group in self.param_groups: - for p in group['params']: - if p.grad is None: - continue - if p.dtype == torch.float32: - g_all_32.append(p.grad.data) - elif p.dtype == torch.float16: - g_all_16.append(p.grad.data) - else: - raise RuntimeError('FusedLAMB only support fp16 and fp32.') - - device = self.param_groups[0]["params"][0].device - g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device) - # compute grad norm for two lists - if len(g_all_32) > 0: - g_norm_32 = multi_tensor_applier_l2norm(self.multi_tensor_l2norm, - self._dummy_overflow_buf, - [g_all_32], False)[0] - if len(g_all_16) > 0: - g_norm_16 = multi_tensor_applier_l2norm(self.multi_tensor_l2norm, - self._dummy_overflow_buf, - [g_all_16], False)[0] - - # blend two grad norms to get global grad norm - global_grad_norm = multi_tensor_applier_l2norm(self.multi_tensor_l2norm, - self._dummy_overflow_buf, - [[g_norm_32, g_norm_16]], - False)[0] - max_grad_norm = self.defaults['max_grad_norm'] - - for group in self.param_groups: - bias_correction = 1 if group['bias_correction'] else 0 - beta1, beta2 = group['betas'] - grad_averaging = 1 if group['grad_averaging'] else 0 - - # assume same step across group now to simplify things - # per parameter step can be easily support by making it tensor, or pass list into kernel - if 'step' in group: - group['step'] += 1 - else: - group['step'] = 1 - - # create lists for multi-tensor apply - g_16, p_16, m_16, v_16 = [], [], [], [] - g_32, p_32, m_32, v_32 = [], [], [], [] - - for p in group['params']: - if p.grad is None: - continue - if p.grad.data.is_sparse: - raise RuntimeError('FusedLAMB does not support sparse gradients, please consider SparseAdam instead') - - state = self.state[p] - # State initialization - if len(state) == 0: - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p.data) - # Exponential moving average of gradient values - state['exp_avg_sq'] = torch.zeros_like(p.data) - - if p.dtype in {torch.float16, torch.bfloat16}: - g_16.append(p.grad.data) - p_16.append(p.data) - m_16.append(state['exp_avg']) - v_16.append(state['exp_avg_sq']) - elif p.dtype == torch.float32: - g_32.append(p.grad.data) - p_32.append(p.data) - m_32.append(state['exp_avg']) - v_32.append(state['exp_avg_sq']) - else: - raise RuntimeError('FusedLAMB only support fp16, bfloat16 and fp32.') - - if(len(g_16) > 0): - multi_tensor_applier(self.multi_tensor_lamb, - self._dummy_overflow_buf, - [g_16, p_16, m_16, v_16], - group['lr'], - beta1, - beta2, - group['eps'], - group['step'], - bias_correction, - group['weight_decay'], - grad_averaging, - self.adam_w_mode, - global_grad_norm, - max_grad_norm, - self.use_nvlamb) - if(len(g_32) > 0): - multi_tensor_applier(self.multi_tensor_lamb, - self._dummy_overflow_buf, - [g_32, p_32, m_32, v_32], - group['lr'], - beta1, - beta2, - group['eps'], - group['step'], - bias_correction, - group['weight_decay'], - grad_averaging, - self.adam_w_mode, - global_grad_norm, - max_grad_norm, - self.use_nvlamb) - - return loss diff --git a/apex/optimizers/fused_lars.py b/apex/optimizers/fused_lars.py deleted file mode 100644 index 3e60b2c..0000000 --- a/apex/optimizers/fused_lars.py +++ /dev/null @@ -1,224 +0,0 @@ -import torch -from torch.optim.optimizer import Optimizer, required -from torch import nn -from torch.nn.parameter import Parameter -from apex.multi_tensor_apply import multi_tensor_applier - -class FusedLARS(Optimizer): - def __init__(self, params, lr=required, momentum=0, dampening=0, - weight_decay=0, trust_coefficient=0.001, eps=0.0, - nesterov=False, wd_after_momentum=False, - materialize_master_grads=True, set_grad_none=False): - - if lr is not required and lr < 0.0: - raise ValueError("Invalid learning rate: {}".format(lr)) - if momentum < 0.0: - raise ValueError("Invalid momentum value: {}".format(momentum)) - if weight_decay < 0.0: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - - defaults = dict(lr=lr, momentum=momentum, dampening=dampening, - weight_decay=weight_decay, nesterov=nesterov, trust_coefficient=trust_coefficient, eps=eps, is_skipped=False) - if nesterov and (momentum <= 0 or dampening != 0): - raise ValueError("Nesterov momentum requires a momentum and zero dampening") - super(FusedLARS, self).__init__(params, defaults) - - self.wd_after_momentum = wd_after_momentum - self.materialize_master_grads = materialize_master_grads - self.most_recent_scale = 1.0 - self.scale_set_by_backward = False - self.set_grad_none = set_grad_none - self.trust_coefficient = trust_coefficient - self.eps = eps - - if multi_tensor_applier.available: - import amp_C - # Skip buffer - self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device) - self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm - self.multi_tensor_lars = amp_C.multi_tensor_lars - self._dummy_overflow_buf = torch.cuda.IntTensor(1).zero_() - else: - raise RuntimeError('apex.optimizers.FusedLARS requires cuda extensions') - - def __setstate__(self, state): - super(FusedLARS, self).__setstate__(state) - for group in self.param_groups: - group.setdefault('nesterov', False) - - def zero_grad(self): - if self.set_grad_none: - for group in self.param_groups: - for p in group['params']: - p.grad = None - else: - super(FusedLARS, self).zero_grad() - - def get_momentums(self, params): - momentums = [] - first_run = True - for p in params: - if p.grad is None: - continue - - param_state = self.state[p] - d_p = p.grad.data - # torch.optim.SGD initializes momentum in the main loop, we have - # to do it here, and track whether or not we've done so, so that - # momentum application can be skipped in the main kernel. - if 'momentum_buffer' not in param_state: - first_run = True - buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) - momentums.append(buf) - else: - first_run = False - momentums.append(param_state['momentum_buffer']) - return momentums, first_run - - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - loss = closure() - - explicit_master_params = (hasattr(self, "_amp_stash") and - hasattr(self._amp_stash, "fp32_from_fp16_groups")) - explicit_master_params = False - - for gid, group in enumerate(self.param_groups): - weight_decay = group['weight_decay'] - momentum = group['momentum'] - dampening = group['dampening'] - nesterov = group['nesterov'] - lr = group['lr'] - is_skipped = group['is_skipped'] - - # For each group, there are 3 possible combinations we need to consider: - # grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy - # 1. fp16, fp16, fp16, No - # 2. fp32, fp32, fp32, No - # 3. fp16, fp32, fp32, Yes - - first_runs = [True, True] - g_norms_grp = [] - w_norms_grp = [] - - - # I think a bit of code divergence in exchange for naming clarity is worthwhile - if explicit_master_params: - print('explicit_master_params') - stash = self._amp_stash - - fp32_params = [p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None] - fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None] - fp32_momentums, first_runs[1] = self.get_momentums(fp32_params) - - if self.materialize_master_grads: - fp16_model_params = [p for i, p in enumerate( - stash.fp16_groups[gid]) if stash.fp32_from_fp16_groups[gid][i].grad is not None] - fp32_from_fp16_grads = [p.grad for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None] - fp32_from_fp16_params = [p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None] - fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params) - - fp16_set = [fp32_from_fp16_grads, fp32_from_fp16_params, - fp32_from_fp16_momentums, fp16_model_params] - else: - fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None] - fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None] - fp32_from_fp16_params = [p for i, p in enumerate( - stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None] - fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params) - - fp16_set = [fp16_model_grads, fp32_from_fp16_params, - fp32_from_fp16_momentums, fp16_model_params] - - launch_sets= [fp16_set, [fp32_grads, fp32_params, fp32_momentums]] - - else: - fp16_params = [p for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)] - #fp16_grads = [p.grad for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)] - fp16_grads = [] - for p in fp16_params: - if p.is_contiguous(): - fp16_grads.append(p.grad) - elif p.is_contiguous(memory_format=torch.channels_last): - fp16_grads.append(p.grad.to(memory_format=torch.channels_last)) - fp16_momentums, first_runs[0] = self.get_momentums(fp16_params) - # Compute L2 norms - if len(fp16_params) > 0: - w_norms = multi_tensor_applier( - self.multi_tensor_l2norm, - self._dummy_overflow_buf, - [[p.data for p in fp16_params]], - True)[1] - g_norms = multi_tensor_applier( - self.multi_tensor_l2norm, - self._dummy_overflow_buf, - [[p.data for p in fp16_grads]], - True)[1] - else: - w_norms = [] - g_norms = [] - w_norms_grp.append(w_norms) - g_norms_grp.append(g_norms) - - fp32_params = [p for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)] - fp32_grads = [] - for p in fp32_params: - if p.is_contiguous(): - fp32_grads.append(p.grad) - elif p.is_contiguous(memory_format=torch.channels_last): - fp32_grads.append(p.grad.to(memory_format=torch.channels_last)) - fp32_momentums, first_runs[1] = self.get_momentums(fp32_params) - # Compute L2 norms - if len(fp32_params) > 0: - w_norms = multi_tensor_applier( - self.multi_tensor_l2norm, - self._dummy_overflow_buf, - [[p.data for p in fp32_params]], - True)[1] - g_norms = multi_tensor_applier( - self.multi_tensor_l2norm, - self._dummy_overflow_buf, - [[p.data for p in fp32_grads]], - True)[1] - else: - w_norms = [] - g_norms = [] - w_norms_grp.append(w_norms) - g_norms_grp.append(g_norms) - - launch_sets = [[fp16_grads, fp16_params, fp16_momentums], - [fp32_grads, fp32_params, fp32_momentums]] - - for s, (launch_set, first_run, g_norms, w_norms) in enumerate(zip(launch_sets, first_runs, g_norms_grp, w_norms_grp)): - assert len(launch_set[0]) == len(launch_set[1]) - assert len(launch_set[0]) == len(launch_set[2]) - if len(launch_set[0]) > 0: - multi_tensor_applier( - self.multi_tensor_lars, - self._dummy_overflow_buf, - launch_set, - g_norms, - w_norms, - group['lr'], - group['trust_coefficient'], - self.eps, - weight_decay, - momentum, - dampening, - nesterov, - first_run, - self.wd_after_momentum, - 1.0/self.most_recent_scale, - group['is_skipped']) - - self.most_recent_scale = 1.0 - self.scale_set_by_backward = False - - return loss diff --git a/apex/optimizers/fused_mixed_precision_lamb.py b/apex/optimizers/fused_mixed_precision_lamb.py deleted file mode 100644 index 7ecda4f..0000000 --- a/apex/optimizers/fused_mixed_precision_lamb.py +++ /dev/null @@ -1,256 +0,0 @@ -import torch -from copy import deepcopy -from itertools import chain -from collections import defaultdict, abc as container_abcs - -from apex.multi_tensor_apply import multi_tensor_applier, multi_tensor_applier_l2norm - -class FusedMixedPrecisionLamb(torch.optim.Optimizer): - - def __init__(self, params, lr=1e-3, step=0, bias_correction=True, - betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01, - amsgrad=False, adam_w_mode=True, - grad_averaging=True, max_grad_norm=1.0, use_nvlamb=False, - reduced_precision_dtype=None): - if amsgrad: - raise RuntimeError('FusedLAMB does not support the AMSGrad variant.') - - # The learning rate (lr) and optimizer step (step) should be located on device - # in order to faciliated device sync free execution - defaults = dict(lr=torch.tensor(lr, dtype=torch.float32), - step=torch.tensor([step], dtype=torch.int), - bias_correction=bias_correction, - betas=betas, eps=eps, weight_decay=weight_decay, - grad_averaging=grad_averaging, - max_grad_norm=max_grad_norm) - tensor_state = ['lr', 'step'] - super(FusedMixedPrecisionLamb, self).__init__(params, defaults) - - device = self.param_groups[0]['params'][0].device - - for idx,group in enumerate(self.param_groups): - for item in tensor_state: - self.param_groups[idx][item] = group[item].to(device=device) - - if multi_tensor_applier.available and multi_tensor_applier_l2norm.available: - import amp_C - self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm_mp - # Skip buffer - self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=device) - self.multi_tensor_lamb = amp_C.multi_tensor_lamb_mp - else: - raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions') - - # Mixed Precision support - self.reduced_precision_dtype = reduced_precision_dtype - self.param_groups_full_precision = [] - - self._step_supports_amp_scaling = True - self.adam_w_mode = 1 if adam_w_mode else 0 - self.use_nvlamb = use_nvlamb - - # This method is overridden from the parent class because there is not a way to override - # the nested function cast() that copies a saved piece of state to the device without - # redundantly doing the copy. - def load_state_dict(self, state_dict): - r"""Loads the optimizer state. - - Args: - state_dict (dict): optimizer state. Should be an object returned - from a call to :meth:`state_dict`. - """ - # deepcopy, to be consistent with module API - state_dict = deepcopy(state_dict) - # Validate the state_dict - groups = self.param_groups - saved_groups = state_dict['param_groups'] - - if len(groups) != len(saved_groups): - raise ValueError("loaded state dict has a different number of " - "parameter groups") - param_lens = (len(g['params']) for g in groups) - saved_lens = (len(g['params']) for g in saved_groups) - if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): - raise ValueError("loaded state dict contains a parameter group " - "that doesn't match the size of optimizer's group") - - # Update the state - id_map = {old_id: p for old_id, p in - zip(chain.from_iterable((g['params'] for g in saved_groups)), - chain.from_iterable((g['params'] for g in groups)))} - - def cast(param, value): - r"""Make a deep copy of value, casting all tensors to device of param.""" - if isinstance(value, torch.Tensor): - # The original version casted the saved value to the params dtype - # This doesn't work for mixed precision Lamb where the momentum and - # velocity are expected to be in full precision while the params are - # in reduced precision - value = value.to(value.device) - return value - elif isinstance(value, dict): - return {k: cast(param, v) for k, v in value.items()} - elif isinstance(value, container_abcs.Iterable): - return type(value)(cast(param, v) for v in value) - else: - return value - - # Copy state assigned to params (and cast tensors to appropriate types). - # State that is not assigned to params is copied as is (needed for - # backward compatibility). - state = defaultdict(dict) - for k, v in state_dict['state'].items(): - if k in id_map: - param = id_map[k] - state[param] = cast(param, v) - else: - state[k] = v - - # Update parameter groups, setting their 'params' value - def update_group(group, new_group): - new_group['params'] = group['params'] - return new_group - param_groups = [ - update_group(g, ng) for g, ng in zip(groups, saved_groups)] - self.__setstate__({'state': state, 'param_groups': param_groups}) - - def _setup_full_precision_params(self): - for i, pg in enumerate(self.param_groups): - param_list = pg['params'] - self.param_groups_full_precision.append({ - 'params': [ - p.clone().detach().to(dtype=torch.float32) - if (self.reduced_precision_dtype is not None) and (p.dtype == self.reduced_precision_dtype) - else None - for p in param_list - ], - }) - - # add_param_groups() is overridden because default items can be tensors. The - # parent version does not clone the default item, so two param groups can - # accidentally point to the same default item value where they can differ - # given they are in separate groups. - def add_param_group(self, param_group): - super().add_param_group(param_group) - for name, default in self.defaults.items(): - if isinstance(default, torch.Tensor): - self.param_groups[len(self.param_groups) - 1][name] = default.clone() - - @torch.no_grad() - def step(self, closure=None, grad_scaler=None): - loss = None - if closure is not None: - loss = closure() - - # The full precision params are set up in the first step of the optimizer - # instead of in the constructor because the full precision params will get out - # out of sync with the model params if DDP syncs the model params across devices - # after the optimizer is constructed. - if len(self.param_groups_full_precision) == 0 : - self._setup_full_precision_params() - - # create separate grad lists for params - grad_list = [] - for gid,group in enumerate(self.param_groups): - for pid,p in enumerate(group['params']): - assert group['params'][0].dtype == p.dtype, \ - "Error: Parameters are not of the identical type: {} != {}".format( - group['params'][0].dtype, p.dtype) - if p.grad is None: - continue - grad_list.append(p.grad) - - # Overflow check of gradients - device = self.param_groups[0]["params"][0].device - found_inf = ( - grad_scaler._check_inf_per_device(self)[device] - if grad_scaler is not None else torch.zeros((1,), device=device) - ) - self._dummy_overflow_buf.copy_(found_inf) - - # Get unscale scale factor - scale, inv_scale = None, None - if grad_scaler: - scale = grad_scaler._get_scale_async() - inv_scale = scale.double().reciprocal().float() - else: - scale = torch.ones((1,), device=device) - inv_scale = torch.ones((1,), device=device) - - # grad_norm is of scaled gradients. - # So, multiply `max_grad_norm` by scale. - max_grad_norm = self.defaults['max_grad_norm'] * scale - grad_norm = multi_tensor_applier_l2norm( - self.multi_tensor_l2norm, - self._dummy_overflow_buf, - [grad_list], - False, - )[0] - - # Run LAMB optimization math - for gid, (group, group_full) in enumerate(zip(self.param_groups, self.param_groups_full_precision)): - bias_correction = 1 if group['bias_correction'] else 0 - beta1, beta2 = group['betas'] - grad_averaging = 1 if group['grad_averaging'] else 0 - - # assume same step across group now to simplify things - # per parameter step can be easily support by making it tensor, or pass list into kernel - group['step'] += (self._dummy_overflow_buf != 1).to(torch.int) - - state_lists = [ [], # (0) grads - [], # (1) params - [], # (2) momentum state - [], # (3) velocity state - ] - if self.reduced_precision_dtype is not None: - state_lists.append([]) # (4) params reduced_dtype - - - for p, p_full in zip(group['params'], group_full['params']): - if p.grad is None: - continue - assert not p.grad.is_sparse - - state = self.state[p] - # State initialization - if len(state) == 0: - dtype = p.dtype - if self.reduced_precision_dtype is not None and p.dtype == self.reduced_precision_dtype : - dtype = torch.float32 - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p.data, dtype=dtype) - # Exponential moving average of gradient values - state['exp_avg_sq'] = torch.zeros_like(p.data, dtype=dtype) - - if self.reduced_precision_dtype is not None : - state_lists[0].append(p.grad.data) - state_lists[1].append(p_full.data) - state_lists[2].append(state['exp_avg']) - state_lists[3].append(state['exp_avg_sq']) - state_lists[4].append(p.data) - else : - state_lists[0].append(p.grad.data) - state_lists[1].append(p.data) - state_lists[2].append(state['exp_avg']) - state_lists[3].append(state['exp_avg_sq']) - - multi_tensor_applier( - self.multi_tensor_lamb, - self._dummy_overflow_buf, - state_lists, - group['lr'], - beta1, - beta2, - group['eps'], - group['step'], - bias_correction, - group['weight_decay'], - grad_averaging, - self.adam_w_mode, - grad_norm, - max_grad_norm, - self.use_nvlamb, - found_inf, - inv_scale) - - return loss diff --git a/apex/optimizers/fused_novograd.py b/apex/optimizers/fused_novograd.py deleted file mode 100644 index b3ec5ac..0000000 --- a/apex/optimizers/fused_novograd.py +++ /dev/null @@ -1,214 +0,0 @@ -import torch -from apex.multi_tensor_apply import multi_tensor_applier - -class FusedNovoGrad(torch.optim.Optimizer): - - """Implements NovoGrad algorithm. - - Currently GPU-only. Requires Apex to be installed via - ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``. - - This version of fused NovoGrad implements 2 fusions. - - * Fusion of the NovoGrad update's elementwise operations - * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches. - - :class:`apex.optimizers.FusedNovoGrad`'s usage is identical to any Pytorch optimizer:: - - opt = apex.optimizers.FusedNovoGrad(model.parameters(), lr = ....) - ... - opt.step() - - :class:`apex.optimizers.FusedNovoGrad` may be used with or without Amp. If you wish to use :class:`FusedNovoGrad` with Amp, - you may choose any ``opt_level``:: - - opt = apex.optimizers.FusedNovoGrad(model.parameters(), lr = ....) - model, opt = amp.initialize(model, opt, opt_level="O0" or "O1 or "O2") - ... - opt.step() - - In general, ``opt_level="O1"`` is recommended. - - It has been proposed in `Jasper: An End-to-End Convolutional Neural Acoustic Model`_. - More info: https://nvidia.github.io/OpenSeq2Seq/html/optimizers.html#novograd - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups. - lr (float, optional): learning rate. (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its norm. (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability. (default: 1e-8) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - amsgrad (boolean, optional): whether to use the AMSGrad variant of this - algorithm from the paper `On the Convergence of Adam and Beyond`_ - NOT SUPPORTED now! (default: False) - reg_inside_moment (bool, optional): whether do regularization (norm and L2) - in momentum calculation. True for include, False for not include and - only do it on update term. (default: False) - grad_averaging (bool, optional): whether apply (1-beta1) to grad when - calculating running averages of gradient. (default: True) - norm_type (int, optional): which norm to calculate for each layer. - 2 for L2 norm, and 0 for infinite norm. These 2 are only supported - type now. (default: 2) - init_zero (bool, optional): whether init norm with 0 (start averaging on - 1st step) or first step norm (start averaging on 2nd step). True for - init with 0. (default: False) - set_grad_none (bool, optional): whether set grad to None when zero_grad() - method is called. (default: True) - - .. _Jasper - An End-to-End Convolutional Neural Acoustic Model: - https://arxiv.org/abs/1904.03288 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__(self, params, lr=1e-3, bias_correction=True, - betas=(0.9, 0.999), eps=1e-8, weight_decay=0., - amsgrad=False, reg_inside_moment=False, - grad_averaging=True, norm_type=2, init_zero=False, - set_grad_none=True): - if amsgrad: - raise RuntimeError('FusedNovoGrad does not support the AMSGrad variant.') - defaults = dict(lr=lr, bias_correction=bias_correction, - betas=betas, eps=eps, weight_decay=weight_decay, - grad_averaging=grad_averaging, norm_type=norm_type, - init_zero=init_zero) - super(FusedNovoGrad, self).__init__(params, defaults) - if multi_tensor_applier.available: - import amp_C - # Skip buffer - - # Creating the overflow buffer on the same device as the params tensors. - self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device) - self.multi_tensor_novograd = amp_C.multi_tensor_novograd - else: - raise RuntimeError('apex.optimizers.FusedNovoGrad requires cuda extensions') - - self.moment_mode = 0 if reg_inside_moment else 1 - self.set_grad_none = set_grad_none - - def zero_grad(self): - if self.set_grad_none: - for group in self.param_groups: - for p in group['params']: - p.grad = None - else: - super(FusedNovoGrad, self).zero_grad() - - def load_state_dict(self, state_dict): - super(FusedNovoGrad, self).load_state_dict(state_dict) - # in case exp_avg_sq is not on the same device as params, move it there - for group in self.param_groups: - if len(group['params']) > 0: - group['exp_avg_sq'][0] = group['exp_avg_sq'][0].to(group['params'][0].device) - group['exp_avg_sq'][1] = group['exp_avg_sq'][1].to(group['params'][0].device) - - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - loss = closure() - - for group in self.param_groups: - bias_correction = 1 if group['bias_correction'] else 0 - beta1, beta2 = group['betas'] - grad_averaging = 1 if group['grad_averaging'] else 0 - - # assume same step across group now to simplify things - # per parameter step can be easily support by making it tensor, or pass list into kernel - if 'step' in group: - group['step'] += 1 - else: - group['step'] = 1 - - # create lists for multi-tensor apply - g_16, p_16, m_16 = [], [], [] - g_32, p_32, m_32 = [], [], [] - - for p in group['params']: - if p.grad is None: - continue - if p.grad.data.is_sparse: - raise RuntimeError('FusedNovoGrad does not support sparse gradients, please consider SparseAdam instead') - - state = self.state[p] - # State initialization - if len(state) == 0: - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p.data) - - if p.dtype in {torch.float16, torch.bfloat16}: - g_16.append(p.grad.data) - p_16.append(p.data) - m_16.append(state['exp_avg']) - elif p.dtype == torch.float32: - g_32.append(p.grad.data) - p_32.append(p.data) - m_32.append(state['exp_avg']) - else: - raise RuntimeError('FusedNovoGrad only support fp16, bfloat16 and fp32.') - - # we store per weight norm as one tensor for one group/precision combination - # different from optim.Adam, we store norm here(not ^2) so we can unify calculation for norm types - if 'exp_avg_sq' not in group: - group['exp_avg_sq'] = [None, None] - if group['init_zero']: - # Creating the following parameters on the same device as the params tensors. - group['exp_avg_sq'][0] = torch.cuda.FloatTensor(len(g_16), device=self.param_groups[0]["params"][0].device).contiguous().fill_(0) - group['exp_avg_sq'][1] = torch.cuda.FloatTensor(len(g_32), device=self.param_groups[0]["params"][0].device).contiguous().fill_(0) - else: # init with first step norm, so first blend have no effect - if group['norm_type'] == 0: - v_16 = [torch.max(torch.abs(g.to(torch.float32))).item() for g in g_16] - v_32 = [torch.max(torch.abs(g)).item() for g in g_32] - elif group['norm_type'] == 2: - v_16 = [torch.sum(torch.pow(g.to(torch.float32), 2)).sqrt().item() for g in g_16] - v_32 = [torch.sum(torch.pow(g, 2)).sqrt().item() for g in g_32] - else: - raise RuntimeError('FusedNovoGrad only support l2/inf norm now.') - # Creating the following parameters on the same device as the params tensors. - group['exp_avg_sq'][0] = torch.cuda.FloatTensor(v_16, device=self.param_groups[0]["params"][0].device) - group['exp_avg_sq'][1] = torch.cuda.FloatTensor(v_32, device=self.param_groups[0]["params"][0].device) - else: - assert(len(g_16) == group['exp_avg_sq'][0].numel()) - assert(len(g_32) == group['exp_avg_sq'][1].numel()) - - if(len(g_16) > 0): - multi_tensor_applier(self.multi_tensor_novograd, - self._dummy_overflow_buf, - [g_16, p_16, m_16], - group['exp_avg_sq'][0], - group['lr'], - beta1, - beta2, - group['eps'], - group['step'], - bias_correction, - group['weight_decay'], - grad_averaging, - self.moment_mode, - group['norm_type']) - if(len(g_32) > 0): - multi_tensor_applier(self.multi_tensor_novograd, - self._dummy_overflow_buf, - [g_32, p_32, m_32], - group['exp_avg_sq'][1], - group['lr'], - beta1, - beta2, - group['eps'], - group['step'], - bias_correction, - group['weight_decay'], - grad_averaging, - self.moment_mode, - group['norm_type']) - - - return loss diff --git a/apex/optimizers/fused_sgd.py b/apex/optimizers/fused_sgd.py deleted file mode 100644 index 88f26f2..0000000 --- a/apex/optimizers/fused_sgd.py +++ /dev/null @@ -1,264 +0,0 @@ -import torch -from torch.optim.optimizer import Optimizer, required - -from apex.multi_tensor_apply import multi_tensor_applier - -class FusedSGD(Optimizer): - r"""Implements stochastic gradient descent (optionally with momentum). - - Currently GPU-only. Requires Apex to be installed via - ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``. - - This version of fused SGD implements 2 fusions. - - * Fusion of the SGD update's elementwise operations - * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches. - - :class:`apex.optimizers.FusedSGD` may be used as a drop-in replacement for ``torch.optim.SGD``:: - - opt = apex.optimizers.FusedSGD(model.parameters(), lr = ....) - ... - opt.step() - - :class:`apex.optimizers.FusedSGD` may be used with or without Amp. If you wish to use :class:`FusedSGD` with Amp, - you may choose any ``opt_level``:: - - opt = apex.optimizers.FusedSGD(model.parameters(), lr = ....) - model, opt = amp.initialize(model, opt, opt_level="O0" or "O1 or "O2") - ... - opt.step() - - In general, ``opt_level="O1"`` is recommended. - - Nesterov momentum is based on the formula from - `On the importance of initialization and momentum in deep learning`__. - - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float): learning rate - momentum (float, optional): momentum factor (default: 0) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - dampening (float, optional): dampening for momentum (default: 0) - nesterov (bool, optional): enables Nesterov momentum (default: False) - - Example: - >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) - >>> optimizer.zero_grad() - >>> loss_fn(model(input), target).backward() - >>> optimizer.step() - - __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf - - .. note:: - The implementation of SGD with Momentum/Nesterov subtly differs from - Sutskever et. al. and implementations in some other frameworks. - - Considering the specific case of Momentum, the update can be written as - - .. math:: - v = \rho * v + g \\ - p = p - lr * v - - where p, g, v and :math:`\rho` denote the parameters, gradient, - velocity, and momentum respectively. - - This is in contrast to Sutskever et. al. and - other frameworks which employ an update of the form - - .. math:: - v = \rho * v + lr * g \\ - p = p - v - - The Nesterov version is analogously modified. - """ - - def __init__(self, params, lr=required, momentum=0, dampening=0, - weight_decay=0, nesterov=False, - wd_after_momentum=False, - materialize_master_grads=True, - set_grad_none=False): - if lr is not required and lr < 0.0: - raise ValueError("Invalid learning rate: {}".format(lr)) - if momentum < 0.0: - raise ValueError("Invalid momentum value: {}".format(momentum)) - if weight_decay < 0.0: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - - defaults = dict(lr=lr, momentum=momentum, dampening=dampening, - weight_decay=weight_decay, nesterov=nesterov) - if nesterov and (momentum <= 0 or dampening != 0): - raise ValueError("Nesterov momentum requires a momentum and zero dampening") - super(FusedSGD, self).__init__(params, defaults) - - self.wd_after_momentum = wd_after_momentum - self.materialize_master_grads = materialize_master_grads - self.most_recent_scale = 1.0 - self.scale_set_by_backward = False - self.set_grad_none = set_grad_none - - if multi_tensor_applier.available: - import amp_C - # Skip buffer - self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device) - self.multi_tensor_sgd = amp_C.multi_tensor_sgd - else: - raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions') - - def __setstate__(self, state): - super(FusedSGD, self).__setstate__(state) - for group in self.param_groups: - group.setdefault('nesterov', False) - - def zero_grad(self): - if self.set_grad_none: - for group in self.param_groups: - for p in group['params']: - p.grad = None - else: - super(FusedSGD, self).zero_grad() - - def get_momentums(self, params): - momentums = [] - first_run = True - for p in params: - param_state = self.state[p] - # torch.optim.SGD initializes momentum in the main loop, we have - # to do it here, and track whether or not we've done so, so that - # momentum application can be skipped in the main kernel. - if 'momentum_buffer' not in param_state: - first_run = True - buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) - momentums.append(buf) - else: - first_run = False - momentums.append(param_state['momentum_buffer']) - return momentums, first_run - - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - loss = closure() - - explicit_master_params = (hasattr(self, "_amp_stash") and - hasattr(self._amp_stash, "fp32_from_fp16_groups")) - - for gid, group in enumerate(self.param_groups): - weight_decay = group['weight_decay'] - momentum = group['momentum'] - dampening = group['dampening'] - nesterov = group['nesterov'] - - - # For each group, there are 3 possible combinations we need to consider: - # grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy - # 1. fp16, fp16, fp16, No - # 2. fp32, fp32, fp32, No - # 3. fp16, fp32, fp32, Yes - - first_runs = [True, True] - - # I think a bit of code divergence in exchange for naming clarity is worthwhile - if explicit_master_params: - stash = self._amp_stash - - fp32_params = [p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None] - fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None] - fp32_momentums, first_runs[1] = self.get_momentums(fp32_params) - - if self.materialize_master_grads: - fp16_model_params = [p for i, p in enumerate( - stash.fp16_groups[gid]) if stash.fp32_from_fp16_groups[gid][i].grad is not None] - fp32_from_fp16_params = [p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None] - fp32_from_fp16_grads = [] - for p in fp32_from_fp16_params: - if p.is_contiguous(memory_format=torch.contiguous_format): - fp32_from_fp16_grads.append(p.grad) - elif p.is_contiguous(memory_format=torch.channels_last): - fp32_from_fp16_grads.append(p.grad.to(memory_format=torch.channels_last)) - elif p.is_contiguous(memory_format=torch.channel_last_3d): - fp32_from_fp16_grads.append(p.grad.to(memory_format=torch.channel_last_3d)) - else: - assert(False), "Unsupported memory format. Supports only contiguous_format, channels_last, or channel_last_3d." - fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params) - - fp16_set = [fp32_from_fp16_grads, fp32_from_fp16_params, - fp32_from_fp16_momentums, fp16_model_params] - else: - fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None] - fp16_model_grads = [] - for p in fp16_model_params: - if p.is_contiguous(memory_format=torch.contiguous_format): - fp16_model_grads.append(p.grad) - elif p.is_contiguous(memory_format=torch.channels_last): - fp16_model_grads.append(p.grad.to(memory_format=torch.channels_last)) - elif p.is_contiguous(memory_format=torch.channel_last_3d): - fp16_model_grads.append(p.grad.to(memory_format=torch.channel_last_3d)) - else: - assert(False), "Unsupported memory format. Supports only contiguous_format, channels_last, or channel_last_3d." - fp32_from_fp16_params = [p for i, p in enumerate( - stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None] - fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params) - - fp16_set = [fp16_model_grads, fp32_from_fp16_params, - fp32_from_fp16_momentums, fp16_model_params] - - launch_sets= [fp16_set, [fp32_grads, fp32_params, fp32_momentums]] - else: - fp16_params = [p for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)] - fp16_grads = [] - for p in fp16_params: - if p.is_contiguous(memory_format=torch.contiguous_format): - fp16_grads.append(p.grad) - elif p.is_contiguous(memory_format=torch.channels_last): - fp16_grads.append(p.grad.to(memory_format=torch.channels_last)) - elif p.is_contiguous(memory_format=torch.channel_last_3d): - fp16_grads.append(p.grad.to(memory_format=torch.channel_last_3d)) - else: - assert(False), "Unsupported memory format. Supports only contiguous_format, channels_last, or channel_last_3d." - fp16_momentums, first_runs[0] = self.get_momentums(fp16_params) - - fp32_params = [p for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)] - fp32_grads = [] - for p in fp32_params: - if p.is_contiguous(memory_format=torch.contiguous_format): - fp32_grads.append(p.grad) - elif p.is_contiguous(memory_format=torch.channels_last): - fp32_grads.append(p.grad.to(memory_format=torch.channels_last)) - elif p.is_contiguous(memory_format=torch.channel_last_3d): - fp32_grads.append(p.grad.to(memory_format=torch.channel_last_3d)) - else: - assert(False), "Unsupported memory format. Supports only contiguous_format, channels_last, or channel_last_3d." - fp32_momentums, first_runs[1] = self.get_momentums(fp32_params) - - launch_sets = [[fp16_grads, fp16_params, fp16_momentums], - [fp32_grads, fp32_params, fp32_momentums]] - - for s, (launch_set, first_run) in enumerate(zip(launch_sets, first_runs)): - assert len(launch_set[0]) == len(launch_set[1]) - assert len(launch_set[0]) == len(launch_set[2]) - if len(launch_set[0]) > 0: - # multi_tensor_applier has nhwc support: https://github.com/NVIDIA/apex/pull/732 - multi_tensor_applier( - self.multi_tensor_sgd, - self._dummy_overflow_buf, - launch_set, - weight_decay, - momentum, - dampening, - group['lr'], - nesterov, - first_run, - self.wd_after_momentum, - 1.0/self.most_recent_scale) - - self.most_recent_scale = 1.0 - self.scale_set_by_backward = False - - return loss diff --git a/apex/parallel/LARC.py b/apex/parallel/LARC.py deleted file mode 100644 index 4a93fcd..0000000 --- a/apex/parallel/LARC.py +++ /dev/null @@ -1,107 +0,0 @@ -import torch -from torch import nn -from torch.nn.parameter import Parameter - -class LARC(object): - """ - :class:`LARC` is a pytorch implementation of both the scaling and clipping variants of LARC, - in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive - local learning rate for each individual parameter. The algorithm is designed to improve - convergence of large batch training. - - See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate. - - In practice it modifies the gradients of parameters as a proxy for modifying the learning rate - of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer. - - ``` - model = ... - optim = torch.optim.Adam(model.parameters(), lr=...) - optim = LARC(optim) - ``` - - It can even be used in conjunction with apex.fp16_utils.FP16_optimizer. - - ``` - model = ... - optim = torch.optim.Adam(model.parameters(), lr=...) - optim = LARC(optim) - optim = apex.fp16_utils.FP16_Optimizer(optim) - ``` - - Args: - optimizer: Pytorch optimizer to wrap and modify learning rate for. - trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888 - clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`. - eps: epsilon kludge to help with numerical stability while calculating adaptive_lr - """ - - def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8): - self.optim = optimizer - self.trust_coefficient = trust_coefficient - self.eps = eps - self.clip = clip - - def __getstate__(self): - return self.optim.__getstate__() - - def __setstate__(self, state): - self.optim.__setstate__(state) - - @property - def state(self): - return self.optim.state - - def __repr__(self): - return self.optim.__repr__() - - @property - def param_groups(self): - return self.optim.param_groups - - @param_groups.setter - def param_groups(self, value): - self.optim.param_groups = value - - def state_dict(self): - return self.optim.state_dict() - - def load_state_dict(self, state_dict): - self.optim.load_state_dict(state_dict) - - def zero_grad(self): - self.optim.zero_grad() - - def add_param_group(self, param_group): - self.optim.add_param_group( param_group) - - def step(self): - with torch.no_grad(): - weight_decays = [] - for group in self.optim.param_groups: - # absorb weight decay control from optimizer - weight_decay = group['weight_decay'] if 'weight_decay' in group else 0 - weight_decays.append(weight_decay) - group['weight_decay'] = 0 - for p in group['params']: - if p.grad is None: - continue - param_norm = torch.norm(p.data) - grad_norm = torch.norm(p.grad.data) - - if param_norm != 0 and grad_norm != 0: - # calculate adaptive lr + weight decay - adaptive_lr = self.trust_coefficient * (param_norm) / (grad_norm + param_norm * weight_decay + self.eps) - - # clip learning rate for LARC - if self.clip: - # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)` - adaptive_lr = min(adaptive_lr/group['lr'], 1) - - p.grad.data += weight_decay * p.data - p.grad.data *= adaptive_lr - - self.optim.step() - # return weight decay control to optimizer - for i, group in enumerate(self.optim.param_groups): - group['weight_decay'] = weight_decays[i] diff --git a/apex/parallel/README.md b/apex/parallel/README.md deleted file mode 100644 index e7910d8..0000000 --- a/apex/parallel/README.md +++ /dev/null @@ -1,66 +0,0 @@ -## Distributed Data Parallel - -distributed.py contains the source code for `apex.parallel.DistributedDataParallel`, a module wrapper that enables multi-process multi-GPU data parallel training optimized for NVIDIA's NCCL communication library. - -`apex.parallel.DistributedDataParallel` achieves high performance by overlapping communication with -computation in the backward pass and bucketing smaller transfers to reduce the total number of -transfers required. - -multiproc.py contains the source code for `apex.parallel.multiproc`, a launch utility that places one process on each of the node's available GPUs. - -#### [API Documentation](https://nvidia.github.io/apex/parallel.html) - -#### [Example/Walkthrough](https://github.com/NVIDIA/apex/tree/master/examples/distributed) - -#### [Imagenet example with Mixed Precision](https://github.com/NVIDIA/apex/tree/master/examples/imagenet) - -#### [Simple example with FP16_Optimizer](https://github.com/NVIDIA/apex/tree/master/examples/FP16_Optimizer_simple/distributed_apex) - -### Synchronized Batch Normalization - -`apex.parallel.SyncBatchNorm` has similar APIs as with `torch.nn.BatchNorm*N*d`. -It reduces stats on the first (channel) dimension of the Tensor and accepts -arbitrary spatial dimensions. - -#### Installation - -Apex provides two sync BN implementation: - -1. There is the Python-only implementation, which is the default implementation -when install with `python setup.py install`. -It uses PyTorch primitive operations and distributed communication package from -`torch.distributed`. - - - _Python-only implementation requires input tensor to be of same data type as -layer_ - -2. We also provide implementation with kernels through CUDA/C++ extension with -improved performance. We are experimenting with Welford and Kahan for reduction -hoping to get better accuracy. - To use the kernel implementation, user need to install Apex with CUDA extension -enabled `python setup.py install --cuda_ext`. - - - _Custom kernel implementation supports fp16 input with fp32 layer as cudnn. -This is required to run imagenet example in fp16._ - - - _Currently kernel implementation only supports GPU._ - -#### HowTo - -1. User could use `apex.parallel.SyncBatchNorm` by building their module with -the layer explicitly. - -``` -import apex -input_t = torch.randn(3, 5, 20).cuda() -sbn = apex.parallel.SyncBatchNorm(5).cuda() -output_t = sbn(input) -``` - -2. User could also take a constructed `torch.nn.Model` and replace all its `torch.nn.BatchNorm*N*d` modules with `apex.parallel.SyncBatchNorm` through utility function `apex.parallel.convert_syncbn_model`. - -``` -# model is an instance of torch.nn.Module -import apex -sync_bn_model = apex.parallel.convert_syncbn_model(model) -``` diff --git a/apex/parallel/__init__.py b/apex/parallel/__init__.py deleted file mode 100644 index 3cd7ae5..0000000 --- a/apex/parallel/__init__.py +++ /dev/null @@ -1,95 +0,0 @@ -import torch - -if hasattr(torch.distributed, 'ReduceOp'): - ReduceOp = torch.distributed.ReduceOp -elif hasattr(torch.distributed, 'reduce_op'): - ReduceOp = torch.distributed.reduce_op -else: - ReduceOp = torch.distributed.deprecated.reduce_op - -from .distributed import DistributedDataParallel, Reducer -# This is tricky because I'd like SyncBatchNorm to be exposed the same way -# for both the cuda-enabled and python-fallback versions, and I don't want -# to suppress the error information. -try: - import syncbn - from .optimized_sync_batchnorm import SyncBatchNorm -except ImportError as err: - from .sync_batchnorm import SyncBatchNorm - SyncBatchNorm.syncbn_import_error = err - -def convert_syncbn_model(module, process_group=None, channel_last=False): - ''' - Recursively traverse module and its children to replace all instances of - ``torch.nn.modules.batchnorm._BatchNorm`` with :class:`apex.parallel.SyncBatchNorm`. - - All ``torch.nn.BatchNorm*N*d`` wrap around - ``torch.nn.modules.batchnorm._BatchNorm``, so this function lets you easily switch - to use sync BN. - - Args: - module (torch.nn.Module): input module - - Example:: - - >>> # model is an instance of torch.nn.Module - >>> import apex - >>> sync_bn_model = apex.parallel.convert_syncbn_model(model) - ''' - mod = module - if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): - return module - if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): - mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group, channel_last=channel_last) - mod.running_mean = module.running_mean - mod.running_var = module.running_var - mod.num_batches_tracked = module.num_batches_tracked - if module.affine: - mod.weight.data = module.weight.data.clone().detach() - mod.bias.data = module.bias.data.clone().detach() - for name, child in module.named_children(): - mod.add_module(name, convert_syncbn_model(child, - process_group=process_group, - channel_last=channel_last)) - # TODO(jie) should I delete model explicitly? - del module - return mod - -def create_syncbn_process_group(group_size): - ''' - Creates process groups to be used for syncbn of a give ``group_size`` and returns - process group that current GPU participates in. - - ``group_size`` must divide the total number of GPUs (world_size). - - ``group_size`` of 0 would be considered as =world_size. In this case ``None`` will be returned. - - ``group_size`` of 1 would be equivalent to using non-sync bn, but will still carry the overhead. - - Args: - group_size (int): number of GPU's to collaborate for sync bn - - Example:: - - >>> # model is an instance of torch.nn.Module - >>> import apex - >>> group = apex.parallel.create_syncbn_process_group(group_size) - ''' - - if group_size==0: - return None - - world_size = torch.distributed.get_world_size() - assert(world_size >= group_size) - assert(world_size % group_size == 0) - - group=None - for group_num in (range(world_size//group_size)): - group_ids = range(group_num*group_size, (group_num+1)*group_size) - cur_group = torch.distributed.new_group(ranks=group_ids) - if (torch.distributed.get_rank()//group_size == group_num): - group = cur_group - #can not drop out and return here, every process must go through creation of all subgroups - - assert(group is not None) - return group diff --git a/apex/parallel/distributed.py b/apex/parallel/distributed.py deleted file mode 100644 index 6aa6a6e..0000000 --- a/apex/parallel/distributed.py +++ /dev/null @@ -1,640 +0,0 @@ -import torch -import torch.distributed as dist -from torch.nn.modules import Module -from torch.autograd import Variable -from collections import OrderedDict -from itertools import chain -import copy -import importlib -from ..multi_tensor_apply import multi_tensor_applier - -imported_flatten_impl = False - -def import_flatten_impl(): - global flatten_impl, unflatten_impl, imported_flatten_impl - try: - import apex_C - flatten_impl = apex_C.flatten - unflatten_impl = apex_C.unflatten - except ImportError: - print("Warning: apex was installed without --cpp_ext. Falling back to Python flatten and unflatten.") - flatten_impl = torch._utils._flatten_dense_tensors - unflatten_impl = torch._utils._unflatten_dense_tensors - imported_flatten_impl = True - -def flatten(bucket): - if not imported_flatten_impl: - import_flatten_impl() - return flatten_impl(bucket) - -def unflatten(coalesced, bucket): - if not imported_flatten_impl: - import_flatten_impl() - return unflatten_impl(coalesced, bucket) - -# apply_dist_call requires that tensors in 'bucket' are all the same type. -def apply_flat_dist_call(bucket, call, extra_args=None): - - coalesced = flatten(bucket) - - if extra_args is not None: - call(coalesced, *extra_args) - else: - call(coalesced) - - if call is dist.all_reduce: - coalesced /= dist.get_world_size() - - for buf, synced in zip(bucket, unflatten(coalesced, bucket)): - buf.copy_(synced) - -def split_half_float_double_bfloat16(tensors): - dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"] - buckets = [] - for i, dtype in enumerate(dtypes): - bucket = [t for t in tensors if t.type() == dtype] - if bucket: - buckets.append(bucket) - return buckets - -def split_by_type(tensors): - buckets = OrderedDict() - for tensor in tensors: - tp = tensor.type() - if tp not in buckets: - buckets[tp] = [] - buckets[tp].append(tensor) - return buckets - -# flat_dist_call organizes 'tensors' by type. -def flat_dist_call(tensors, call, extra_args=None): - buckets = split_by_type(tensors) - - for tp in buckets: - bucket = buckets[tp] - apply_flat_dist_call(bucket, call, extra_args) - - -def extract_tensors(maybe_tensor, tensor_list): - if torch.is_tensor(maybe_tensor): - tensor_list.append(maybe_tensor) - else: - try: - for item in maybe_tensor: - extract_tensors(item, tensor_list) - except TypeError: - return - - -class Reducer(object): - """ - :class:`apex.parallel.Reducer` is a simple class that helps allreduce a module's parameters - across processes. :class:`Reducer` is intended to give the user additional control: - Unlike :class:`DistributedDataParallel`, :class:`Reducer` will not automatically allreduce - parameters during ``backward()``. - Instead, :class:`Reducer` waits for the user to call ``.reduce()`` manually. - This enables, for example, delaying the allreduce to be carried out every - several iterations instead of every single iteration. - - Like :class:`DistributedDataParallel`, :class:`Reducer` averages any tensors it allreduces - over the number of participating processes. - - :class:`Reducer` is designed to work with the upstream launch utility script - ``torch.distributed.launch`` with ``--nproc_per_node <= number of gpus per node``. - When used with this launcher, :class:`Reducer` assumes 1:1 mapping of processes to GPUs. - It also assumes that your script calls ``torch.cuda.set_device(args.rank)`` before creating the model. - - Args: - module_or_grads_list: Either a network definition (module) being run in multi-gpu/distributed mode, or an iterable of gradients to be reduced. If a module is passed in, the Reducer constructor will sync the parameters across processes (broadcasting from rank 0) to make sure they're all initialized with the same values. If a list of gradients (that came from some module) is passed in, the user is responsible for manually syncing that module's parameters at the beginning of training. - """ - - def __init__(self, module_or_grads_list): - if isinstance(module_or_grads_list, Module): - self.module = module_or_grads_list - flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) ) - - else: - self.module = None - self.grads = [] - extract_tensors(module_or_grads_list, self.grads) - - def reduce(self): - if self.module: - grads = [param.grad.data for param in self.module.parameters() if param.grad is not None] - flat_dist_call(grads, dist.all_reduce) - else: - flat_dist_call(self.grads, dist.all_reduce) - - -class DistributedDataParallel(Module): - """ - :class:`apex.parallel.DistributedDataParallel` is a module wrapper that enables - easy multiprocess distributed data parallel training, similar to ``torch.nn.parallel.DistributedDataParallel``. Parameters are broadcast across participating processes on initialization, and gradients are - allreduced and averaged over processes during ``backward()``. - - :class:`DistributedDataParallel` is optimized for use with NCCL. It achieves high performance by - overlapping communication with computation during ``backward()`` and bucketing smaller gradient - transfers to reduce the total number of transfers required. - - :class:`DistributedDataParallel` is designed to work with the upstream launch utility script - ``torch.distributed.launch`` with ``--nproc_per_node <= number of gpus per node``. - When used with this launcher, :class:`DistributedDataParallel` assumes 1:1 mapping of processes to GPUs. - It also assumes that your script calls ``torch.cuda.set_device(args.rank)`` before creating the model. - - https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed shows detailed usage. - https://github.com/NVIDIA/apex/tree/master/examples/imagenet shows another example - that combines :class:`DistributedDataParallel` with mixed precision training. - - Args: - module: Network definition to be run in multi-gpu/distributed mode. - message_size (int, default=1e7): Minimum number of elements in a communication bucket. - delay_allreduce (bool, default=False): Delay all communication to the end of the backward pass. This disables overlapping communication with computation. - allreduce_trigger_params (list, optional, default=None): If supplied, should contain a list of parameters drawn from the model. Allreduces will be kicked off whenever one of these parameters receives its gradient (as opposed to when a bucket of size message_size is full). At the end of backward(), a cleanup allreduce to catch any remaining gradients will also be performed automatically. If allreduce_trigger_params is supplied, the message_size argument will be ignored. - allreduce_always_fp32 (bool, default=False): Convert any FP16 gradients to FP32 before allreducing. This can improve stability for widely scaled-out runs. - gradient_average (bool, default=True): Option to toggle whether or not DDP averages the allreduced gradients over processes. For proper scaling, the default value of True is recommended. - gradient_predivide_factor (float, default=1.0): Allows perfoming the average of gradients over processes partially before and partially after the allreduce. Before allreduce: ``grads.mul_(1.0/gradient_predivide_factor)``. After allreduce: ``grads.mul_(gradient_predivide_factor/world size)``. This can reduce the stress on the dynamic range of FP16 allreduces for widely scaled-out runs. - - .. warning:: - If ``gradient_average=False``, the pre-allreduce division (``grads.mul_(1.0/gradient_predivide_factor)``) will still be applied, but the post-allreduce gradient averaging (``grads.mul_(gradient_predivide_factor/world size)``) will be omitted. - - """ - - def __init__(self, - module, - message_size=10000000, - delay_allreduce=False, - shared_param=None, - allreduce_trigger_params=None, - retain_allreduce_buffers=False, - allreduce_always_fp32=False, - num_allreduce_streams=1, - allreduce_communicators=None, - gradient_average=True, - gradient_predivide_factor=1.0, - gradient_average_split_factor=None, - prof=False): - super(DistributedDataParallel, self).__init__() - - # Backward/forward compatibility around - # https://github.com/pytorch/pytorch/commit/540ef9b1fc5506369a48491af8a285a686689b36 and - # https://github.com/pytorch/pytorch/commit/044d00516ccd6572c0d6ab6d54587155b02a3b86 - if hasattr(dist, "get_backend"): - self._backend = dist.get_backend() - if hasattr(dist, "DistBackend"): - self.backend_enum_holder = dist.DistBackend - else: - self.backend_enum_holder = dist.Backend - else: - self._backend = dist._backend - self.backend_enum_holder = dist.dist_backend - - self.warn_on_half = True if self._backend == self.backend_enum_holder.GLOO else False - - self.prof = prof - - self.allreduce_different_streams = (num_allreduce_streams > 1) - self.num_allreduce_streams = num_allreduce_streams - self.allreduce_communicators = allreduce_communicators - if self.allreduce_communicators: - assert len(allreduce_communicators[0]) == num_allreduce_streams - assert len(allreduce_communicators[0]) == len(allreduce_communicators[1]) - assert self.allreduce_different_streams - - if self.allreduce_different_streams and delay_allreduce: - raise ValueError("self.allreduce_different_streams may only be used if delay_allreduce=False.") - - if shared_param is not None: - raise ValueError("shared_param is no longer supported as an option. It was misleadingly named from the start. It turns out overlapping communication with computation should work fine with shared parameters. If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True|False instead.") - - self.world_size = float(dist.get_world_size()) - - self.retain_allreduce_buffers = retain_allreduce_buffers - self.allreduce_always_fp32 = allreduce_always_fp32 - self.gradient_average = gradient_average - self.gradient_predivide_factor = gradient_predivide_factor - - self.custom_allreduce_triggers = False - if allreduce_trigger_params is not None: - if delay_allreduce: - raise ValueError("Setting allreduce_trigger_params is only valid if delay_allreduce=False.") - self.custom_allreduce_triggers = True - self.allreduce_trigger_params = set([id(param) for param in allreduce_trigger_params]) - - self.delay_allreduce = delay_allreduce - self.message_size = message_size - - self.main_stream = torch.cuda.current_stream() - - self.bucket_streams = [] - self.bucket_events = [] - - self.module = module - - self._disable_allreduce = False - - if self._backend == self.backend_enum_holder.NCCL: - for param in self.module.parameters(): - assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU." - - self.active_params = [] - - self.param_type_to_tmp_i = {"torch.cuda.HalfTensor" : 0, - "torch.cuda.FloatTensor" : 1, - "torch.cuda.DoubleTensor" : 2, - "torch.cuda.BFloat16Tensor" : 3} - - if multi_tensor_applier.available: - # TODO: I really need to centralize the C++ backed imports - import amp_C - self.multi_tensor_scale = amp_C.multi_tensor_scale - self._overflow_buf = torch.cuda.IntTensor([0]) - - self.create_hooks() - - flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) ) - - - def __setstate__(self, state): - super(DistributedDataParallel, self).__setstate__(state) - if self.allreduce_different_streams and delay_allreduce: - raise ValueError("self.allreduce_different_streams may only be used if delay_allreduce=False.") - - if self.delay_allreduce: - self.needs_refresh = True - - self.bucket_streams = [] - self.bucket_events = [] - - - def __getstate__(self): - attrs = copy.copy(self.__dict__) - if self._backend != self.backend_enum_holder.NCCL: - del attrs['self.bucket_streams'] - del attrs['self.bucket_events'] - return attrs - - def enable_allreduce(self): - self._disable_allreduce = False - - def disable_allreduce(self): - self._disable_allreduce = True - - # Broadcast rank 0's bucket structure across all processes, and have all processes - # regenerate their bucket structures to match. - def sync_bucket_structure(self): - # Append leftover buckets - for tmp_bucket in self.tmp_buckets: - if len(tmp_bucket) > 0: - self.active_i_buckets.append(tmp_bucket) - - self.num_buckets = len(self.active_i_buckets) - self.bucket_sizes = [len(bucket) for bucket in self.active_i_buckets] - - info_tensor = torch.cuda.IntTensor([self.num_buckets] + - self.bucket_sizes + - list(chain(*self.active_i_buckets))) - - dist.broadcast(info_tensor, 0) - - info = [int(entry) for entry in info_tensor] - - self.num_buckets = info[0] - self.bucket_sizes = info[1:self.num_buckets + 1] - self.buckets = [[None for _ in range(self.bucket_sizes[i])] - for i in range(self.num_buckets)] - # Technically, active_i_buckets' work is done. But the information is still useful to - # keep around. Therefore, refresh active_i_buckets based on rank 0 as well. - self.active_i_buckets = [[None for _ in range(self.bucket_sizes[i])] - for i in range(self.num_buckets)] - - flattened_buckets = info[self.num_buckets + 1:] - flat_i = 0 - for bucket_idx in range(self.num_buckets): - for bucket_loc in range(self.bucket_sizes[bucket_idx]): - param_i = flattened_buckets[flat_i] - self.active_i_buckets[bucket_idx][bucket_loc] = param_i - self.param_id_to_bucket[id(self.active_params[param_i])] = (bucket_idx, bucket_loc) - flat_i += 1 - - - def create_hooks(self): - # Fallback hook that's only called at the end of backward. - # Used if you deliberately want to delay allreduces to the end, or to refresh the - # bucket structure that will be used to overlap communication with computation in later - # iterations. - def allreduce_params(): - # Bucket record refresh - if not self.delay_allreduce: - if self.needs_refresh: - self.sync_bucket_structure() - - self.needs_refresh = False - - self.allreduce_fallback() - - - def overlapping_backward_epilogue(): - for stream, event in zip(self.bucket_streams, self.bucket_events): - stream.record_event(event) - torch.cuda.current_stream().wait_event(event) - - # Sanity checks that all the buckets were kicked off - if self.next_bucket != self.num_buckets: - raise RuntimeError("In epilogue, next_bucket ({}) != num_buckets ({}). ".format( - self.next_bucket, self.num_buckets), - "This probably indicates some buckets were not allreduced.") - - for actual, expected in zip(self.buckets_ready_size, self.bucket_sizes): - if actual != expected: - raise RuntimeError("Some param buckets were not allreduced.") - - - self.grad_accs = [] - for param in self.module.parameters(): - if param.requires_grad: - def wrapper(param): - param_tmp = param.expand_as(param) - grad_acc = param_tmp.grad_fn.next_functions[0][0] - - def allreduce_hook(*unused): - if self.prof: - torch.cuda.nvtx.range_push("allreduce_hook") - - if not self._disable_allreduce: - if self.delay_allreduce or self.needs_refresh: - # TODO: How do we want to handle multiple backward passes between - # each forward, e.g., backward passes with retain_graph=True? - # needs_refresh and callback_queued are both vulnerable states. - if not self.delay_allreduce and self.needs_refresh: - # Use the backward pass to build the bucket structure on the fly. - active_i = self.param_id_to_active_i[id(param)] - - # Float, half, and double tensors are grouped into buckets separately. - current_type = self.param_type_to_tmp_i[param.type()] - - self.tmp_buckets[current_type].append(active_i) - - ship_tmp_bucket = False - if self.custom_allreduce_triggers: - if id(param) in self.allreduce_trigger_params: - ship_tmp_bucket = True - else: - self.tmp_numels[current_type] += param.numel() - if self.tmp_numels[current_type] >= self.message_size: - ship_tmp_bucket = True - - # To consider: If custom_allreduce_triggers are in use, ship all - # tmp_buckets, not just tmp_buckets[current_type]. - if ship_tmp_bucket: - self.active_i_buckets.append(self.tmp_buckets[current_type]) - self.tmp_buckets[current_type] = [] - self.tmp_numels[current_type] = 0 - - if not self.callback_queued: - Variable._execution_engine.queue_callback(allreduce_params) - self.callback_queued = True - else: - if not self.callback_queued: - Variable._execution_engine.queue_callback(overlapping_backward_epilogue) - self.callback_queued = True - - self.comm_ready_buckets(param) - - if self.prof: - torch.cuda.nvtx.range_pop() - - grad_acc.register_hook(allreduce_hook) - self.grad_accs.append(grad_acc) - - wrapper(param) - - - def _stream_this_bucket(self, bucket_idx): - if self.allreduce_different_streams: - return self.bucket_streams[bucket_idx%self.num_allreduce_streams] - else: - return self.bucket_streams[0] - - - def _event_this_bucket(self, bucket_idx): - if self.allreduce_different_streams: - return self.bucket_events[bucket_idx%self.num_allreduce_streams] - else: - return self.bucket_events[0] - - - def allreduce_bucket(self, bucket, bucket_idx, force_default_stream): - tensor = flatten(bucket) - - if force_default_stream: - bucket_stream = self.main_stream - else: - bucket_stream = self._stream_this_bucket(bucket_idx) - bucket_event = self._event_this_bucket(bucket_idx) - torch.cuda.current_stream().record_event(bucket_event) - bucket_stream.wait_event(bucket_event) - - with torch.cuda.stream(bucket_stream): - # self.main_stream.wait_stream(torch.cuda.current_stream()) - # torch.cuda.synchronize() - - tensor_to_allreduce = tensor - - if self.allreduce_always_fp32: - tensor_to_allreduce = tensor.float() - - if self.gradient_predivide_factor != 1.0: - tensor_to_allreduce.mul_(1./self.gradient_predivide_factor) - - if self.allreduce_different_streams and not force_default_stream: - dist.all_reduce(tensor_to_allreduce, group=self.bucket_pgs[bucket_idx%self.num_allreduce_streams]) - else: - dist.all_reduce(tensor_to_allreduce) - - if self.gradient_average: - tensor_to_allreduce.mul_(self.gradient_predivide_factor/self.world_size) - - if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce: - tensor.copy_(tensor_to_allreduce) - - if not self.retain_allreduce_buffers: - if multi_tensor_applier.available: - multi_tensor_applier( - self.multi_tensor_scale, - self._overflow_buf, - [unflatten(tensor, bucket), bucket], - 1.0) - else: - for buf, synced in zip(bucket, unflatten(tensor, bucket)): - buf.copy_(synced) - - # I think we actually do need this here. After allreduce_bucket returns, tensor will - # eventually go out of scope and die, at which point it could otherwise be freed for - # further reuse by the main stream while the allreduce/div/unflatten are underway in bucket_stream. - tensor.record_stream(bucket_stream) - - return tensor - - - def allreduce_maybe_retain(self, bucket, bucket_idx, force_default_stream=False): - allreduced = self.allreduce_bucket(bucket, bucket_idx, force_default_stream) - if self.retain_allreduce_buffers: - if self.allreduce_buffers[bucket_idx] is not None: - raise RuntimeError("The backward pass is attempting to replace an already-filled " - "allreduce buffer. This is almost certainly an error.") - self.allreduce_buffers[bucket_idx] = allreduced - for view, grad in zip(unflatten(allreduced, bucket), bucket): - grad.data = view - # for buf, synced in zip(bucket, unflatten(allreduced, bucket)): - # buf.copy_(synced) - - - def allreduce_fallback(self): - for stream, event in zip(self.bucket_streams, self.bucket_events): - stream.record_event(event) - torch.cuda.current_stream().wait_event(event) - - if self.retain_allreduce_buffers: - grads = [param.grad for param in self.module.parameters() if param.grad is not None] - else: - grads = [param.grad.data for param in self.module.parameters() if param.grad is not None] - - split_buckets = split_half_float_double_bfloat16(grads) - - # If retain_allreduce_buffers is True and delay_allreduce is False, - # this will only be done during the first backward pass, ignored by the - # training script, and overwritten in the next forward pass. So it's harmless. - if self.retain_allreduce_buffers: - self.allreduce_buffers = [None for _ in range(len(split_buckets))] - - for i, bucket in enumerate(split_buckets): - allreduced = self.allreduce_maybe_retain(bucket, i, force_default_stream=True) - - - def comm_ready_buckets(self, param): - # Need to do this in every hook for compatibility with Ruberry's streaming backward PR. - # self.reduction_stream.wait_stream(torch.cuda.current_stream()) - if self.prof: - torch.cuda.nvtx.range_push("comm_ready_buckets") - - bucket_idx, bucket_loc = self.param_id_to_bucket[id(param)] - - if self.buckets[bucket_idx][bucket_loc] is not None: - raise RuntimeError("The backward pass is attempting to replace an already-filled " - "bucket slot. This is almost certainly an error.") - - if self.retain_allreduce_buffers: - self.buckets[bucket_idx][bucket_loc] = param.grad - else: - self.buckets[bucket_idx][bucket_loc] = param.grad.data - - self.buckets_ready_size[bucket_idx] += 1 - - if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]: - if bucket_idx == self.next_bucket: - self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx) - - self.next_bucket += 1 - - # Reversing upstream's logic here, because we constructed our buckets based on - # the order things were received during backward. - if len(self.ready_buckets_not_reduced) > 0: - sorted_todo = sorted(self.ready_buckets_not_reduced) - for i in sorted_todo: - # Nothing can be reduced now - if i > self.next_bucket: - break - elif i == self.next_bucket: - self.allreduce_maybe_retain(self.buckets[i], i) - self.ready_buckets_not_reduced.remove(i) - self.next_bucket += 1 - else: - raise ValueError("i should always be >= next_bucket") - else: - self.ready_buckets_not_reduced.add(bucket_idx) - - if self.prof: - torch.cuda.nvtx.range_pop() - - - def forward(self, *inputs, **kwargs): - result = self.module(*inputs, **kwargs) - - if self.prof: - torch.cuda.nvtx.range_push("forward pass DDP logic") - - if not self._disable_allreduce: - if not self.delay_allreduce: - param_list = [param for param in self.module.parameters() if param.requires_grad] - - # Conditions under which to refresh self.record - # Forward has the authority to set needs_refresh to True, but only allreduce_params - # in backward has the authority to set needs_refresh to False. - # Parentheses are not necessary for correct order of operations, but make the intent clearer. - if ((not self.active_params) or - (len(param_list) != len(self.active_params)) or - any([param1 is not param2 for param1, param2 in zip(param_list, self.active_params)])): - self.needs_refresh = True - - if self.needs_refresh: - self.active_i_buckets = [] - self.buckets = [] - self.tmp_buckets = [[], [], [], []] # [running half, float, double, bfloat16 buckets] - self.tmp_numels = [0, 0, 0, 0] - self.bucket_sizes = [] - self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)} - self.param_id_to_bucket = {} - self.bucket_pgs = [] - self.bucket_streams = [] - self.bucket_events = [] - else: - # self.buckets = [[None for _ in range(self.bucket_sizes[i])] - # for i in range(self.num_buckets)] - if not self.buckets: - self.buckets = [[None for _ in range(self.bucket_sizes[i])] - for i in range(self.num_buckets)] - else: - assert len(self.buckets) == self.num_buckets, "len(buckets) = {}, expected {}".format( - len(self.buckets), self.num_buckets) - for b, bucket in enumerate(self.buckets): - assert len(bucket) == self.bucket_sizes[b], "len(buckets[{}]) = {}, expected {})".format( - b, len(buckets[b]), self.bucket_sizes[b]) - for i in range(len(bucket)): - bucket[i] = None - - if self.allreduce_communicators: - self.bucket_pgs = self.allreduce_communicators[0] - self.bucket_streams = self.allreduce_communicators[1] - self.bucket_events = [torch.cuda.Event(enable_timing=False, - blocking=False) for _ in range(self.num_allreduce_streams)] - else: - if self.allreduce_different_streams: - if not self.bucket_pgs: - self.bucket_pgs = [dist.new_group() for _ in range(self.num_allreduce_streams)] - for i, bg in enumerate(self.bucket_pgs): - print("rank {} created group {} with backend {}".format( - dist.get_rank(), i, dist.get_backend(bg))) - if self.allreduce_different_streams: - if not self.bucket_streams: - self.bucket_streams = [torch.cuda.Stream() for _ in range(self.num_allreduce_streams)] - self.bucket_events = [torch.cuda.Event(enable_timing=False, - blocking=False) for _ in range(self.num_allreduce_streams)] - else: - if not self.bucket_streams: - self.bucket_streams = [torch.cuda.Stream()] - self.bucket_events = [torch.cuda.Event(enable_timing=False, blocking=False)] - - self.buckets_ready_size = [0 for i in range(self.num_buckets)] - if(self.retain_allreduce_buffers): - self.allreduce_buffers = [None for _ in range(self.num_buckets)] - self.next_bucket = 0 - self.ready_buckets_not_reduced = set() - - self.active_params = param_list - - self.callback_queued = False - - if self.prof: - torch.cuda.nvtx.range_pop() - - return result diff --git a/apex/parallel/multiproc.py b/apex/parallel/multiproc.py deleted file mode 100644 index ff743df..0000000 --- a/apex/parallel/multiproc.py +++ /dev/null @@ -1,35 +0,0 @@ -import torch -import sys -import subprocess - -def docstring_hack(): - """ - Multiproc file which will launch a set of processes locally for multi-gpu - usage: python -m apex.parallel.multiproc main.py ... - """ - pass - -argslist = list(sys.argv)[1:] -world_size = torch.cuda.device_count() - -if '--world-size' in argslist: - world_size = int(argslist[argslist.index('--world-size')+1]) -else: - argslist.append('--world-size') - argslist.append(str(world_size)) - -workers = [] - -for i in range(world_size): - if '--rank' in argslist: - argslist[argslist.index('--rank')+1] = str(i) - else: - argslist.append('--rank') - argslist.append(str(i)) - stdout = None if i == 0 else open("GPU_"+str(i)+".log", "w") - print(argslist) - p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout) - workers.append(p) - -for p in workers: - p.wait() diff --git a/apex/parallel/optimized_sync_batchnorm.py b/apex/parallel/optimized_sync_batchnorm.py deleted file mode 100644 index 65cf5ea..0000000 --- a/apex/parallel/optimized_sync_batchnorm.py +++ /dev/null @@ -1,85 +0,0 @@ -import torch -from torch.nn.modules.batchnorm import _BatchNorm -from torch.nn import functional as F - -import syncbn -from .optimized_sync_batchnorm_kernel import SyncBatchnormFunction - - -class SyncBatchNorm(_BatchNorm): - """ - synchronized batch normalization module extented from `torch.nn.BatchNormNd` - with the added stats reduction across multiple processes. - :class:`apex.parallel.SyncBatchNorm` is designed to work with - `DistributedDataParallel`. - - When running in training mode, the layer reduces stats across all processes - to increase the effective batchsize for normalization layer. This is useful - in applications where batch size is small on a given process that would - diminish converged accuracy of the model. The model uses collective - communication package from `torch.distributed`. - - When running in evaluation mode, the layer falls back to - `torch.nn.functional.batch_norm` - - Args: - num_features: :math:`C` from an expected input of size - :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)` - eps: a value added to the denominator for numerical stability. - Default: 1e-5 - momentum: the value used for the running_mean and running_var - computation. Can be set to ``None`` for cumulative moving average - (i.e. simple average). Default: 0.1 - affine: a boolean value that when set to ``True``, this module has - learnable affine parameters. Default: ``True`` - track_running_stats: a boolean value that when set to ``True``, this - module tracks the running mean and variance, and when set to ``False``, - this module does not track such statistics and always uses batch - statistics in both training and eval modes. Default: ``True`` - process_group: pass in a process group within which the stats of the - mini-batch is being synchronized. ``None`` for using default process - group - channel_last: a boolean value that when set to ``True``, this module - take the last dimension of the input tensor to be the channel - dimension. Default: False - - Examples:: - >>> # channel first tensor - >>> sbn = apex.parallel.SyncBatchNorm(100).cuda() - >>> inp = torch.randn(10, 100, 14, 14).cuda() - >>> out = sbn(inp) - >>> inp = torch.randn(3, 100, 20).cuda() - >>> out = sbn(inp) - >>> # channel last tensor - >>> sbn = apex.parallel.SyncBatchNorm(100, channel_last=True).cuda() - >>> inp = torch.randn(10, 14, 14, 100).cuda() - """ - - def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False, fuse_relu=False): - super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) - self.process_group = process_group - self.channel_last = channel_last - self.fuse_relu = fuse_relu - - def _specify_process_group(self, process_group): - self.process_group = process_group - - def _specify_channel_last(self, channel_last): - self.channel_last = channel_last - - def forward(self, input, z = None): - # if input.dim() == 2, we switch to channel_last for efficient memory accessing - channel_last = self.channel_last if input.dim() != 2 else True - - if not self.training and self.track_running_stats and not channel_last and not self.fuse_relu and z == None: - # fall back to pytorch implementation for inference - return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps) - else: - exponential_average_factor = 0.0 - if self.training and self.track_running_stats: - self.num_batches_tracked += 1 - if self.momentum is None: - exponential_average_factor = 1.0 / float(self.num_batches_tracked) - else: - exponential_average_factor = self.momentum - return SyncBatchnormFunction.apply(input, z, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, channel_last, self.fuse_relu) diff --git a/apex/parallel/optimized_sync_batchnorm_kernel.py b/apex/parallel/optimized_sync_batchnorm_kernel.py deleted file mode 100644 index 6168471..0000000 --- a/apex/parallel/optimized_sync_batchnorm_kernel.py +++ /dev/null @@ -1,119 +0,0 @@ -import torch -from torch.autograd.function import Function - -import syncbn -from apex.parallel import ReduceOp - -class SyncBatchnormFunction(Function): - - @staticmethod - def forward(ctx, input, z, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None, channel_last = False, fuse_relu = False): - input = input.contiguous() - world_size = 0 - - mean = None - var_biased = None - inv_std = None - var = None - out = None - count = None - if track_running_stats: - if channel_last: - count = int(input.numel()/input.size(-1)) - mean, var_biased = syncbn.welford_mean_var_c_last(input) - num_channels = input.size(-1) - else: - count = int(input.numel()/input.size(1)) - mean, var_biased = syncbn.welford_mean_var(input) - num_channels = input.size(1) - - if torch.distributed.is_initialized(): - if not process_group: - process_group = torch.distributed.group.WORLD - device = mean.device - world_size = torch.distributed.get_world_size(process_group) - - count_t = torch.empty(1, dtype=mean.dtype, device=mean.device).fill_(count) - combined = torch.cat([mean.view(-1), var_biased.view(-1), count_t], dim=0) - combined_list = [torch.empty_like(combined) for k in range(world_size)] - torch.distributed.all_gather(combined_list, combined, process_group) - combined = torch.stack(combined_list, dim=0) - mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1) - count_all = count_all.view(-1) - mean, var, inv_std = syncbn.welford_parallel(mean_all, invstd_all, count_all.to(torch.int32), eps) - else: - device = mean.device - count_all = torch.cuda.IntTensor([count], device=device) - inv_std = 1.0 / torch.sqrt(var_biased + eps) - var = var_biased * (count) / (count-1) - - if count == 1 and world_size < 2: - raise ValueError('Expected more than 1 value per channel when training, got input size{}'.format(input.size())) - - r_m_inc = mean if running_mean.dtype != torch.float16 else mean.half() - r_v_inc = var if running_variance.dtype != torch.float16 else var.half() - running_mean.data = running_mean.data * (1-momentum) + momentum*r_m_inc - running_variance.data = running_variance.data * (1-momentum) + momentum*r_v_inc - else: - mean = running_mean.data - inv_std = 1.0 / torch.sqrt(running_variance.data + eps) - - ctx.save_for_backward(input, weight, mean, inv_std, z, bias, count_all.to(torch.int32)) - ctx.process_group = process_group - ctx.channel_last = channel_last - ctx.world_size = world_size - ctx.fuse_relu = fuse_relu - - if channel_last: - out = syncbn.batchnorm_forward_c_last(input, z, mean, inv_std, weight, bias, fuse_relu) - else: - out = syncbn.batchnorm_forward(input, mean, inv_std, weight, bias) - - return out - - @staticmethod - def backward(ctx, grad_output): - grad_output = grad_output.contiguous() - # mini batch mean & var are calculated by forward path. - # mu = 1./N*np.sum(h, axis = 0) - # var = 1./N*np.sum((h-mu)**2, axis = 0) - saved_input, weight, mean, inv_std, z, bias, count = ctx.saved_tensors - process_group = ctx.process_group - channel_last = ctx.channel_last - world_size = ctx.world_size - fuse_relu = ctx.fuse_relu - grad_input = grad_z = grad_weight = grad_bias = None - - if fuse_relu: - grad_output = syncbn.relu_bw_c_last(grad_output, saved_input, z, mean, inv_std, weight, bias) - if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]: - grad_z = grad_output.clone() - - # TODO: update kernel to not pre_divide by item_num - if channel_last: - sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight) - else: - sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight) - - # calculate grad_input - if ctx.needs_input_grad[0]: - - if torch.distributed.is_initialized(): - num_channels = sum_dy.shape[0] - combined = torch.cat([sum_dy, sum_dy_xmu], dim=0) - torch.distributed.all_reduce( - combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False) - sum_dy, sum_dy_xmu = torch.split(combined, num_channels) - - if channel_last: - grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count) - else: - grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count) - - if weight is None or not ctx.needs_input_grad[2]: - grad_weight = None - - if weight is None or not ctx.needs_input_grad[3]: - grad_bias = None - - return grad_input, grad_z, grad_weight, grad_bias, None, None, None, None, None, None, None, None diff --git a/apex/parallel/sync_batchnorm.py b/apex/parallel/sync_batchnorm.py deleted file mode 100644 index 1fcfe43..0000000 --- a/apex/parallel/sync_batchnorm.py +++ /dev/null @@ -1,134 +0,0 @@ -import torch -from torch.nn.modules.batchnorm import _BatchNorm -from torch.nn import functional as F - -from .sync_batchnorm_kernel import SyncBatchnormFunction -from apex.parallel import ReduceOp - - -class SyncBatchNorm(_BatchNorm): - """ - synchronized batch normalization module extented from ``torch.nn.BatchNormNd`` - with the added stats reduction across multiple processes. - :class:`apex.parallel.SyncBatchNorm` is designed to work with - ``DistributedDataParallel``. - - When running in training mode, the layer reduces stats across all processes - to increase the effective batchsize for normalization layer. This is useful - in applications where batch size is small on a given process that would - diminish converged accuracy of the model. The model uses collective - communication package from ``torch.distributed``. - - When running in evaluation mode, the layer falls back to - ``torch.nn.functional.batch_norm``. - - Args: - num_features: :math:`C` from an expected input of size - :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)` - eps: a value added to the denominator for numerical stability. - Default: 1e-5 - momentum: the value used for the running_mean and running_var - computation. Can be set to ``None`` for cumulative moving average - (i.e. simple average). Default: 0.1 - affine: a boolean value that when set to ``True``, this module has - learnable affine parameters. Default: ``True`` - track_running_stats: a boolean value that when set to ``True``, this - module tracks the running mean and variance, and when set to ``False``, - this module does not track such statistics and always uses batch - statistics in both training and eval modes. Default: ``True`` - - Example:: - - >>> sbn = apex.parallel.SyncBatchNorm(100).cuda() - >>> inp = torch.randn(10, 100, 14, 14).cuda() - >>> out = sbn(inp) - >>> inp = torch.randn(3, 100, 20).cuda() - >>> out = sbn(inp) - """ - - warned = False - - def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False): - if channel_last == True: - raise AttributeError("channel_last is not supported by primitive SyncBatchNorm implementation. Try install apex with `--cuda_ext` if channel_last is desired.") - - if not SyncBatchNorm.warned: - if hasattr(self, "syncbn_import_error"): - print("Warning: using Python fallback for SyncBatchNorm, possibly because apex was installed without --cuda_ext. The exception raised when attempting to import the cuda backend was: ", self.syncbn_import_error) - else: - print("Warning: using Python fallback for SyncBatchNorm") - SyncBatchNorm.warned = True - - super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) - self.process_group = process_group - - def _specify_process_group(self, process_group): - self.process_group = process_group - - def forward(self, input): - torch.cuda.nvtx.range_push("sync_bn_fw_with_mean_var") - mean = None - var = None - cast = None - out = None - - # casting to handle mismatch input type to layer type - if self.running_mean is not None: - if self.running_mean.dtype != input.dtype: - input = input.to(self.running_mean.dtype) - cast = input.dtype - elif self.weight is not None: - if self.weight.dtype != input.dtype: - input = input.to(self.weight.dtype) - cast = input.dtype - - if not self.training and self.track_running_stats: - # fall back to pytorch implementation for inference - torch.cuda.nvtx.range_pop() - out = F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps) - else: - process_group = self.process_group - world_size = 1 - if not self.process_group: - process_group = torch.distributed.group.WORLD - self.num_batches_tracked += 1 - with torch.no_grad(): - channel_first_input = input.transpose(0, 1).contiguous() - squashed_input_tensor_view = channel_first_input.view( - channel_first_input.size(0), -1) - # total number of data points for each variance entry. Used to calculate unbiased variance estimate - m = None - local_m = float(squashed_input_tensor_view.size()[1]) - local_mean = torch.mean(squashed_input_tensor_view, 1) - local_sqr_mean = torch.pow( - squashed_input_tensor_view, 2).mean(1) - if torch.distributed.is_initialized(): - world_size = torch.distributed.get_world_size(process_group) - torch.distributed.all_reduce( - local_mean, ReduceOp.SUM, process_group) - mean = local_mean / world_size - torch.distributed.all_reduce( - local_sqr_mean, ReduceOp.SUM, process_group) - sqr_mean = local_sqr_mean / world_size - m = local_m * world_size - else: - m = local_m - mean = local_mean - sqr_mean = local_sqr_mean - # var(x) = E (( x - mean_x ) ** 2) - # = 1 / N * sum ( x - mean_x ) ** 2 - # = 1 / N * sum (x**2) - mean_x**2 - var = sqr_mean - mean.pow(2) - - if self.running_mean is not None: - self.running_mean = self.momentum * mean + \ - (1 - self.momentum) * self.running_mean - if self.running_var is not None: - # as noted by the paper, we used unbiased variance estimate of the mini-batch - # Var[x] = m / (m-1) * Eb (sample_variance) - self.running_var = m / \ - (m-1) * self.momentum * var + \ - (1 - self.momentum) * self.running_var - torch.cuda.nvtx.range_pop() - out = SyncBatchnormFunction.apply(input, self.weight, self.bias, mean, var, self.eps, process_group, world_size) - return out.to(cast) diff --git a/apex/parallel/sync_batchnorm_kernel.py b/apex/parallel/sync_batchnorm_kernel.py deleted file mode 100644 index e407a63..0000000 --- a/apex/parallel/sync_batchnorm_kernel.py +++ /dev/null @@ -1,87 +0,0 @@ -import torch -from torch.autograd.function import Function - -from apex.parallel import ReduceOp - - -class SyncBatchnormFunction(Function): - - @staticmethod - def forward(ctx, input, weight, bias, running_mean, running_variance, eps, process_group, world_size): - torch.cuda.nvtx.range_push("sync_BN_fw") - # transpose it to channel last to support broadcasting for input with different rank - c_last_input = input.transpose(1, -1).contiguous().clone() - - ctx.save_for_backward(c_last_input, weight, bias, - running_mean, running_variance) - ctx.eps = eps - ctx.process_group = process_group - ctx.world_size = world_size - - c_last_input = (c_last_input - running_mean) / \ - torch.sqrt(running_variance + eps) - - if weight is not None: - c_last_input = c_last_input * weight - if bias is not None: - c_last_input = c_last_input + bias - - torch.cuda.nvtx.range_pop() - return c_last_input.transpose(1, -1).contiguous().clone() - - @staticmethod - def backward(ctx, grad_output): - torch.cuda.nvtx.range_push("sync_BN_bw") - # mini batch mean & var are calculated by forward path. - # mu = 1./N*np.sum(h, axis = 0) - # var = 1./N*np.sum((h-mu)**2, axis = 0) - c_last_input, weight, bias, running_mean, running_variance = ctx.saved_tensors - - eps = ctx.eps - process_group = ctx.process_group - world_size = ctx.world_size - grad_input = grad_weight = grad_bias = None - num_features = running_mean.size()[0] - - # transpose it to channel last to support broadcasting for input with different rank - torch.cuda.nvtx.range_push("carilli field") - c_last_grad = grad_output.transpose(1, -1).contiguous() - # squash non-channel dimension so we can easily calculate mean - c_grad = c_last_grad.view(-1, num_features).contiguous() - torch.cuda.nvtx.range_pop() - - # calculate grad_input - if ctx.needs_input_grad[0]: - # dh = gamma * (var + eps)**(-1. / 2.) * (dy - np.mean(dy, axis=0) - # - (h - mu) * (var + eps)**(-1.0) * np.mean(dy * (h - mu), axis=0)) - mean_dy = c_grad.mean(0) - mean_dy_xmu = (c_last_grad * (c_last_input - - running_mean)).view(-1, num_features).mean(0) - if torch.distributed.is_initialized(): - torch.distributed.all_reduce( - mean_dy, ReduceOp.SUM, process_group) - mean_dy = mean_dy / world_size - torch.distributed.all_reduce( - mean_dy_xmu, ReduceOp.SUM, process_group) - mean_dy_xmu = mean_dy_xmu / world_size - c_last_grad_input = (c_last_grad - mean_dy - (c_last_input - running_mean) / ( - running_variance + eps) * mean_dy_xmu) / torch.sqrt(running_variance + eps) - if weight is not None: - c_last_grad_input.mul_(weight) - grad_input = c_last_grad_input.transpose(1, -1).contiguous() - - # calculate grad_weight - grad_weight = None - if weight is not None and ctx.needs_input_grad[1]: - # dgamma = np.sum((h - mu) * (var + eps)**(-1. / 2.) * dy, axis=0) - grad_weight = ((c_last_input - running_mean) / torch.sqrt( - running_variance + eps) * c_last_grad).view(-1, num_features).sum(0) - - # calculate grad_bias - grad_bias = None - if bias is not None and ctx.needs_input_grad[2]: - # dbeta = np.sum(dy, axis=0) - grad_bias = c_grad.sum(0) - - torch.cuda.nvtx.range_pop() - return grad_input, grad_weight, grad_bias, None, None, None, None, None diff --git a/apex/testing/__init__.py b/apex/testing/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/apex/testing/common_utils.py b/apex/testing/common_utils.py deleted file mode 100644 index 82b660f..0000000 --- a/apex/testing/common_utils.py +++ /dev/null @@ -1,33 +0,0 @@ -''' -This file contains common utility functions for running the unit tests on ROCM. -''' - -import torch -import os -import sys -from functools import wraps -import unittest - - -TEST_WITH_ROCM = os.getenv('APEX_TEST_WITH_ROCM', '0') == '1' -SKIP_FLAKY_TEST = os.getenv('APEX_SKIP_FLAKY_TEST', '0') == '1' - -## Wrapper to skip the unit tests. -def skipIfRocm(fn): - @wraps(fn) - def wrapper(*args, **kwargs): - if TEST_WITH_ROCM: - raise unittest.SkipTest("test doesn't currently work on ROCm stack.") - else: - fn(*args, **kwargs) - return wrapper - -## Wrapper to skip the flaky unit tests. -def skipFlakyTest(fn): - @wraps(fn) - def wrapper(*args, **kwargs): - if SKIP_FLAKY_TEST: - raise unittest.SkipTest("Test is flaky.") - else: - fn(*args, **kwargs) - return wrapper diff --git a/apex/transformer/README.md b/apex/transformer/README.md deleted file mode 100644 index 7383f65..0000000 --- a/apex/transformer/README.md +++ /dev/null @@ -1,81 +0,0 @@ -# apex.transformer - -`apex.transformer` is a module which enables efficient large Transformer models at scale. - -`apex.transformer.tensor_parallel` and `apex.transformer.pipeline_parallel` are both based on [NVIDIA/Megatron-LM](https://github.com/NVIDIA/Megatron-LM)'s module. -The former is based on `megatron.mpu` and the latter is on `megatron.schedules` and `megatron.p2p_communication`. - -## Tensor Model Parallel (TP) - -APEX's tensor model parallel utilities provides some `torch.nn.Module`'s, custom fused kernels, and PRNG state handling. -See Appendix B.2 of [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) for the details of -PRNG state handling. - -## Pipeline Model Parallel (PP) -APEX's pipeline model parallel functions require models to have `.set_input_tensor` because -the input tensor for `.forward` method can be `None`. - -The following is a really casual sketch of training script with apex pp. - -```python -import torch -import torch.nn as nn -import torch.nn.functional as F - -from apex.transformer import parallel_state -from apex.transformer.pipeline_parallel import get_forward_backward_func - - -class Model(nn.Module): - - ... - - def __init__(self, *args, **kwargs): - super().__init__() - pre_process = kwargs.pop("pre_process") - post_process = kwargs.pop("post_process") - - def set_input_tensor(self, tensor): - self.input_tensor = tensor - - def forward(self, x, ...): - if parallel_state.is_pipeline_first_stage(): - input = x - else: - input = self.input_tensor - ... - - -def model_provider_func(*args, **kwargs): - return Model(*args, **kwargs) - - -def loss_func(pred, label): - loss = ... - averaged_loss = average_losses_across_data_parallel_group([loss]) - return loss, {'nice_loss': averaged_loss} - - -def forward_step_func(batch, model): - input, label = process_batch(batch) - out = model(input) - return out, partial(loss_func, label) - - -forward_backward_func = get_forward_backward_func(virtual_pipeline_model_parallel_size, pipeline_model_parallel_size) - - -parallel_state.initialize_model_parallel( - tensor_model_parallel_size, - pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size, -) -# The following line basically is equivalent to `build_model(Model, wrap_with_ddp, virtual_pipeline_model_parallel_size, *model_args, **model_kwargs)` -model = build_model(model_provider_func, wrap_with_ddp, virtual_pipeline_model_parallel_size, *model_args, **model_kwargs) -optimizer = ... -data_loader = ... -for epoch in range(num_epochs): - for batch in data_loader: - forward_backward_func(forward_step_func, batch, model, forward_only=False, tensor_shape) - optimizer.step() -``` diff --git a/apex/transformer/__init__.py b/apex/transformer/__init__.py deleted file mode 100644 index ff9c7b9..0000000 --- a/apex/transformer/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -from apex.transformer import amp -from apex.transformer import functional -from apex.transformer import parallel_state -from apex.transformer import pipeline_parallel -from apex.transformer import tensor_parallel -from apex.transformer import utils -from apex.transformer.enums import LayerType -from apex.transformer.enums import AttnType -from apex.transformer.enums import AttnMaskType - - -__all__ = [ - "amp", - "functional", - "parallel_state", - "pipeline_parallel", - "tensor_parallel", - "utils", - # enums.py - "LayerType", - "AttnType", - "AttnMaskType", -] diff --git a/apex/transformer/_data/__init__.py b/apex/transformer/_data/__init__.py deleted file mode 100644 index 2831dfb..0000000 --- a/apex/transformer/_data/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from apex.transformer._data._batchsampler import MegatronPretrainingRandomSampler -from apex.transformer._data._batchsampler import MegatronPretrainingSampler - - -__all__ = [ - "MegatronPretrainingRandomSampler", - "MegatronPretrainingSampler", -] diff --git a/apex/transformer/_data/_batchsampler.py b/apex/transformer/_data/_batchsampler.py deleted file mode 100644 index b2e96a9..0000000 --- a/apex/transformer/_data/_batchsampler.py +++ /dev/null @@ -1,180 +0,0 @@ -"""BatchSampler implementations for POC of dynamic batch size or rampup_batch_size support. - -Implementations are based on https://github.com/NVIDIA/Megatron-LM/blob/bcd605f8570ebeeb0436c115ebbfafc3c5a40ae5/megatron/data/data_samplers.py. -""" # NOQA -import abc - -import torch - - -__all__ = [ - "MegatronPretrainingSampler", - "MegatronPretrainingRandomSampler", -] - - -class _Base: - """Base class for Megatron style BatchSampler.""" - - @abc.abstractmethod - def __len__(self) -> int: - ... - - @abc.abstractmethod - def __iter__(self): - ... - - @property - @abc.abstractmethod - def local_minibatch_size(self) -> int: - ... - - @local_minibatch_size.setter - @abc.abstractclassmethod - def local_minibatch_size(self) -> None: - ... - - -class MegatronPretrainingSampler(_Base): - - def __init__( - self, - total_samples: int, - consumed_samples: int, - local_minibatch_size: int, - data_parallel_rank: int, - data_parallel_size: int, - drop_last: bool = True, - ): - # Sanity checks. - if total_samples <= 0: - raise RuntimeError('no sample to consume: {}'.format(self.total_samples)) - if consumed_samples >= total_samples: - raise RuntimeError('no samples left to consume: {}, {}'.format(self.consumed_samples, self.total_samples)) - if local_minibatch_size <= 0: - raise RuntimeError(f"local minibatch size must be greater than 0: {local_minibatch_size}") - if data_parallel_size <= 0: - raise RuntimeError(f"data parallel size must be greater than 0: {data_parallel_size}") - if data_parallel_rank >= data_parallel_size: - raise RuntimeError('data_parallel_rank should be smaller than data size: {}, {}'.format(self.data_parallel_rank, data_parallel_size)) - # Keep a copy of input params for later use. - self.total_samples = total_samples - self.consumed_samples = consumed_samples - self._local_minibatch_size = local_minibatch_size - self.data_parallel_rank = data_parallel_rank - self.data_parallel_size = data_parallel_size - self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * data_parallel_size - self.drop_last = drop_last - - def __len__(self): - return self.total_samples - - def get_start_end_idx(self): - start_idx = self.data_parallel_rank * self.local_minibatch_size - end_idx = start_idx + self.local_minibatch_size - return start_idx, end_idx - - @property - def local_minibatch_size(self) -> int: - return self._local_minibatch_size - - @local_minibatch_size.setter - def local_minibatch_size(self, new_local_minibatch_size) -> None: - self._local_minibatch_size = new_local_minibatch_size - self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * self.data_parallel_size - - def __iter__(self): - batch = [] - # Last batch will be dropped if drop_last is not set False - for idx in range(self.consumed_samples, self.total_samples): - batch.append(idx) - if len(batch) == self.local_minibatch_size: - start_idx, end_idx = self.get_start_end_idx() - yield batch[start_idx:end_idx] - batch = [] - - # Check the last partial batch and see drop_last is set - if len(batch) > 0 and not self.drop_last: - start_idx, end_idx = self.get_start_end_idx() - yield batch[start_idx:end_idx] - - -class MegatronPretrainingRandomSampler(_Base): - """Megatron style Random Batch Sampler. - - Major difference is that `__iter__` yields a local minibatch, not a microbatch. - A local minibatch consists of `global_batch_size / data_parallel_size` - - Args: - total_samples: The number of data samples, i.e. ``len(dataset)``. - consumed_samples: The number of samples already consumed in pretraining. - local_minibatch_size: The number of data in each batch returned from `__iter__`. Basically - `local_minibatch_size = global_batch_size / data_parallel_size`. - data_parallel_rank: - data_parallel_size: - """ - - def __init__( - self, - total_samples: int, - consumed_samples: int, - local_minibatch_size: int, - data_parallel_rank: int, - data_parallel_size: int, - ) -> None: - if total_samples <= 0: - raise ValueError(f"no sample to consume: total_samples of {total_samples}") - if local_minibatch_size <= 0: - raise ValueError(f"Invalid local_minibatch_size: {local_minibatch_size}") - if data_parallel_size <= 0: - raise ValueError(f"Invalid data_parallel_size: {data_parallel_size}") - if data_parallel_rank >= data_parallel_size: - raise ValueError( - f"data_parallel_rank should be smaller than data parallel size: {data_parallel_rank} < {data_parallel_size}" - ) - # Keep a copy of input params for later use. - self.total_samples = total_samples - self.consumed_samples = consumed_samples - self._local_minibatch_size = local_minibatch_size - self.data_parallel_rank = data_parallel_rank - self.data_parallel_size = data_parallel_size - self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * self.data_parallel_size - self.last_batch_size = self.total_samples % self.local_minibatch_times_data_parallel_size - - def __len__(self) -> int: - return self.total_samples - - @property - def local_minibatch_size(self) -> int: - return self._local_minibatch_size - - @local_minibatch_size.setter - def local_minibatch_size(self, new_local_minibatch_size) -> None: - self._local_minibatch_size = new_local_minibatch_size - self.local_minibatch_times_data_parallel_size = self._local_minibatch_size * self.data_parallel_size - - def __iter__(self): - active_total_samples = self.total_samples - self.last_batch_size - self.epoch = self.consumed_samples // active_total_samples - current_epoch_samples = self.consumed_samples % active_total_samples - # note(mkozuki): might be better to uncomment - # assert current_epoch_samples % (self.data_parallel_size * apex.transformer.pipeline_parallel.utils.get_micro_batch_size()) == 0 - - # data sharding and random sampling - bucket_size = (self.total_samples // self.local_minibatch_times_data_parallel_size) * self.local_minibatch_size - bucket_offset = current_epoch_samples // self.data_parallel_size - start_idx = self.data_parallel_rank * bucket_size - - g = torch.Generator() - g.manual_seed(self.epoch) - random_idx = torch.randperm(bucket_size, generator=g).tolist() - idx_range = [start_idx + x for x in random_idx[bucket_offset:]] - - batch = [] - # Last batch if not complete will be dropped. - for idx in idx_range: - batch.append(idx) - if len(batch) == self.local_minibatch_size: - self.consumed_samples += self.local_minibatch_times_data_parallel_size - yield batch - batch = [] diff --git a/apex/transformer/amp/__init__.py b/apex/transformer/amp/__init__.py deleted file mode 100644 index dbef36a..0000000 --- a/apex/transformer/amp/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from apex.transformer.amp.grad_scaler import GradScaler - - -__all__ = [ - "GradScaler", -] diff --git a/apex/transformer/amp/grad_scaler.py b/apex/transformer/amp/grad_scaler.py deleted file mode 100644 index 5bcd061..0000000 --- a/apex/transformer/amp/grad_scaler.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from collections import defaultdict - -import torch - -from apex.transformer import parallel_state - - -class GradScaler(torch.cuda.amp.GradScaler): - """ - Gradient scaler for model-parallel inf check. The inf in gradients are checked across tensor-parallel - ranks in (1) executing optimizer step and (2) gradient scaler update. - """ - - def __init__( - self, init_scale=2.0 ** 16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True - ): - super().__init__( - init_scale=init_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - enabled=enabled, - ) - - def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs): - retval = None - found_inf = torch.cuda.FloatTensor([sum(v.item() for v in optimizer_state["found_inf_per_device"].values())]) - - # Update across all model parallel instances. - torch.distributed.all_reduce( - found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group() - ) - - if found_inf.item() == 0: - retval = optimizer.step(*args, **kwargs) - return retval - - def update(self, new_scale=None): - """ - Updates the scale factor. - If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` - to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, - the scale is multiplied by ``growth_factor`` to increase it. - Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not - used directly, it's used to fill GradScaler's internal scale tensor. So if - ``new_scale`` was a tensor, later in-place changes to that tensor will not further - affect the scale GradScaler uses internally.) - Args: - new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor. - .. warning:: - :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has - been invoked for all optimizers used this iteration. - """ - if not self._enabled: - return - - _scale, _growth_tracker = self._check_scale_growth_tracker("update") - - if new_scale is not None: - # Accept a new user-defined scale. - if isinstance(new_scale, float): - self._scale.fill_(new_scale) # type: ignore[union-attr] - else: - reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." - assert isinstance(new_scale, torch.cuda.FloatTensor), reason # type: ignore[attr-defined] - assert new_scale.numel() == 1, reason - assert new_scale.requires_grad is False, reason - self._scale.copy_(new_scale) # type: ignore[union-attr] - else: - # Consume shared inf/nan data collected from optimizers to update the scale. - # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. - found_infs = [ - found_inf.to(device=_scale.device, non_blocking=True) - for state in self._per_optimizer_states.values() - for found_inf in state["found_inf_per_device"].values() - ] - - assert len(found_infs) > 0, "No inf checks were recorded prior to update." - - found_inf_combined = found_infs[0] - - # Update across all model parallel instances. - torch.distributed.all_reduce( - found_inf_combined, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group() - ) - - if len(found_infs) > 1: - for i in range(1, len(found_infs)): - found_inf = found_infs[i] - # Update across all model parallel instances. - torch.distributed.all_reduce( - found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group() - ) - found_inf_combined += found_inf - - torch._amp_update_scale_( - _scale, - _growth_tracker, - found_inf_combined, - self._growth_factor, - self._backoff_factor, - self._growth_interval, - ) - - # To prepare for next iteration, clear the data collected from optimizers this iteration. - self._per_optimizer_states = defaultdict(torch.cuda.amp.grad_scaler._refresh_per_optimizer_state) diff --git a/apex/transformer/enums.py b/apex/transformer/enums.py deleted file mode 100644 index 78da6c9..0000000 --- a/apex/transformer/enums.py +++ /dev/null @@ -1,35 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import enum - - -class LayerType(enum.Enum): - encoder = 1 - decoder = 2 - - -class AttnType(enum.Enum): - self_attn = 1 - cross_attn = 2 - - -class AttnMaskType(enum.Enum): - padding = 1 - causal = 2 - - -class ModelType(enum.Enum): - encoder_or_decoder = 1 - encoder_and_decoder = 2 diff --git a/apex/transformer/functional/__init__.py b/apex/transformer/functional/__init__.py deleted file mode 100644 index d770c88..0000000 --- a/apex/transformer/functional/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax - -__all__ = [ - "FusedScaleMaskSoftmax", -] diff --git a/apex/transformer/functional/fused_softmax.py b/apex/transformer/functional/fused_softmax.py deleted file mode 100644 index 8ceaffe..0000000 --- a/apex/transformer/functional/fused_softmax.py +++ /dev/null @@ -1,211 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch - -from apex._autocast_utils import _cast_if_autocast_enabled -from apex.transformer.enums import AttnMaskType - - -class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply upper triangular mask (typically used in gpt models). - 3. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs, scale): - import scaled_upper_triang_masked_softmax_cuda - - scale_t = torch.tensor([scale]) - softmax_results = scaled_upper_triang_masked_softmax_cuda.forward( - inputs, scale_t[0] - ) - - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - import scaled_upper_triang_masked_softmax_cuda - - softmax_results, scale_t = ctx.saved_tensors - input_grads = scaled_upper_triang_masked_softmax_cuda.backward( - output_grads, softmax_results, scale_t[0] - ) - - return input_grads, None - - -def scaled_upper_triang_masked_softmax(inputs, _, scale): - b, np, sq, sk = inputs.size() - assert sq == sk, "causal mask is only for self attention" - # Reshaping input to 3D tensor (attn_batches, sq, sk) - inputs = inputs.view(-1, sq, sk) - args = _cast_if_autocast_enabled(inputs, scale) - with torch.cuda.amp.autocast(enabled=False): - probs = ScaledUpperTriangMaskedSoftmax.apply(*args) - return probs.view(b, np, sq, sk) - - -# NOTE (mkozuki): `ScaledMaskedSoftmax` somehow doesn't work well with `torch.cuda.amp.custom_fwd`. -# Without `cast_inputs` kwarg, somehow inputs are not cast to dtype used in the autocast context. -# So I needed to manually write two `torch.autograd.Function` inheritances. -# Fused operation which performs following three operations in sequence -# 1. Scale the tensor. -# 2. Apply the mask. -# 3. Perform softmax. -class ScaledMaskedSoftmax(torch.autograd.Function): - @staticmethod - def forward(ctx, inputs, mask, scale): - import scaled_masked_softmax_cuda - - scale_t = torch.tensor([scale]) - - softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - import scaled_masked_softmax_cuda - - softmax_results, scale_t = ctx.saved_tensors - - input_grads = scaled_masked_softmax_cuda.backward( - output_grads, softmax_results, scale_t[0] - ) - return input_grads, None, None - - -def scaled_masked_softmax(inputs, mask, scale): - # input is 4D tensor (b, np, sq, sk) - args = _cast_if_autocast_enabled(inputs, mask, scale) - with torch.cuda.amp.autocast(enabled=False): - return ScaledMaskedSoftmax.apply(*args) - - -class FusedScaleMaskSoftmax(torch.nn.Module): - """ - fused operation: scaling + mask + softmax - - Arguments: - input_in_fp16: flag to indicate if input in fp16 data format. - input_in_bf16: flag to indicate if input in bf16 data format. - attn_mask_type: attention mask type (pad or causal) - scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion - mask_func: mask function to be applied. - softmax_in_fp32: if true, softmax in performed at fp32 precision. - scale: scaling factor used in input tensor scaling. - """ - - def __init__( - self, - input_in_fp16, - input_in_bf16, - attn_mask_type, - scaled_masked_softmax_fusion, - mask_func, - softmax_in_fp32, - scale, - ): - super().__init__() - self.input_in_fp16 = input_in_fp16 - self.input_in_bf16 = input_in_bf16 - if self.input_in_fp16 and self.input_in_bf16: - raise RuntimeError( - "both fp16 and bf16 flags cannot be active at the same time." - ) - self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 - self.attn_mask_type = attn_mask_type - self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion - self.mask_func = mask_func - self.softmax_in_fp32 = softmax_in_fp32 - self.scale = scale - - if not (self.scale is None or softmax_in_fp32): - raise RuntimeError("softmax should be in fp32 when scaled") - - if self.scaled_masked_softmax_fusion: - if self.attn_mask_type == AttnMaskType.causal: - self.fused_softmax_func = scaled_upper_triang_masked_softmax - elif self.attn_mask_type == AttnMaskType.padding: - self.fused_softmax_func = scaled_masked_softmax - else: - raise ValueError("Invalid attn_mask_type.") - - def forward(self, input, mask): - # [b, np, sq, sk] - assert input.dim() == 4 - - if self.is_kernel_available(mask, *input.size()): - return self.forward_fused_softmax(input, mask) - else: - return self.forward_torch_softmax(input, mask) - - def is_kernel_available(self, mask, b, np, sq, sk): - attn_batches = b * np - - if ( - self.scaled_masked_softmax_fusion # user want to fuse - and self.input_in_float16 # input must be fp16 - and ( - self.attn_mask_type == AttnMaskType.causal - or (self.attn_mask_type == AttnMaskType.padding and mask is not None) - ) - and 16 < sk <= 2048 # sk must be 16 ~ 2048 - and sq % 4 == 0 # sq must be divisor of 4 - and sk % 4 == 0 # sk must be divisor of 4 - and attn_batches % 4 == 0 # np * b must be divisor of 4 - ): - if 0 <= sk <= 2048: - batch_per_block = self.get_batch_per_block(sq, sk, b, np) - - if self.attn_mask_type == AttnMaskType.causal: - if attn_batches % batch_per_block == 0: - return True - else: - if sq % batch_per_block == 0: - return True - return False - - def forward_fused_softmax(self, input, mask): - # input.shape = [b, np, sq, sk] - scale = self.scale if self.scale is not None else 1.0 - return self.fused_softmax_func(input, mask, scale) - - def forward_torch_softmax(self, input, mask): - if self.input_in_float16 and self.softmax_in_fp32: - input = input.float() - - if self.scale is not None: - input = input * self.scale - mask_output = self.mask_func(input, mask) if mask is not None else input - probs = torch.nn.Softmax(dim=-1)(mask_output) - - if self.input_in_float16 and self.softmax_in_fp32: - if self.input_in_fp16: - probs = probs.half() - else: - probs = probs.bfloat16() - - return probs - - @staticmethod - def get_batch_per_block(sq, sk, b, np): - import scaled_masked_softmax_cuda - - return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np) diff --git a/apex/transformer/layers/__init__.py b/apex/transformer/layers/__init__.py deleted file mode 100644 index bc247d3..0000000 --- a/apex/transformer/layers/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -from apex.transformer.layers.layer_norm import FastLayerNorm -from apex.transformer.layers.layer_norm import FusedLayerNorm -from apex.transformer.layers.layer_norm import MixedFusedLayerNorm - - -__all__ = [ - "FastLayerNorm", - "FusedLayerNorm", - "MixedFusedLayerNorm", -] diff --git a/apex/transformer/layers/layer_norm.py b/apex/transformer/layers/layer_norm.py deleted file mode 100644 index 81cc239..0000000 --- a/apex/transformer/layers/layer_norm.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# NOTE(mkozuki): This file defines two LayerNorm that are compatible with Megatron-LM. -# while avoiding introducing the breaking change of `"sequence_parallel_enabled"` attribute into apex.normalization.FusedLayerNorm -# and apex.contrib.layer_norm.FastLayerNorm. -import warnings - -import torch - -from apex.normalization import FusedLayerNorm as OrigFusedLayerNorm -from apex.normalization import MixedFusedLayerNorm as OrigMixedFusedLayerNorm -try: - from apex.contrib.layer_norm import FastLayerNorm as OrigFastLayerNorm -except ImportError: - HAS_FAST_LAYER_NORM = False -else: - HAS_FAST_LAYER_NORM = True - - -__all__ = [ - "FusedLayerNorm", - "FastLayerNorm", - "MixedFusedLayerNorm", -] - - -def _set_sequence_parallel_enabled( - param: torch.Tensor, - sequence_parallel_enabled: bool, -) -> None: - setattr(param, "sequence_parallel_enabled", sequence_parallel_enabled) - - -class FusedLayerNorm(OrigFusedLayerNorm): - def __init__( - self, - normalized_shape, - eps: float = 1e-5, - elementwise_affine: bool = True, - *, - sequence_parallel_enabled: bool = False, - ): - super().__init__( - normalized_shape=normalized_shape, - eps=eps, - elementwise_affine=elementwise_affine, - ) - self.sequence_parallel_enabled = sequence_parallel_enabled - if self.elementwise_affine: - _set_sequence_parallel_enabled(self.weight, self.sequence_parallel_enabled) - _set_sequence_parallel_enabled(self.bias, self.sequence_parallel_enabled) - - -# note: MixedFusedLayerNorm is no different from FusedLayerNorm if it's used in `torch.cuda.amp`. -class MixedFusedLayerNorm(OrigMixedFusedLayerNorm): - def __init__( - self, - normalized_shape, - eps: float = 1e-5, - **kwargs, - ) -> None: - self.sequence_parallel_enabled = kwargs.get("sequence_parallel_enabled", False) - super().__init__(normalized_shape=normalized_shape, eps=eps, **kwargs) - if self.sequence_parallel_enabled: - _set_sequence_parallel_enabled(self.weight, self.sequence_parallel_enabled) - _set_sequence_parallel_enabled(self.bias, self.sequence_parallel_enabled) - - -if HAS_FAST_LAYER_NORM: - class FastLayerNorm(OrigFastLayerNorm): - def __init__( - self, - hidden_size, - eps: float = 1e-5, - *, - sequence_parallel_enabled: bool = False, - ): - super().__init__( - hidden_size=hidden_size, - eps=eps - ) - self.sequence_parallel_enabled = sequence_parallel_enabled - _set_sequence_parallel_enabled(self.weight, self.sequence_parallel_enabled) - _set_sequence_parallel_enabled(self.bias, self.sequence_parallel_enabled) -else: - class FastLayerNorm(FusedLayerNorm): - def __init__( - self, - hidden_size, - eps: float = 1e-5, - *, - sequence_parallel_enabled: bool = False, - ): - warnings.warn("`apex.contrib.layer_norm.FastLayerNorm` isn't available thus falling back to `apex.normalization.FusedLayerNorm`") - super().__init__( - normalized_shape=hidden_size, - eps=eps, - elementwise_affine=True, - sequence_parallel_enabled=sequence_parallel_enabled, - ) diff --git a/apex/transformer/log_util.py b/apex/transformer/log_util.py deleted file mode 100644 index 7eaafee..0000000 --- a/apex/transformer/log_util.py +++ /dev/null @@ -1,18 +0,0 @@ -import logging -import os - - -def get_transformer_logger(name: str) -> logging.Logger: - name_wo_ext = os.path.splitext(name)[0] - return logging.getLogger(name_wo_ext) - - -def set_logging_level(verbosity) -> None: - """Change logging severity. - - Args: - verbosity - """ - from apex import _library_root_logger - - _library_root_logger.setLevel(verbosity) diff --git a/apex/transformer/microbatches.py b/apex/transformer/microbatches.py deleted file mode 100644 index 69673bc..0000000 --- a/apex/transformer/microbatches.py +++ /dev/null @@ -1,195 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Megatron number of micro-batches calculators.""" -from abc import ABC -from abc import abstractmethod -from typing import Optional, List - -from apex.transformer.log_util import get_transformer_logger - - -_logger = get_transformer_logger(__name__) - - -def build_num_microbatches_calculator( - rank: int, - rampup_batch_size: Optional[List[int]], - global_batch_size: int, - micro_batch_size: int, - data_parallel_size: int, -): - # Constant num micro-batches. - if rampup_batch_size is None: - num_microbatches_calculator = ConstantNumMicroBatches( - global_batch_size, micro_batch_size, data_parallel_size - ) - if rank == 0: - _logger.info( - "setting number of micro-batches to constant {}".format( - num_microbatches_calculator.get() - ) - ) - - else: - assert len(rampup_batch_size) == 3, ( - "expected the following " - "format: --rampup-batch-size " - " " - ) - start_batch_size = int(rampup_batch_size[0]) - batch_size_increment = int(rampup_batch_size[1]) - ramup_samples = int(rampup_batch_size[2]) - if rank == 0: - _logger.info( - "will use batch size rampup starting from global batch " - "size {} to global batch size {} with batch size increments " - "{} over {} samples.".format( - start_batch_size, - global_batch_size, - batch_size_increment, - ramup_samples, - ) - ) - num_microbatches_calculator = RampupBatchsizeNumMicroBatches( - start_batch_size, - batch_size_increment, - ramup_samples, - global_batch_size, - micro_batch_size, - data_parallel_size, - ) - - return num_microbatches_calculator - - -class NumMicroBatchesCalculator(ABC): - def __init__(self): - self.num_micro_batches = None - self.current_global_batch_size = None - - def get(self): - return self.num_micro_batches - - def get_current_global_batch_size(self): - return self.current_global_batch_size - - @abstractmethod - def update(self, consumed_samples, consistency_check): - pass - - -class ConstantNumMicroBatches(NumMicroBatchesCalculator): - def __init__(self, global_batch_size, micro_batch_size, data_parallel_size): - micro_batch_times_data_parallel = micro_batch_size * data_parallel_size - assert global_batch_size % micro_batch_times_data_parallel == 0, ( - "global batch size ({}) is not divisible by micro batch size ({})" - " times data parallel size ({})".format( - global_batch_size, micro_batch_size, data_parallel_size - ) - ) - self.num_micro_batches = global_batch_size // micro_batch_times_data_parallel - assert self.num_micro_batches >= 1 - self.current_global_batch_size = global_batch_size - - self.micro_batch_size = micro_batch_size - - def update(self, consumed_samples, consistency_check): - pass - - -class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator): - def __init__( - self, - start_batch_size, - batch_size_increment, - ramup_samples, - global_batch_size, - micro_batch_size, - data_parallel_size, - ): - """Batch size ramp up. - Over - steps = (global-batch-size - start-batch-size) / batch_size_increment - increment batch size from start-batch-size to global-batch-size using - rampup-samples / steps - samples. - Arguments: - start_batch_size: global batch size to start with - batch_size_increment: global batch size increments - ramup_samples: number of samples to use ramp up global - batch size from `start_batch_size` to `global_batch_size` - global_batch_size: global batch size post rampup - micro_batch_size: micro batch size - data_parallel_size: data parallel size. - """ - - self.micro_batch_size = micro_batch_size - self.data_parallel_size = data_parallel_size - self.micro_batch_times_data_parallel_size = ( - self.micro_batch_size * self.data_parallel_size - ) - assert self.micro_batch_times_data_parallel_size > 0 - - assert start_batch_size > 0 - self.start_batch_size = start_batch_size - - assert global_batch_size > 0 - self.global_batch_size = global_batch_size - diff_batch_size = self.global_batch_size - self.start_batch_size - assert diff_batch_size >= 0 - assert batch_size_increment > 0 - self.batch_size_increment = batch_size_increment - assert diff_batch_size % batch_size_increment == 0, ( - "expected " - "global batch size interval ({}) to be divisible by global batch " - "size increment ({})".format(diff_batch_size, batch_size_increment) - ) - - num_increments = diff_batch_size // self.batch_size_increment - self.ramup_samples = ramup_samples - assert self.ramup_samples >= 0 - self.rampup_samples_per_increment = self.ramup_samples / num_increments - - # Initialize number of microbatches. - self.update(0, False) - - def update(self, consumed_samples, consistency_check): - - if consumed_samples > self.ramup_samples: - self.current_global_batch_size = self.global_batch_size - else: - steps = int(consumed_samples / self.rampup_samples_per_increment) - self.current_global_batch_size = ( - self.start_batch_size + steps * self.batch_size_increment - ) - assert self.current_global_batch_size <= self.global_batch_size - - if consistency_check: - assert ( - self.current_global_batch_size - % self.micro_batch_times_data_parallel_size - == 0 - ), ( - "current global " - "batch size ({}) is not divisible by micro-batch-size ({}) times" - "data parallel size ({})".format( - self.current_global_batch_size, - self.micro_batch_size, - self.data_parallel_size, - ) - ) - self.num_micro_batches = ( - self.current_global_batch_size // self.micro_batch_times_data_parallel_size - ) diff --git a/apex/transformer/parallel_state.py b/apex/transformer/parallel_state.py deleted file mode 100644 index a8d16bf..0000000 --- a/apex/transformer/parallel_state.py +++ /dev/null @@ -1,682 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# TODO (mkozuki): Replace assert with RuntimeError. -# TODO (mkozuki): Sort the functions in the same order of megatron/mpu/initialize.py -"""Model and data parallel groups.""" -from typing import Tuple, Optional -import warnings - -import torch - -from apex.transformer.log_util import get_transformer_logger - - -_logger = get_transformer_logger(__name__) - -# N.B. (mkozuki): Diff btwn Megatron-LM & apex parallel_state -# set(megatron_mpu_initialize_funcs) - set(apex.transformer.parallel_state) = -# { -# 'get_num_layers', -# } - - -# Intra-layer model parallel group that the current rank belongs to. -_TENSOR_MODEL_PARALLEL_GROUP = None -# Inter-layer model parallel group that the current rank belongs to. -_PIPELINE_MODEL_PARALLEL_GROUP = None -# Model parallel group (both intra- and pipeline) that the current rank belongs to. -_MODEL_PARALLEL_GROUP = None -# Embedding group. -_EMBEDDING_GROUP = None -# Position embedding group. -_POSITION_EMBEDDING_GROUP = None -# Relative position embedding group. -_ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = None -_DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = None -# Data parallel group that the current rank belongs to. -_DATA_PARALLEL_GROUP = None - -_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None -_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None -_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None - -# These values enable us to change the mpu sizes on the fly. -_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None -_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None -_MPU_TENSOR_MODEL_PARALLEL_RANK = None -_MPU_PIPELINE_MODEL_PARALLEL_RANK = None - -# A list of ranks that have a copy of the embedding. -_EMBEDDING_GLOBAL_RANKS = None - -# A list of ranks that have a copy of the position embedding. -_POSITION_EMBEDDING_GLOBAL_RANKS = None - -# A list of ranks that have a copy of the relative position embedding. -_ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = None -_DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = None - -# A list of global ranks for each pipeline group to ease calculation of the source -# rank when broadcasting from the first or last pipeline stage -_PIPELINE_GLOBAL_RANKS = None - - -def is_unitialized(): - """Useful for code segments that may be accessed with or without mpu initialization""" - return _DATA_PARALLEL_GROUP is None - - -def initialize_model_parallel( - tensor_model_parallel_size_: int = 1, - pipeline_model_parallel_size_: int = 1, - virtual_pipeline_model_parallel_size_: Optional[int] = None, - pipeline_model_parallel_split_rank_: Optional[int] = None, - *, - default_backend: Optional[str] = None, - p2p_backend: Optional[str] = None, -) -> None: - """ - Initialize model data parallel groups. - - Arguments: - tensor_model_parallel_size: number of GPUs used to parallelize model tensor. - pipeline_model_parallel_size: number of GPUs used to parallelize model pipeline. - virtual_pipeline_model_parallel_size: number of virtual stages (interleaved pipeline). - pipeline_model_parallel_split_rank: for models with both encoder and decoder, rank in pipeline with split point. - Keyword Arguments: - default_backend: Backend of process groups except for pipeline parallel ones. - If :obj:`None`, the backend specified in `torch.distributed.init_process_group` will be used. - p2p_backend: Backend of process groups for pipeline model parallel. - If :obj:`None`, the backend specified in `torch.distributed.init_process_group` will be used. - - .. note:: - `torch_ucc `_ is - necessary for "ucc" backend. - - Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we - use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize - the model pipeline. The present function will - create 8 tensor model-parallel groups, 4 pipeline model-parallel groups - and 8 data-parallel groups as: - 8 data_parallel groups: - [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15] - 8 tensor model-parallel groups: - [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15] - 4 pipeline model-parallel groups: - [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15] - Note that for efficiency, the caller should make sure adjacent ranks - are on the same DGX box. For example if we are using 2 DGX-1 boxes - with a total of 16 GPUs, rank 0 to 7 belong to the first box and - ranks 8 to 15 belong to the second box. - """ - # Get world size and rank. Ensure some consistencies. - assert torch.distributed.is_initialized() - assert default_backend is None or default_backend in ("nccl", "ucc") - assert p2p_backend is None or p2p_backend in ("nccl", "ucc") - if "ucc" in (default_backend, p2p_backend): - check_torch_ucc_availability() - warnings.warn("`ucc` backend support is experimental", ExperimentalWarning) - if default_backend == "ucc": - warnings.warn("The UCC's functionality as `default_backend` is not well verified", ExperimentalWarning) - - world_size: int = torch.distributed.get_world_size() - tensor_model_parallel_size: int = min(tensor_model_parallel_size_, world_size) - pipeline_model_parallel_size: int = min(pipeline_model_parallel_size_, world_size) - if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0: - raise RuntimeError( - f"`world_size` ({world_size}) is not divisible by tensor_model_parallel_size ({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})" - ) - data_parallel_size: int = world_size // ( - tensor_model_parallel_size * pipeline_model_parallel_size - ) - if torch.distributed.get_rank() == 0: - _logger.info( - "> initializing tensor model parallel with size {}".format( - tensor_model_parallel_size - ) - ) - _logger.info( - "> initializing pipeline model parallel with size {}".format( - pipeline_model_parallel_size - ) - ) - _logger.info( - "> initializing data parallel with size {}".format(data_parallel_size) - ) - - num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size - num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size - num_data_parallel_groups: int = world_size // data_parallel_size - - if virtual_pipeline_model_parallel_size_ is not None: - # n.b. (eqy) This check was inherited from Megatron-LM, need to revisit - # the root cause as we do see numerical mismatches with 2 stages and - # the interleaved schedule - assert pipeline_model_parallel_size_ > 2, ( - "pipeline-model-parallel size should be greater than 2 with " - "interleaved schedule" - ) - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0 - _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = ( - virtual_pipeline_model_parallel_size_ - ) - - if pipeline_model_parallel_split_rank_ is not None: - global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK - _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_ - - rank = torch.distributed.get_rank() - - # Build the data-parallel groups. - global _DATA_PARALLEL_GROUP - assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized" - all_data_parallel_group_ranks = [] - for i in range(pipeline_model_parallel_size): - start_rank = i * num_pipeline_model_parallel_groups - end_rank = (i + 1) * num_pipeline_model_parallel_groups - for j in range(tensor_model_parallel_size): - ranks = range(start_rank + j, end_rank, tensor_model_parallel_size) - all_data_parallel_group_ranks.append(list(ranks)) - group = torch.distributed.new_group(ranks, backend=default_backend) - if rank in ranks: - _DATA_PARALLEL_GROUP = group - - # Build the model-parallel groups. - global _MODEL_PARALLEL_GROUP - assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized" - for i in range(data_parallel_size): - ranks = [ - data_parallel_group_ranks[i] - for data_parallel_group_ranks in all_data_parallel_group_ranks - ] - group = torch.distributed.new_group(ranks, backend=default_backend) - if rank in ranks: - _MODEL_PARALLEL_GROUP = group - - # Build the tensor model-parallel groups. - global _TENSOR_MODEL_PARALLEL_GROUP - assert ( - _TENSOR_MODEL_PARALLEL_GROUP is None - ), "tensor model parallel group is already initialized" - for i in range(num_tensor_model_parallel_groups): - ranks = list( - range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) - ) - group = torch.distributed.new_group(ranks, backend=default_backend) - if rank in ranks: - _TENSOR_MODEL_PARALLEL_GROUP = group - - # Build the pipeline model-parallel groups and embedding groups - # (first and last rank in each pipeline model-parallel group). - global _PIPELINE_MODEL_PARALLEL_GROUP - global _PIPELINE_GLOBAL_RANKS - assert ( - _PIPELINE_MODEL_PARALLEL_GROUP is None - ), "pipeline model parallel group is already initialized" - global _EMBEDDING_GROUP - global _EMBEDDING_GLOBAL_RANKS - assert _EMBEDDING_GROUP is None, "embedding group is already initialized" - global _POSITION_EMBEDDING_GROUP - global _POSITION_EMBEDDING_GLOBAL_RANKS - assert ( - _POSITION_EMBEDDING_GROUP is None - ), "position embedding group is already initialized" - global _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP - global _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP - global _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS - global _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS - assert _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP is None or \ - _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP is None, \ - 'relative position embedding group is already initialized' - for i in range(num_pipeline_model_parallel_groups): - ranks = range(i, world_size, num_pipeline_model_parallel_groups) - group = torch.distributed.new_group(ranks, backend=p2p_backend) - if rank in ranks: - _PIPELINE_MODEL_PARALLEL_GROUP = group - _PIPELINE_GLOBAL_RANKS = ranks - # Setup embedding group (to exchange gradients between - # first and last stages). - encoder_relative_position_embedding_ranks = None - decoder_relative_position_embedding_ranks = None - if len(ranks) > 1: - embedding_ranks = [ranks[0], ranks[-1]] - position_embedding_ranks = [ranks[0]] - encoder_relative_position_embedding_ranks = [ranks[0]] - decoder_relative_position_embedding_ranks = [ranks[0]] - if pipeline_model_parallel_split_rank_ is not None: - encoder_relative_position_embedding_ranks = \ - ranks[:pipeline_model_parallel_split_rank_] - decoder_relative_position_embedding_ranks = \ - ranks[pipeline_model_parallel_split_rank_:] - if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks: - embedding_ranks = [ - ranks[0], - ranks[pipeline_model_parallel_split_rank_], - ranks[-1], - ] - if ( - ranks[pipeline_model_parallel_split_rank_] - not in position_embedding_ranks - ): - position_embedding_ranks = [ - ranks[0], - ranks[pipeline_model_parallel_split_rank_], - ] - else: - embedding_ranks = ranks - position_embedding_ranks = ranks - encoder_relative_position_embedding_ranks = ranks - decoder_relative_position_embedding_ranks = ranks - - group = torch.distributed.new_group(embedding_ranks, backend=default_backend) - if rank in embedding_ranks: - _EMBEDDING_GROUP = group - if rank in ranks: - _EMBEDDING_GLOBAL_RANKS = embedding_ranks - - group = torch.distributed.new_group(position_embedding_ranks, backend=default_backend) - if rank in position_embedding_ranks: - _POSITION_EMBEDDING_GROUP = group - if rank in ranks: - _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks - - if encoder_relative_position_embedding_ranks: - group = torch.distributed.new_group(encoder_relative_position_embedding_ranks) - if rank in encoder_relative_position_embedding_ranks: - _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = group - if rank in ranks: - _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = \ - encoder_relative_position_embedding_ranks - - if decoder_relative_position_embedding_ranks: - group = torch.distributed.new_group(decoder_relative_position_embedding_ranks) - if rank in decoder_relative_position_embedding_ranks: - _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = group - if rank in ranks: - _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = \ - decoder_relative_position_embedding_ranks - -def get_rank_info() -> Tuple[int, int, int]: - """Returns a tuple of (data, tensor, pipeline, virtual pipeline)-parallel-rank for logger.""" - if model_parallel_is_initialized(): - return ( - get_data_parallel_rank(), - get_tensor_model_parallel_rank(), - get_pipeline_model_parallel_rank(), - get_virtual_pipeline_model_parallel_rank(), - ) - return (0, 0, 0, 0) - - -def model_parallel_is_initialized(): - """Check if model and data parallel groups are initialized.""" - if ( - _TENSOR_MODEL_PARALLEL_GROUP is None - or _PIPELINE_MODEL_PARALLEL_GROUP is None - or _DATA_PARALLEL_GROUP is None - ): - return False - return True - - -def get_model_parallel_group(): - """Get the model parallel group the caller rank belongs to.""" - assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized" - return _MODEL_PARALLEL_GROUP - - -def get_tensor_model_parallel_group(): - """Get the tensor model parallel group the caller rank belongs to.""" - assert ( - _TENSOR_MODEL_PARALLEL_GROUP is not None - ), "intra_layer_model parallel group is not initialized" - return _TENSOR_MODEL_PARALLEL_GROUP - - -def get_pipeline_model_parallel_group(): - """Get the pipeline model parallel group the caller rank belongs to.""" - assert ( - _PIPELINE_MODEL_PARALLEL_GROUP is not None - ), "pipeline_model parallel group is not initialized" - return _PIPELINE_MODEL_PARALLEL_GROUP - - -def get_data_parallel_group(): - """Get the data parallel group the caller rank belongs to.""" - assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized" - return _DATA_PARALLEL_GROUP - - -def get_embedding_group(): - """Get the embedding group the caller rank belongs to.""" - assert _EMBEDDING_GROUP is not None, "embedding group is not initialized" - return _EMBEDDING_GROUP - - -def get_position_embedding_group(): - """Get the position embedding group the caller rank belongs to.""" - assert ( - _POSITION_EMBEDDING_GROUP is not None - ), "position embedding group is not initialized" - return _POSITION_EMBEDDING_GROUP - -def get_encoder_relative_position_embedding_group(): - """Get the encoder relative position embedding group the caller rank belongs to.""" - assert _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP is not None, \ - 'encoder relative position embedding group is not initialized' - return _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP - -def get_decoder_relative_position_embedding_group(): - """Get the decoder relative position embedding group the caller rank belongs to.""" - assert _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP is not None, \ - 'decoder relative position embedding group is not initialized' - return _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP - -def is_rank_in_embedding_group(ignore_virtual=False): - """Return true if current rank is in embedding group, False otherwise.""" - rank = torch.distributed.get_rank() - global _EMBEDDING_GLOBAL_RANKS - if ignore_virtual: - return rank in _EMBEDDING_GLOBAL_RANKS - if rank in _EMBEDDING_GLOBAL_RANKS: - if rank == _EMBEDDING_GLOBAL_RANKS[0]: - return is_pipeline_first_stage(ignore_virtual=False) - elif rank == _EMBEDDING_GLOBAL_RANKS[-1]: - return is_pipeline_last_stage(ignore_virtual=False) - else: - return True - return False - - -def is_rank_in_position_embedding_group(): - """Return whether the current rank is in position embedding group.""" - rank = torch.distributed.get_rank() - global _POSITION_EMBEDDING_GLOBAL_RANKS - return rank in _POSITION_EMBEDDING_GLOBAL_RANKS - -def is_rank_in_encoder_relative_position_embedding_group(): - """Return true if current rank is in encoder relative position embedding group, False otherwise.""" - rank = torch.distributed.get_rank() - global _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS - return rank in _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS - -def is_rank_in_decoder_relative_position_embedding_group(): - """Return true if current rank is in decoder relative position embedding group, False otherwise.""" - rank = torch.distributed.get_rank() - global _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS - return rank in _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS - -def is_pipeline_stage_before_split(rank=None): - """Return True if pipeline stage executes encoder block for a model - with both encoder and decoder.""" - if get_pipeline_model_parallel_world_size() == 1: - return True - if rank is None: - rank = get_pipeline_model_parallel_rank() - global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK - if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None: - return True - if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK: - return True - return False - - -def is_pipeline_stage_after_split(rank=None): - """Return True if pipeline stage executes decoder block for a model - with both encoder and decoder.""" - if get_pipeline_model_parallel_world_size() == 1: - return True - if rank is None: - rank = get_pipeline_model_parallel_rank() - global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK - if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None: - return True - if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK: - return True - return False - - -def is_pipeline_stage_at_split(): - """Return true if pipeline stage executes decoder block and next - stage executes encoder block for a model with both encoder and - decoder.""" - rank = get_pipeline_model_parallel_rank() - return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split( - rank + 1 - ) - - -def set_tensor_model_parallel_world_size(world_size): - """Set the tensor model parallel size""" - global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE - _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size - - -def set_pipeline_model_parallel_world_size(world_size): - """Set the pipeline model parallel size""" - global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size - - -def get_tensor_model_parallel_world_size(): - """Return world size for the tensor model parallel group.""" - global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE - if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None: - return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE - return torch.distributed.get_world_size(group=get_tensor_model_parallel_group()) - - -def get_pipeline_model_parallel_world_size(): - """Return world size for the pipeline model parallel group.""" - global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None: - return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group()) - - -def set_tensor_model_parallel_rank(rank): - """Set tensor model parallel rank.""" - global _MPU_TENSOR_MODEL_PARALLEL_RANK - _MPU_TENSOR_MODEL_PARALLEL_RANK = rank - - -def set_pipeline_model_parallel_rank(rank): - """Set pipeline model parallel rank.""" - global _MPU_PIPELINE_MODEL_PARALLEL_RANK - _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank - - -def get_tensor_model_parallel_rank(): - """Return my rank for the tensor model parallel group.""" - global _MPU_TENSOR_MODEL_PARALLEL_RANK - if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None: - return _MPU_TENSOR_MODEL_PARALLEL_RANK - return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) - - -def get_pipeline_model_parallel_rank(): - """Return my rank for the pipeline model parallel group.""" - global _MPU_PIPELINE_MODEL_PARALLEL_RANK - if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None: - return _MPU_PIPELINE_MODEL_PARALLEL_RANK - return torch.distributed.get_rank(group=get_pipeline_model_parallel_group()) - - -# TODO (mkozuki): Add [`get_num_layers`](https://github.com/NVIDIA/Megatron-LM/blob/e156d2fea7fc5c98e645f7742eb86b643956d840/megatron/mpu/initialize.py#L321) here, maybe? - - -def get_pipeline_model_parallel_split_rank(): - """Return my rank for the pipeline model parallel split rank.""" - global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK - return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK - - -def set_pipeline_model_parallel_split_rank(pipeline_model_parallel_split_rank: int): - """Set my rank for the pipeline model parallel split rank.""" - global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK - _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank - - -def is_pipeline_first_stage(ignore_virtual=False): - """Return True if in the first pipeline model-parallel stage, False otherwise.""" - if not ignore_virtual: - if ( - get_virtual_pipeline_model_parallel_world_size() is not None - and get_virtual_pipeline_model_parallel_rank() != 0 - ): - return False - return get_pipeline_model_parallel_rank() == 0 - - -def is_pipeline_last_stage(ignore_virtual=False): - """Return True if in the last pipeline model-parallel stage, False otherwise.""" - if not ignore_virtual: - virtual_pipeline_model_parallel_world_size = ( - get_virtual_pipeline_model_parallel_world_size() - ) - if virtual_pipeline_model_parallel_world_size is not None and get_virtual_pipeline_model_parallel_rank() != ( - virtual_pipeline_model_parallel_world_size - 1 - ): - return False - return get_pipeline_model_parallel_rank() == ( - get_pipeline_model_parallel_world_size() - 1 - ) - - -def get_virtual_pipeline_model_parallel_rank(): - """Return the virtual pipeline-parallel rank.""" - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK - return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK - - -def set_virtual_pipeline_model_parallel_rank(rank): - """Set the virtual pipeline-parallel rank.""" - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK - _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank - - -def get_virtual_pipeline_model_parallel_world_size(): - """Return the virtual pipeline-parallel world size.""" - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - - -def get_tensor_model_parallel_src_rank(): - """Calculate the global rank corresponding to the first local rank - in the tensor model parallel group.""" - global_rank = torch.distributed.get_rank() - local_world_size = get_tensor_model_parallel_world_size() - return (global_rank // local_world_size) * local_world_size - - -def get_data_parallel_src_rank(): - """Calculate the global rank corresponding to the first local rank in the data parallel group.""" - global_rank = torch.distributed.get_rank() - data_parallel_size: int = get_data_parallel_world_size() - num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size - return global_rank % num_data_parallel_groups - - -def get_pipeline_model_parallel_first_rank(): - assert ( - _PIPELINE_GLOBAL_RANKS is not None - ), "Pipeline parallel group is not initialized" - return _PIPELINE_GLOBAL_RANKS[0] - - -def get_pipeline_model_parallel_last_rank(): - assert ( - _PIPELINE_GLOBAL_RANKS is not None - ), "Pipeline parallel group is not initialized" - last_rank_local = get_pipeline_model_parallel_world_size() - 1 - return _PIPELINE_GLOBAL_RANKS[last_rank_local] - - -def get_pipeline_model_parallel_next_rank(): - assert ( - _PIPELINE_GLOBAL_RANKS is not None - ), "Pipeline parallel group is not initialized" - rank_in_pipeline = get_pipeline_model_parallel_rank() - world_size = get_pipeline_model_parallel_world_size() - return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] - - -def get_pipeline_model_parallel_prev_rank(): - assert ( - _PIPELINE_GLOBAL_RANKS is not None - ), "Pipeline parallel group is not initialized" - rank_in_pipeline = get_pipeline_model_parallel_rank() - world_size = get_pipeline_model_parallel_world_size() - return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] - - -def get_data_parallel_world_size(): - """Return world size for the data parallel group.""" - return torch.distributed.get_world_size(group=get_data_parallel_group()) - - -def get_data_parallel_rank(): - """Return my rank for the data parallel group.""" - return torch.distributed.get_rank(group=get_data_parallel_group()) - - -# note (mkozuki): `destroy_model_parallel` voids more global variables than Megatron-LM. -# Otherwise pipeline parallel forward_backward functions test hangs possibly because -# the clean-up of the original is NOT enough. -def destroy_model_parallel(): - """Set the groups to none.""" - global _MODEL_PARALLEL_GROUP - _MODEL_PARALLEL_GROUP = None - global _TENSOR_MODEL_PARALLEL_GROUP - _TENSOR_MODEL_PARALLEL_GROUP = None - global _PIPELINE_MODEL_PARALLEL_GROUP - _PIPELINE_MODEL_PARALLEL_GROUP = None - global _DATA_PARALLEL_GROUP - _DATA_PARALLEL_GROUP = None - global _EMBEDDING_GROUP - _EMBEDDING_GROUP = None - global _POSITION_EMBEDDING_GROUP - _POSITION_EMBEDDING_GROUP = None - global _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP - _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = None - global _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP - _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = None - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK - _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None - global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None - global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE - _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None - global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE - _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None - global _MPU_TENSOR_MODEL_PARALLEL_RANK - _MPU_TENSOR_MODEL_PARALLEL_RANK = None - global _MPU_PIPELINE_MODEL_PARALLEL_RANK - _MPU_PIPELINE_MODEL_PARALLEL_RANK = None - - -# Used to warn when the UCC is specified. -class ExperimentalWarning(Warning): pass - - -def check_torch_ucc_availability() -> None: - try: - import torch_ucc # NOQA - except ImportError: - raise ImportError( - "UCC backend requires [torch_ucc](https://github.com/facebookresearch/torch_ucc) but not found" - ) diff --git a/apex/transformer/pipeline_parallel/__init__.py b/apex/transformer/pipeline_parallel/__init__.py deleted file mode 100644 index 98bb960..0000000 --- a/apex/transformer/pipeline_parallel/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func -from apex.transformer.pipeline_parallel.schedules.common import build_model - - -__all__ = [ - "get_forward_backward_func", - "build_model", -] diff --git a/apex/transformer/pipeline_parallel/_timers.py b/apex/transformer/pipeline_parallel/_timers.py deleted file mode 100644 index 55d89f3..0000000 --- a/apex/transformer/pipeline_parallel/_timers.py +++ /dev/null @@ -1,83 +0,0 @@ -import time - -import torch - - -class _Timer: - """Timer.""" - - def __init__(self, name): - self.name_ = name - self.elapsed_ = 0.0 - self.started_ = False - self.start_time = time.time() - - def start(self): - """Start the timer.""" - assert not self.started_, "timer has already been started" - torch.cuda.synchronize() - self.start_time = time.time() - self.started_ = True - - def stop(self): - """Stop the timer.""" - assert self.started_, "timer is not started" - torch.cuda.synchronize() - self.elapsed_ += time.time() - self.start_time - self.started_ = False - - def reset(self): - """Reset timer.""" - self.elapsed_ = 0.0 - self.started_ = False - - def elapsed(self, reset=True): - """Calculate the elapsed time.""" - started_ = self.started_ - # If the timing in progress, end it first. - if self.started_: - self.stop() - # Get the elapsed time. - elapsed_ = self.elapsed_ - # Reset the elapsed time - if reset: - self.reset() - # If timing was in progress, set it back. - if started_: - self.start() - return elapsed_ - - -class _Timers: - """Group of timers.""" - - def __init__(self): - self.timers = {} - - def __call__(self, name): - if name not in self.timers: - self.timers[name] = _Timer(name) - return self.timers[name] - - def write(self, names, writer, iteration, normalizer=1.0, reset=False): - """Write timers to a tensorboard writer""" - # currently when using add_scalars, - # torch.utils.add_scalars makes each timer its own run, which - # polutes the runs list, so we just add each as a scalar - assert normalizer > 0.0 - for name in names: - value = self.timers[name].elapsed(reset=reset) / normalizer - writer.add_scalar(name + "-time", value, iteration) - - def log(self, names, normalizer=1.0, reset=True): - """Log a group of timers.""" - assert normalizer > 0.0 - string = "time (ms)" - for name in names: - elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer - string += " | {}: {:.2f}".format(name, elapsed_time) - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1): - print(string, flush=True) - else: - print(string, flush=True) diff --git a/apex/transformer/pipeline_parallel/p2p_communication.py b/apex/transformer/pipeline_parallel/p2p_communication.py deleted file mode 100644 index 6c4b0d9..0000000 --- a/apex/transformer/pipeline_parallel/p2p_communication.py +++ /dev/null @@ -1,578 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# TODO(mkozuki): Consider removing `timers`. - -from functools import reduce -import operator -from typing import Union, Optional, Tuple - -import torch - -from apex.transformer import parallel_state -from apex.transformer.log_util import get_transformer_logger -from apex.transformer.utils import split_tensor_into_1d_equal_chunks -from apex.transformer.utils import gather_split_1d_tensor -from apex.transformer.pipeline_parallel.utils import Shape -from apex.transformer.pipeline_parallel._timers import _Timers - - -_logger = get_transformer_logger(__name__) - - -class FutureTensor: - def __init__(self, tensor: torch.Tensor, waitfunc): - self.tensor = tensor - self.waitfunc = waitfunc - - def get(self): - if self.waitfunc is not None: - res = self.waitfunc() - if isinstance(res, torch.Tensor): - self.tensor = res - self.waitfunc = None - return self.tensor - - -def _run_p2pops( - tensor_send_prev: Union[torch.Tensor, None], - tensor_send_next: Union[torch.Tensor, None], - tensor_recv_prev: Union[torch.Tensor, None], - tensor_recv_next: Union[torch.Tensor, None], - async_comm: bool = False -): - ops = [] - p2p_group = parallel_state.get_pipeline_model_parallel_group() - default_group = parallel_state.get_model_parallel_group() - - need_to_sync = p2p_group.name() != default_group.name() - - if tensor_send_prev is not None: - send_prev_op = torch.distributed.P2POp( - op=torch.distributed.isend, - tensor=tensor_send_prev, - peer=parallel_state.get_pipeline_model_parallel_prev_rank(), - group=p2p_group, - ) - ops.append(send_prev_op) - if tensor_recv_prev is not None: - recv_prev_op = torch.distributed.P2POp( - op=torch.distributed.irecv, - tensor=tensor_recv_prev, - peer=parallel_state.get_pipeline_model_parallel_prev_rank(), - group=p2p_group, - ) - ops.append(recv_prev_op) - if tensor_send_next is not None: - send_next_op = torch.distributed.P2POp( - op=torch.distributed.isend, - tensor=tensor_send_next, - peer=parallel_state.get_pipeline_model_parallel_next_rank(), - group=p2p_group, - ) - ops.append(send_next_op) - if tensor_recv_next is not None: - recv_next_op = torch.distributed.P2POp( - op=torch.distributed.irecv, - tensor=tensor_recv_next, - peer=parallel_state.get_pipeline_model_parallel_next_rank(), - group=p2p_group, - ) - ops.append(recv_next_op) - if len(ops) > 0: - if need_to_sync: - torch.cuda.synchronize() - - reqs = torch.distributed.batch_isend_irecv(ops) - if async_comm: - assert len(reqs) == len(ops) - tensor_send_prev_req = None if tensor_send_prev is None else reqs.pop(0) - tensor_recv_prev_req = None if tensor_recv_prev is None else reqs.pop(0) - tensor_send_next_req = None if tensor_send_next is None else reqs.pop(0) - tensor_recv_next_req = None if tensor_recv_next is None else reqs.pop(0) - return (tensor_send_prev_req, tensor_recv_prev_req, tensor_send_next_req, tensor_recv_next_req) - else: - for req in reqs: - req.wait() - return (None, None, None, None) - return (None, None, None, None) - - -# TODO(mkozuki): Check if it's possible to sunset `override_scatter_gather_tensors_in_pipeline`. -# TODO(mkozuki): Think about if it's possible to push some logic and arguments e.g. -# `scatter_gather_tensors_in_pipeline`, `sequence_parallel_enabled`, and -# `override_scatter_gather_tensors_in_pipeline` # to the user of -# apex.transformer forward_backwardfunctions. -def _communicate( - tensor_send_next: Optional[torch.Tensor], - tensor_send_prev: Optional[torch.Tensor], - recv_prev: bool, - recv_next: bool, - tensor_shape: Optional[Shape] = None, - override_scatter_gather_tensors_in_pipeline: bool = False, - dtype_: Optional[torch.dtype] = None, - *, - scatter_gather_tensors_in_pipeline: bool = True, - params_dtype: Optional[torch.dtype] = None, - fp32_residual_connection: bool = False, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, -) -> Tuple[Union[torch.Tensor, FutureTensor, None], Union[torch.Tensor, FutureTensor, None]]: - """Base function for communication of tensors between stages. - - - .. note:: - Reference https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/cfd2e2160700b7f2c1bf35298ac14bc341f4c759/megatron/p2p_communication.py#L24-L159 - - dtype logic: If none of ``dtype_``, ``params_dtype``, ``fp32_residual_connection`` is specified, - torch.float32 is used. - - See https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/arguments.py#L145-L159 - for the details of arguments of ``dtype_``, ``params_dtype``, ``fp32_residual_connection``. - - Args: - tensor_send_next: tensor to send to next rank (no tensor sent if set to None). - tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None). - recv_prev: boolean for whether tensor should be received from previous rank. - recv_next: boolean for whether tensor should be received from next rank. - tensor_shape: optional, use when the input sequence contains less tokens than the default sequence length - override_scatter_gather_tensors_in_pipeline: - optional, this is used when tensor_shape is provided to override scatter gather tensors - dtype_: This is used when tensor_shape is provided and what is the type of tensor_shape - - Keyword args: - scatter_gather_tensors_in_pipeline: Optional. If :obj:`True`, use scatter/gather to optimize communication of tensors. - params_dtype: Optional and legacy. Defaults to torch.float. If you manually call `.half()` or `.bfloat16()` on - your model deliberately, pass this argument. - fp32_residual_connection: Optional. If :obj:`True`, move residual connections to fp32. - sequence_parallel_enabled: Set to :obj:`True` if sequence parallel is enabled. - This argument is here for consistency with Megatron-LM. - This argument has an effect on the communication optimization, not on tensor_shape update. - - Returns: - tuple containing - - - tensor_recv_prev: `torch.Tensor` if `recv_prev` is :obj:`True`, `None` otherwise. - - tensor_recv_next: `torch.Tensor` if `recv_next` is :obj:`True`, `None` otherwise. - """ - if async_comm and sequence_parallel_enabled: - import warnings # NOQA - class ExperimentalWarning(UserWarning): pass # NOQA - warnings.warn( - "The combination of `async_comm` and `sequence_parallel_enabled` is not well tested.", - ExperimentalWarning, - ) - # Create placeholder tensors for receive in forward and backward directions if needed. - tensor_recv_prev = None - tensor_recv_next = None - if tensor_shape is None: - # In megatron, `tensor_shape` is set to `(args.seq_length, args.micro_batch_size, args.hidden_size)` - raise RuntimeError( - "`tensor_shape` must be specified. Common `tensor_shape` is `(seq_length, micro_batch_size, hidden_size)`") - - tensor_parallel_size = parallel_state.get_tensor_model_parallel_world_size() - override_scatter_gather_tensors_in_pipeline_ = False - # TODO(mkozuki): Demystify hardcode False of `scatter_gather_tensors_in_pipeline` and add a testcase if possible. - # NOTE(mkozuki): This is super strange and doesn't make sense to me. I have no idea what is happening here. - # However, I can say that this hardcoding override is necessary for sequence parallel in nemo megatron to work. - # I've not managed to reproduce the hang using standalone GPT with sequence parallel. - # The hang in NeMo Megatron happens in the 3rd iteration, the last iteration of stead phase inside - # forward_backward_pipelining_without_interleaving, pipeline parallel rank of 0 (tensor model parallel world - # size of 2 and pipeline model parallel world size of 2). The commit then of APEX and NeMo were - # https://github.com/NVIDIA/apex/pull/1396/commits/3060c98dd8ba42abf7702ea9d2cff0f39ea74f45 and - # https://github.com/NVIDIA/NeMo/pull/4232/commits/1cb32dfca2ab9b20f53ebdb84476c34cb42f0205. - # The PyTorch version was 1.13.0a0+git2d354cd, for what is worth. - # Currently, indiscriminately this is set to `False`, which can lead to an unexpected performance regression - # for non sequence parallel case. - scatter_gather_tensors_in_pipeline = False - if scatter_gather_tensors_in_pipeline and not sequence_parallel_enabled: - tensor_chunk_size = int(reduce(operator.mul, tensor_shape, 1)) - if tensor_chunk_size % tensor_parallel_size == 0: - tensor_chunk_shape = [tensor_chunk_size // tensor_parallel_size] - else: - tensor_chunk_shape = tensor_shape - override_scatter_gather_tensors_in_pipeline_ = True - else: - tensor_chunk_shape = tensor_shape - - # The dtype logic below is copied from NVIDIA/Megatron-LM repo: - # https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/p2p_communication.py#L74-L81 - dtype = params_dtype or torch.float - if fp32_residual_connection: - dtype = torch.float - requires_grad = True - if dtype_ is not None: - dtype = dtype_ - # TODO(mkozuki): Figure out why this logic of requires_grad isn't working - # when sequence_parallel_enabled=True. Otherwise, `x.retain_grad()` of - # https://github.com/crcrpar/apex/blob/069832078a652b4bd8a99db84faf953a81415ab3/apex/transformer/pipeline_parallel/schedules/common.py#L360 - # fails. - # requires_grad = False - - if recv_prev: - tensor_recv_prev = torch.empty( - tensor_chunk_shape, - requires_grad=requires_grad, - device=torch.cuda.current_device(), - dtype=dtype, - ) - if recv_next: - tensor_recv_next = torch.empty( - tensor_chunk_shape, - requires_grad=requires_grad, - device=torch.cuda.current_device(), - dtype=dtype, - ) - - # Split tensor into smaller chunks if using scatter-gather optimization. - scatter_gather_optimization_doable = ( - not override_scatter_gather_tensors_in_pipeline_ - and scatter_gather_tensors_in_pipeline - and not sequence_parallel_enabled - ) - if scatter_gather_optimization_doable: - if tensor_send_next is not None: - tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next) - - if tensor_send_prev is not None: - tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev) - - # Send tensors in both the forward and backward directions as appropriate. - tensor_send_prev_req, tensor_recv_prev_req, tensor_send_next_req, tensor_recv_next_req = _run_p2pops(tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next, async_comm=async_comm) - - if async_comm: - tensor_recv_prev_waitfunc = None - tensor_recv_next_waitfunc = None - # TODO: investigate whether this is necessary for correctness (ref: https://github.com/pytorch/pytorch/issues/38642) - # see also: sync added for async_comm callbacks below in gather_recv_prev_wait and gather_recv_next_wait - if tensor_recv_prev_req is not None: - def tensor_recv_prev_wait(): - tensor_recv_prev_req.wait() - torch.cuda.synchronize() - tensor_recv_prev_waitfunc = tensor_recv_prev_wait - if tensor_recv_next_req is not None: - def tensor_recv_next_wait(): - tensor_recv_next_req.wait() - torch.cuda.synchronize() - tensor_recv_next_waitfunc = tensor_recv_next_wait - else: - # To protect against race condition when using batch_isend_irecv(). - torch.cuda.synchronize() - - # If using scatter-gather optimization, gather smaller chunks. - if scatter_gather_optimization_doable: - if not async_comm: - if recv_prev: - tensor_recv_prev = ( - gather_split_1d_tensor(tensor_recv_prev) - .view(tensor_shape) - .requires_grad_() - ) - - if recv_next: - tensor_recv_next = ( - gather_split_1d_tensor(tensor_recv_next) - .view(tensor_shape) - .requires_grad_() - ) - else: - def gather_recv_prev_wait(): - tensor_recv_prev_req.wait() - # From @Deepak's PR https://github.com/NVIDIA/Megatron-LM/commit/27fc468964064eeb33b703c9a0b2af938d80dd14 - # A sync seems to be needed before gather otherwise losses jump around e.g., in run_gpt_minimal_test - torch.cuda.synchronize() - return ( - gather_split_1d_tensor(tensor_recv_prev) - .view(tensor_shape) - .requires_grad_() - ) - def gather_recv_next_wait(): - tensor_recv_next_req.wait() - torch.cuda.synchronize() - return ( - gather_split_1d_tensor(tensor_recv_next) - .view(tensor_shape) - .requires_grad_() - ) - tensor_recv_prev_waitfunc = gather_recv_prev_wait - tensor_recv_next_waitfunc = gather_recv_next_wait - if async_comm: - future_tensor_recv_prev = None - future_tensor_recv_next = None - if tensor_recv_prev is not None: - future_tensor_recv_prev = FutureTensor(tensor_recv_prev, tensor_recv_prev_waitfunc) - if tensor_recv_next is not None: - future_tensor_recv_next = FutureTensor(tensor_recv_next, tensor_recv_next_waitfunc) - return future_tensor_recv_prev, future_tensor_recv_next - return tensor_recv_prev, tensor_recv_next - - -def recv_forward( - tensor_shape: Shape, - override_scatter_gather_tensors_in_pipeline: bool = False, - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - timers: _Timers = None, -) -> Union[torch.Tensor, FutureTensor, None]: - """Receive tensor from previous rank in pipeline (forward receive).""" - if parallel_state.is_pipeline_first_stage(): - return None - # if timers is not None: - # timers("forward-recv").start() - input_tensor, _ = _communicate( - tensor_send_next=None, - tensor_send_prev=None, - recv_prev=True, - recv_next=False, - tensor_shape=tensor_shape, - override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline, - dtype_=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - # if timers is not None: - # timers("forward-recv").stop() - return input_tensor - - -def recv_backward( - tensor_shape: Shape = None, - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - timers: _Timers = None, -) -> Union[torch.Tensor, FutureTensor, None]: - """Receive tensor from next rank in pipeline (backward receive).""" - if parallel_state.is_pipeline_last_stage(): - return None - # if timers is not None: - # timers("backward-recv").start() - _, output_tensor_grad = _communicate( - tensor_send_next=None, - tensor_send_prev=None, - recv_prev=False, - recv_next=True, - tensor_shape=tensor_shape, - dtype_=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - # if timers is not None: - # timers("backward-recv").stop() - return output_tensor_grad - - -def send_forward( - output_tensor: torch.Tensor, - override_scatter_gather_tensors_in_pipeline: bool = False, - tensor_shape: Shape = None, - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - timers: _Timers = None, -) -> None: - """Send tensor to next rank in pipeline (forward send).""" - if parallel_state.is_pipeline_last_stage(): - return - # if timers is not None: - # timers("forward-send").start() - _communicate( - tensor_send_next=output_tensor, - tensor_send_prev=None, - recv_prev=False, - recv_next=False, - override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline, - tensor_shape=tensor_shape, - dtype_=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - # if timers is not None: - # timers("forward-send").stop() - - -def send_backward( - input_tensor_grad: torch.Tensor, - tensor_shape: Shape, - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - timers: _Timers = None, -) -> None: - """Send tensor to previous rank in pipeline (backward send).""" - if parallel_state.is_pipeline_first_stage(): - return - # if timers is not None: - # timers("backward-send").start() - _communicate( - tensor_send_next=None, - tensor_send_prev=input_tensor_grad, - recv_prev=False, - recv_next=False, - tensor_shape=tensor_shape, - dtype_=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - # if timers is not None: - # timers("backward-send").stop() - - -def send_forward_recv_backward( - output_tensor: torch.Tensor, - tensor_shape: Shape, - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - timers: _Timers = None, -) -> Union[torch.Tensor, FutureTensor, None]: - """Batched send and recv with next rank in pipeline.""" - if parallel_state.is_pipeline_last_stage(): - return None - # if timers is not None: - # timers("forward-send-backward-recv").start() - _, output_tensor_grad = _communicate( - tensor_send_next=output_tensor, - tensor_send_prev=None, - recv_prev=False, - recv_next=True, - tensor_shape=tensor_shape, - dtype_=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - # if timers is not None: - # timers("forward-send-backward-recv").stop() - return output_tensor_grad - - -def send_backward_recv_forward( - input_tensor_grad: torch.Tensor, - tensor_shape: Shape, - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - timers: _Timers = None, -) -> Union[torch.Tensor, FutureTensor, None]: - """Batched send and recv with previous rank in pipeline.""" - if parallel_state.is_pipeline_first_stage(): - return None - # if timers is not None: - # timers("backward-send-forward-recv").start() - input_tensor, _ = _communicate( - tensor_send_next=None, - tensor_send_prev=input_tensor_grad, - recv_prev=True, - recv_next=False, - tensor_shape=tensor_shape, - dtype_=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - # if timers is not None: - # timers("backward-send-forward-recv").stop() - return input_tensor - - -def send_forward_recv_forward( - output_tensor: torch.Tensor, - recv_prev: bool, - tensor_shape: Shape, - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - timers: _Timers = None, -) -> Union[torch.Tensor, FutureTensor]: - """Batched recv from previous rank and send to next rank in pipeline.""" - # if timers is not None: - # timers("forward-send-forward-recv").start() - input_tensor, _ = _communicate( - tensor_send_next=output_tensor, - tensor_send_prev=None, - recv_prev=recv_prev, - recv_next=False, - tensor_shape=tensor_shape, - dtype_=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - # if timers is not None: - # timers("forward-send-forward-recv").stop() - return input_tensor - - -def send_backward_recv_backward( - input_tensor_grad: torch.Tensor, - recv_next: bool, - tensor_shape: Shape, - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - timers: _Timers = None, -) -> Union[torch.Tensor, FutureTensor]: - """Batched recv from next rank and send to previous rank in pipeline.""" - # if timers is not None: - # timers("backward-send-backward-recv").start() - _, output_tensor_grad = _communicate( - tensor_send_next=None, - tensor_send_prev=input_tensor_grad, - recv_prev=False, - recv_next=recv_next, - tensor_shape=tensor_shape, - dtype_=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - # if timers is not None: - # timers("backward-send-backward-recv").stop() - return output_tensor_grad - - -def send_forward_backward_recv_forward_backward( - output_tensor: torch.Tensor, - input_tensor_grad: torch.Tensor, - recv_prev: bool, - recv_next: bool, - tensor_shape: Shape, - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - timers: _Timers = None, -) -> Tuple[Union[torch.Tensor, FutureTensor], Union[torch.Tensor, FutureTensor]]: - """Batched send and recv with previous and next ranks in pipeline.""" - # if timers is not None: - # timers("forward-backward-send-forward-backward-recv").start() - input_tensor, output_tensor_grad = _communicate( - tensor_send_next=output_tensor, - tensor_send_prev=input_tensor_grad, - recv_prev=recv_prev, - recv_next=recv_next, - tensor_shape=tensor_shape, - dtype_=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - # if timers is not None: - # timers("forward-backward-send-forward-backward-recv").stop() - return input_tensor, output_tensor_grad diff --git a/apex/transformer/pipeline_parallel/schedules/__init__.py b/apex/transformer/pipeline_parallel/schedules/__init__.py deleted file mode 100644 index 7e13192..0000000 --- a/apex/transformer/pipeline_parallel/schedules/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -from apex.transformer import parallel_state -from apex.transformer.pipeline_parallel.utils import get_num_microbatches -from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import ( - forward_backward_no_pipelining, -) -from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import ( - _forward_backward_pipelining_with_interleaving, -) -from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import ( - forward_backward_pipelining_without_interleaving, -) - -__all__ = [ - "get_forward_backward_func", -] - - -class ExperimentalWarning(Warning): - pass - - -def get_forward_backward_func( - virtual_pipeline_model_parallel_size, pipeline_model_parallel_size, -): - if parallel_state.get_pipeline_model_parallel_world_size() > 1: - if virtual_pipeline_model_parallel_size is not None: - if get_num_microbatches() % pipeline_model_parallel_size != 0: - msg = "number of microbatches is not divisible by pipeline-parallel size when using interleaved schedule" - raise RuntimeError(msg) - forward_backward_func = _forward_backward_pipelining_with_interleaving - else: - forward_backward_func = forward_backward_pipelining_without_interleaving - else: - forward_backward_func = forward_backward_no_pipelining - return forward_backward_func diff --git a/apex/transformer/pipeline_parallel/schedules/common.py b/apex/transformer/pipeline_parallel/schedules/common.py deleted file mode 100644 index 6016035..0000000 --- a/apex/transformer/pipeline_parallel/schedules/common.py +++ /dev/null @@ -1,398 +0,0 @@ -from typing import Any, Callable, Dict, List, Tuple, Union, Optional, Sequence - -import torch -from torch.autograd.variable import Variable - -from apex.normalization.fused_layer_norm import FusedLayerNorm -from apex.transformer import parallel_state -from apex.transformer.enums import ModelType -from apex.transformer.pipeline_parallel.p2p_communication import FutureTensor -from apex.transformer.pipeline_parallel.utils import get_num_microbatches -from apex.transformer.pipeline_parallel.utils import listify_model -from apex.transformer.pipeline_parallel.utils import unwrap_model -from apex.transformer.pipeline_parallel.utils import get_model_type -from apex.transformer.tensor_parallel.layers import ( - set_defaults_if_not_set_tensor_model_parallel_attributes, -) -from apex.transformer.log_util import get_transformer_logger - - -_logger = get_transformer_logger(__name__) - - -Batch = Union[torch.Tensor, FutureTensor, List[Union[torch.Tensor, FutureTensor]], Tuple[Union[torch.Tensor, FutureTensor], ...]] -LossFunc = Callable[[torch.Tensor], torch.Tensor] -FwdStepFunc = Callable[ - [Optional[Batch], torch.nn.Module], Tuple[torch.Tensor, LossFunc] -] - - -def build_model( - model_provider_func: Callable[[Any, Dict[str, Any]], torch.nn.Module], - wrap_with_ddp: bool = True, - virtual_pipeline_model_parallel_size: Optional[int] = None, - model_type: ModelType = ModelType.encoder_or_decoder, - *args: Any, - **kwargs: Any, -) -> List[torch.nn.Module]: - """Build the model satisfying pipeline model parallel requirements. - - This function sets `pre_process` and `post_process` to `**kwargs` and pass `*args` and `**kwargs` to - `model_provider_func`. - - Args: - model_provider_func: A function which takes `*args` and `**kwargs` and returns a `nn.Module`. - wrap_with_ddp: If :obj:`True`, wrap the instantiated model - with `torch.nn.parallel.distributed.DistributedDataParallel`, a.k.a. `DDP`. - virtual_pipeline_model_parallel_size: Specify when using interleaving scheduling pipeline model parallel. - model_type: - *args: arguments for model provider func - **kwargs: Keyword arguments for model provider func - - Returns: - a list of `nn.Module`(s). If `virtual_pipeline_model_parallel_size` is not None, - the list has multiple models, otherwise one. - """ - if ( - parallel_state.get_pipeline_model_parallel_world_size() > 1 - and virtual_pipeline_model_parallel_size is not None - ): - model = [] - for i in range(virtual_pipeline_model_parallel_size): - cur_args = args - cur_kwargs = kwargs - parallel_state.set_virtual_pipeline_model_parallel_rank(i) - # Set pre_process and post_process only after virtual rank is set. - pre_process = parallel_state.is_pipeline_first_stage() - post_process = parallel_state.is_pipeline_last_stage() - cur_kwargs.update( - {"pre_process": pre_process, "post_process": post_process,} - ) - this_model = model_provider_func(*cur_args, **cur_kwargs) - model.append(this_model) - else: - cur_args = args - cur_kwargs = kwargs - if model_type == ModelType.encoder_or_decoder: - pre_process = parallel_state.is_pipeline_first_stage() - post_process = parallel_state.is_pipeline_last_stage() - cur_kwargs.update( - {"pre_process": pre_process, "post_process": post_process,} - ) - model = model_provider_func(*cur_args, **cur_kwargs) - elif model_type == ModelType.encoder_and_decoder: - pre_process = parallel_state.is_pipeline_first_stage() - post_process = parallel_state.is_pipeline_last_stage() - # `add_encoder` & `add_decoder` logic. - add_encoder, add_decoder = True, True - if parallel_state.get_pipeline_model_parallel_world_size() > 1: - split_rank = parallel_state.get_pipeline_model_parallel_split_rank() - if split_rank is None: - raise RuntimeError( - "Split rank needs to be specified for model with both encoder and decoder." - ) - rank = parallel_state.get_pipeline_model_parallel_rank() - world_size = parallel_state.get_pipeline_model_parallel_world_size() - pre_process = rank == 0 or rank == split_rank - post_process = rank == (split_rank - 1) or rank == (world_size - 1) - add_encoder = parallel_state.is_pipeline_stage_before_split() - add_decoder = parallel_state.is_pipeline_stage_after_split() - cur_kwargs.update( - { - "pre_process": pre_process, - "post_process": post_process, - "add_encoder": add_encoder, - "add_decoder": add_decoder, - } - ) - model = model_provider_func(*cur_args, **cur_kwargs) - model.model_type = model_type - - if not isinstance(model, list): - model = [model] - - # Set tensor model parallel attributes if not set. - # Only parameters that are already tensor model parallel have these - # attributes set for them. We should make sure the default attributes - # are set for all params so the optimizer can use them. - for model_module in model: - for param in model_module.parameters(): - set_defaults_if_not_set_tensor_model_parallel_attributes(param) - - # Print number of parameters. - if ( - parallel_state.model_parallel_is_initialized() - and parallel_state.get_data_parallel_rank() == 0 - ): - msg = " > number of parameters on (tensor, pipeline) model parallel rank ({}, {}): {}".format( - parallel_state.get_tensor_model_parallel_rank(), - parallel_state.get_pipeline_model_parallel_rank(), - _calc_number_of_params(model), - ) - print(msg, flush=True) - - # GPU allocation. - for model_module in model: - model_module.cuda(torch.cuda.current_device()) - - if wrap_with_ddp: - i = torch.cuda.current_device() - model = [ - torch.nn.parallel.distributed.DistributedDataParallel( - model_module, - device_ids=[i], - output_device=i, - process_group=parallel_state.get_data_parallel_group(), - ) - for model_module in model - ] - return model - - -def _calc_number_of_params(model: List[torch.nn.Module]) -> int: - assert isinstance(model, list) - return sum( - [ - sum([p.nelement() for p in model_module.parameters()]) - for model_module in model - ] - ) - - -def _get_params_for_weight_decay_optimization( - model: Union[torch.nn.Module, List[torch.nn.Module]], - *, - no_weight_decay_modules=(FusedLayerNorm,), -) -> Dict[str, torch.nn.Parameter]: - """Divide params into with-weight-decay and without-weight-decay groups. - - Layernorms and biases will have no weight decay but the rest will. - """ - modules = listify_model(model) - weight_decay_params = {"params": []} - no_weight_decay_params = {"params": [], "weight_decay": 0.0} - for module in modules: - for module_ in module.modules(): - if isinstance(module_, no_weight_decay_modules): - no_weight_decay_params["params"].extend( - [p for p in list(module_._parameters.values()) if p is not None] - ) - else: - weight_decay_params["params"].extend( - [ - p - for n, p in list(module_._parameters.items()) - if p is not None and n != "bias" - ] - ) - no_weight_decay_params["params"].extend( - [ - p - for n, p in list(module_._parameters.items()) - if p is not None and n == "bias" - ] - ) - - return weight_decay_params, no_weight_decay_params - - -def free_output_tensor( - output_tensors: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]], - deallocate_pipeline_outputs: bool = False, -) -> None: - """Pseudo-free the output tensor's `.data` field. - - This method should be called right after the output tensor has been sent to the next - pipeline stage. At this point, the output tensor is only useful for its `.grad_fn` field, - and not its `.data`. - """ - if not deallocate_pipeline_outputs: - return - if output_tensors is None: - return - if isinstance(output_tensors, torch.Tensor): - output_tensors = [output_tensors] - for output_tensor in output_tensors: - output_tensor.data = torch.cuda.FloatTensor([0]) - - -def custom_backward(output: torch.Tensor, grad_output: Optional[torch.Tensor]) -> None: - """Directly call C++ autograd engine. - - To make the `free_output_tensor` optimization work, the C++ autograd engine must be called - directly, bypassing PyTorch's `torch.autograd.backward`. PyTorch's `backward` checks that the - output and grad have the same shape, while C++ `backward` does not. - """ - assert ( - output.numel() == 1 - ), "output should be pseudo-freed in schedule, to optimize memory consumption" - assert isinstance(output, torch.Tensor), "output == {}.".format( - type(output).__name__ - ) - assert isinstance( - grad_output, (torch.Tensor, type(None)) - ), "grad_outptu == {}.".format(type(grad_output).__name__) - - # Handle scalar output - if grad_output is None: - assert output.numel() == 1, "Implicit grad requires scalar output." - grad_output = torch.ones_like(output, memory_format=torch.preserve_format) - - # Call C++ engine [ see torch/csrc/autograd/python_engine.cpp ] - Variable._execution_engine.run_backward( - tensors=(output,), - grad_tensors=(grad_output,), - keep_graph=False, - create_graph=False, - inputs=(), - allow_unreachable=True, - accumulate_grad=True, - ) - - -def forward_step( - forward_step_func: FwdStepFunc, - batch: Optional[Batch], - model: torch.nn.Module, - input_tensor: Optional[Union[torch.Tensor, List[torch.Tensor]]], - losses_reduced: List[torch.Tensor], - dtype: torch.dtype, - disable_autocast: bool = False, -) -> Union[torch.Tensor, Sequence[torch.Tensor]]: - """Forward step for passed-in model. - - If first stage, input tensor is obtained from batch, otherwise passed-in input_tensor is used. - - Returns output tensor. - - Args: - forward_step_func: Model specific function. This takes a minibatch and model as its arguments and - returns the model's output and the loss function. - batch: minibatch - model: unwrappable model - input_tensor: - losses_reduced: - dtype: - disable_autocast: - - Returns: - output_tensor - """ - # timers = get_timers() - # timers("forward-compute").start() - unwrapped_model = unwrap_model(model) - model_type = get_model_type(unwrapped_model) - # NOTE (mkozuki): The passed `model` is expected to implement `set_input_tensor`. - # See https://github.com/NVIDIA/Megatron-LM/blob/5ac5571ba0265af4c491ee0af1508ca7589450c6/megatron/model/transformer.py#L679 # NOQA - # for the details of `set_input_tensor`. - unwrap_output_tensor = not isinstance(input_tensor, list) - if unwrap_output_tensor: - input_tensor = [input_tensor] - - input_tensor = [inp.get() if isinstance(inp, FutureTensor) else inp for inp in input_tensor] - - unwrapped_model.set_input_tensor(input_tensor) - with torch.cuda.amp.autocast( - enabled=not disable_autocast and dtype in (torch.half, torch.bfloat16), - dtype=dtype, - ): - output_tensor, loss_func = forward_step_func(batch, model) - if parallel_state.is_pipeline_last_stage(): - output_tensor = loss_func(output_tensor) - loss, loss_reduced = output_tensor - output_tensor = loss / get_num_microbatches() - losses_reduced.append(loss_reduced) - # timers("forward-compute").stop() - - # If T5 model (or other model with encoder and decoder) - # and in decoder stack, then send encoder_hidden_state - # downstream as well. - if ( - parallel_state.is_pipeline_stage_after_split() - and model_type == ModelType.encoder_and_decoder - ): - return [output_tensor, input_tensor[-1]] - if unwrap_output_tensor: - return output_tensor - return [output_tensor] - - -def backward_step( - input_tensor: Optional[torch.Tensor], - output_tensor: torch.Tensor, - output_tensor_grad: Optional[torch.Tensor], - model_type: ModelType, - *, - grad_scaler: Optional[torch.cuda.amp.GradScaler] = None, - deallocate_pipeline_outputs: bool = False, -) -> Union[None, torch.Tensor, Sequence[torch.Tensor]]: - """Backward step through passed-in output tensor. - - If last stage, output_tensor_grad is None, otherwise gradient of loss - with respect to stage's output tensor. - - Returns gradient of loss with respect to input tensor (None if first - stage). - - Args: - input_tensor: - output_tensor: - output_tensor_grad: - Keyword Arguments: - grad_scaler: - deallocate_pipeline_outputs: Experimental. - Returns: - input_tensor_grad - """ - - # timers = get_timers() - # timers("backward-compute").start() - - # Retain the grad on the input_tensor. - unwrap_input_tensor_grad = not isinstance(input_tensor, list) - if unwrap_input_tensor_grad: - input_tensor = [input_tensor] - - input_tensor = [inp.get() if isinstance(inp, FutureTensor) else inp for inp in input_tensor] - - for x in input_tensor: - if x is not None: - x.retain_grad() - - if not isinstance(output_tensor, list): - output_tensor = [output_tensor] - - output_tensor = [out.get() if isinstance(out, FutureTensor) else out for out in output_tensor] - - if not isinstance(output_tensor_grad, list): - output_tensor_grad = [output_tensor_grad] - - output_tensor_grad = [ogr.get() if isinstance(ogr, FutureTensor) else ogr for ogr in output_tensor_grad] - - # Backward pass. - if grad_scaler is not None and output_tensor_grad[0] is None: - output_tensor[0] = grad_scaler.scale(output_tensor[0]) - if deallocate_pipeline_outputs: - custom_backward(output_tensor[0], output_tensor_grad[0]) - else: - torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0]) - - # Collect the grad of the input_tensor. - input_tensor_grad = [None] - if input_tensor is not None: - input_tensor_grad = [] - for x in input_tensor: - input_tensor_grad.append(None if x is None else x.grad) - - # Handle single skip connection if it exists (encoder_hidden_state in model with encoder and decoder). - if ( - parallel_state.get_pipeline_model_parallel_world_size() > 1 - and parallel_state.is_pipeline_stage_after_split() - and model_type == ModelType.encoder_and_decoder - ): - if output_tensor_grad[1] is not None: - # todo (mkozuki): Replace the inplace add with `+= output_tensor_grad[1]`? - input_tensor_grad[-1].add_(output_tensor_grad[1]) - - # timers("backward-compute").stop() - return input_tensor_grad[0] if unwrap_input_tensor_grad else input_tensor_grad diff --git a/apex/transformer/pipeline_parallel/schedules/fwd_bwd_no_pipelining.py b/apex/transformer/pipeline_parallel/schedules/fwd_bwd_no_pipelining.py deleted file mode 100644 index 5500085..0000000 --- a/apex/transformer/pipeline_parallel/schedules/fwd_bwd_no_pipelining.py +++ /dev/null @@ -1,132 +0,0 @@ -from contextlib import contextmanager -from typing import List, Union, Optional - -import torch - -from apex.transformer.pipeline_parallel.utils import listify_model -from apex.transformer.pipeline_parallel.utils import get_num_microbatches -from apex.transformer.pipeline_parallel.utils import get_kth_microbatch -from apex.transformer.pipeline_parallel.utils import get_model_type -from apex.transformer.pipeline_parallel.schedules.common import Batch -from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc -from apex.transformer.pipeline_parallel.schedules.common import forward_step -from apex.transformer.pipeline_parallel.schedules.common import backward_step -from apex.transformer.log_util import get_transformer_logger - - -_all__ = ["forward_backward_no_pipelining"] - - -_logger = get_transformer_logger(__name__) - - -@contextmanager -def placeholder_handler(): - try: - yield - finally: - pass - - -def forward_backward_no_pipelining( - forward_step_func: FwdStepFunc, - batch: Batch, - model: Union[torch.nn.Module, List[torch.nn.Module]], - *, - forward_only: bool, - dtype: Optional[torch.dtype] = None, - grad_scaler: Optional[torch.cuda.amp.GradScaler] = None, - disable_autocast: bool = False, - custom_sync_context_handler=None, - **kwargs, -): - """Run forward and backward passes with no pipeline parallelism (no inter-stage communication). - - This pipeline parallel scheduling handles the last microbatch differently to synchronize gradients. - - Args: - forward_step_func: A function which takes a minibatch and model as its arguments and - returns model's forward output and the loss function. - The loss function is supposed to take one `torch.Tensor` and - return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`. - batch: A List of torch.Tensors - model: A `torch.nn.Module` or a list of `torch.nn.Module`. - - Keyword args: - forward_only: - grad_scaler: - dtype: - disable_autocast: Turn off `enabled` flag of `torch.cuda.amp.autocast` if :obj:`True`. - Should be used when your forward and loss computation is in the autocast context to - avoid unnecesarily nest autocast context. - custom_sync_context_handler: - **kwargs: Added to handle `tensor_shape` which has no effect on this function. - - Returns: - a list of dictionaries of loss `torch.Tensor`s if the last stage, empty list otherwise. - """ - model = listify_model(model) - if len(model) != 1: - msg = f"`model` is expected be a `nn.Module`, but {type(model)}" - raise RuntimeError(msg) - model = model[0] - model_type = get_model_type(model) - - if custom_sync_context_handler is not None: - context_handler = custom_sync_context_handler - elif isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel): - context_handler = model.no_sync - else: - context_handler = placeholder_handler - - losses_reduced = [] - input_tensor, output_tensor_grad = None, None - num_micro_batches = get_num_microbatches() - with context_handler(): - for i in range(num_micro_batches - 1): - _logger.info(f"Iter {i} of {num_micro_batches - 1}") - cur_micro_batch = get_kth_microbatch(batch, i) - _logger.debug("Call `forward_step`") - output_tensor = forward_step( - forward_step_func, - cur_micro_batch, - model, - input_tensor, - losses_reduced, - dtype=dtype, - disable_autocast=disable_autocast, - ) - if not forward_only: - _logger.debug("Call `backward_step`") - backward_step( - input_tensor, - output_tensor, - output_tensor_grad, - model_type=model_type, - grad_scaler=grad_scaler, - ) - - # Run computation for last microbatch out of context handler (want to - # synchronize gradients). - _logger.info("Cooldown") - _logger.debug("Call `forward_step`") - output_tensor = forward_step( - forward_step_func, - get_kth_microbatch(batch, num_micro_batches - 1), - model, - input_tensor, - losses_reduced, - dtype=dtype, - disable_autocast=disable_autocast, - ) - if not forward_only: - _logger.debug("Call `backward_step`") - backward_step( - input_tensor, - output_tensor, - output_tensor_grad, - model_type=model_type, - grad_scaler=grad_scaler, - ) - - return losses_reduced diff --git a/apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_with_interleaving.py b/apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_with_interleaving.py deleted file mode 100644 index 17ad833..0000000 --- a/apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_with_interleaving.py +++ /dev/null @@ -1,415 +0,0 @@ -from typing import List, Union, Optional, Sequence -import warnings - -import torch - -from apex.transformer import parallel_state -from apex.transformer.pipeline_parallel import p2p_communication -from apex.transformer.pipeline_parallel.schedules.common import Batch -from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc -from apex.transformer.pipeline_parallel.schedules.common import backward_step -from apex.transformer.pipeline_parallel.schedules.common import forward_step -from apex.transformer.pipeline_parallel.schedules.common import free_output_tensor -from apex.transformer.pipeline_parallel.utils import get_kth_microbatch -from apex.transformer.pipeline_parallel.utils import get_num_microbatches -from apex.transformer.pipeline_parallel.utils import get_model_type -from apex.transformer.log_util import get_transformer_logger - - -__all__ = ["_forward_backward_pipelining_with_interleaving"] - - -_logger = get_transformer_logger(__name__) - - -# TODO(mkozuki): Reduce cyclomatic complexity -def _forward_backward_pipelining_with_interleaving( - forward_step_func: FwdStepFunc, - batch: List[Optional[Batch]], - model: List[torch.nn.Module], - *, - forward_only: bool, - tensor_shape: Optional[Union[List[int], torch.Size]] = None, - dtype: Optional[torch.dtype] = None, - grad_scaler: Optional[torch.cuda.amp.GradScaler] = None, - disable_autocast: bool = False, - deallocate_pipeline_outputs: bool = False, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - **kwargs, -) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]: - """Run interleaved 1F1B schedule with communication between pipeline stages as needed. - - This function assumes `batch` and `model` is a list of `Batch`'s and a list of `torch.nn.Module`, respectively. - This means that model is split into model chunks. - - This pipeline parallel scheduling consists of three steps: - 1. warmup - 2. 1F1B a.k.a. steady state - 3. cooldown - Note that if `forward_only` this scheduling consists of only warmup phase. - - Args: - forward_step_func: A function which takes a minibatch and model as its arguments and - returns model's forward output and the loss function. - The loss function is supposed to take one `torch.Tensor` and - return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`. - batch: A minibatch, i.e., a list of `torch.Tensor`'s. - model: A `torch.nn.Module` or a list of `torch.nn.Module`. - - Keyword args: - forward_only: - tensor_shape: Shape of tensor. The tensor is expected to be 3D and its order of dimension - is supposed to be ``(sequence, batch, hidden)``. - dtype: dtype used in p2p communication. If ``None`` (default value), - torch.float32 will be used even if ``autocast`` is enabled. - grad_scaler: - disable_autocast: - deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of - each pipeline stage. Experimental. - sequence_parallel_enabled: Set to :obj:`True` for this function to handle sequence length. - When :obj:`True`, the sequence length on each tensor model parallel rank is updated - to :math:`original\_sequence\_length / tensor\_model\_parallel\_world\_size`. - - Returns: - a list of loss `torch.Tensor`s if the last stage, empty list otherwise. - """ - if not isinstance(model, list): - raise RuntimeError("`model` must be a list of `nn.Module`'s'") - - if deallocate_pipeline_outputs: - warnings.warn( - "`deallocate_pipeline_outputs` is experimental and subject to change. " - "This option is not recommended." - ) - - # mypy will blame the following if statement - if sequence_parallel_enabled: - seq_length, batch_size, hidden = tensor_shape - tensor_shape = ( - seq_length // parallel_state.get_tensor_model_parallel_world_size(), - batch_size, - hidden, - ) - - num_model_chunks: int = len(model) - input_tensors: List[List[Union[None, torch.Tensor]]] = [ - [] for _ in range(num_model_chunks) - ] - output_tensors: List[List[Union[None, torch.Tensor]]] = [ - [] for _ in range(num_model_chunks) - ] - curr_iters: List[int] = [0 for _ in range(num_model_chunks)] - losses_reduced: List[Union[None, torch.Tensor]] = [] - if not forward_only: - output_tensor_grads: List[List[Union[None, torch.Tensor]]] = [ - [] for _ in range(num_model_chunks) - ] - - pipeline_parallel_size: int = parallel_state.get_pipeline_model_parallel_world_size() - pipeline_parallel_rank: int = parallel_state.get_pipeline_model_parallel_rank() - - # Compute number of warmup and remaining microbatches. - num_microbatches: int = get_num_microbatches() * num_model_chunks - all_warmup_microbatches: bool = False - if forward_only: - num_warmup_microbatches: int = num_microbatches - else: - # Run all forward passes and then all backward passes if number of - # microbatches is just the number of pipeline stages. - # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on - # all workers, followed by more microbatches after depending on - # stage ID (more forward passes for earlier stages, later stages can - # immediately start with 1F1B). - if get_num_microbatches() == pipeline_parallel_size: - num_warmup_microbatches = num_microbatches - all_warmup_microbatches = True - else: - num_warmup_microbatches = ( - pipeline_parallel_size - pipeline_parallel_rank - 1 - ) * 2 - num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size - num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) - num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches - - _logger.info( - f"num_microbatches: {num_microbatches}, " - f"num_warmup_microbatches: {num_warmup_microbatches}, " - f"num_microbatches_remaining: {num_microbatches_remaining}" - ) - - ################################################################################################################### - # Helper function definitions. - ################################################################################################################### - def get_model_chunk_id(microbatch_id: int, forward: bool) -> int: - """Helper function to get the model chunk ID given the iteration number.""" - pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size() - microbatch_id_in_group = microbatch_id % ( - pipeline_parallel_size * num_model_chunks - ) - model_chunk_id = microbatch_id_in_group // pipeline_parallel_size - if not forward: - model_chunk_id = num_model_chunks - model_chunk_id - 1 - return model_chunk_id - - def forward_step_helper(microbatch_id: int, curr_iters: List[int]) -> torch.Tensor: - """Helper method to run forward step with model split into chunks - - (run set_virtual_pipeline_model_parallel_rank() before calling forward_step()). - """ - model_chunk_id = get_model_chunk_id(microbatch_id, forward=True) - parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id) - - # forward step - if parallel_state.is_pipeline_first_stage() and len( - input_tensors[model_chunk_id] - ) == len(output_tensors[model_chunk_id]): - input_tensors[model_chunk_id].append(None) - input_tensor = input_tensors[model_chunk_id][-1] - output_tensor = forward_step( - forward_step_func, - get_kth_microbatch(batch, curr_iters[model_chunk_id]), - model[model_chunk_id], - input_tensor, - losses_reduced, - dtype, - disable_autocast, - ) - curr_iters[model_chunk_id] += 1 - output_tensors[model_chunk_id].append(output_tensor) - - # if forward-only, no need to save tensors for a backward pass - if forward_only: - input_tensors[model_chunk_id].pop() - output_tensors[model_chunk_id].pop() - - return output_tensor - - def backward_step_helper(microbatch_id: int) -> torch.Tensor: - """Helper method to run backward step with model split into chunks - - (run set_virtual_pipeline_model_parallel_rank() before calling backward_step()). - """ - model_chunk_id = get_model_chunk_id(microbatch_id, forward=False) - model_type = get_model_type(model[model_chunk_id]) - parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id) - - if parallel_state.is_pipeline_last_stage(): - if len(output_tensor_grads[model_chunk_id]) == 0: - output_tensor_grads[model_chunk_id].append(None) - input_tensor = input_tensors[model_chunk_id].pop(0) - output_tensor = output_tensors[model_chunk_id].pop(0) - output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0) - input_tensor_grad = backward_step( - input_tensor, - output_tensor, - output_tensor_grad, - model_type=model_type, - grad_scaler=grad_scaler, - deallocate_pipeline_outputs=deallocate_pipeline_outputs, - ) - - return input_tensor_grad - - ################################################################################################################### - # Run warmup forward passes. - ################################################################################################################### - parallel_state.set_virtual_pipeline_model_parallel_rank(0) - input_tensors[0].append( - p2p_communication.recv_forward( - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - ) - _logger.info("Warmup phase") - for k in range(num_warmup_microbatches): - _logger.debug(f"warmup iter: {k} / {num_warmup_microbatches}") - output_tensor = forward_step_helper(k, curr_iters) - - # Determine if tensor should be received from previous stage. - next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True) - recv_prev = True - if parallel_state.is_pipeline_first_stage(ignore_virtual=True): - if next_forward_model_chunk_id == 0: - recv_prev = False - if k == (num_microbatches - 1): - recv_prev = False - _logger.debug( - f"next fwd model chunk ID: {next_forward_model_chunk_id}, recv_prev: {recv_prev}" - ) - - # Don't send tensor downstream if on last stage. - if parallel_state.is_pipeline_last_stage(): - _logger.debug("Pipeline last stage, not sending tensor downstream") - output_tensor = None - - # Send and receive tensors as appropriate (send tensors computed - # in this iteration; receive tensors for next iteration). - if ( - k == (num_warmup_microbatches - 1) - and not forward_only - and not all_warmup_microbatches - ): - input_tensor_grad = None - recv_next = True - if parallel_state.is_pipeline_last_stage(ignore_virtual=True): - recv_next = False - _logger.debug("send fwd&bwd and receive fwd&bwd") - ( - input_tensor, - output_tensor_grad, - ) = p2p_communication.send_forward_backward_recv_forward_backward( - output_tensor, - input_tensor_grad, - recv_prev=recv_prev, - recv_next=recv_next, - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad) - else: - _logger.debug("send fwd and receive fwd") - input_tensor = p2p_communication.send_forward_recv_forward( - output_tensor, - recv_prev=recv_prev, - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - input_tensors[next_forward_model_chunk_id].append(input_tensor) - free_output_tensor(output_tensor, deallocate_pipeline_outputs) - - ################################################################################################################### - # Run 1F1B in steady state. - ################################################################################################################### - _logger.info("Steady phase") - for k in range(num_microbatches_remaining): - # Forward pass. - _logger.debug(f" steady phase iter {k} / {num_microbatches_remaining}") - forward_k = k + num_warmup_microbatches - output_tensor = forward_step_helper(forward_k, curr_iters) - - # Backward pass. - backward_k = k - input_tensor_grad = backward_step_helper(backward_k) - - # Send output_tensor and input_tensor_grad, receive input_tensor - # and output_tensor_grad. - - # Determine if current stage has anything to send in either direction, - # otherwise set tensor to None. - forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) - parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) - if parallel_state.is_pipeline_last_stage(): - output_tensor = None - - backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) - parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) - _logger.debug( - f"fwd/bwd model chunk id: {forward_model_chunk_id}/{backward_model_chunk_id}" - ) - if parallel_state.is_pipeline_first_stage(): - input_tensor_grad = None - - # Determine if peers are sending, and where in data structure to put - # received tensors. - recv_prev = True - if parallel_state.is_pipeline_first_stage(ignore_virtual=True): - # First stage is ahead of last stage by (pipeline_parallel_size - 1). - next_forward_model_chunk_id = get_model_chunk_id( - forward_k - (pipeline_parallel_size - 1), forward=True - ) - if next_forward_model_chunk_id == (num_model_chunks - 1): - recv_prev = False - next_forward_model_chunk_id += 1 - else: - next_forward_model_chunk_id = get_model_chunk_id( - forward_k + 1, forward=True - ) - - recv_next = True - if parallel_state.is_pipeline_last_stage(ignore_virtual=True): - # Last stage is ahead of first stage by (pipeline_parallel_size - 1). - next_backward_model_chunk_id = get_model_chunk_id( - backward_k - (pipeline_parallel_size - 1), forward=False - ) - if next_backward_model_chunk_id == 0: - recv_next = False - next_backward_model_chunk_id -= 1 - else: - next_backward_model_chunk_id = get_model_chunk_id( - backward_k + 1, forward=False - ) - - # If last iteration, don't receive; we already received one extra - # before the start of the for loop. - if k == (num_microbatches_remaining - 1): - recv_prev = False - - # Communicate tensors. - _logger.debug("send fwd&bwd and receive fwd&bwd") - ( - input_tensor, - output_tensor_grad, - ) = p2p_communication.send_forward_backward_recv_forward_backward( - output_tensor, - input_tensor_grad, - recv_prev=recv_prev, - recv_next=recv_next, - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - free_output_tensor(output_tensor, deallocate_pipeline_outputs) - - # Put input_tensor and output_tensor_grad in data structures in the - # right location. - if recv_prev: - input_tensors[next_forward_model_chunk_id].append(input_tensor) - if recv_next: - output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad) - - ################################################################################################################### - # Run cooldown backward passes (flush out pipeline). - ################################################################################################################### - _logger.info("Cooldown phase") - if not forward_only: - if all_warmup_microbatches: - output_tensor_grads[num_model_chunks - 1].append( - p2p_communication.recv_backward( - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - ) - for k in range(num_microbatches_remaining, num_microbatches): - _logger.debug( - f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})" - ) - input_tensor_grad = backward_step_helper(k) - next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False) - recv_next = True - if parallel_state.is_pipeline_last_stage(ignore_virtual=True): - if next_backward_model_chunk_id == (num_model_chunks - 1): - recv_next = False - if k == (num_microbatches - 1): - recv_next = False - output_tensor_grads[next_backward_model_chunk_id].append( - p2p_communication.send_backward_recv_backward( - input_tensor_grad, - recv_next=recv_next, - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - ) - - return losses_reduced diff --git a/apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_without_interleaving.py b/apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_without_interleaving.py deleted file mode 100644 index 5dc2933..0000000 --- a/apex/transformer/pipeline_parallel/schedules/fwd_bwd_pipelining_without_interleaving.py +++ /dev/null @@ -1,489 +0,0 @@ -from typing import Union, List, Optional, Sequence -import warnings - -import torch - -from apex.transformer import parallel_state -from apex.transformer.enums import ModelType -from apex.transformer.pipeline_parallel import p2p_communication -from apex.transformer.pipeline_parallel.p2p_communication import FutureTensor -from apex.transformer.pipeline_parallel.utils import get_kth_microbatch -from apex.transformer.pipeline_parallel.utils import listify_model -from apex.transformer.pipeline_parallel.utils import get_num_microbatches -from apex.transformer.pipeline_parallel.utils import get_model_type -from apex.transformer.pipeline_parallel.schedules.common import Batch -from apex.transformer.pipeline_parallel.schedules.common import FwdStepFunc -from apex.transformer.pipeline_parallel.schedules.common import backward_step -from apex.transformer.pipeline_parallel.schedules.common import forward_step -from apex.transformer.pipeline_parallel.schedules.common import free_output_tensor -from apex.transformer.log_util import get_transformer_logger - - -__all__ = ["forward_backward_pipelining_without_interleaving"] - - -_logger = get_transformer_logger(__name__) - - -def get_tensor_shapes( - rank: int, - model_type: ModelType, - *, - tensor_shape: Union[List[int], torch.Size], - decoder_sequence_length: Optional[int] = None, - sequence_parallel_enabled: bool = False, -) -> Sequence[Sequence[int]]: - """Get tensors shapes - - Args: - rank: pipeline parallel rank - model_type: - - Keyword Args: - tensor_shape: - decoder_sequence_length: - sequence_parallel_enabled: - """ - # Determine right tensor sizes (based on position of rank with respect to split - # rank) and model size. - # Send two tensors if model is T5 and rank is in decoder stage: - # first tensor is decoder (pre-transpose), - # second tensor is encoder (post-transpose). - # If model is T5 and rank is at the boundary: - # send one tensor (post-transpose from encoder). - # Otherwise, send one tensor (pre-transpose). - assert ( - len(tensor_shape) == 3 - ), f"`tensor_shape` should be [sequence_length, micro_batch_size, hidden_size] but {tensor_shape}" - - sequence_length, micro_batch_size, hidden_size = tensor_shape - - tensor_shapes = [] - - if sequence_parallel_enabled: - seq_length = sequence_length // parallel_state.get_tensor_model_parallel_world_size() - else: - seq_length = sequence_length - - if model_type == ModelType.encoder_and_decoder: - - if sequence_parallel_enabled: - dec_seq_length = decoder_sequence_length // parallel_state.get_tensor_model_parallel_world_size() - else: - dec_seq_length = decoder_sequence_length - - if parallel_state.is_pipeline_stage_before_split(rank): - tensor_shapes.append((seq_length, micro_batch_size, hidden_size)) - else: - tensor_shapes.append((dec_seq_length, micro_batch_size, hidden_size)) - tensor_shapes.append((seq_length, micro_batch_size, hidden_size)) - else: - tensor_shapes.append((seq_length, micro_batch_size, hidden_size)) - - return tensor_shapes - - -def recv_forward( - tensor_shapes: List[Union[None, List[int]]], - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, -) -> List[Union[None, torch.Tensor, FutureTensor]]: - input_tensors = [] - for tensor_shape in tensor_shapes: - if tensor_shape is None: - input_tensors.append(None) - else: - input_tensors.append( - p2p_communication.recv_forward( - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - ) - return input_tensors - - -def recv_backward( - tensor_shapes: List[Union[None, List[int]]], - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, -) -> List[Union[None, torch.Tensor, FutureTensor]]: - output_tensor_grads = [] - for tensor_shape in tensor_shapes: - if tensor_shape is None: - output_tensor_grads.append(None) - else: - output_tensor_grads.append( - p2p_communication.recv_backward( - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - ) - return output_tensor_grads - - -def send_forward( - output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]], - tensor_shapes: List[Union[None, List[int]]], - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, -) -> None: - if not isinstance(output_tensors, list): - output_tensors = [output_tensors] - for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes): - if tensor_shape is None: - continue - p2p_communication.send_forward( - output_tensor, - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - - -def send_backward( - input_tensor_grads: Union[torch.Tensor, List[Union[None, torch.Tensor]]], - tensor_shapes: List[Union[None, List[int]]], - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, -) -> None: - if not isinstance(input_tensor_grads, list): - input_tensor_grads = [input_tensor_grads] - for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes): - if tensor_shape is None: - continue - p2p_communication.send_backward( - input_tensor_grad, - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - - -def send_forward_recv_backward( - output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]], - tensor_shapes: List[Union[None, List[int]]], - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, -) -> List[Union[None, torch.Tensor, FutureTensor]]: - if not isinstance(output_tensors, list): - output_tensors = [output_tensors] - output_tensor_grads = [] - for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes): - if tensor_shape is None: - output_tensor_grads.append(None) - continue - output_tensor_grad = p2p_communication.send_forward_recv_backward( - output_tensor, - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - output_tensor_grads.append(output_tensor_grad) - return output_tensor_grads - - -def send_backward_recv_forward( - input_tensor_grads: Union[torch.Tensor, List[Union[None, torch.Tensor]]], - tensor_shapes: List[Union[None, List[int]]], - *, - dtype: Optional[torch.dtype] = None, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, -) -> List[Union[None, torch.Tensor, FutureTensor]]: - if not isinstance(input_tensor_grads, list): - input_tensor_grads = [input_tensor_grads] - input_tensors = [] - for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes): - if tensor_shape is None: - input_tensors.append(None) - continue - input_tensor = p2p_communication.send_backward_recv_forward( - input_tensor_grad, - tensor_shape=tensor_shape, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - input_tensors.append(input_tensor) - return input_tensors - - -def forward_backward_pipelining_without_interleaving( - forward_step_func: FwdStepFunc, - batch: Optional[Batch], - model: Union[torch.nn.Module, List[torch.nn.Module]], - *, - forward_only: bool, - tensor_shape: Optional[Union[List[int], torch.Size]] = None, - decoder_sequence_length: Optional[int] = None, - dtype: Optional[torch.dtype] = None, - grad_scaler: Optional[torch.cuda.amp.GradScaler] = None, - disable_autocast: bool = False, - deallocate_pipeline_outputs: bool = False, - async_comm: bool = False, - sequence_parallel_enabled: bool = False, - **kwargs, -) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]: - """Run non-interleaved 1F1B schedule, with communication between pipeline stages. - - This pipeline parallel scheduling consists of three steps: - 1. warmup - 2. 1F1B a.k.a. steady state - 3. cooldown if not forward_only - - Args: - forward_step_func: A function which takes a minibatch and model as its arguments and - returns model's forward output and the loss function. - The loss function is supposed to take one `torch.Tensor` and - return a `torch.Tensor` of loss and a dictionary of `str` and `torch.Tensor`. - batch: A minibatch, i.e., a list of `torch.Tensor`'s. - model: A `torch.nn.Module` or a list of `torch.nn.Module`. - - Keyword args: - forward_only: - tensor_shape: Shape of tensor. The tensor is expected to be 3D and its order of dimension - is supposed to be ``(sequence, batch, hidden)``. - dtype: dtype used in p2p communication. If ``None`` (default value), - torch.float32 will be used even if ``autocast`` is enabled. - grad_scaler: - disable_autocast: - deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of - each pipeline stage. Experimental. - sequence_parallel_enabled: Set to :obj:`True` for this function to handle sequence length. - When :obj:`True`, the sequence length on each tensor model parallel rank is updated - to :math:`original\_sequence\_length / tensor\_model\_parallel\_world\_size`. - - Returns: - a list of loss `torch.Tensor`s if the last stage, empty list otherwise. - """ - # timers = get_timers() - - if deallocate_pipeline_outputs: - warnings.warn( - "`deallocate_pipeline_outputs` is experimental and subject to change. " - "This option is not recommended." - ) - - model: List[torch.nn.Module] = listify_model(model) - if len(model) != 1: - msg = f"`model` is expected be a `nn.Module`, but {type(model)}" - raise RuntimeError(msg) - model: torch.nn.Module = model[0] - - # Compute number of warmup microbatches. - num_microbatches: int = get_num_microbatches() - num_warmup_microbatches: int = ( - parallel_state.get_pipeline_model_parallel_world_size() - parallel_state.get_pipeline_model_parallel_rank() - 1 - ) - num_warmup_microbatches: int = min(num_warmup_microbatches, num_microbatches) - num_microbatches_remaining: int = num_microbatches - num_warmup_microbatches - - model_type = get_model_type(model) - rank: int = parallel_state.get_pipeline_model_parallel_rank() - recv_tensor_shapes: List[List[int]] = get_tensor_shapes( - rank - 1, - model_type, - tensor_shape=tensor_shape, - decoder_sequence_length=decoder_sequence_length, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - send_tensor_shapes: List[List[int]] = get_tensor_shapes( - rank, - model_type, - tensor_shape=tensor_shape, - decoder_sequence_length=decoder_sequence_length, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - - _logger.info( - f"num_microbatches: {num_microbatches}, " - f"num_warmup_microbatches: {num_warmup_microbatches}, " - f"num_microbatches_remaining: {num_microbatches_remaining}" - ) - - # Input, output tensors only need to be saved when doing backward passes - input_tensors: List[Union[None, torch.Tensor]] = [] - output_tensors: List[Union[None, torch.Tensor]] = [] - losses_reduced: List[Union[None, torch.Tensor]] = [] - ################################################################################################################### - # Run warmup forward passes. - ################################################################################################################### - _logger.info("Warmup") - for i in range(num_warmup_microbatches): - _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}") - _logger.debug("receive fwd") - input_tensor = recv_forward( - tensor_shapes=recv_tensor_shapes, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i) - output_tensor = forward_step( - forward_step_func, - cur_microbatch, - model, - input_tensor, - losses_reduced, - dtype, - disable_autocast, - ) - _logger.debug("send fwd") - send_forward( - output_tensor, - tensor_shapes=send_tensor_shapes, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - - if not forward_only: - input_tensors.append(input_tensor) - output_tensors.append(output_tensor) - free_output_tensor(output_tensor, deallocate_pipeline_outputs) - - # Before running 1F1B, need to receive first forward tensor. - # If all microbatches are run in warmup / cooldown phase, then no need to - # receive this tensor here. - if num_microbatches_remaining > 0: - _logger.debug("recv_forward before steady state start") - input_tensor: List[Union[None, torch.Tensor, FutureTensor]] = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm) - - ################################################################################################################### - # Run 1F1B in steady state. - ################################################################################################################### - _logger.info("Steady phase") - for i in range(num_microbatches_remaining): - _logger.debug(f"steady iter: {i} / {num_microbatches_remaining}") - last_iteration: bool = i == (num_microbatches_remaining - 1) - - cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i + num_warmup_microbatches) - output_tensor: Union[torch.Tensor, Sequence[torch.Tensor]] = forward_step( - forward_step_func, - cur_microbatch, - model, - input_tensor, - losses_reduced, - dtype, - disable_autocast, - ) - if forward_only: - _logger.debug("send fwd") - send_forward( - output_tensor, - tensor_shapes=send_tensor_shapes, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - - if not last_iteration: - _logger.debug("receive fwd (last iteration)") - input_tensor = recv_forward( - tensor_shapes=recv_tensor_shapes, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - - else: - _logger.debug("send fwd & receive bwd") - output_tensor_grad = send_forward_recv_backward( - output_tensor, - tensor_shapes=send_tensor_shapes, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - - # Add input_tensor and output_tensor to end of list. - input_tensors.append(input_tensor) - output_tensors.append(output_tensor) - free_output_tensor(output_tensor, deallocate_pipeline_outputs) - - # Pop input_tensor and output_tensor from the start of the list for the backward pass. - input_tensor = input_tensors.pop(0) - output_tensor = output_tensors.pop(0) - - input_tensor_grad = backward_step( - input_tensor, - output_tensor, - output_tensor_grad, - model_type=model_type, - grad_scaler=grad_scaler, - deallocate_pipeline_outputs=deallocate_pipeline_outputs, - ) - - if last_iteration: - input_tensor = None - _logger.debug("send bwd") - send_backward( - input_tensor_grad, - tensor_shapes=recv_tensor_shapes, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - else: - _logger.debug("send bwd and receive fwd") - input_tensor = send_backward_recv_forward( - input_tensor_grad, - tensor_shapes=recv_tensor_shapes, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - ################################################################################################################### - # Run cooldown backward passes. - ################################################################################################################### - _logger.info("Cooldown phase") - if not forward_only: - for i in range(num_warmup_microbatches): - _logger.debug(f"cooldown iter: {i} / {num_warmup_microbatches}") - input_tensor = input_tensors.pop(0) - output_tensor = output_tensors.pop(0) - - _logger.debug("receive bwd") - output_tensor_grad = recv_backward( - tensor_shapes=send_tensor_shapes, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - - input_tensor_grad = backward_step( - input_tensor, - output_tensor, - output_tensor_grad, - model_type=model_type, - grad_scaler=grad_scaler, - deallocate_pipeline_outputs=deallocate_pipeline_outputs, - ) - - _logger.debug("send bwd") - send_backward( - input_tensor_grad, - tensor_shapes=recv_tensor_shapes, - dtype=dtype, - async_comm=async_comm, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - - return losses_reduced diff --git a/apex/transformer/pipeline_parallel/utils.py b/apex/transformer/pipeline_parallel/utils.py deleted file mode 100644 index ae550d0..0000000 --- a/apex/transformer/pipeline_parallel/utils.py +++ /dev/null @@ -1,357 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Utilities for pipeline model parallel.""" -from typing import Optional, List, Union - -import torch -from torch.nn.parallel import DistributedDataParallel - -from apex.multi_tensor_apply import multi_tensor_applier -from apex.transformer import parallel_state -from apex.transformer.enums import ModelType -from apex.transformer.microbatches import build_num_microbatches_calculator -from apex.transformer.pipeline_parallel._timers import _Timers -if multi_tensor_applier.available: - import amp_C - - -_GLOBAL_ARGS = None -_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None -_GLOBAL_TOKENIZER = None -_GLOBAL_TENSORBOARD_WRITER = None -_GLOBAL_AUTORESUME = None -_GLOBAL_TIMERS = None - - -Shape = Union[List[int], torch.Size] - - -def listify_model(model: Union[torch.nn.Module, List[torch.nn.Module]]) -> List[torch.nn.Module]: - if isinstance(model, list): - return model - return [model] - - -def _ensure_var_is_initialized(var, name): - """Make sure the input variable is not None.""" - assert var is not None, "{} is not initialized.".format(name) - - -def _ensure_var_is_not_initialized(var, name): - """Make sure the input variable is not None.""" - assert var is None, "{} is already initialized.".format(name) - - -def setup_microbatch_calculator( - rank: int, - rampup_batch_size: Optional[List[int]], - global_batch_size: int, - micro_batch_size: int, - data_parallel_size: int, -) -> None: - global _GLOBAL_NUM_MICROBATCHES_CALCULATOR - _ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, 'num microbatches calculator') - - _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator( - rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size) - - -def _reconfigure_microbatch_calculator( - rank: int, - rampup_batch_size: Optional[List[int]], - global_batch_size: int, - micro_batch_size: int, - data_parallel_size: int, -) -> None: - if torch.distributed.get_rank() == 0: - import warnings - warnings.warn("This function is only for unittest") - global _GLOBAL_NUM_MICROBATCHES_CALCULATOR - - _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator( - rank, rampup_batch_size, global_batch_size, micro_batch_size, data_parallel_size) - - -def get_micro_batch_size(): - return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.micro_batch_size - - -def get_num_microbatches(): - return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get() - - -def get_current_global_batch_size(): - return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size() - - -def update_num_microbatches(consumed_samples, consistency_check=True): - _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check) - - -# note (mkozuki): Comment out in favor of `get_kth_microbatch` -def _split_batch_into_microbatch( - batch: List[torch.Tensor], - *, - _micro_batch_size: Optional[int] = None, - _global_batch_size: Optional[int] = None, -) -> List[List[torch.Tensor]]: - micro_batch_size = _micro_batch_size - global_batch_size = _global_batch_size - if micro_batch_size is None: - micro_batch_size = get_micro_batch_size() - if global_batch_size is None: - global_batch_size = get_current_global_batch_size() - for i in range(0, global_batch_size, micro_batch_size): - yield [x[i * micro_batch_size:(i + 1) * micro_batch_size] for x in batch] - - -# TODO(mkozuki): Support non-tensor local minibatches? -def get_kth_microbatch(batch: Optional[List[torch.Tensor]], k: int) -> List[torch.Tensor]: - """Create a list of microbatches from a list of local minibatches. - - This function creates a list of `k`th microbatches from a list of local minibatches. - `a local minibatch` consists of `global_batch_size / data_parallel_size` samples. - """ - if batch is None: - return batch - micro_batch_size = get_micro_batch_size() - start = k * micro_batch_size - end = start + micro_batch_size - microbatch = list() - for x in batch: - size = x.size(0) - assert size > start and size >= end - microbatch.append(x[start:end]) - assert len(microbatch) > 0 - return microbatch - - -def get_autoresume(): - return _GLOBAL_AUTORESUME - - -def _set_timers(): - """Initialize timers.""" - global _GLOBAL_TIMERS - _ensure_var_is_not_initialized(_GLOBAL_TIMERS, "timers") - _GLOBAL_TIMERS = _Timers() - - -def get_timers(): - """Return timers.""" - _ensure_var_is_initialized(_GLOBAL_TIMERS, "timers") - return _GLOBAL_TIMERS - - -def print_rank_0(message: str) -> None: - """If distributed is initialized, print only on rank 0.""" - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == 0: - print(message, flush=True) - else: - print(message, flush=True) - - -def is_last_rank(): - return torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1) - - -def print_rank_last(message): - """If distributed is initialized, print only on last rank.""" - if torch.distributed.is_initialized(): - if is_last_rank(): - print(message, flush=True) - else: - print(message, flush=True) - - -def param_is_not_shared(param: torch.nn.Parameter) -> bool: - return getattr(param, "shared", False) - - -def unwrap_model(model, module_instances=(DistributedDataParallel,)): - return_list = True - if not isinstance(model, list): - model = [model] - return_list = False - unwrapped_model = [] - for model_module in model: - while isinstance(model_module, module_instances): - model_module = model_module.module - unwrapped_model.append(model_module) - if not return_list: - return unwrapped_model[0] - return unwrapped_model - - -def get_model_type( - model: torch.nn.Module, -) -> ModelType: - """Get `model_type` of `model`. - - If ``model`` doesn't have ``model_type`` attribute, return ``ModelType.encoder_or_decoder``. - - Args: - model - """ - return getattr(unwrap_model(model), "model_type", ModelType.encoder_or_decoder) - - -def calc_params_l2_norm(model: torch.nn.Module, bf16: bool): - """Calculate l2 norm of parameters """ - # args = get_args() - if not isinstance(model, list): - model = [model] - # Remove duplicate params. - params_data = [] - for model_ in model: - for param in model_.parameters(): - is_not_shared = param_is_not_shared(param) - is_not_tp_duplicate = parallel_state.param_is_not_tensor_parallel_duplicate(param) - if is_not_shared and is_not_tp_duplicate: - if bf16: - params_data.append(param.data.float()) - else: - params_data.append(param.data) - # Calculate norm - dummy_overflow_buf = torch.cuda.IntTensor([0]) - norm, _ = multi_tensor_applier( - amp_C.multi_tensor_l2norm, dummy_overflow_buf, [params_data], False # no per-parameter norm - ) - norm_2 = norm * norm - # Sum across all model-parallel GPUs. - torch.distributed.all_reduce( - norm_2, op=torch.distributed.ReduceOp.SUM, group=parallel_state.get_model_parallel_group() - ) - return norm_2.item() ** 0.5 - - -def average_losses_across_data_parallel_group(losses): - """Reduce a tensor of losses across all GPUs.""" - averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses]) - torch.distributed.all_reduce(averaged_losses, group=parallel_state.get_data_parallel_group()) - averaged_losses = averaged_losses / torch.distributed.get_world_size( - group=parallel_state.get_data_parallel_group() - ) - - return averaged_losses - - -def report_memory(name): - """Simple GPU memory report.""" - mega_bytes = 1024.0 * 1024.0 - string = name + " memory (MB)" - string += " | allocated: {}".format(torch.cuda.memory_allocated() / mega_bytes) - string += " | max allocated: {}".format(torch.cuda.max_memory_allocated() / mega_bytes) - string += " | reserved: {}".format(torch.cuda.memory_reserved() / mega_bytes) - string += " | max reserved: {}".format(torch.cuda.max_memory_reserved() / mega_bytes) - if parallel_state.get_data_parallel_rank() == 0: - print("[Rank {}] {}".format(torch.distributed.get_rank(), string), flush=True) - - -def print_params_min_max_norm(optimizer, iteration): - """Print min, max, and norm of all parameters.""" - index = 0 - rank = torch.distributed.get_rank() - string = "iteration, rank, index, tensor-model-parallel, min, max, norm\n" - optimizer_ = optimizer.optimizer - for param_group in optimizer_.param_groups: - for param in param_group["params"]: - index += 1 - min_ = param.data.min() - max_ = param.data.max() - norm = torch.linalg.norm(param.data) - string += "{:7d}, {:4d}, {:4d}, {:2d}, ".format( - iteration, rank, index, int(param.tensor_model_parallel) - ) - string += "{:.6E}, {:.6E}, {:.6E}\n".format(min_, max_, norm) - print(string, flush=True) - - -# NOTE (mkozuki): APEX doesn't have anything equivalent for -# `_GLOBAL_ADLR_AUTORESUME` like Megatron-LM. -# def check_adlr_autoresume_termination(iteration, model, optimizer, lr_scheduler, save: bool): -# """Check for autoresume signal and exit if it is received.""" -# from apex.ppu.checkpointing import save_checkpoint -# -# autoresume = get_adlr_autoresume() -# # Add barrier to ensure consistency. -# torch.distributed.barrier() -# if autoresume.termination_requested(): -# if save: -# save_checkpoint(iteration, model, optimizer, lr_scheduler) -# print_rank_0(">>> autoresume termination request found!") -# if torch.distributed.get_rank() == 0: -# autoresume.request_resume() -# print_rank_0(">>> training terminated. Returning") -# sys.exit(0) - - -def get_ltor_masks_and_position_ids( - data, eod_token, reset_position_ids, reset_attention_mask, eod_mask_loss -): - """Build masks and position id for left to right model.""" - - # Extract batch size and sequence length. - micro_batch_size, seq_length = data.size() - - # Attention mask (lower triangular). - if reset_attention_mask: - att_mask_batch = micro_batch_size - else: - att_mask_batch = 1 - attention_mask = torch.tril( - torch.ones((att_mask_batch, seq_length, seq_length), device=data.device) - ).view(att_mask_batch, 1, seq_length, seq_length) - - # Loss mask. - loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) - if eod_mask_loss: - loss_mask[data == eod_token] = 0.0 - - # Position ids. - position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device) - position_ids = position_ids.unsqueeze(0).expand_as(data) - # We need to clone as the ids will be modifed based on batch index. - if reset_position_ids: - position_ids = position_ids.clone() - - if reset_position_ids or reset_attention_mask: - # Loop through the batches: - for b in range(micro_batch_size): - - # Find indecies where EOD token is. - eod_index = position_ids[b, data[b] == eod_token] - # Detach indecies from positions if going to modify positions. - if reset_position_ids: - eod_index = eod_index.clone() - - # Loop through EOD indecies: - prev_index = 0 - for j in range(eod_index.size()[0]): - i = eod_index[j] - # Mask attention loss. - if reset_attention_mask: - attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0 - # Reset positions. - if reset_position_ids: - position_ids[b, (i + 1) :] -= i + 1 - prev_index - prev_index = i + 1 - - # Convert attention mask to binary: - attention_mask = attention_mask < 0.5 - - return attention_mask, loss_mask, position_ids diff --git a/apex/transformer/tensor_parallel/__init__.py b/apex/transformer/tensor_parallel/__init__.py deleted file mode 100644 index ccad80e..0000000 --- a/apex/transformer/tensor_parallel/__init__.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Model parallel utility interface.""" - -from apex.transformer.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy - -from apex.transformer.tensor_parallel.data import broadcast_data - -from apex.transformer.tensor_parallel.layers import ( - ColumnParallelLinear, - RowParallelLinear, - VocabParallelEmbedding, - set_tensor_model_parallel_attributes, - set_defaults_if_not_set_tensor_model_parallel_attributes, - copy_tensor_model_parallel_attributes, -) - -from apex.transformer.tensor_parallel.mappings import ( - copy_to_tensor_model_parallel_region, - gather_from_tensor_model_parallel_region, - reduce_from_tensor_model_parallel_region, - scatter_to_tensor_model_parallel_region, - scatter_to_sequence_parallel_region, -) - -from .random import ( - checkpoint, - get_cuda_rng_tracker, - init_checkpointed_activations_memory_buffer, - model_parallel_cuda_manual_seed, - reset_checkpointed_activations_memory_buffer, -) - -from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim - - -__all__ = [ - # cross_entropy.py - "vocab_parallel_cross_entropy", - # data.py - "broadcast_data", - # layers.py - "ColumnParallelLinear", - "RowParallelLinear", - "VocabParallelEmbedding", - "set_tensor_model_parallel_attributes", - "set_defaults_if_not_set_tensor_model_parallel_attributes", - "copy_tensor_model_parallel_attributes", - # mappings.py - "copy_to_tensor_model_parallel_region", - "gather_from_tensor_model_parallel_region", - "reduce_from_tensor_model_parallel_region", - "scatter_to_tensor_model_parallel_region", - "scatter_to_sequence_parallel_region", - # random.py - "checkpoint", - "get_cuda_rng_tracker", - "init_checkpointed_activations_memory_buffer", - "model_parallel_cuda_manual_seed", - "reset_checkpointed_activations_memory_buffer", - # utils.py - "split_tensor_along_last_dim", -] diff --git a/apex/transformer/tensor_parallel/cross_entropy.py b/apex/transformer/tensor_parallel/cross_entropy.py deleted file mode 100644 index 3918645..0000000 --- a/apex/transformer/tensor_parallel/cross_entropy.py +++ /dev/null @@ -1,103 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch - -from apex.transformer.parallel_state import get_tensor_model_parallel_group -from apex.transformer.parallel_state import get_tensor_model_parallel_rank -from apex.transformer.parallel_state import get_tensor_model_parallel_world_size -from apex.transformer.tensor_parallel.utils import VocabUtility - - -class _VocabParallelCrossEntropy(torch.autograd.Function): - @staticmethod - def forward(ctx, vocab_parallel_logits, target): - - # Maximum value along vocab dimension across all GPUs. - logits_max = torch.max(vocab_parallel_logits, dim=-1)[0] - torch.distributed.all_reduce( - logits_max, op=torch.distributed.ReduceOp.MAX, group=get_tensor_model_parallel_group() - ) - # Subtract the maximum value. - vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1) - - # Get the partition's vocab indecies - get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size - partition_vocab_size = vocab_parallel_logits.size()[-1] - rank = get_tensor_model_parallel_rank() - world_size = get_tensor_model_parallel_world_size() - vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size) - - # Create a mask of valid vocab ids (1 means it needs to be masked). - target_mask = (target < vocab_start_index) | (target >= vocab_end_index) - masked_target = target.clone() - vocab_start_index - masked_target[target_mask] = 0 - - # Get predicted-logits = logits[target]. - # For Simplicity, we convert logits to a 2-D tensor with size - # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. - logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) - masked_target_1d = masked_target.view(-1) - arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) - predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] - predicted_logits_1d = predicted_logits_1d.clone().contiguous() - predicted_logits = predicted_logits_1d.view_as(target) - predicted_logits[target_mask] = 0.0 - # All reduce is needed to get the chunks from other GPUs. - torch.distributed.all_reduce( - predicted_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group() - ) - - # Sum of exponential of logits along vocab dimension across all GPUs. - exp_logits = vocab_parallel_logits - torch.exp(vocab_parallel_logits, out=exp_logits) - sum_exp_logits = exp_logits.sum(dim=-1) - torch.distributed.all_reduce( - sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=get_tensor_model_parallel_group() - ) - - # Loss = log(sum(exp(logits))) - predicted-logit. - loss = torch.log(sum_exp_logits) - predicted_logits - - # Store softmax, target-mask and masked-target for backward pass. - exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) - ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) - - return loss - - @staticmethod - def backward(ctx, grad_output): - - # Retreive tensors from the forward path. - softmax, target_mask, masked_target_1d = ctx.saved_tensors - - # All the inputs have softmax as thier gradient. - grad_input = softmax - # For simplicity, work with the 2D gradient. - partition_vocab_size = softmax.size()[-1] - grad_2d = grad_input.view(-1, partition_vocab_size) - - # Add the gradient from matching classes. - arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) - grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float() - - # Finally elementwise multiplication with the output gradients. - grad_input.mul_(grad_output.unsqueeze(dim=-1)) - - return grad_input, None - - -def vocab_parallel_cross_entropy(vocab_parallel_logits, target): - """Helper function for the cross entropy.""" - return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target) diff --git a/apex/transformer/tensor_parallel/data.py b/apex/transformer/tensor_parallel/data.py deleted file mode 100644 index 39d6ca8..0000000 --- a/apex/transformer/tensor_parallel/data.py +++ /dev/null @@ -1,122 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch - -from apex.transformer.parallel_state import get_tensor_model_parallel_group -from apex.transformer.parallel_state import get_tensor_model_parallel_rank -from apex.transformer.parallel_state import get_tensor_model_parallel_src_rank - - -_MAX_DATA_DIM = 5 - - -def _check_data_types(keys, data, target_dtype): - """Check that all the keys have the same target data type.""" - for key in keys: - assert data[key].dtype == target_dtype, ( - "{} has data type {} which " - "is different than {}".format(key, data[key].dtype, target_dtype) - ) - - -def _build_key_size_numel_dictionaries(keys, data): - """Build the size on rank 0 and broadcast.""" - max_dim = _MAX_DATA_DIM - sizes = [0 for _ in range(max_dim) for _ in keys] - - # Pack the sizes on rank zero. - if get_tensor_model_parallel_rank() == 0: - offset = 0 - for key in keys: - assert data[key].dim() < max_dim, "you should increase MAX_DATA_DIM" - size = data[key].size() - for i, s in enumerate(size): - sizes[i + offset] = s - offset += max_dim - - # Move to GPU and broadcast. - sizes_cuda = torch.cuda.LongTensor(sizes) - torch.distributed.broadcast( - sizes_cuda, - get_tensor_model_parallel_src_rank(), - group=get_tensor_model_parallel_group(), - ) - - # Move back to cpu and unpack. - sizes_cpu = sizes_cuda.cpu() - key_size = {} - key_numel = {} - total_numel = 0 - offset = 0 - for key in keys: - i = 0 - size = [] - numel = 1 - while sizes_cpu[offset + i] > 0: - this_size = sizes_cpu[offset + i] - size.append(this_size) - numel *= this_size - i += 1 - key_size[key] = size - key_numel[key] = numel - total_numel += numel - offset += max_dim - - return key_size, key_numel, total_numel - - -def broadcast_data(keys, data, datatype): - """Broadcast data from rank zero of each model parallel group to the - members of the same model parallel group. - - Arguments: - keys: list of keys in the data disctionary to be broadcasted - data: data dictionary of string keys and cpu tensor values. - datatype: torch data type of all tensors in data associated - with keys. - """ - # Build (key, size) and (key, number of elements) dictionaries along - # with the total number of elements on all ranks. - key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data) - # Pack on rank zero. - if get_tensor_model_parallel_rank() == 0: - # Check that all keys have the same data type. - _check_data_types(keys, data, datatype) - # Flatten the data associated with the keys - flatten_data = torch.cat( - [data[key].contiguous().view(-1) for key in keys], dim=0 - ).cuda() - else: - flatten_data = torch.empty( - total_numel, device=torch.cuda.current_device(), dtype=datatype - ) - - # Broadcast - torch.distributed.broadcast( - flatten_data, - get_tensor_model_parallel_src_rank(), - group=get_tensor_model_parallel_group(), - ) - - # Unpack - output = {} - offset = 0 - for key in keys: - size = key_size[key] - numel = key_numel[key] - output[key] = flatten_data.narrow(0, offset, numel).view(size) - offset += numel - - return output diff --git a/apex/transformer/tensor_parallel/layers.py b/apex/transformer/tensor_parallel/layers.py deleted file mode 100644 index e2d7e52..0000000 --- a/apex/transformer/tensor_parallel/layers.py +++ /dev/null @@ -1,780 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# Parts of the code here are adapted from PyTorch -# repo: https://github.com/pytorch/pytorch -from typing import Optional, Dict, Tuple, List -import warnings - -import torch -import torch.nn.functional as F -import torch.nn.init as init -from torch.nn.parameter import Parameter - -from apex._autocast_utils import _cast_if_autocast_enabled -from apex.transformer.parallel_state import get_tensor_model_parallel_group -from apex.transformer.parallel_state import get_tensor_model_parallel_rank -from apex.transformer.parallel_state import get_tensor_model_parallel_world_size -from apex.transformer.utils import divide -from apex.transformer.tensor_parallel.mappings import ( - copy_to_tensor_model_parallel_region, -) -from apex.transformer.tensor_parallel.mappings import ( - gather_from_tensor_model_parallel_region, -) -from apex.transformer.tensor_parallel.mappings import ( - reduce_from_tensor_model_parallel_region, -) -from apex.transformer.tensor_parallel.mappings import ( - scatter_to_tensor_model_parallel_region, -) -from apex.transformer.tensor_parallel.mappings import ( - reduce_scatter_to_sequence_parallel_region, -) -from apex.transformer.tensor_parallel.random import get_cuda_rng_tracker -from apex.transformer.tensor_parallel.utils import VocabUtility -from apex.transformer.log_util import get_transformer_logger - - -_logger = get_transformer_logger(__name__) - - -_grad_accum_fusion_available = True -try: - import fused_weight_gradient_mlp_cuda -except ImportError: - _grad_accum_fusion_available = False - - -_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = { - "tensor_model_parallel": False, - "partition_dim": -1, - "partition_stride": 1, -} - - -def param_is_not_tensor_parallel_duplicate(param: torch.Tensor) -> bool: - return ( - hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel - ) or (get_tensor_model_parallel_rank() == 0) - - -def set_tensor_model_parallel_attributes(tensor: torch.Tensor, is_parallel: bool, dim: int, stride: int) -> None: - # Make sure the attributes are not set. - for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: - assert not hasattr(tensor, attribute) - # Set the attributes. - setattr(tensor, "tensor_model_parallel", is_parallel) - setattr(tensor, "partition_dim", dim) - setattr(tensor, "partition_stride", stride) - - -def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor: torch.Tensor) -> None: - def maybe_set(attribute, value): - if not hasattr(tensor, attribute): - setattr(tensor, attribute, value) - - for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: - maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute]) - - -def copy_tensor_model_parallel_attributes(destination_tensor: torch.Tensor, source_tensor: torch.Tensor) -> None: - def maybe_copy(attribute): - if hasattr(source_tensor, attribute): - setattr(destination_tensor, attribute, getattr(source_tensor, attribute)) - - for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS: - maybe_copy(attribute) - - -def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1): - """Initialize affine weight for model parallel on GPU. - - Args: - weight (Parameter): - init_method (Callable[[Tensor], None]): Taking a Tensor and initialize its elements. - partition_dim (int): Dimension to apply partition. - stride (int): - """ - - set_tensor_model_parallel_attributes( - tensor=weight, is_parallel=True, dim=partition_dim, stride=stride - ) - - with get_cuda_rng_tracker().fork(): - init_method(weight) - - -# TODO (mkozuki): Re-consider removing params_dtype from arguments to make this -# more parallel with _initialize_affine_weight_gpu -def _initialize_affine_weight_cpu( - weight, - output_size, - input_size, - per_partition_size, - partition_dim, - init_method, - stride=1, - return_master_weight=False, - *, - params_dtype=torch.float32, -): - """Initialize affine weight for model parallel. - - Build the master weight on all processes and scatter - the relevant chunk.""" - - set_tensor_model_parallel_attributes( - tensor=weight, is_parallel=True, dim=partition_dim, stride=stride - ) - - # Initialize master weight - master_weight = torch.empty( - output_size, input_size, dtype=torch.float, requires_grad=False - ) - init_method(master_weight) - master_weight = master_weight.to(dtype=params_dtype) - - # Split and copy - per_partition_per_stride_size = divide(per_partition_size, stride) - weight_list = torch.split( - master_weight, per_partition_per_stride_size, dim=partition_dim - ) - rank = get_tensor_model_parallel_rank() - world_size = get_tensor_model_parallel_world_size() - my_weight_list = weight_list[rank::world_size] - - with torch.no_grad(): - torch.cat(my_weight_list, dim=partition_dim, out=weight) - if return_master_weight: - return master_weight - return None - - -class VocabParallelEmbedding(torch.nn.Module): - """Embedding parallelized in the vocabulary dimension. - - This is mainly adapted from torch.nn.Embedding and all the default - values are kept. - Arguments: - num_embeddings: vocabulary size. - embedding_dim: size of hidden state. - init_method: method to initialize weights. - """ - - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - init_method=init.xavier_normal_, - *, - params_dtype: torch.dtype=torch.float32, - use_cpu_initialization: bool = False, - ): - super().__init__() - # Keep the input dimensions. - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - # Set the detauls for compatibility. - self.padding_idx = None - self.max_norm = None - self.norm_type = 2.0 - self.scale_grad_by_freq = False - self.sparse = False - self._weight = None - self.tensor_model_parallel_size = get_tensor_model_parallel_world_size() - # Divide the weight matrix along the vocabulary dimension. - ( - self.vocab_start_index, - self.vocab_end_index, - ) = VocabUtility.vocab_range_from_global_vocab_size( - self.num_embeddings, - get_tensor_model_parallel_rank(), - self.tensor_model_parallel_size, - ) - self.num_embeddings_per_partition = ( - self.vocab_end_index - self.vocab_start_index - ) - - # Allocate weights and initialize. - if use_cpu_initialization: - self.weight = Parameter( - torch.empty( - self.num_embeddings_per_partition, - self.embedding_dim, - dtype=params_dtype, - ) - ) - _initialize_affine_weight_cpu( - self.weight, - self.num_embeddings, - self.embedding_dim, - self.num_embeddings_per_partition, - 0, - init_method, - params_dtype=params_dtype, - ) - else: - self.weight = Parameter( - torch.empty( - self.num_embeddings_per_partition, - self.embedding_dim, - device=torch.cuda.current_device(), - dtype=params_dtype, - ) - ) - _initialize_affine_weight_gpu( - self.weight, init_method, partition_dim=0, stride=1 - ) - - def forward(self, input_): - if self.tensor_model_parallel_size > 1: - # Build the mask. - input_mask = (input_ < self.vocab_start_index) | ( - input_ >= self.vocab_end_index - ) - # Mask the input. - masked_input = input_.clone() - self.vocab_start_index - masked_input[input_mask] = 0 - else: - masked_input = input_ - # Get the embeddings. - output_parallel = F.embedding( - masked_input, - self.weight, - self.padding_idx, - self.max_norm, - self.norm_type, - self.scale_grad_by_freq, - self.sparse, - ) - # Mask the output embedding. - if self.tensor_model_parallel_size > 1: - output_parallel[input_mask, :] = 0.0 - # Reduce across all the model parallel GPUs. - output = reduce_from_tensor_model_parallel_region(output_parallel) - return output - - -class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function): - """Linear layer execution with asynchronous communication and gradient accumulation fusion in backprop.""" - - @staticmethod - def forward( - ctx, - input: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor], - gradient_accumulation_fusion: bool, - async_grad_allreduce: bool, - sequence_parallel_enabled: bool, - use_16bit_in_wgrad_accum_fusion: bool = False, - ): - ctx.save_for_backward(input, weight) - ctx.use_bias = bias is not None - ctx.gradient_accumulation_fusion = gradient_accumulation_fusion - ctx.async_grad_allreduce = async_grad_allreduce - ctx.sequence_parallel_enabled = sequence_parallel_enabled - ctx.use_16bit_in_wgrad_accum_fusion = use_16bit_in_wgrad_accum_fusion - - if ctx.sequence_parallel_enabled: - world_size = get_tensor_model_parallel_world_size() - # `input` is supposed to be 3D and its order of dimension is [sequence, batch, hidden] - shape = list(input.shape) - shape[0] *= world_size - - all_gather_buffer = torch.empty( - shape, - dtype=input.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - torch.distributed._all_gather_base(all_gather_buffer, input, group=get_tensor_model_parallel_group()) - total_input = all_gather_buffer - else: - total_input = input - output = torch.matmul(total_input, weight.t()) - if bias is not None: - output = output + bias - return output - - @staticmethod - def backward(ctx, grad_output): - input, weight = ctx.saved_tensors - use_bias = ctx.use_bias - - if ctx.sequence_parallel_enabled: - world_size = get_tensor_model_parallel_world_size() - shape = list(input.shape) - shape[0] *= world_size - - all_gather_buffer = torch.empty( - shape, - dtype=input.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - handle = torch.distributed._all_gather_base( - all_gather_buffer, - input, - group=get_tensor_model_parallel_group(), - async_op=True, - ) - total_input = all_gather_buffer - else: - total_input = input - grad_input = grad_output.matmul(weight) - - if ctx.sequence_parallel_enabled: - handle.wait() - - # Convert the tensor shapes to 2D for execution compatibility - grad_output = grad_output.view( - grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2] - ) - total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2]) - if ctx.async_grad_allreduce: - # Asynchronous all-reduce - handle = torch.distributed.all_reduce( - grad_input, group=get_tensor_model_parallel_group(), async_op=True - ) - - if ctx.sequence_parallel_enabled: - assert not ctx.async_grad_allreduce - sub_grad_input = torch.empty(input.shape, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False) - handle = torch.distributed._reduce_scatter_base( - sub_grad_input, - grad_input, - group=get_tensor_model_parallel_group(), - async_op=True - ) - - if ctx.gradient_accumulation_fusion: - if not ctx.use_16bit_in_wgrad_accum_fusion: - fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32( - total_input, grad_output, weight.main_grad - ) - else: - fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16( - total_input, grad_output, weight.main_grad - ) - grad_weight = None - else: - grad_weight = grad_output.t().matmul(total_input) - - grad_bias = grad_output.sum(dim=0) if use_bias else None - if ctx.sequence_parallel_enabled: - handle.wait() - return sub_grad_input, grad_weight, grad_bias, None, None, None, None - if ctx.async_grad_allreduce: - handle.wait() - return grad_input, grad_weight, grad_bias, None, None, None, None - - -def linear_with_grad_accumulation_and_async_allreduce( - input: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor], - gradient_accumulation_fusion: bool, - async_grad_allreduce: bool, - sequence_parallel_enabled: bool, -) -> torch.Tensor: - args = _cast_if_autocast_enabled( - input, - weight, - bias, - gradient_accumulation_fusion, - async_grad_allreduce, - sequence_parallel_enabled, - False, # use_16bit_in_wgrad_accum_fusion - ) - with torch.cuda.amp.autocast(enabled=False): - return LinearWithGradAccumulationAndAsyncCommunication.apply(*args) - - -def linear_with_grad_accumulation_and_async_allreduce_in16bit( - input: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor], - gradient_accumulation_fusion: bool, - async_grad_allreduce: bool, - sequence_parallel_enabled: bool, -) -> torch.Tensor: - args = _cast_if_autocast_enabled( - input, - weight, - bias, - gradient_accumulation_fusion, - async_grad_allreduce, - sequence_parallel_enabled, - True, # use_16bit_in_wgrad_accum_fusion - ) - with torch.cuda.amp.autocast(enabled=False): - return LinearWithGradAccumulationAndAsyncCommunication.apply(*args) - - -class ColumnParallelLinear(torch.nn.Module): - """Linear layer with column parallelism. - - The linear layer is defined as Y = XA + b. A is parallelized along - its second dimension as A = [A_1, ..., A_p]. - - .. note:: - Input is supposed to be three dimensional and each dimension - is expected to be sequence, batch, and hidden feature, respectively. - - Arguments: - input_size: first dimension of matrix A. - output_size: second dimension of matrix A. - bias: If true, add bias - gather_output: If true, call all-gether on output and make Y avaiable - to all GPUs, otherwise, every GPU will have its output - which is Y_i = XA_i - init_method: method to initialize weights. Note that bias is always set - to zero. - stride: For the strided linear layers. - keep_master_weight_for_test: This was added for testing and should be - set to False. It returns the master weights - used for initialization. - skip_bias_add: This was added to enable performance optimations where bias - can be fused with other elementwise operations. we skip - adding bias but instead return it. - - Keyword Arguments: - no_async_tensor_model_parallel_allreduce: - params_dtype: - use_cpu_initialization: - gradient_accumulation_fusion: - accumulation_in_fp16: - sequence_parallel_enabled: - """ - - def __init__( - self, - input_size, - output_size, - bias=True, - gather_output=True, - init_method=init.xavier_normal_, - stride=1, - keep_master_weight_for_test=False, - skip_bias_add=False, - *, - no_async_tensor_model_parallel_allreduce=False, - params_dtype=torch.float32, - use_cpu_initialization=False, - gradient_accumulation_fusion=False, - accumulation_in_fp16: bool = False, - sequence_parallel_enabled: bool = False, - ): - super().__init__() - - # Keep input parameters - self.input_size = input_size - self.output_size = output_size - self.gather_output = gather_output - # Divide the weight matrix along the last dimension. - world_size = get_tensor_model_parallel_world_size() - self.output_size_per_partition = divide(output_size, world_size) - self.skip_bias_add = skip_bias_add - - # Parameters. - # Note: torch.nn.functional.linear performs XA^T + b and as a result - # we allocate the transpose. - # Initialize weight. - if use_cpu_initialization: - self.weight = Parameter( - torch.empty(self.output_size_per_partition, self.input_size, dtype=params_dtype) - ) - self.master_weight = _initialize_affine_weight_cpu( - self.weight, - self.output_size, - self.input_size, - self.output_size_per_partition, - 0, - init_method, - stride=stride, - return_master_weight=keep_master_weight_for_test, - params_dtype=params_dtype, - ) - else: - self.weight = Parameter( - torch.empty( - self.output_size_per_partition, - self.input_size, - device=torch.cuda.current_device(), - dtype=params_dtype, - ) - ) - _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=stride) - - if bias: - if use_cpu_initialization: - self.bias = Parameter(torch.empty(self.output_size_per_partition, dtype=params_dtype)) - else: - self.bias = Parameter( - torch.empty( - self.output_size_per_partition, - device=torch.cuda.current_device(), - dtype=params_dtype, - ) - ) - set_tensor_model_parallel_attributes(self.bias, True, 0, stride) - # Always initialize bias to zero. - with torch.no_grad(): - self.bias.zero_() - else: - self.register_parameter("bias", None) - - self.async_tensor_model_parallel_allreduce = ( - not no_async_tensor_model_parallel_allreduce and world_size > 1 - ) - if sequence_parallel_enabled: - if world_size <= 1: - warnings.warn( - f"`sequence_parallel_enabled` is set to `True`, but got world_size of {world_size}" - ) - # sequence_parallel_enabled = False - self.sequence_parallel_enabled = sequence_parallel_enabled - if gradient_accumulation_fusion: - if not _grad_accum_fusion_available: - # Basically, apex.transformer module users are expected to install APEX's - # `--cpp_ext` and `--cuda_ext`. The example installation command is as follows: - # `pip install --global-option="--cpp_ext" --global-option="--cuda_ext ." - # at the root of APEX repository. - warnings.warn( - "`gradient_accumulation_fusion` is set to `True` but " - "the custom CUDA extension of `fused_weight_gradient_mlp_cuda` module not " - "found. Thus `gradient_accumulation_fusion` set to `False`. " - "Note that the extension requires CUDA>=11." - ) - gradient_accumulation_fusion = False - self.gradient_accumulation_fusion = gradient_accumulation_fusion - - - if self.async_tensor_model_parallel_allreduce and self.sequence_parallel_enabled: - raise RuntimeError("`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` cannot be enabled at the same time.") - - self._forward_impl = ( - linear_with_grad_accumulation_and_async_allreduce_in16bit - if accumulation_in_fp16 - else linear_with_grad_accumulation_and_async_allreduce - ) - - def forward(self, input_: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """Forward of ColumnParallelLinear - - Args: - input_: 3D tensor whose order of dimension is [sequence, batch, hidden] - - Returns: - - output - - bias - """ - bias = self.bias if not self.skip_bias_add else None - - if self.async_tensor_model_parallel_allreduce or self.sequence_parallel_enabled: - input_parallel = input_ - else: - input_parallel = copy_to_tensor_model_parallel_region(input_) - - # Matrix multiply. - output_parallel = self._forward_impl( - input=input_parallel, - weight=self.weight, - bias=bias, - gradient_accumulation_fusion=self.gradient_accumulation_fusion, - async_grad_allreduce=self.async_tensor_model_parallel_allreduce, - sequence_parallel_enabled=self.sequence_parallel_enabled, - ) - if self.gather_output: - # All-gather across the partitions. - assert not self.sequence_parallel_enabled - output = gather_from_tensor_model_parallel_region(output_parallel) - else: - output = output_parallel - output_bias = self.bias if self.skip_bias_add else None - return output, output_bias - - -class RowParallelLinear(torch.nn.Module): - """Linear layer with row parallelism. - - The linear layer is defined as Y = XA + b. A is parallelized along - its first dimension and X along its second dimension as: - - - - | A_1 | - | . | - A = | . | X = [X_1, ..., X_p] - | . | - | A_p | - - - - - .. note:: - Input is supposed to be three dimensional and each dimension - is expected to be sequence, batch, and hidden feature, respectively. - - Arguments: - input_size: first dimension of matrix A. - output_size: second dimension of matrix A. - bias: If true, add bias. Note that bias is not parallelized. - input_is_parallel: If true, we assume that the input is already - split across the GPUs and we do not split - again. - init_method: method to initialize weights. Note that bias is always set - to zero. - stride: For the strided linear layers. - keep_master_weight_for_test: This was added for testing and should be - set to False. It returns the master weights - used for initialization. - skip_bias_add: This was added to enable performance optimization where bias - can be fused with other elementwise operations. We skip - adding bias but instead return it. - Keyword Arguments: - params_dtype: - use_cpu_initialization: - gradient_accumulation_fusion: - accumulation_in_fp16: - sequence_parallel_enabled: - """ - - def __init__( - self, - input_size, - output_size, - bias=True, - input_is_parallel=False, - init_method=init.xavier_normal_, - stride=1, - keep_master_weight_for_test=False, - skip_bias_add=False, - *, - params_dtype=torch.float32, - use_cpu_initialization=False, - gradient_accumulation_fusion=False, - accumulation_in_fp16: bool = False, - sequence_parallel_enabled: bool = False, - ): - super().__init__() - - # Keep input parameters - self.input_size = input_size - self.output_size = output_size - self.input_is_parallel = input_is_parallel - # Divide the weight matrix along the last dimension. - world_size = get_tensor_model_parallel_world_size() - self.input_size_per_partition = divide(input_size, world_size) - self.skip_bias_add = skip_bias_add - self.gradient_accumulation_fusion = gradient_accumulation_fusion - self.sequence_parallel_enabled = sequence_parallel_enabled - if self.sequence_parallel_enabled and not self.input_is_parallel: - raise RuntimeError("To enable `sequence_parallel_enabled`, `input_is_parallel` must be `True`") - - # as an argument to this function? - # Parameters. - # Note: torch.nn.functional.linear performs XA^T + b and as a result - # we allocate the transpose. - # Initialize weight. - if use_cpu_initialization: - self.weight = Parameter( - torch.empty( - self.output_size, self.input_size_per_partition, dtype=params_dtype - ) - ) - self.master_weight = _initialize_affine_weight_cpu( - self.weight, - self.output_size, - self.input_size, - self.input_size_per_partition, - 1, - init_method, - stride=stride, - return_master_weight=keep_master_weight_for_test, - params_dtype=params_dtype, - ) - else: - self.weight = Parameter( - torch.empty( - self.output_size, - self.input_size_per_partition, - device=torch.cuda.current_device(), - dtype=params_dtype, - ) - ) - _initialize_affine_weight_gpu( - self.weight, init_method, partition_dim=1, stride=stride - ) - if bias: - if use_cpu_initialization: - self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype)) - else: - self.bias = Parameter( - torch.empty( - self.output_size, - device=torch.cuda.current_device(), - dtype=params_dtype, - ) - ) - # Always initialize bias to zero. - with torch.no_grad(): - self.bias.zero_() - setattr(self.bias, "sequence_parallel_enabled", sequence_parallel_enabled) - else: - self.register_parameter("bias", None) - - self._forward_impl = ( - linear_with_grad_accumulation_and_async_allreduce_in16bit - if accumulation_in_fp16 - else linear_with_grad_accumulation_and_async_allreduce - ) - - def forward(self, input_: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """Forward of RowParallelLinear - - Args: - input_: 3D tensor whose order of dimension is [sequence, batch, hidden] - - Returns: - - output - - bias - """ - # Set up backprop all-reduce. - if self.input_is_parallel: - input_parallel = input_ - else: - assert not self.sequence_parallel_enabled - input_parallel = scatter_to_tensor_model_parallel_region(input_) - # Matrix multiply. - output_parallel = self._forward_impl( - input=input_parallel, - weight=self.weight, - bias=None, - gradient_accumulation_fusion=self.gradient_accumulation_fusion, - async_grad_allreduce=False, - sequence_parallel_enabled=False, - ) - # All-reduce across all the partitions. - if self.sequence_parallel_enabled: - output_ = reduce_scatter_to_sequence_parallel_region(output_parallel) - else: - output_ = reduce_from_tensor_model_parallel_region(output_parallel) - if not self.skip_bias_add: - output = output_ + self.bias if self.bias is not None else output_ - output_bias = None - else: - output = output_ - output_bias = self.bias - return output, output_bias diff --git a/apex/transformer/tensor_parallel/mappings.py b/apex/transformer/tensor_parallel/mappings.py deleted file mode 100644 index da3454e..0000000 --- a/apex/transformer/tensor_parallel/mappings.py +++ /dev/null @@ -1,304 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch - -from apex.transformer.parallel_state import get_tensor_model_parallel_group -from apex.transformer.parallel_state import get_tensor_model_parallel_world_size -from apex.transformer.parallel_state import get_tensor_model_parallel_rank -from apex.transformer.tensor_parallel.utils import split_tensor_along_last_dim - - -def _reduce(input_: torch.Tensor) -> torch.Tensor: - """All-reduce the input tensor across model parallel group.""" - - # Bypass the function if we are using only 1 GPU. - if get_tensor_model_parallel_world_size() == 1: - return input_ - - # All-reduce. - torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group()) - - return input_ - - -def _split_along_last_dim(input_: torch.Tensor) -> torch.Tensor: - """Split the tensor along its last dimension and keep the - corresponding slice.""" - - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - # Split along last dimension. - input_list = split_tensor_along_last_dim(input_, world_size) - - # Note: torch.split does not create contiguous tensors by default. - rank = get_tensor_model_parallel_rank() - output = input_list[rank].contiguous() - - return output - - -def _split_along_first_dim(input_: torch.Tensor) -> torch.Tensor: - """Split the tensor along its first dimension and keep the corresponding slice.""" - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU for tensor model parallel. - if world_size == 1: - return input_ - - # Split along first dimension. - dim_size = input_.size(0) - assert dim_size % world_size == 0 - local_dim_size = dim_size // world_size - dim_offset = get_tensor_model_parallel_rank() * local_dim_size - output = input_[dim_offset:dim_offset + local_dim_size].contiguous() - return output - - -def _gather_along_last_dim(input_: torch.Tensor) -> torch.Tensor: - """Gather tensors and concatenate along the last dimension.""" - - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - # Size and dimension. - last_dim = input_.dim() - 1 - rank = get_tensor_model_parallel_rank() - - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - torch.distributed.all_gather( - tensor_list, input_, group=get_tensor_model_parallel_group() - ) - - # Note: torch.cat already creates a contiguous tensor. - output = torch.cat(tensor_list, dim=last_dim).contiguous() - - return output - - -def _gather_along_first_dim(input_: torch.Tensor) -> torch.Tensor: - """Gather tensors and concatenate along the first dimension.""" - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - shape = list(input_.shape) - shape[0] *= world_size - - output = torch.empty(shape, dtype=input_.dtype, device=torch.cuda.current_device()) - torch.distributed._all_gather_base( - output, - input_.contiguous(), - group=get_tensor_model_parallel_group() - ) - return output - - -def _reduce_scatter_along_first_dim(input_: torch.Tensor) -> torch.Tensor: - """Reduce-scatter the input tensor across model parallel group.""" - world_size = get_tensor_model_parallel_world_size() - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - - shape = list(input_.shape) - assert shape[0] % world_size == 0 - shape[0] //= world_size - output = torch.empty(shape, dtype=input_.dtype, device=torch.cuda.current_device()) - torch.distributed._reduce_scatter_base( - output, - input_.contiguous(), - group=get_tensor_model_parallel_group() - ) - return output - - -class _CopyToModelParallelRegion(torch.autograd.Function): - """Pass the input to the tensor model parallel region.""" - - # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to - # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method - @staticmethod - def symbolic(graph, input_): - return input_ - - @staticmethod - def forward(ctx, input_): - return input_ - - @staticmethod - def backward(ctx, grad_output): - return _reduce(grad_output) - - -class _ReduceFromModelParallelRegion(torch.autograd.Function): - """All-reduce the input from the tensor model parallel region.""" - - # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to - # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method - @staticmethod - def symbolic(graph, input_): - return _reduce(input_) - - @staticmethod - def forward(ctx, input_): - return _reduce(input_) - - @staticmethod - def backward(ctx, grad_output): - return grad_output - - -class _ScatterToModelParallelRegion(torch.autograd.Function): - """Split the input and keep only the corresponding chuck to the rank.""" - - # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to - # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method - @staticmethod - def symbolic(graph, input_): - return _split_along_last_dim(input_) - - @staticmethod - def forward(ctx, input_): - return _split_along_last_dim(input_) - - @staticmethod - def backward(ctx, grad_output): - return _gather_along_last_dim(grad_output) - - -class _GatherFromModelParallelRegion(torch.autograd.Function): - """Gather the input from tensor model parallel region and concatenate.""" - - # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to - # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method - @staticmethod - def symbolic(graph, input_): - return _gather_along_last_dim(input_) - - @staticmethod - def forward(ctx, input_): - return _gather_along_last_dim(input_) - - @staticmethod - def backward(ctx, grad_output): - return _split_along_last_dim(grad_output) - - -class _ScatterToSequenceParallelRegion(torch.autograd.Function): - """Split the input and keep only the corresponding chunk to the rank.""" - - # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to - # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method - @staticmethod - def symbolic(graph, input_): - return _split_along_first_dim(input_) - - @staticmethod - def forward(ctx, input_): - return _split_along_first_dim(input_) - - @staticmethod - def backward(ctx, grad_output): - return _gather_along_first_dim(grad_output) - - -class _GatherFromSequenceParallelRegion(torch.autograd.Function): - """Gather the input from sequence parallel region and concatenate.""" - - # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to - # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method - @staticmethod - def symbolic(graph, input_, to_model_parallel: bool = True): - return _gather_along_first_dim(input_) - - @staticmethod - def forward(ctx, input_, to_model_parallel: bool = True): - ctx.to_model_parallel = to_model_parallel - return _gather_along_first_dim(input_) - - @staticmethod - def backward(ctx, grad_output): - if ctx.to_model_parallel: - return _reduce_scatter_along_first_dim(grad_output), None - else: - return _split_along_first_dim(grad_output), None - - -class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function): - """Reduce scatter the input from the sequence parallel region and concatenate.""" - - # FIXME(mkozuki): Definition of static symbolic methods don't look correct according to - # https://pytorch.org/docs/stable/onnx.html#static-symbolic-method - @staticmethod - def symbolic(graph, input_): - return _reduce_scatter_along_first_dim(input_) - - @staticmethod - def forward(ctx, input_): - return _reduce_scatter_along_first_dim(input_) - - @staticmethod - def backward(ctx, grad_output): - return _gather_along_first_dim(grad_output) - - -# ----------------- -# Helper functions. -# ----------------- - - -def copy_to_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor: - return _CopyToModelParallelRegion.apply(input_) - - -def reduce_from_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor: - return _ReduceFromModelParallelRegion.apply(input_) - - -def scatter_to_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor: - return _ScatterToModelParallelRegion.apply(input_) - - -def gather_from_tensor_model_parallel_region(input_: torch.Tensor) -> torch.Tensor: - return _GatherFromModelParallelRegion.apply(input_) - - -def scatter_to_sequence_parallel_region(input_: torch.Tensor) -> torch.Tensor: - return _ScatterToSequenceParallelRegion.apply(input_) - - -def gather_from_sequence_parallel_region(input_: torch.Tensor, to_model_parallel: bool = True) -> torch.Tensor: - return _GatherFromSequenceParallelRegion.apply(input_, to_model_parallel) - - -def reduce_scatter_to_sequence_parallel_region(input_: torch.Tensor) -> torch.Tensor: - return _ReduceScatterToSequenceParallelRegion.apply(input_) - - -__all__ = [ - "copy_to_tensor_model_parallel_region", - "reduce_from_tensor_model_parallel_region", - "scatter_to_tensor_model_parallel_region", - "gather_from_tensor_model_parallel_region", - "scatter_to_sequence_parallel_region", - "gather_from_sequence_parallel_region", - "reduce_scatter_to_sequence_parallel_region", -] diff --git a/apex/transformer/tensor_parallel/memory.py b/apex/transformer/tensor_parallel/memory.py deleted file mode 100644 index 6df9a13..0000000 --- a/apex/transformer/tensor_parallel/memory.py +++ /dev/null @@ -1,151 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# TODO(mkozuki): Remove this file as Megatron-LM seems to have done so. -import torch - - -# A dictionary of all the memory buffers allocated. -_MEM_BUFFS = dict() - - -def allocate_mem_buff(name, numel, dtype, track_usage): - """Allocate a memory buffer.""" - assert name not in _MEM_BUFFS, "memory buffer {} already allocated.".format(name) - _MEM_BUFFS[name] = MemoryBuffer(name, numel, dtype, track_usage) - return _MEM_BUFFS[name] - - -def get_mem_buff(name): - """Get the memory buffer.""" - return _MEM_BUFFS[name] - - -class MemoryBuffer: - """Contiguous memory buffer. - Allocate a contiguous memory of type `dtype` and size `numel`. It is - used to reduce memory fragmentation. - - Usage: After the allocation, the `_start` index is set tot the first - index of the memory. A memory chunk starting from `_start` index - can be `allocated` for an input tensor, with the elements of the - tensor being coppied. The buffer can be reused by resetting the - `_start` index. - - """ - - def __init__(self, name, numel, dtype, track_usage): - if torch.distributed.get_rank() == 0: - element_size = torch.tensor([], dtype=dtype).element_size() - print( - "> building the {} memory buffer with {} num elements " - "and {} dtype ({:.1f} MB)...".format( - name, numel, dtype, numel * element_size / 1024 / 1024 - ), - flush=True, - ) - self.name = name - self.numel = numel - self.dtype = dtype - self.data = torch.empty( - self.numel, - dtype=self.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - - # Index tracking the start of the free memory. - self._start = 0 - - # Values used for tracking usage. - self.track_usage = track_usage - if self.track_usage: - self.in_use_value = 0.0 - self.total_value = 0.0 - - def reset(self): - """Reset the buffer start index to the beginning of the buffer.""" - self._start = 0 - - def is_in_use(self): - """Whether the current buffer hold on to any memory.""" - return self._start > 0 - - def numel_in_use(self): - """Return number of elements in use.""" - return self._start - - def add(self, tensor): - """Allocate a chunk of memory from the buffer to tensor and copy - the values.""" - assert ( - tensor.dtype == self.dtype - ), "Input tensor type {} different from buffer type {}".format( - tensor.dtype, self.dtype - ) - # Number of elements of the input tensor. - tensor_numel = torch.numel(tensor) - new_start = self._start + tensor_numel - assert ( - new_start <= self.numel - ), "Not enough memory left in the buffer ({} > {})".format( - tensor_numel, self.numel - self._start - ) - # New tensor is a view into the memory. - new_tensor = self.data[self._start : new_start] - self._start = new_start - new_tensor = new_tensor.view(tensor.shape) - new_tensor.copy_(tensor) - # Return a pointer to the new tensor. - return new_tensor - - def get_data(self): - """Return the data currently in use.""" - if self.track_usage: - self.in_use_value += float(self._start) - self.total_value += float(self.numel) - return self.data[: self._start] - - def print_average_usage(self): - """Print memory usage average over time. We would like this value - to be as high as possible.""" - assert self.track_usage, "You need to enable track usage." - if torch.distributed.get_rank() == 0: - print( - " > usage of {} memory buffer: {:.2f} %".format( - self.name, self.in_use_value * 100.0 / self.total_value - ), - flush=True, - ) - - -class RingMemBuffer: - """A ring of memory buffers.""" - - def __init__(self, name, num_buffers, numel, dtype, track_usage): - self.num_buffers = num_buffers - self.buffers = [ - allocate_mem_buff(name + " {}".format(i), numel, dtype, track_usage) - for i in range(num_buffers) - ] - self._index = -1 - - def get_next_buffer(self): - self._index += 1 - self._index = self._index % self.num_buffers - buff = self.buffers[self._index] - assert not buff.is_in_use(), "buffer is already in use." - return buff diff --git a/apex/transformer/tensor_parallel/random.py b/apex/transformer/tensor_parallel/random.py deleted file mode 100644 index 4bd64cb..0000000 --- a/apex/transformer/tensor_parallel/random.py +++ /dev/null @@ -1,311 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# NOTE(mkozuki): This file is based on megatron-lm/mpu/random.py with some differences: -# - Not using "viewless" tensor: -# - _kernel_make_viewless_tensor -# - MakeViewlessTensor -# - make_viewless_tensor -# - assert_viewless_tensor -# - safely_set_viewless_tensor_data - -# Parts of the code here are adapted from PyTorch -# repo: https://github.com/pytorch/pytorch -import contextlib - -import torch -from torch import _C -from torch.cuda import _lazy_call, device as device_ctx_manager -from torch.utils.checkpoint import detach_variable - -from apex.transformer.parallel_state import get_tensor_model_parallel_rank -from apex.transformer.tensor_parallel.memory import allocate_mem_buff -from apex.transformer.utils import split_tensor_into_1d_equal_chunks -from apex.transformer.utils import gather_split_1d_tensor - - -# Default name for the model parallel rng tracker. -_MODEL_PARALLEL_RNG_TRACKER_NAME = "model-parallel-rng" - -# TODO(mkozuki): Remove `_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER` as megatron-lm doesn't seem to use. -# Whether apply model parallelism to checkpointed hidden states. -_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None - - -# TODO(mkozuki): Remove `init_checkpointed_activations_memory_buffer` as megatron-lm doesn't seem to use. -def init_checkpointed_activations_memory_buffer( - micro_batch_size, - max_position_embeddings, - hidden_size, - num_layers, - tensor_model_parallel_size, - checkpoint_num_layers, - fp16, -): - """Initializ the memory buffer for the checkpointed activations.""" - - per_layer = ( - micro_batch_size - * max_position_embeddings - * hidden_size - // tensor_model_parallel_size - ) - assert ( - num_layers % checkpoint_num_layers == 0 - ), "number of layers is not divisible by checkpoint-num-layers" - num_checkpointer_layers = num_layers // checkpoint_num_layers - numel = per_layer * num_checkpointer_layers - dtype = torch.half - if not fp16: - dtype = torch.float - - global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER - assert ( - _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None - ), "checkpointed activations memory buffer is already allocated." - _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff( - "checkpointed activations", numel, dtype, track_usage=False - ) - - -# TODO(mkozuki): Remove `reset_checkpointed_activations_memory_buffer` as megatron-lm doesn't seem to use. -def reset_checkpointed_activations_memory_buffer(): - """Reset the memory used for checkpointing.""" - if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None: - _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset() - - -def _set_cuda_rng_state(new_state, device=-1): - """Sets the random number generator state of the current GPU. - - Arguments: - new_state (torch.ByteTensor): The desired state - This function is adapted from PyTorch repo (torch.cuda.set_rng_state) - with a single change: the input state is not cloned. Cloning caused - major performance issues for +4 GPU cases. - """ - if hasattr(_C, "_cuda_setRNGState") and callable(_C._cuda_setRNGState): - # older PyTorch - def cb(): - with device_ctx_manager(device): - _C._cuda_setRNGState(new_state) - - else: - # newer PyTorch - if device == -1: - device = torch.device("cuda") - elif isinstance(device, str): - device = torch.device(device) - elif isinstance(device, int): - device = torch.device("cuda", device) - - def cb(): - idx = device.index - if idx is None: - idx = torch.cuda.current_device() - default_generator = torch.cuda.default_generators[idx] - default_generator.set_state(new_state) - - _lazy_call(cb) - - -class CudaRNGStatesTracker: - """Tracker for the cuda RNG states. - - Using the `add` method, a cuda rng state is initialized based on - the input `seed` and is assigned to `name`. Later, by forking the - rng state, we can perform operations and return to our starting - cuda state. - """ - - def __init__(self): - # Map from a string name to the cuda rng state. - self.states_ = {} - # Seeds are just for book keeping and ensure no seed is set twice. - self.seeds_ = set() - - def reset(self): - """Set to the initial state (no tracker).""" - self.states_ = {} - self.seeds_ = set() - - def get_states(self): - """Get rng states. Copy the dictionary so we have direct - pointers to the states, not just a pointer to the dictionary.""" - states = {} - for name in self.states_: - states[name] = self.states_[name] - return states - - def set_states(self, states): - """Set the rng states. For efficiency purposes, we do not check - the size of seed for compatibility.""" - self.states_ = states - - def add(self, name, seed): - """Track the rng state.""" - # Check seed is not already used. - if seed in self.seeds_: - raise Exception("seed {} already exists".format(seed)) - self.seeds_.add(seed) - # Check that state is not already defined. - if name in self.states_: - raise Exception("cuda rng state {} already exists".format(name)) - # Get the current rng state. - orig_rng_state = torch.cuda.get_rng_state() - # Set the new state and store it. - torch.cuda.manual_seed(seed) - self.states_[name] = torch.cuda.get_rng_state() - # Reset rng state to what it was. - _set_cuda_rng_state(orig_rng_state) - - @contextlib.contextmanager - def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): - """Fork the cuda rng state, perform operations, and exit with - the original state.""" - # Check if we have added the state - if name not in self.states_: - raise Exception("cuda rng state {} is not added".format(name)) - # Store current rng state. - orig_cuda_rng_state = torch.cuda.get_rng_state() - # Set rng state to the desired one - _set_cuda_rng_state(self.states_[name]) - # Do the stuff we wanted to do. - try: - yield - finally: - # Update the current rng state for later use. - self.states_[name] = torch.cuda.get_rng_state() - # And set the state to the original state we started with. - _set_cuda_rng_state(orig_cuda_rng_state) - - -# RNG tracker object. -_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() - - -def get_cuda_rng_tracker(): - """Get cuda rng tracker.""" - return _CUDA_RNG_STATE_TRACKER - - -def model_parallel_cuda_manual_seed(seed): - """Initialize model parallel cuda seed. - - This function should be called after the model parallel is - initialized. Also, no torch.cuda.manual_seed should be called - after this function. Basically, this is replacement for that - function. - Two set of RNG states are tracked: - default state: This is for data parallelism and is the same among a - set of model parallel GPUs but different across - different model paralle groups. This is used for - example for dropout in the non-tensor-model-parallel regions. - tensor-model-parallel state: This state is different among a set of model - parallel GPUs, but the same across data parallel - groups. This is used for example for dropout in - model parallel regions. - """ - # 2718 is just for fun and any POSITIVE value will work. - offset = seed + 2718 - tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank() - # Data parallel gets the original seed. - data_parallel_seed = seed - - _CUDA_RNG_STATE_TRACKER.reset() - # Set the default state. - torch.cuda.manual_seed(data_parallel_seed) - # and model parallel state. - _CUDA_RNG_STATE_TRACKER.add( - _MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed - ) - - -# TODO (mkozuki): Move the below gradient checkpoint related features to another (new) file. -class CheckpointFunction(torch.autograd.Function): - """This function is adapted from torch.utils.checkpoint with - two main changes: - 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` - 2) the states in the model parallel tracker are also properly - tracked/set/reset. - """ - - @staticmethod - def forward(ctx, run_function, distribute_saved_activations, *args): - ctx.run_function = run_function - ctx.distribute_saved_activations = distribute_saved_activations - - # Copy the rng states. - ctx.fwd_cpu_rng_state = torch.get_rng_state() - ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() - ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() - - with torch.no_grad(): - outputs = run_function(*args) - - # Divide hidden states across model parallel group and only keep - # the chunk corresponding to the current rank. - if ctx.distribute_saved_activations: - ctx.input_0_shape = args[0].shape - - # Store everything. - ctx.save_for_backward(*args) - return outputs - - @staticmethod - def backward(ctx, *args): - if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError( - "Checkpointing is not compatible with .grad(), " - "please use .backward() if possible" - ) - inputs = ctx.saved_tensors - - # Store the current states. - bwd_cpu_rng_state = torch.get_rng_state() - bwd_cuda_rng_state = torch.cuda.get_rng_state() - bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() - - # Set the states to what it used to be before the forward pass. - torch.set_rng_state(ctx.fwd_cpu_rng_state) - _set_cuda_rng_state(ctx.fwd_cuda_rng_state) - get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) - - # Compute the forward pass. - detached_inputs = detach_variable(inputs) - with torch.enable_grad(): - outputs = ctx.run_function(*detached_inputs) - - # Set the states back to what it was at the start of this function. - torch.set_rng_state(bwd_cpu_rng_state) - _set_cuda_rng_state(bwd_cuda_rng_state) - get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) - - if isinstance(outputs, torch.Tensor): - outputs = (outputs,) - torch.autograd.backward(outputs, args) - grads = tuple( - inp.grad if isinstance(inp, torch.Tensor) else inp - for inp in detached_inputs - ) - return (None, None) + grads - - -# NOTE(mkozuki): It doesn't look like `distribute_saved_activations` is used in apex.transformer -# but I added this change to reduce the superficial difference from Megatron-LM. -def checkpoint(function, distribute_saved_activations, *args): - """Checkpoint a model or part of the model. - This has been directly copied from torch.utils.checkpoint.""" - return CheckpointFunction.apply(function, distribute_saved_activations, *args) diff --git a/apex/transformer/tensor_parallel/utils.py b/apex/transformer/tensor_parallel/utils.py deleted file mode 100644 index 85ea413..0000000 --- a/apex/transformer/tensor_parallel/utils.py +++ /dev/null @@ -1,64 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import List, Sequence - -import torch - -from apex.transformer.utils import divide - - -def split_tensor_along_last_dim( - tensor: torch.Tensor, - num_partitions: int, - contiguous_split_chunks: bool = False, -) -> List[torch.Tensor]: - """Split a tensor along its last dimension. - Arguments: - tensor: input tensor. - num_partitions: number of partitions to split the tensor - contiguous_split_chunks: If True, make each chunk contiguous - in memory. - """ - # Get the size and dimension. - last_dim = tensor.dim() - 1 - last_dim_size = divide(tensor.size()[last_dim], num_partitions) - # Split. - tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) - # Note: torch.split does not create contiguous tensors by default. - if contiguous_split_chunks: - return tuple(chunk.contiguous() for chunk in tensor_list) - - return tensor_list - - -class VocabUtility: - """Split the vocabulary into `world_size` chunks and return the - first and last index of the vocabulary belonging to the `rank` - partition: Note that indices in [fist, last)""" - - @staticmethod - def vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size: int, rank, world_size: int - ) -> Sequence[int]: - index_f = rank * per_partition_vocab_size - index_l = index_f + per_partition_vocab_size - return index_f, index_l - - @staticmethod - def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]: - per_partition_vocab_size = divide(global_vocab_size, world_size) - return VocabUtility.vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size, rank, world_size - ) diff --git a/apex/transformer/testing/__init__.py b/apex/transformer/testing/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/apex/transformer/testing/arguments.py b/apex/transformer/testing/arguments.py deleted file mode 100644 index f39c288..0000000 --- a/apex/transformer/testing/arguments.py +++ /dev/null @@ -1,971 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Megatron arguments.""" - -import argparse -import os - -import torch - -def parse_args(extra_args_provider=None, defaults={}, - ignore_unknown_args=False): - """Parse all arguments.""" - parser = argparse.ArgumentParser(description='Megatron-LM Arguments', - allow_abbrev=False) - - # Standard arguments. - parser = _add_network_size_args(parser) - parser = _add_regularization_args(parser) - parser = _add_training_args(parser) - parser = _add_initialization_args(parser) - parser = _add_learning_rate_args(parser) - parser = _add_checkpointing_args(parser) - parser = _add_mixed_precision_args(parser) - parser = _add_distributed_args(parser) - parser = _add_validation_args(parser) - parser = _add_data_args(parser) - parser = _add_autoresume_args(parser) - parser = _add_biencoder_args(parser) - parser = _add_vision_args(parser) - parser = _add_logging_args(parser) - - # NOTE(mkozuki): This option is added to investigate the potential of `torch.autograd.graph.save_on_cpu()`. - # ref: https://pytorch.org/docs/stable/autograd.html#torch.autograd.graph.save_on_cpu. - parser.add_argument('--cpu-offload', action='store_true', default=False, help='Turns on CPU offloading') - - # Custom arguments. - if extra_args_provider is not None: - parser = extra_args_provider(parser) - - # Parse. - if ignore_unknown_args: - args, _ = parser.parse_known_args() - else: - args = parser.parse_args() - - # Distributed args. - args.rank = int(os.getenv('RANK', '0')) - args.world_size = int(os.getenv("WORLD_SIZE", '1')) - # Tensor model parallel size. - args.tensor_model_parallel_size = min( - args.tensor_model_parallel_size, args.world_size) - assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\ - ' ({}) is not divisible by tensor model parallel size ({})'.format( - args.world_size, args.tensor_model_parallel_size) - # Pipeline model parallel size. - args.pipeline_model_parallel_size = min( - args.pipeline_model_parallel_size, - (args.world_size // args.tensor_model_parallel_size)) - args.transformer_pipeline_model_parallel_size = ( - args.pipeline_model_parallel_size - 1 - if args.standalone_embedding_stage else - args.pipeline_model_parallel_size - ) - # Checks. - model_parallel_size = args.pipeline_model_parallel_size * \ - args.tensor_model_parallel_size - assert args.world_size % model_parallel_size == 0, 'world size is not'\ - ' divisible by tensor parallel size ({}) times pipeline parallel ' \ - 'size ({})'.format(args.world_size, args.tensor_model_parallel_size, - args.pipeline_model_parallel_size) - args.data_parallel_size = args.world_size // model_parallel_size - if args.rank == 0: - print('using world size: {}, data-parallel-size: {}, ' - 'tensor-model-parallel size: {}, ' - 'pipeline-model-parallel size: {} '.format( - args.world_size, args.data_parallel_size, - args.tensor_model_parallel_size, - args.pipeline_model_parallel_size), flush=True) - if args.pipeline_model_parallel_size > 1: - if args.pipeline_model_parallel_split_rank is not None: - assert args.pipeline_model_parallel_split_rank < \ - args.pipeline_model_parallel_size, 'split rank needs'\ - ' to be less than pipeline model parallel size ({})'.format( - args.pipeline_model_parallel_size) - - # Deprecated arguments - assert args.batch_size is None, '--batch-size argument is no longer ' \ - 'valid, use --micro-batch-size instead' - del args.batch_size - assert args.warmup is None, '--warmup argument is no longer valid, use ' \ - '--lr-warmup-fraction instead' - del args.warmup - assert args.model_parallel_size is None, '--model-parallel-size is no ' \ - 'longer valid, use --tensor-model-parallel-size instead' - del args.model_parallel_size - if args.checkpoint_activations: - args.recompute_granularity = 'full' - args.recompute_method = 'uniform' - if args.rank == 0: - print('--checkpoint-activations is no longer valid, ' - 'use --recompute-granularity and --recompute-method instead. ' - 'Defaulting to recompute-granularity=full and recompute-method=uniform.') - del args.checkpoint_activations - - if args.recompute_activations: - args.recompute_granularity = 'selective' - del args.recompute_activations - - # Set input defaults. - for key in defaults: - # For default to be valid, it should not be provided in the - # arguments that are passed to the program. We check this by - # ensuring the arg is set to None. - if getattr(args, key) is not None: - if args.rank == 0: - print('WARNING: overriding default arguments for {key}:{v} \ - with {key}:{v2}'.format(key=key, v=defaults[key], - v2=getattr(args, key)), - flush=True) - else: - setattr(args, key, defaults[key]) - - # Batch size. - assert args.micro_batch_size is not None - assert args.micro_batch_size > 0 - if args.global_batch_size is None: - args.global_batch_size = args.micro_batch_size * args.data_parallel_size - if args.rank == 0: - print('setting global batch size to {}'.format( - args.global_batch_size), flush=True) - assert args.global_batch_size > 0 - if args.num_layers_per_virtual_pipeline_stage is not None: - assert args.pipeline_model_parallel_size > 2, \ - 'pipeline-model-parallel size should be greater than 2 with ' \ - 'interleaved schedule' - assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \ - 'number of layers is not divisible by number of layers per virtual ' \ - 'pipeline stage' - args.virtual_pipeline_model_parallel_size = \ - (args.num_layers // args.pipeline_model_parallel_size) // \ - args.num_layers_per_virtual_pipeline_stage - else: - args.virtual_pipeline_model_parallel_size = None - - # Parameters dtype. - args.params_dtype = torch.float - if args.fp16: - assert not args.bf16 - args.params_dtype = torch.half - if args.bf16: - assert not args.fp16 - args.params_dtype = torch.bfloat16 - # bfloat16 requires gradient accumulation and all-reduce to - # be done in fp32. - if not args.accumulate_allreduce_grads_in_fp32: - args.accumulate_allreduce_grads_in_fp32 = True - if args.rank == 0: - print('accumulate and all-reduce gradients in fp32 for ' - 'bfloat16 data type.', flush=True) - - if args.rank == 0: - print('using {} for parameters ...'.format(args.params_dtype), - flush=True) - - # If we do accumulation and all-reduces in fp32, we need to have local DDP - # and we should make sure use-contiguous-buffers-in-local-ddp is not off. - if args.accumulate_allreduce_grads_in_fp32: - assert args.DDP_impl == 'local' - assert args.use_contiguous_buffers_in_local_ddp - else: - if args.gradient_accumulation_fusion: - args.gradient_accumulation_fusion = False - if args.rank == 0: - print('Gradient accumulation fusion to linear layer weight ' - 'gradient computation is supported only with fp32 ' - 'gradient accumulation. Setting gradient_accumulation_fusion ' - 'to False', flush=True) - - # For torch DDP, we do not use contiguous buffer - if args.DDP_impl == 'torch': - args.use_contiguous_buffers_in_local_ddp = False - - if args.dataloader_type is None: - args.dataloader_type = 'single' - - # Consumed tokens. - args.consumed_train_samples = 0 - args.consumed_valid_samples = 0 - - # Iteration-based training. - if args.train_iters: - # If we use iteration-based training, make sure the - # sample-based options are off. - assert args.train_samples is None, \ - 'expected iteration-based training' - assert args.lr_decay_samples is None, \ - 'expected iteration-based learning rate decay' - assert args.lr_warmup_samples == 0, \ - 'expected iteration-based learning rate warmup' - assert args.rampup_batch_size is None, \ - 'expected no batch-size rampup for iteration-based training' - if args.lr_warmup_fraction is not None: - assert args.lr_warmup_iters == 0, \ - 'can only specify one of lr-warmup-fraction and lr-warmup-iters' - - # Sample-based training. - if args.train_samples: - # If we use sample-based training, make sure the - # iteration-based options are off. - assert args.train_iters is None, \ - 'expected sample-based training' - assert args.lr_decay_iters is None, \ - 'expected sample-based learning rate decay' - assert args.lr_warmup_iters == 0, \ - 'expected sample-based learnig rate warmup' - if args.lr_warmup_fraction is not None: - assert args.lr_warmup_samples == 0, \ - 'can only specify one of lr-warmup-fraction ' \ - 'and lr-warmup-samples' - - # Check required arguments. - required_args = ['num_layers', 'hidden_size', 'num_attention_heads', - 'max_position_embeddings'] - for req_arg in required_args: - _check_arg_is_not_none(args, req_arg) - - # Checks. - if args.ffn_hidden_size is None: - args.ffn_hidden_size = 4 * args.hidden_size - - if args.kv_channels is None: - assert args.hidden_size % args.num_attention_heads == 0 - args.kv_channels = args.hidden_size // args.num_attention_heads - - if args.seq_length is not None: - assert args.encoder_seq_length is None - args.encoder_seq_length = args.seq_length - else: - assert args.encoder_seq_length is not None - args.seq_length = args.encoder_seq_length - - if args.seq_length is not None: - assert args.max_position_embeddings >= args.seq_length - if args.decoder_seq_length is not None: - assert args.max_position_embeddings >= args.decoder_seq_length - if args.lr is not None: - assert args.min_lr <= args.lr - if args.save is not None: - assert args.save_interval is not None - # Mixed precision checks. - if args.fp16_lm_cross_entropy: - assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.' - if args.fp32_residual_connection: - assert args.fp16 or args.bf16, \ - 'residual connection in fp32 only supported when using fp16 or bf16.' - - if args.weight_decay_incr_style == 'constant': - assert args.start_weight_decay is None - assert args.end_weight_decay is None - args.start_weight_decay = args.weight_decay - args.end_weight_decay = args.weight_decay - else: - assert args.start_weight_decay is not None - assert args.end_weight_decay is not None - - TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) - # Persistent fused layer norm. - if TORCH_MAJOR < 1 or (TORCH_MAJOR == 1 and TORCH_MINOR < 11): - args.no_persist_layer_norm = True - if args.rank == 0: - print('Persistent fused layer norm kernel is supported from ' - 'pytorch v1.11 (nvidia pytorch container paired with v1.11). ' - 'Defaulting to no_persist_layer_norm=True') - - # Activation recomputing. - if args.distribute_saved_activations: - assert args.tensor_model_parallel_size > 1, 'can distribute ' \ - 'recomputed activations only across tensor model ' \ - 'parallel groups' - assert args.recompute_granularity == 'full', \ - 'distributed recompute activations is only '\ - 'application to full recompute granularity' - assert args.recompute_method is not None, \ - 'for distributed recompute activations to work you '\ - 'need to use a recompute method ' - assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \ - 'distributed recompute activations are supported for pytorch ' \ - 'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \ - 'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR) - - if args.recompute_granularity == 'selective': - assert args.recompute_method is None, \ - 'recompute method is not yet supported for ' \ - 'selective recomputing granularity' - - # disable async_tensor_model_parallel_allreduce when - # model parallel memory optimization is enabled - if args.sequence_parallel: - args.async_tensor_model_parallel_allreduce = False - - _print_args(args) - return args - - -def _print_args(args): - """Print arguments.""" - if args.rank == 0: - print('------------------------ arguments ------------------------', - flush=True) - str_list = [] - for arg in vars(args): - dots = '.' * (48 - len(arg)) - str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg))) - for arg in sorted(str_list, key=lambda x: x.lower()): - print(arg, flush=True) - print('-------------------- end of arguments ---------------------', - flush=True) - - -def _check_arg_is_not_none(args, arg): - assert getattr(args, arg) is not None, '{} argument is None'.format(arg) - - -def _add_inference_args(parser): - group = parser.add_argument_group(title='inference') - - group.add_argument('--inference-batch-times-seqlen-threshold', - type=int, default=512, - help='During inference, if batch-size times ' - 'sequence-length is smaller than this threshold ' - 'then we will not use pipelining, otherwise we will.') - - return parser - - -def _add_network_size_args(parser): - group = parser.add_argument_group(title='network size') - - group.add_argument('--num-layers', type=int, default=None, - help='Number of transformer layers.') - group.add_argument('--hidden-size', type=int, default=None, - help='Tansformer hidden size.') - group.add_argument('--ffn-hidden-size', type=int, default=None, - help='Transformer Feed-Forward Network hidden size. ' - 'This is set to 4*hidden-size if not provided') - group.add_argument('--num-attention-heads', type=int, default=None, - help='Number of transformer attention heads.') - group.add_argument('--kv-channels', type=int, default=None, - help='Projection weights dimension in multi-head ' - 'attention. This is set to ' - ' args.hidden_size // args.num_attention_heads ' - 'if not provided.') - group.add_argument('--max-position-embeddings', type=int, default=None, - help='Maximum number of position embeddings to use. ' - 'This is the size of position embedding.') - group.add_argument('--make-vocab-size-divisible-by', type=int, default=128, - help='Pad the vocab size to be divisible by this value.' - 'This is added for computational efficieny reasons.') - group.add_argument('--layernorm-epsilon', type=float, default=1e-5, - help='Layer norm epsilon.') - group.add_argument('--apply-residual-connection-post-layernorm', - action='store_true', - help='If set, use original BERT residula connection ' - 'ordering.') - group.add_argument('--openai-gelu', action='store_true', - help='Use OpenAIs GeLU implementation. This option' - 'should not be used unless for backward compatibility' - 'reasons.') - group.add_argument('--onnx-safe', type=bool, required=False, - help='Use workarounds for known problems with ' - 'Torch ONNX exporter') - group.add_argument('--bert-no-binary-head', action='store_false', - help='Disable BERT binary head.', - dest='bert_binary_head') - group.add_argument('--num-experts', type=int, default=None, - help='Number of Experts in Switch Transformer (None means no Switch)') - - return parser - - -def _add_logging_args(parser): - group = parser.add_argument_group(title='logging') - - group.add_argument('--log-params-norm', action='store_true', - help='If set, calculate and log parameters norm.') - group.add_argument('--log-num-zeros-in-grad', action='store_true', - help='If set, calculate and log the number of zeros in gradient.') - group.add_argument('--tensorboard-log-interval', type=int, default=1, - help='Report to tensorboard interval.') - group.add_argument('--tensorboard-queue-size', type=int, default=1000, - help='Size of the tensorboard queue for pending events ' - 'and summaries before one of the ‘add’ calls forces a ' - 'flush to disk.') - group.add_argument('--log-timers-to-tensorboard', action='store_true', - help='If set, write timers to tensorboard.') - group.add_argument('--log-batch-size-to-tensorboard', action='store_true', - help='If set, write batch-size to tensorboard.') - group.add_argument('--no-log-learnig-rate-to-tensorboard', - action='store_false', - help='Disable learning rate logging to tensorboard.', - dest='log_learning_rate_to_tensorboard') - group.add_argument('--no-log-loss-scale-to-tensorboard', - action='store_false', - help='Disable loss-scale logging to tensorboard.', - dest='log_loss_scale_to_tensorboard') - group.add_argument('--log-validation-ppl-to-tensorboard', - action='store_true', - help='If set, write validation perplexity to ' - 'tensorboard.') - group.add_argument('--log-memory-to-tensorboard', - action='store_true', - help='Enable memory logging to tensorboard.') - group.add_argument('--log-world-size-to-tensorboard', - action='store_true', - help='Enable world size logging to tensorboard.') - - return parser - - -def _add_regularization_args(parser): - group = parser.add_argument_group(title='regularization') - - group.add_argument('--attention-dropout', type=float, default=0.1, - help='Post attention dropout probability.') - group.add_argument('--hidden-dropout', type=float, default=0.1, - help='Dropout probability for hidden state transformer.') - group.add_argument('--weight-decay', type=float, default=0.01, - help='Weight decay coefficient for L2 regularization.') - group.add_argument('--start-weight-decay', type=float, - help='Initial weight decay coefficient for L2 regularization.') - group.add_argument('--end-weight-decay', type=float, - help='End of run weight decay coefficient for L2 regularization.') - group.add_argument('--weight-decay-incr-style', type=str, default='constant', - choices=['constant', 'linear', 'cosine'], - help='Weight decay increment function.') - group.add_argument('--clip-grad', type=float, default=1.0, - help='Gradient clipping based on global L2 norm.') - group.add_argument('--adam-beta1', type=float, default=0.9, - help='First coefficient for computing running averages ' - 'of gradient and its square') - group.add_argument('--adam-beta2', type=float, default=0.999, - help='Second coefficient for computing running averages ' - 'of gradient and its square') - group.add_argument('--adam-eps', type=float, default=1e-08, - help='Term added to the denominator to improve' - 'numerical stability') - group.add_argument('--sgd-momentum', type=float, default=0.9, - help='Momentum factor for sgd') - - return parser - - -def _add_training_args(parser): - group = parser.add_argument_group(title='training') - - group.add_argument('--micro-batch-size', type=int, default=None, - help='Batch size per model instance (local batch size). ' - 'Global batch size is local batch size times data ' - 'parallel size times number of micro batches.') - group.add_argument('--batch-size', type=int, default=None, - help='Old batch size parameter, do not use. ' - 'Use --micro-batch-size instead') - group.add_argument('--global-batch-size', type=int, default=None, - help='Training batch size. If set, it should be a ' - 'multiple of micro-batch-size times data-parallel-size. ' - 'If this value is None, then ' - 'use micro-batch-size * data-parallel-size as the ' - 'global batch size. This choice will result in 1 for ' - 'number of micro-batches.') - group.add_argument('--rampup-batch-size', nargs='*', default=None, - help='Batch size ramp up with the following values:' - ' --rampup-batch-size ' - ' ' - ' ' - 'For example:' - ' --rampup-batch-size 16 8 300000 \ ' - ' --global-batch-size 1024' - 'will start with global batch size 16 and over ' - ' (1024 - 16) / 8 = 126 intervals will increase' - 'the batch size linearly to 1024. In each interval' - 'we will use approximately 300000 / 126 = 2380 samples.') - group.add_argument('--recompute-activations', action='store_true', - help='recompute activation to allow for training ' - 'with larger models, sequences, and batch sizes.') - group.add_argument('--recompute-granularity', type=str, default=None, - choices=['full', 'selective'], - help='Checkpoint activations to allow for training ' - 'with larger models, sequences, and batch sizes. ' - 'It is supported at two granularities 1) full: ' - 'whole transformer layer is recomputed, ' - '2) selective: core attention part of the transformer ' - 'layer is recomputed.') - group.add_argument('--distribute-saved-activations', - action='store_true', - help='If set, distribute recomputed activations ' - 'across model parallel group.') - group.add_argument('--recompute-method', type=str, default=None, - choices=['uniform', 'block'], - help='1) uniform: uniformly divide the total number of ' - 'Transformer layers and recompute the input activation of ' - 'each divided chunk at specified granularity, ' - '2) recompute the input activations of only a set number of ' - 'individual Transformer layers per pipeline stage and do the ' - 'rest without any recomputing at specified granularity' - 'default) do not apply activations recompute to any layers') - group.add_argument('--recompute-num-layers', type=int, default=1, - help='1) uniform: the number of Transformer layers in each ' - 'uniformly divided recompute unit, ' - '2) block: the number of individual Transformer layers ' - 'to recompute within each pipeline stage.') - - # deprecated - group.add_argument('--checkpoint-activations', action='store_true', - help='Checkpoint activation to allow for training ' - 'with larger models, sequences, and batch sizes.') - group.add_argument('--train-iters', type=int, default=None, - help='Total number of iterations to train over all ' - 'training runs. Note that either train-iters or ' - 'train-samples should be provided.') - group.add_argument('--train-samples', type=int, default=None, - help='Total number of samples to train over all ' - 'training runs. Note that either train-iters or ' - 'train-samples should be provided.') - group.add_argument('--log-interval', type=int, default=100, - help='Report loss and timing interval.') - group.add_argument('--exit-interval', type=int, default=None, - help='Exit the program after the iteration is divisible ' - 'by this value.') - group.add_argument('--exit-duration-in-mins', type=int, default=None, - help='Exit the program after this many minutes.') - group.add_argument('--tensorboard-dir', type=str, default=None, - help='Write TensorBoard logs to this directory.') - group.add_argument('--no-masked-softmax-fusion', - action='store_false', - help='Disable fusion of query_key_value scaling, ' - 'masking, and softmax.', - dest='masked_softmax_fusion') - group.add_argument('--no-bias-gelu-fusion', action='store_false', - help='Disable bias and gelu fusion.', - dest='bias_gelu_fusion') - group.add_argument('--no-bias-dropout-fusion', action='store_false', - help='Disable bias and dropout fusion.', - dest='bias_dropout_fusion') - group.add_argument('--optimizer', type=str, default='adam', - choices=['adam', 'sgd'], - help='Optimizer function') - group.add_argument('--dataloader-type', type=str, default=None, - choices=['single', 'cyclic'], - help='Single pass vs multiple pass data loader') - group.add_argument('--no-async-tensor-model-parallel-allreduce', - action='store_true', - help='Disable asynchronous execution of ' - 'tensor-model-parallel all-reduce with weight ' - 'gradient compuation of a column-linear layer.', - dest='async_tensor_model_parallel_allreduce') - group.add_argument('--no-persist-layer-norm', action='store_true', - help='Disable using persistent fused layer norm kernel. ' - 'This kernel supports only a set of hidden sizes. Please ' - 'check persist_ln_hidden_sizes if your hidden ' - 'size is supported.') - group.add_argument('--sequence-parallel', action='store_true', - help='Enable sequence parallel optimization.') - group.add_argument('--no-gradient-accumulation-fusion', - action='store_false', - help='Disable fusing gradient accumulation to weight ' - 'gradient computation of linear layers', - dest='gradient_accumulation_fusion') - return parser - - -def _add_initialization_args(parser): - group = parser.add_argument_group(title='initialization') - - group.add_argument('--seed', type=int, default=1234, - help='Random seed used for python, numpy, ' - 'pytorch, and cuda.') - group.add_argument('--init-method-std', type=float, default=0.02, - help='Standard deviation of the zero mean normal ' - 'distribution used for weight initialization.') - group.add_argument('--init-method-xavier-uniform', action='store_true', - help='Enable Xavier uniform parameter initialization') - - return parser - - -def _add_learning_rate_args(parser): - group = parser.add_argument_group(title='learning rate') - - group.add_argument('--lr', type=float, default=None, - help='Initial learning rate. Depending on decay style ' - 'and initial warmup, the learing rate at each ' - 'iteration would be different.') - group.add_argument('--lr-decay-style', type=str, default='linear', - choices=['constant', 'linear', 'cosine'], - help='Learning rate decay function.') - group.add_argument('--lr-decay-iters', type=int, default=None, - help='number of iterations to decay learning rate over,' - ' If None defaults to `--train-iters`') - group.add_argument('--lr-decay-samples', type=int, default=None, - help='number of samples to decay learning rate over,' - ' If None defaults to `--train-samples`') - group.add_argument('--lr-warmup-fraction', type=float, default=None, - help='fraction of lr-warmup-(iters/samples) to use ' - 'for warmup (as a float)') - group.add_argument('--lr-warmup-iters', type=int, default=0, - help='number of iterations to linearly warmup ' - 'learning rate over.') - group.add_argument('--lr-warmup-samples', type=int, default=0, - help='number of samples to linearly warmup ' - 'learning rate over.') - group.add_argument('--warmup', type=int, default=None, - help='Old lr warmup argument, do not use. Use one of the' - '--lr-warmup-* arguments above') - group.add_argument('--min-lr', type=float, default=0.0, - help='Minumum value for learning rate. The scheduler' - 'clip values below this threshold.') - group.add_argument('--override-lr-scheduler', action='store_true', - help='Reset the values of the scheduler (learning rate,' - 'warmup iterations, minimum learning rate, maximum ' - 'number of iterations, and decay style from input ' - 'arguments and ignore values from checkpoints. Note' - 'that all the above values will be reset.') - group.add_argument('--use-checkpoint-lr-scheduler', action='store_true', - help='Use checkpoint to set the values of the scheduler ' - '(learning rate, warmup iterations, minimum learning ' - 'rate, maximum number of iterations, and decay style ' - 'from checkpoint and ignore input arguments.') - - return parser - - -def _add_checkpointing_args(parser): - group = parser.add_argument_group(title='checkpointing') - - group.add_argument('--save', type=str, default=None, - help='Output directory to save checkpoints to.') - group.add_argument('--save-interval', type=int, default=None, - help='Number of iterations between checkpoint saves.') - group.add_argument('--no-save-optim', action='store_true', default=None, - help='Do not save current optimizer.') - group.add_argument('--no-save-rng', action='store_true', default=None, - help='Do not save current rng state.') - group.add_argument('--load', type=str, default=None, - help='Directory containing a model checkpoint.') - group.add_argument('--no-load-optim', action='store_true', default=None, - help='Do not load optimizer when loading checkpoint.') - group.add_argument('--no-load-rng', action='store_true', default=None, - help='Do not load rng state when loading checkpoint.') - group.add_argument('--finetune', action='store_true', - help='Load model for finetuning. Do not load optimizer ' - 'or rng state from checkpoint and set iteration to 0. ' - 'Assumed when loading a release checkpoint.') - - return parser - - -def _add_mixed_precision_args(parser): - group = parser.add_argument_group(title='mixed precision') - - group.add_argument('--fp16', action='store_true', - help='Run model in fp16 mode.') - group.add_argument('--bf16', action='store_true', - help='Run model in bfloat16 mode.') - group.add_argument('--loss-scale', type=float, default=None, - help='Static loss scaling, positive power of 2 ' - 'values can improve fp16 convergence. If None, dynamic' - 'loss scaling is used.') - group.add_argument('--initial-loss-scale', type=float, default=2**32, - help='Initial loss-scale for dynamic loss scaling.') - group.add_argument('--min-loss-scale', type=float, default=1.0, - help='Minimum loss scale for dynamic loss scale.') - group.add_argument('--loss-scale-window', type=float, default=1000, - help='Window over which to raise/lower dynamic scale.') - group.add_argument('--hysteresis', type=int, default=2, - help='hysteresis for dynamic loss scaling') - group.add_argument('--fp32-residual-connection', action='store_true', - help='Move residual connections to fp32.') - group.add_argument('--no-query-key-layer-scaling', action='store_false', - help='Do not scale Q * K^T by 1 / layer-number.', - dest='apply_query_key_layer_scaling') - group.add_argument('--attention-softmax-in-fp32', action='store_true', - help='Run attention masking and softmax in fp32. ' - 'This flag is ignored unless ' - '--no-query-key-layer-scaling is specified.') - group.add_argument('--accumulate-allreduce-grads-in-fp32', - action='store_true', - help='Gradient accumulation and all-reduce in fp32.') - group.add_argument('--fp16-lm-cross-entropy', action='store_true', - help='Move the cross entropy unreduced loss calculation' - 'for lm head to fp16.') - - return parser - - -def _add_distributed_args(parser): - group = parser.add_argument_group(title='distributed') - - group.add_argument('--tensor-model-parallel-size', type=int, default=1, - help='Degree of tensor model parallelism.') - group.add_argument('--pipeline-model-parallel-size', type=int, default=1, - help='Degree of pipeline model parallelism.') - group.add_argument('--pipeline-model-parallel-split-rank', - type=int, default=None, - help='Rank where encoder and decoder should be split.') - group.add_argument('--model-parallel-size', type=int, default=None, - help='Old model parallel argument, do not use. Use ' - '--tensor-model-parallel-size instead.') - group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None, - help='Number of layers per virtual pipeline stage') - group.add_argument('--distributed-backend', default='nccl', - choices=['nccl', 'gloo'], - help='Which backend to use for distributed training.') - group.add_argument('--DDP-impl', default='local', - choices=['local', 'torch'], - help='which DistributedDataParallel implementation ' - 'to use.') - group.add_argument('--no-contiguous-buffers-in-local-ddp', - action='store_false', help='If set, dont use ' - 'contiguous buffer in local DDP.', - dest='use_contiguous_buffers_in_local_ddp') - group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false', - help='Use scatter/gather to optimize communication of tensors in pipeline', - dest='scatter_gather_tensors_in_pipeline') - group.add_argument('--local_rank', type=int, default=None, - help='local rank passed from distributed launcher.') - group.add_argument('--lazy-mpu-init', type=bool, required=False, - help='If set to True, initialize_megatron() ' - 'skips DDP initialization and returns function to ' - 'complete it instead.Also turns on ' - '--use-cpu-initialization flag. This is for ' - 'external DDP manager.' ) - group.add_argument('--use-cpu-initialization', action='store_true', - default=None, help='If set, affine parallel weights ' - 'initialization uses CPU' ) - group.add_argument('--empty-unused-memory-level', default=0, type=int, - choices=[0, 1, 2], - help='Call torch.cuda.empty_cache() each iteration ' - '(training and eval), to reduce fragmentation.' - '0=off, 1=moderate, 2=aggressive.') - group.add_argument('--standalone-embedding-stage', action='store_true', - default=False, help='If set, *input* embedding layer ' - 'is placed on its own pipeline stage, without any ' - 'transformer layers. (For T5, this flag currently only ' - 'affects the encoder embedding.)') - return parser - - -def _add_validation_args(parser): - group = parser.add_argument_group(title='validation') - - group.add_argument('--eval-iters', type=int, default=100, - help='Number of iterations to run for evaluation' - 'validation/test for.') - group.add_argument('--eval-interval', type=int, default=1000, - help='Interval between running evaluation on ' - 'validation set.') - - return parser - - -def _add_data_args(parser): - group = parser.add_argument_group(title='data and dataloader') - - group.add_argument('--data-path', nargs='*', default=None, - help='Path to the training dataset. Accepted format:' - '1) a single data path, 2) multiple datasets in the' - 'form: dataset1-weight dataset1-path dataset2-weight ' - 'dataset2-path ...') - group.add_argument('--split', type=str, default='969, 30, 1', - help='Comma-separated list of proportions for training,' - ' validation, and test split. For example the split ' - '`90,5,5` will use 90%% of data for training, 5%% for ' - 'validation and 5%% for test.') - group.add_argument('--vocab-file', type=str, default=None, - help='Path to the vocab file.') - group.add_argument('--merge-file', type=str, default=None, - help='Path to the BPE merge file.') - group.add_argument('--vocab-extra-ids', type=int, default=0, - help='Number of additional vocabulary tokens. ' - 'They are used for span masking in the T5 model') - group.add_argument('--seq-length', type=int, default=None, - help='Maximum sequence length to process.') - group.add_argument('--encoder-seq-length', type=int, default=None, - help='Maximum encoder sequence length to process.' - 'This should be exclusive of --seq-length') - group.add_argument('--decoder-seq-length', type=int, default=None, - help="Maximum decoder sequence length to process.") - group.add_argument('--retriever-seq-length', type=int, default=256, - help='Maximum sequence length for the biencoder model ' - ' for retriever') - group.add_argument('--sample-rate', type=float, default=1.0, - help='sample rate for training data. Supposed to be 0 ' - ' < sample_rate < 1') - group.add_argument('--mask-prob', type=float, default=0.15, - help='Probability of replacing a token with mask.') - group.add_argument('--short-seq-prob', type=float, default=0.1, - help='Probability of producing a short sequence.') - group.add_argument('--mmap-warmup', action='store_true', - help='Warm up mmap files.') - group.add_argument('--num-workers', type=int, default=2, - help="Dataloader number of workers.") - group.add_argument('--tokenizer-type', type=str, - default=None, - choices=['BertWordPieceLowerCase', - 'BertWordPieceCase', - 'GPT2BPETokenizer'], - help='What type of tokenizer to use.') - group.add_argument('--data-impl', type=str, default='infer', - choices=['lazy', 'cached', 'mmap', 'infer'], - help='Implementation of indexed datasets.') - group.add_argument('--reset-position-ids', action='store_true', - help='Reset posistion ids after end-of-document token.') - group.add_argument('--reset-attention-mask', action='store_true', - help='Reset self attention maske after ' - 'end-of-document token.') - group.add_argument('--eod-mask-loss', action='store_true', - help='Mask loss for the end of document tokens.') - - return parser - - -def _add_autoresume_args(parser): - group = parser.add_argument_group(title='autoresume') - - group.add_argument('--adlr-autoresume', action='store_true', - help='Enable autoresume on adlr cluster.') - group.add_argument('--adlr-autoresume-interval', type=int, default=1000, - help='Intervals over which check for autoresume' - 'termination signal') - - return parser - - -def _add_biencoder_args(parser): - group = parser.add_argument_group(title='biencoder') - - # network size - group.add_argument('--ict-head-size', type=int, default=None, - help='Size of block embeddings to be used in ICT and ' - 'REALM (paper default: 128)') - group.add_argument('--biencoder-projection-dim', type=int, default=0, - help='Size of projection head used in biencoder (paper' - ' default: 128)') - group.add_argument('--biencoder-shared-query-context-model', action='store_true', - help='Whether to share the parameters of the query ' - 'and context models or not') - - # checkpointing - group.add_argument('--ict-load', type=str, default=None, - help='Directory containing an ICTBertModel checkpoint') - group.add_argument('--bert-load', type=str, default=None, - help='Directory containing an BertModel checkpoint ' - '(needed to start ICT and REALM)') - - # data - group.add_argument('--titles-data-path', type=str, default=None, - help='Path to titles dataset used for ICT') - group.add_argument('--query-in-block-prob', type=float, default=0.1, - help='Probability of keeping query in block for ' - 'ICT dataset') - group.add_argument('--use-one-sent-docs', action='store_true', - help='Whether to use one sentence documents in ICT') - group.add_argument('--evidence-data-path', type=str, default=None, - help='Path to Wikipedia Evidence frm DPR paper') - - # training - group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int, - default=[], help="Which top-k accuracies to report " - "(e.g. '1 5 20')") - group.add_argument('--retriever-score-scaling', action='store_true', - help='Whether to scale retriever scores by inverse ' - 'square root of hidden size') - - # faiss index - group.add_argument('--block-data-path', type=str, default=None, - help='Where to save/load BlockData to/from') - group.add_argument('--embedding-path', type=str, default=None, - help='Where to save/load Open-Retrieval Embedding' - ' data to/from') - - # indexer - group.add_argument('--indexer-batch-size', type=int, default=128, - help='How large of batches to use when doing indexing ' - 'jobs') - group.add_argument('--indexer-log-interval', type=int, default=1000, - help='After how many batches should the indexer ' - 'report progress') - return parser - - -def _add_vision_args(parser): - group = parser.add_argument_group(title="vision") - - # general vision arguments - group.add_argument('--num-classes', type=int, default=1000, - help='num of classes in vision classificaiton task') - group.add_argument('--img-h', type=int, default=224, - help='Image height for vision classification task') - group.add_argument('--img-w', type=int, default=224, - help='Image height for vision classification task') - group.add_argument('--num-channels', type=int, default=3, - help='Number of channels in input image data') - group.add_argument('--patch-dim', type=int, default=16, - help='patch dimension') - group.add_argument('--classes-fraction', type=float, default=1.0, - help='training with fraction of classes.') - group.add_argument('--data-per-class-fraction', type=float, default=1.0, - help='training with fraction of data per class.') - group.add_argument('--no-data-sharding', action='store_false', - help='Disable data sharding.', - dest='data_sharding') - group.add_argument('--head-lr-mult', type=float, default=1.0, - help='learning rate multiplier for head during finetuning') - - # pretraining type and backbone selection` - group.add_argument('--vision-pretraining', action='store_true', - help='flag to indicate vision pretraining') - group.add_argument('--vision-pretraining-type', type=str, default='classify', - choices=['classify', 'inpaint', 'dino'], - help='pretraining objectives') - group.add_argument('--vision-backbone-type', type=str, default='vit', - choices=['vit', 'mit', 'swin'], - help='backbone types types') - group.add_argument('--swin-backbone-type', type=str, default='tiny', - choices=['tiny', 'base', 'h3'], - help='pretraining objectives') - - # inpainting arguments - group.add_argument('--mask-type', type=str, default='random', - choices=['random', 'row'], - help='mask types') - group.add_argument('--mask-factor', type=float, default=1.0, - help='mask size scaling parameter') - - # dino arguments - group.add_argument('--iter-per-epoch', type=int, default=1250, - help='iterations per epoch') - group.add_argument('--dino-local-img-size', type=int, default=96, - help='Image size for vision classification task') - group.add_argument('--dino-local-crops-number', type=int, default=10, - help='Number of local crops') - group.add_argument('--dino-head-hidden-size', type=int, default=2048, - help='Hidden dimension size in dino head') - group.add_argument('--dino-bottleneck-size', type=int, default=256, - help='Bottle neck dimension in dino head ') - group.add_argument('--dino-freeze-last-layer', type=float, default=1, - help='Freezing last layer weights') - group.add_argument('--dino-norm-last-layer', action='store_true', - help='Disable Norm in last layer.') - group.add_argument('--dino-warmup-teacher-temp', type=float, default=0.04, - help='warump teacher temperature') - group.add_argument('--dino-teacher-temp', type=float, default=0.07, - help='teacher temperature') - group.add_argument('--dino-warmup-teacher-temp-epochs', type=int, default=30, - help='warmup teacher temperaure epochs') - - return parser diff --git a/apex/transformer/testing/commons.py b/apex/transformer/testing/commons.py deleted file mode 100644 index 226e449..0000000 --- a/apex/transformer/testing/commons.py +++ /dev/null @@ -1,297 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -import datetime -import os -import random -from typing import Optional, Union, List, Tuple, Callable, Dict - -import numpy -import torch -import torch.nn as nn - -from apex import transformer -from apex.transformer.tensor_parallel import( - ColumnParallelLinear, - RowParallelLinear, - scatter_to_sequence_parallel_region, -) -from apex.transformer.pipeline_parallel.utils import ( - average_losses_across_data_parallel_group, -) -from apex.transformer.pipeline_parallel.schedules.common import ( - Batch, -) -from apex.transformer.testing import global_vars - - -TEST_SUCCESS_MESSAGE = ">> passed the test :-)" - - -# note (mkozuki): `pre_process` and `post_process` are a placeholder until interleaving schedule test comes. -class MyLayer(nn.Module): - def __init__(self, hidden_size: int, pre_process: bool, post_process: bool): - super().__init__() - self.pre_process = pre_process - self.post_process = post_process - self.layer = nn.Linear(hidden_size, hidden_size) - - def forward(self, x): - return self.layer(x) - - -class MyModel(nn.Module): - def __init__( - self, - hidden_size: int, pre_process: bool = False, post_process: bool = False, - *, - add_encoder: bool = False, add_decoder: bool = False, - ) -> None: - super().__init__() - self.pre_process = pre_process - self.post_process = post_process - self.layer = MyLayer( - hidden_size=hidden_size, pre_process=pre_process, post_process=post_process - ) - self.input_tensor = None - - def set_input_tensor( - self, input_tensor: Union[torch.Tensor, List[torch.Tensor]] - ) -> None: - if not isinstance(input_tensor, list): - input_tensor = [input_tensor] - self.input_tensor = input_tensor[0] - - def forward(self, x: Optional[torch.Tensor]) -> torch.Tensor: - if self.input_tensor is None: - return self.layer(x) - return self.layer(self.input_tensor) - - -class ToyParallelMLP(nn.Module): - def __init__( - self, - hidden_size: int, pre_process: bool = False, post_process: bool = False, - *, - sequence_parallel_enabled: bool = False, - # TODO(mkozuki): Support these two? - add_encoder: bool = False, add_decoder: bool = False, - ) -> None: - super().__init__() - self.pre_process = pre_process - self.post_process = post_process - self.sequence_parallel_enabled = sequence_parallel_enabled - - ffn_hidden_size = 4 * hidden_size - self.dense_h_to_4h = ColumnParallelLinear( - hidden_size, - ffn_hidden_size, - gather_output=False, - # init_method=init_method, - skip_bias_add=True, - # use_cpu_initialization=use_cpu_initialization, - bias=True, - sequence_parallel_enabled=sequence_parallel_enabled, - no_async_tensor_model_parallel_allreduce=True, - ) - self.dense_4h_to_h = RowParallelLinear( - ffn_hidden_size, - hidden_size, - input_is_parallel=True, - # init_method=output_layer_init_method, - skip_bias_add=False, - # use_cpu_initialization=use_cpu_initialization, - bias=True, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - self.activation_func = torch.nn.GELU() - - def set_input_tensor( - self, - input_tensor: Union[torch.Tensor, List[torch.Tensor]], - ) -> None: - if not isinstance(input_tensor, list): - input_tensor = [input_tensor] - self.input_tensor = input_tensor[0] - - def forward( - self, - x: Optional[torch.Tensor], - ) -> torch.Tensor: - """Forward of Simplified ParallelMLP. - - Args: - x: :obj:`None` if pipeline rank != pippeline first rank. When :obj:`None`, - `self.input_tensor` is taken care of by `forward_step` defined in - apex/transformer/pipeline_parallel/schedules/common.py - """ - # [s, b, h] - if self.input_tensor is None: - input = x - else: - input = self.input_tensor - intermediate_parallel, bias_parallel = self.dense_h_to_4h(input) - - if bias_parallel is not None: - intermediate_parallel += bias_parallel - intermediate_parallel = self.activation_func(intermediate_parallel) - # [s, b, h] - output, output_bias = self.dense_4h_to_h(intermediate_parallel) - return output - - -def model_provider_func( - hidden_size: int, - pre_process: bool, - post_process: bool, - *, - add_encoder: bool = False, - add_decoder: bool = False) -> MyModel: - return MyModel(hidden_size, pre_process, post_process, add_encoder=add_encoder, add_decoder=add_decoder) - - -def mlp_provider_func( - hidden_size: int, - pre_process: bool, - post_process: bool, - *, - add_encoder: bool = False, - add_decoder: bool = False, - sequence_parallel_enabled: bool = False, -) -> ToyParallelMLP: - return ToyParallelMLP( - hidden_size, - pre_process, - post_process, - add_encoder=add_encoder, - add_decoder=add_decoder, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - - -def process_batch(batch): - if isinstance(batch, list): - x = batch[0] - else: - x = batch - return x - - -def fwd_step_func(batch, model): - x = process_batch(batch) - y = model(x) - - # note (mkozuki): I don't think this function is nice but I do think this is enough for now - # just to check the sanity of ported pipeline functions. - def loss_func(x): - loss = torch.sum(x) - averaged_loss = average_losses_across_data_parallel_group([loss]) - return loss, {"avg": averaged_loss} - - return y, loss_func - - -@dataclass(frozen=True) -class ToyParallelMLPFwdBwdStepFunc: - - sequence_parallel_enabled: bool - - def __call__( - self, - batch: Batch, - model: torch.nn.Module, - ) -> Tuple[torch.Tensor, Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]]: - x = batch[0] if isinstance(batch, list) else batch - if isinstance(x, torch.Tensor): - x = x.transpose(0, 1).contiguous() - if self.sequence_parallel_enabled: - x = scatter_to_sequence_parallel_region(x) - y = model(x) - - # note (mkozuki): I don't think this function is nice but I do think this is enough for now - # just to check the sanity of ported pipeline functions. - def loss_func(x): - loss = torch.sum(x) - averaged_loss = average_losses_across_data_parallel_group([loss]) - return loss, {"avg": averaged_loss} - - return y, loss_func - - -class IdentityLayer(torch.nn.Module): - def __init__(self, size, scale=1.0): - super(IdentityLayer, self).__init__() - self.weight = torch.nn.Parameter(scale * torch.randn(size)) - - def forward(self): - return self.weight - - -def set_random_seed(seed): - """Set random seed for reproducibility.""" - random.seed(seed) - numpy.random.seed(seed) - torch.manual_seed(seed) - transformer.tensor_parallel.model_parallel_cuda_manual_seed(seed) - - -def initialize_distributed(backend="nccl"): - """Initialize torch.distributed.""" - # Get local rank in case it is provided. - # parser = argparse.ArgumentParser() - # parser.add_argument('--local_rank', type=int, default=None, - # help='local rank passed from distributed launcher') - # args = parser.parse_args() - if backend not in ("nccl", "ucc"): - raise RuntimeError(f"Currently only nccl & ucc are supported but {backend}") - if backend == "ucc": - import torch_ucc # NOQA - args = global_vars.get_args() - local_rank = args.local_rank - - # Get rank and world size. - rank = int(os.getenv("RANK", "0")) - world_size = int(os.getenv("WORLD_SIZE", "1")) - - print( - "> initializing torch.distributed with local rank: {}, " - "rank: {}, world size: {}".format(local_rank, rank, world_size) - ) - - # Set the device id. - device = rank % torch.cuda.device_count() - if local_rank is not None: - device = local_rank - torch.cuda.set_device(device) - - # Call the init process. - init_method = "tcp://" - master_ip = os.getenv("MASTER_ADDR", "localhost") - master_port = os.getenv("MASTER_PORT", "6000") - init_method += master_ip + ":" + master_port - torch.distributed.init_process_group( - backend=backend, world_size=world_size, rank=rank, init_method=init_method, - timeout=datetime.timedelta(seconds=60), - ) - - -def print_separator(message): - torch.distributed.barrier() - filler_len = (78 - len(message)) // 2 - filler = "-" * filler_len - string = "\n" + filler + " {} ".format(message) + filler - if torch.distributed.get_rank() == 0: - print(string, flush=True) - torch.distributed.barrier() diff --git a/apex/transformer/testing/distributed_test_base.py b/apex/transformer/testing/distributed_test_base.py deleted file mode 100644 index 7a81687..0000000 --- a/apex/transformer/testing/distributed_test_base.py +++ /dev/null @@ -1,133 +0,0 @@ -import os -import sys -import unittest -from packaging.version import Version, parse - -import torch -from torch import distributed as dist -from torch.utils import collect_env -from torch.testing._internal import common_utils -from torch.testing._internal import common_distributed - -HAS_TORCH_UCC = None -try: - import torch_ucc - HAS_TORCH_UCC = True -except ImportError: - HAS_TORCH_UCC = False - -# NOTE(mkozuki): Version guard for ucc. ref: https://github.com/openucx/ucc/issues/496 -_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION = Version("470.42.01") -_driver_version = None -if torch.cuda.is_available(): - if collect_env.get_nvidia_driver_version(collect_env.run) != None: - _driver_version = parse(collect_env.get_nvidia_driver_version(collect_env.run)) - else: - _driver_version = None -HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER = _driver_version is not None and _driver_version >= _TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION - - -class DistributedTestBase(common_distributed.MultiProcessTestCase): - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - def setUp(self) -> None: - super().setUp() - self._setup_pre_spawn() - self._spawn_processes() - - def tearDown(self) -> None: - super().tearDown() - - @property - def world_size(self) -> int: - return min(torch.cuda.device_count(), 4) - - @property - def init_method(self): - return f"{common_utils.FILE_SCHEMA}{self.file_name}" - - @classmethod - def _run(cls, rank, test_name, file_name, pipe): - self = cls(test_name) - self.assertTrue(torch.cuda.is_available()) - self.assertTrue(hasattr(self, "DISTRIBUTED_BACKEND")) - self.rank = rank - self.file_name = file_name - - print(f"[dist init] rank = {self.rank}, world_size = {self.world_size}") - - try: - dist.init_process_group( - init_method=self.init_method, - backend=self.DISTRIBUTED_BACKEND, - world_size=int(self.world_size), - rank=self.rank, - ) - except RuntimeError as e: - if "recompile" in e.args[0]: - print(f"Backend of {self.DISTRIBUTED_BACKEND} not available") - sys.exit(0) - raise - - torch.cuda.set_device(self.rank % torch.cuda.device_count()) - - dist.barrier() - self.run_test(test_name, pipe) - dist.barrier() - - dist.destroy_process_group() - sys.exit(0) - - def _setup_pre_spawn(self): - pass - - -class NcclDistributedTestBase(DistributedTestBase): - - DISTRIBUTED_BACKEND = "nccl" - - -@unittest.skipUnless( - HAS_TORCH_UCC, - "Requires [`torch_ucc`](https://github.com/facebookresearch/torch_ucc)", -) -@unittest.skipUnless( - HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER, - f"`torch_ucc` requires NVIDIA driver >= {_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION} but {_driver_version} found. " - "See https://github.com/openucx/ucc/issues/496", -) -class UccDistributedTestBase(DistributedTestBase): - - DISTRIBUTED_BACKEND = "ucc" - - def _setup_pre_spawn(self) -> None: - self.master_addr = "localhost" - os.environ["MASTER_ADDR"] = "localhost" - self._has_master_port = "MASTER_PORT" in os.environ - if self._has_master_port: - self.master_port = os.environ["MASTER_PORT"] - else: - try: - from caffe2.torch.fb.common.utils import get_free_port - self.master_port = str(get_free_port()) - except ImportError: - self.master_port = "12375" - os.environ["MASTER_PORT"] = self.master_port - - self._has_ucx_tls = "UCX_TLS" in os.environ - if not self._has_ucx_tls: - os.environ["UCX_TLS"] = "tcp,cuda" - print('os.environ[\"UCX_TLS\"] = {}'.format(os.environ["UCX_TLS"])) - - def tearDown(self) -> None: - super().tearDown() - if not self._has_master_port: - del os.environ["MASTER_PORT"] - if not self._has_ucx_tls: - del os.environ["UCX_TLS"] - - @property - def init_method(self): - return "tcp://localhost:" + os.environ["MASTER_PORT"] diff --git a/apex/transformer/testing/global_vars.py b/apex/transformer/testing/global_vars.py deleted file mode 100644 index 6b85374..0000000 --- a/apex/transformer/testing/global_vars.py +++ /dev/null @@ -1,270 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Megatron global variables.""" -import os -import sys -import time - -import torch - -from apex.transformer.microbatches import build_num_microbatches_calculator -from .arguments import parse_args - -_GLOBAL_ARGS = None -_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None -_GLOBAL_TOKENIZER = None -_GLOBAL_TENSORBOARD_WRITER = None -_GLOBAL_ADLR_AUTORESUME = None -_GLOBAL_TIMERS = None - - -def get_args(): - """Return arguments.""" - _ensure_var_is_initialized(_GLOBAL_ARGS, 'args') - return _GLOBAL_ARGS - - -def get_num_microbatches() -> int: - return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get() - - -def get_current_global_batch_size() -> int: - return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size() - - -def update_num_microbatches(consumed_samples: int, *, consistency_check: bool = True) -> None: - """Update the number of microbatches upon the number of consumed samples. - - .. note:: - This function has no effect unless ``rampup_batch_size`` is set. - - Args: - consumed_samples: The number of consumed samples so far. Basically this is equal to - :math:`num_iter * global_batch_size`. - consistency_check: If :obj:`True`, sanity checks the consumed samples, i.e., check if - ``consumed_samples`` is divisible by :math:`micro_batch_size \times data_parallel_size`. - """ - _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples, consistency_check) - - -# def get_tokenizer(): -# """Return tokenizer.""" -# _ensure_var_is_initialized(_GLOBAL_TOKENIZER, 'tokenizer') -# return _GLOBAL_TOKENIZER - - -def get_tensorboard_writer(): - """Return tensorboard writer. It can be None so no need - to check if it is initialized.""" - return _GLOBAL_TENSORBOARD_WRITER - - -def get_adlr_autoresume(): - """ADLR autoresume object. It can be None so no need - to check if it is initialized.""" - return _GLOBAL_ADLR_AUTORESUME - - -def get_timers(): - """Return timers.""" - _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers') - return _GLOBAL_TIMERS - - -def set_global_variables(extra_args_provider=None, args_defaults={}, - ignore_unknown_args=False): - """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers.""" - args = _parse_args(extra_args_provider=extra_args_provider, - defaults=args_defaults, - ignore_unknown_args=ignore_unknown_args) - # _build_num_microbatches_calculator(args) - # if args.vocab_file: - # _ = _build_tokenizer(args) - _set_tensorboard_writer(args) - _set_adlr_autoresume(args) - _set_timers() - - -def _parse_args(extra_args_provider=None, defaults={}, - ignore_unknown_args=False): - """Parse entire arguments.""" - global _GLOBAL_ARGS - _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args') - _GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider, - defaults=defaults, - ignore_unknown_args=ignore_unknown_args) - return _GLOBAL_ARGS - - -def _build_num_microbatches_calculator(args): - - global _GLOBAL_NUM_MICROBATCHES_CALCULATOR - _ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, - 'num microbatches calculator') - - _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator( - args) - - -# def _build_tokenizer(args): -# """Initialize tokenizer.""" -# global _GLOBAL_TOKENIZER -# _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer') -# _GLOBAL_TOKENIZER = build_tokenizer(args) -# return _GLOBAL_TOKENIZER - - -# def rebuild_tokenizer(args): -# global _GLOBAL_TOKENIZER -# _GLOBAL_TOKENIZER = None -# return _build_tokenizer(args) - - -def _set_tensorboard_writer(args): - """Set tensorboard writer.""" - global _GLOBAL_TENSORBOARD_WRITER - _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER, - 'tensorboard writer') - - if hasattr(args, 'tensorboard_dir') and \ - args.tensorboard_dir and args.rank == (args.world_size - 1): - try: - from torch.utils.tensorboard import SummaryWriter - print('> setting tensorboard ...') - _GLOBAL_TENSORBOARD_WRITER = SummaryWriter( - log_dir=args.tensorboard_dir, - max_queue=args.tensorboard_queue_size) - except ModuleNotFoundError: - print('WARNING: TensorBoard writing requested but is not ' - 'available (are you using PyTorch 1.1.0 or later?), ' - 'no TensorBoard logs will be written.', flush=True) - - -def _set_adlr_autoresume(args): - """Initialize ADLR autoresume.""" - global _GLOBAL_ADLR_AUTORESUME - _ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, 'adlr autoresume') - - if args.adlr_autoresume: - if args.rank == 0: - print('enabling autoresume ...', flush=True) - sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.')) - try: - from userlib.auto_resume import AutoResume - except BaseException: - print('ADLR autoresume is not available, exiting ...') - sys.exit() - - _GLOBAL_ADLR_AUTORESUME = AutoResume - - -def _set_timers(): - """Initialize timers.""" - global _GLOBAL_TIMERS - _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers') - _GLOBAL_TIMERS = Timers() - - -def _ensure_var_is_initialized(var, name): - """Make sure the input variable is not None.""" - assert var is not None, '{} is not initialized.'.format(name) - - -def _ensure_var_is_not_initialized(var, name): - """Make sure the input variable is not None.""" - assert var is None, '{} is already initialized.'.format(name) - - -class _Timer: - """Timer.""" - - def __init__(self, name): - self.name_ = name - self.elapsed_ = 0.0 - self.started_ = False - self.start_time = time.time() - - def start(self): - """Start the timer.""" - assert not self.started_, 'timer has already been started' - torch.cuda.synchronize() - self.start_time = time.time() - self.started_ = True - - def stop(self): - """Stop the timer.""" - assert self.started_, 'timer is not started' - torch.cuda.synchronize() - self.elapsed_ += (time.time() - self.start_time) - self.started_ = False - - def reset(self): - """Reset timer.""" - self.elapsed_ = 0.0 - self.started_ = False - - def elapsed(self, reset=True): - """Calculate the elapsed time.""" - started_ = self.started_ - # If the timing in progress, end it first. - if self.started_: - self.stop() - # Get the elapsed time. - elapsed_ = self.elapsed_ - # Reset the elapsed time - if reset: - self.reset() - # If timing was in progress, set it back. - if started_: - self.start() - return elapsed_ - - -class Timers: - """Group of timers.""" - - def __init__(self): - self.timers = {} - - def __call__(self, name): - if name not in self.timers: - self.timers[name] = _Timer(name) - return self.timers[name] - - def write(self, names, writer, iteration, normalizer=1.0, reset=False): - """Write timers to a tensorboard writer""" - # currently when using add_scalars, - # torch.utils.add_scalars makes each timer its own run, which - # polutes the runs list, so we just add each as a scalar - assert normalizer > 0.0 - for name in names: - value = self.timers[name].elapsed(reset=reset) / normalizer - writer.add_scalar(name + '-time', value, iteration) - - def log(self, names, normalizer=1.0, reset=True): - """Log a group of timers.""" - assert normalizer > 0.0 - string = 'time (ms)' - for name in names: - elapsed_time = self.timers[name].elapsed( - reset=reset) * 1000.0 / normalizer - string += ' | {}: {:.2f}'.format(name, elapsed_time) - if torch.distributed.is_initialized(): - if torch.distributed.get_rank() == ( - torch.distributed.get_world_size() - 1): - print(string, flush=True) - else: - print(string, flush=True) diff --git a/apex/transformer/testing/standalone_bert.py b/apex/transformer/testing/standalone_bert.py deleted file mode 100644 index dd66abc..0000000 --- a/apex/transformer/testing/standalone_bert.py +++ /dev/null @@ -1,255 +0,0 @@ -import contextlib - -import torch - -from apex.transformer import tensor_parallel -from apex.transformer.enums import AttnMaskType -from apex.transformer.enums import ModelType -from apex.transformer.layers import FusedLayerNorm as LayerNorm -from apex.transformer.testing.global_vars import get_args -from apex.transformer.testing.standalone_transformer_lm import ( - MegatronModule, - get_language_model, - get_linear_layer, - init_method_normal, - scaled_init_method_normal, - parallel_lm_logits, -) - - -def bert_extended_attention_mask(attention_mask): - # We create a 3D attention mask from a 2D tensor mask. - # [b, 1, s] - attention_mask_b1s = attention_mask.unsqueeze(1) - # [b, s, 1] - attention_mask_bs1 = attention_mask.unsqueeze(2) - # [b, s, s] - attention_mask_bss = attention_mask_b1s * attention_mask_bs1 - # [b, 1, s, s] - extended_attention_mask = attention_mask_bss.unsqueeze(1) - - # Convert attention mask to binary: - extended_attention_mask = (extended_attention_mask < 0.5) - - return extended_attention_mask - - -def bert_position_ids(token_ids): - # Create position ids - seq_length = token_ids.size(1) - position_ids = torch.arange(seq_length, dtype=torch.long, - device=token_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(token_ids) - - return position_ids - - -class BertLMHead(MegatronModule): - """Masked LM head for Bert - - Arguments: - mpu_vocab_size: model parallel size of vocabulary. - hidden_size: hidden size - init_method: init method for weight initialization - layernorm_epsilon: tolerance for layer norm divisions - parallel_output: whether output logits being distributed or not. - """ - - def __init__(self, mpu_vocab_size, hidden_size, init_method, - layernorm_epsilon, parallel_output): - - super(BertLMHead, self).__init__() - - args = get_args() - - self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) - # TODO: do we need this? - # mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1) - self.parallel_output = parallel_output - - self.dense = get_linear_layer(hidden_size, hidden_size, init_method) - setattr(self.dense.weight, 'sequence_parallel', args.sequence_parallel) - setattr(self.dense.bias, 'sequence_parallel', args.sequence_parallel) - - self.layernorm = LayerNorm( - hidden_size, eps=layernorm_epsilon, sequence_parallel_enabled=args.sequence_parallel) - self.gelu = torch.nn.functional.gelu - if args.openai_gelu: - self.gelu = openai_gelu - elif args.onnx_safe: - self.gelu = erf_gelu - - - def forward(self, hidden_states, word_embeddings_weight): - hidden_states = self.dense(hidden_states) - hidden_states = self.gelu(hidden_states) - hidden_states = self.layernorm(hidden_states) - output = parallel_lm_logits(hidden_states, - word_embeddings_weight, - self.parallel_output, - bias=self.bias) - return output - - -def post_language_model_processing(lm_output, pooled_output, - lm_head, binary_head, - lm_labels, - logit_weights, - fp16_lm_cross_entropy): - # Output. - lm_logits = lm_head( - lm_output, logit_weights) - - binary_logits = None - if binary_head is not None: - binary_logits = binary_head(pooled_output) - - if lm_labels is None: - # [s b h] => [b s h] - return lm_logits.transpose(0, 1).contiguous(), binary_logits - else: - # [b s] => [s b] - lm_labels = lm_labels.transpose(0, 1).contiguous() - # lm_logits: [s b h] lm_labels: [s b] - if fp16_lm_cross_entropy: - assert lm_logits.dtype == torch.half - lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels) - else: - lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(), - lm_labels) - return lm_loss, binary_logits - - -class BertModel(MegatronModule): - """Bert Language model.""" - - def __init__(self, - num_tokentypes=2, - add_binary_head=True, - parallel_output=True, - pre_process=True, - post_process=True, - cpu_offload=False): - super(BertModel, self).__init__() - args = get_args() - - self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy - self.add_binary_head = add_binary_head - self.parallel_output = parallel_output - self.pre_process = pre_process - self.post_process = post_process - - init_method = init_method_normal(args.init_method_std) - scaled_init_method = scaled_init_method_normal(args.init_method_std, - args.num_layers) - - self.language_model, self._language_model_key = get_language_model( - num_tokentypes=num_tokentypes, - add_pooler=self.add_binary_head, - encoder_attn_mask_type=AttnMaskType.padding, - init_method=init_method, - scaled_init_method=scaled_init_method, - pre_process=self.pre_process, - post_process=self.post_process) - - self.initialize_word_embeddings(init_method_normal) - if self.post_process: - self.lm_head = BertLMHead( - self.word_embeddings_weight().size(0), - args.hidden_size, init_method, args.layernorm_epsilon, parallel_output) - self._lm_head_key = 'lm_head' - self.binary_head = None - if self.add_binary_head: - self.binary_head = get_linear_layer(args.hidden_size, 2, - init_method) - self._binary_head_key = 'binary_head' - - self.forward_context = contextlib.nullcontext - if cpu_offload: - self.forward_context = torch.autograd.graph.save_on_cpu - - def set_input_tensor(self, input_tensor): - """See megatron.model.transformer.set_input_tensor()""" - self.language_model.set_input_tensor(input_tensor) - - def forward(self, bert_model_input, attention_mask, - tokentype_ids=None, lm_labels=None): - with self.forward_context(): - extended_attention_mask = bert_extended_attention_mask(attention_mask) - input_ids = bert_model_input - position_ids = bert_position_ids(input_ids) - - lm_output = self.language_model( - input_ids, - position_ids, - extended_attention_mask, - tokentype_ids=tokentype_ids - ) - - if self.post_process and self.add_binary_head: - lm_output, pooled_output = lm_output - else: - pooled_output = None - - if self.post_process: - return post_language_model_processing(lm_output, pooled_output, - self.lm_head, self.binary_head, - lm_labels, - self.word_embeddings_weight(), - self.fp16_lm_cross_entropy) - else: - return lm_output - - # NOTE(mkozuki): This method is not maintained as apex only tests forward_backward with best effort. - def state_dict_for_save_checkpoint(self, destination=None, prefix='', - keep_vars=False): - """For easy load when model is combined with other heads, - add an extra key.""" - - state_dict_ = {} - state_dict_[self._language_model_key] \ - = self.language_model.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - if self.post_process: - state_dict_[self._lm_head_key] \ - = self.lm_head.state_dict_for_save_checkpoint( - destination, prefix, keep_vars) - if self.post_process and self.add_binary_head: - state_dict_[self._binary_head_key] \ - = self.binary_head.state_dict(destination, prefix, keep_vars) - # Save word_embeddings. - if self.post_process and not self.pre_process: - state_dict_[self._word_embeddings_for_head_key] \ - = self.word_embeddings.state_dict(destination, prefix, keep_vars) - return state_dict_ - - # NOTE(mkozuki): This method is not maintained as apex only tests forward_backward with best effort. - def load_state_dict(self, state_dict, strict=True): - """Customized load.""" - - self.language_model.load_state_dict( - state_dict[self._language_model_key], strict=strict) - if self.post_process: - self.lm_head.load_state_dict( - state_dict[self._lm_head_key], strict=strict) - if self.post_process and self.add_binary_head: - self.binary_head.load_state_dict( - state_dict[self._binary_head_key], strict=strict) - # Load word_embeddings. - if self.post_process and not self.pre_process: - self.word_embeddings.load_state_dict( - state_dict[self._word_embeddings_for_head_key], strict=strict) - - -def bert_model_provider(pre_process=True, post_process=True, cpu_offload=False): - args = get_args() - num_tokentypes = 2 if args.bert_binary_head else 0 - model = BertModel( - num_tokentypes=num_tokentypes, - add_binary_head=args.bert_binary_head, - parallel_output=True, - pre_process=pre_process, - post_process=post_process, - cpu_offload=cpu_offload, - ) - return model diff --git a/apex/transformer/testing/standalone_gpt.py b/apex/transformer/testing/standalone_gpt.py deleted file mode 100644 index 0e3d464..0000000 --- a/apex/transformer/testing/standalone_gpt.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import contextlib -import torch - -from apex.transformer.enums import AttnMaskType -from apex.transformer.enums import ModelType -from apex.transformer import tensor_parallel -from apex.transformer.testing.global_vars import get_args -from apex.transformer.testing.standalone_transformer_lm import MegatronModule -from apex.transformer.testing.standalone_transformer_lm import parallel_lm_logits -from apex.transformer.testing.standalone_transformer_lm import post_language_model_processing -from apex.transformer.testing.standalone_transformer_lm import get_language_model -from apex.transformer.testing.standalone_transformer_lm import init_method_normal -from apex.transformer.testing.standalone_transformer_lm import ( - scaled_init_method_normal, -) - - - -def gpt_model_provider(pre_process: bool = True, post_process: bool = True, cpu_offload: bool = False,) -> "GPTModel": - args = get_args() - model = GPTModel( - num_tokentypes=0, - parallel_output=True, - pre_process=pre_process, - post_process=post_process, - cpu_offload=args.cpu_offload, - ) - return model - - -class GPTModel(MegatronModule): - """GPT-2 Language model.""" - - def __init__( - self, - num_tokentypes:int = 0, - parallel_output: bool = True, - pre_process: bool = True, - post_process: bool = True, - cpu_offload: bool = False, - ): - super().__init__() - args = get_args() - - self.forward_context = contextlib.nullcontext - if cpu_offload: - self.forward_context = torch.autograd.graph.save_on_cpu - - self.parallel_output = parallel_output - self.pre_process = pre_process - self.post_process = post_process - self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy - - self.language_model, self._language_model_key = get_language_model( - num_tokentypes=num_tokentypes, - add_pooler=False, - encoder_attn_mask_type=AttnMaskType.causal, - init_method=init_method_normal(args.init_method_std), - scaled_init_method=scaled_init_method_normal( - args.init_method_std, args.num_layers - ), - pre_process=self.pre_process, - post_process=self.post_process, - ) - - self.initialize_word_embeddings(init_method_normal) - - def set_input_tensor(self, input_tensor): - """See megatron.model.transformer.set_input_tensor()""" - self.language_model.set_input_tensor(input_tensor) - - def forward( - self, - input_ids, - position_ids, - attention_mask, - labels=None, - tokentype_ids=None, - inference_params=None, - ): - - with self.forward_context(): - lm_output = self.language_model( - input_ids, position_ids, attention_mask, inference_params=inference_params - ) - - if self.post_process: - return post_language_model_processing( - lm_output, - # note(mkozuki): Am I overlooking some order of dim change? - labels.t().contiguous(), - self.word_embeddings_weight(), - self.parallel_output, - self.fp16_lm_cross_entropy, - ) - else: - return lm_output diff --git a/apex/transformer/testing/standalone_transformer_lm.py b/apex/transformer/testing/standalone_transformer_lm.py deleted file mode 100644 index 6cd90c7..0000000 --- a/apex/transformer/testing/standalone_transformer_lm.py +++ /dev/null @@ -1,1574 +0,0 @@ -# coding=utf-8 -# Copyright (c) 2021-22, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""GPT-2 model.""" -import enum -import math -import contextlib -import json - -import torch -import torch.nn.functional as F - -import apex.transformer.utils -from apex.transformer.layers import FusedLayerNorm as LayerNorm -from apex.transformer.functional import FusedScaleMaskSoftmax -from apex.transformer import tensor_parallel -from apex.transformer.tensor_parallel.layers import ColumnParallelLinear -from apex.transformer.tensor_parallel.layers import RowParallelLinear -from apex.transformer.tensor_parallel.layers import VocabParallelEmbedding -from apex.transformer.tensor_parallel.mappings import scatter_to_sequence_parallel_region -from apex.transformer import parallel_state -from apex.transformer.testing.global_vars import get_args -from apex.transformer.enums import ModelType -from apex.transformer.enums import LayerType -from apex.transformer.enums import AttnType -from apex.transformer.enums import AttnMaskType -from apex.transformer.log_util import get_transformer_logger - - -_logger = get_transformer_logger(__name__) - - -def param_is_not_shared(param: torch.Tensor) -> bool: - return getattr(param, "shared", False) - - -class MegatronModule(torch.nn.Module): - """Megatron specific extensions of torch Module with support for pipelining.""" - - def __init__(self, share_word_embeddings: bool = True) -> None: - super().__init__() - self.share_word_embeddings = share_word_embeddings - - def word_embeddings_weight(self): - if self.pre_process: - return self.language_model.embedding.word_embeddings.weight - else: - if not self.share_word_embeddings: - raise Exception('word_embeddings_weight() called for last stage, but share_word_embeddings is false') - return self.word_embeddings.weight - - - def initialize_word_embeddings(self, init_method_normal): - args = get_args() - if not self.share_word_embeddings: - raise Exception("initialize_word_embeddings() was called but share_word_embeddings is false") - - # This function just initializes the word embeddings in the final stage - # when we are using pipeline parallelism. Nothing to do if we aren't - # using pipeline parallelism. - if args.pipeline_model_parallel_size == 1: - return - - # Parameters are shared between the word embeddings layers, and the - # heads at the end of the model. In a pipelined setup with more than - # one stage, the initial embedding layer and the head are on different - # workers, so we do the following: - # 1. Create a second copy of word_embeddings on the last stage, with - # initial parameters of 0.0. - # 2. Do an all-reduce between the first and last stage to ensure that - # the two copies of word_embeddings start off with the same - # parameter values. - # 3. In the training loop, before an all-reduce between the grads of - # the two word_embeddings layers to ensure that every applied weight - # update is the same on both stages. - if parallel_state.is_pipeline_last_stage() and not self.pre_process: - assert not parallel_state.is_pipeline_first_stage() - self._word_embeddings_for_head_key = 'word_embeddings_for_head' - # set word_embeddings weights to 0 here, then copy first - # stage's weights using all_reduce below. - self.word_embeddings = VocabParallelEmbedding( - args.padded_vocab_size, args.hidden_size, - init_method=init_method_normal(args.init_method_std)) - self.word_embeddings.weight.data.fill_(0) - self.word_embeddings.weight.shared = True - - # Zero out initial weights for decoder embedding. - # NOTE: We don't currently support T5 with the interleaved schedule. - if not parallel_state.is_pipeline_first_stage(ignore_virtual=True) and self.pre_process: - self.language_model.embedding.zero_parameters() - - # Ensure that first and last stages have the same initial parameter - # values. - if torch.distributed.is_initialized(): - if parallel_state.is_rank_in_embedding_group(): - torch.distributed.all_reduce(self.word_embeddings_weight(), - group=parallel_state.get_embedding_group()) - - # Ensure that encoder(first stage) and decoder(split stage) position - # embeddings have the same initial parameter values - # NOTE: We don't currently support T5 with the interleaved schedule. - if parallel_state.is_rank_in_position_embedding_group() and \ - args.pipeline_model_parallel_split_rank is not None: - # TODO: Support tokentype embedding. - self.language_model.embedding.cuda() - position_embeddings = self.language_model.embedding.position_embeddings - torch.distributed.all_reduce(position_embeddings.weight, - group=parallel_state.get_position_embedding_group()) - - else: - print("WARNING! Distributed processes aren't initialized, so " - "word embeddings in the last layer are not initialized. " - "If you are just manipulating a model this is fine, but " - "this needs to be handled manually. If you are training " - "something is definitely wrong.") - - -def get_linear_layer(rows, columns, init_method): - """Simple linear layer with weight initialization.""" - layer = torch.nn.Linear(rows, columns) - init_method(layer.weight) - with torch.no_grad(): - layer.bias.zero_() - return layer - - -# NOTE(mkozuki): Avoid inplace op. -def attention_mask_func(attention_scores: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: - # attention_scores.masked_fill_(attention_mask, -10000.0) - # return attention_scores - return attention_scores.masked_fill(attention_mask, -10000.0) - - -def init_method_normal(sigma): - """Init method based on N(0, sigma).""" - - def init_(tensor): - return torch.nn.init.normal_(tensor, mean=0.0, std=sigma) - - return init_ - - -def scaled_init_method_normal(sigma, num_layers): - """Init method based on N(0, sigma/sqrt(2*num_layers).""" - std = sigma / math.sqrt(2.0 * num_layers) - - def init_(tensor): - return torch.nn.init.normal_(tensor, mean=0.0, std=std) - - return init_ - - -class ParallelMLP(MegatronModule): - """MLP. - - MLP will take the input with h hidden state, project it to 4*h - hidden dimension, perform nonlinear transformation, and project the - state back into h hidden dimension. - """ - - def __init__(self, init_method, output_layer_init_method): - super().__init__() - args = get_args() - - # Project to 4h. - self.dense_h_to_4h = ColumnParallelLinear( - args.hidden_size, - args.ffn_hidden_size, - gather_output=False, - init_method=init_method, - skip_bias_add=True, - no_async_tensor_model_parallel_allreduce=not args.async_tensor_model_parallel_allreduce, - sequence_parallel_enabled=args.sequence_parallel, - ) - - self.bias_gelu_fusion = args.bias_gelu_fusion - self.activation_func = F.gelu - - # Project back to h. - self.dense_4h_to_h = RowParallelLinear( - args.ffn_hidden_size, - args.hidden_size, - input_is_parallel=True, - init_method=output_layer_init_method, - skip_bias_add=True, - sequence_parallel_enabled=args.sequence_parallel, - ) - - def forward(self, hidden_states): - - # [s, b, 4hp] - intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) - - intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel) - - # [s, b, h] - output, output_bias = self.dense_4h_to_h(intermediate_parallel) - return output, output_bias - - -class CoreAttention(MegatronModule): - - def __init__(self, layer_number, attn_mask_type=AttnMaskType.padding): - super().__init__() - args = get_args() - self.fp16 = args.fp16 - self.bf16 = args.bf16 - - self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling - self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32 - if self.apply_query_key_layer_scaling: - self.attention_softmax_in_fp32 = True - self.layer_number = max(1, layer_number) - self.attn_mask_type = attn_mask_type - self.sequence_parallel = args.sequence_parallel - - projection_size = args.kv_channels * args.num_attention_heads - - # Per attention head and per partition values. - world_size = parallel_state.get_tensor_model_parallel_world_size() - self.hidden_size_per_partition = apex.transformer.utils.divide( - projection_size, world_size - ) - self.hidden_size_per_attention_head = apex.transformer.utils.divide( - projection_size, args.num_attention_heads - ) - self.num_attention_heads_per_partition = apex.transformer.utils.divide( - args.num_attention_heads, world_size - ) - - coeff = None - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - if self.apply_query_key_layer_scaling: - coeff = self.layer_number - self.norm_factor *= coeff - - self.scale_mask_softmax = FusedScaleMaskSoftmax( - self.fp16, - self.bf16, - self.attn_mask_type, - args.masked_softmax_fusion, - attention_mask_func, - self.attention_softmax_in_fp32, - coeff, - ) - # Dropout. Note that for a single iteration, this layer will generate - # different outputs on different number of parallel partitions but - # on average it should not be partition dependent. - self.attention_dropout = torch.nn.Dropout(args.attention_dropout) - - def forward(self, query_layer, key_layer, value_layer, attention_mask): - # =================================== - # Raw attention scores. [b, np, s, s] - # =================================== - # [b, np, sq, sk] - output_size = ( - query_layer.size(1), - query_layer.size(2), - query_layer.size(0), - key_layer.size(0), - ) - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view( - output_size[2], output_size[0] * output_size[1], -1 - ) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - - # preallocting input tensor: [b * np, sq, sk] - matmul_input_buffer = torch.empty( - output_size[0] * output_size[1], - output_size[2], - output_size[3], - dtype=query_layer.dtype, - device=torch.cuda.current_device(), - ) - # Raw attention scores. [b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_input_buffer, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / self.norm_factor), - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - # =========================== - # Attention probs and dropout - # =========================== - # attention scores and attention mask [b, np, sq, sk] - attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - if not self.sequence_parallel: - with tensor_parallel.get_cuda_rng_tracker().fork(): - attention_probs = self.attention_dropout(attention_probs) - else: - attention_probs = self.attention_dropout(attention_probs) - - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. - # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = ( - value_layer.size(1), - value_layer.size(2), - query_layer.size(0), - value_layer.size(3), - ) - - # change view [sk, b * np, hn] - value_layer = value_layer.view( - value_layer.size(0), output_size[0] * output_size[1], -1 - ) - - # change view [b * np, sq, sk] - attention_probs = attention_probs.view( - output_size[0] * output_size[1], output_size[2], -1 - ) - - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - - # [b, np, sq, hn] --> [sq, b, np, hn] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - - # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + ( - self.hidden_size_per_partition, - ) - context_layer = context_layer.view(*new_context_layer_shape) - - return context_layer - - -class ParallelAttention(MegatronModule): - """Parallel self-attention layer abstract class. - - Self-attention layer takes input with size [b, s, h] - and returns output of the same size. - """ - - def __init__( - self, - init_method, - output_layer_init_method, - layer_number, - attention_type=AttnType.self_attn, - attn_mask_type=AttnMaskType.padding, - ): - super().__init__() - args = get_args() - self.layer_number = max(1, layer_number) - self.attention_type = attention_type - self.attn_mask_type = attn_mask_type - self.params_dtype = args.params_dtype - - projection_size = args.kv_channels * args.num_attention_heads - - # Per attention head and per partition values. - world_size = parallel_state.get_tensor_model_parallel_world_size() - self.hidden_size_per_attention_head = apex.transformer.utils.divide( - projection_size, args.num_attention_heads - ) - self.num_attention_heads_per_partition = apex.transformer.utils.divide( - args.num_attention_heads, world_size - ) - - # Strided linear layer. - if attention_type == AttnType.self_attn: - self.query_key_value = ColumnParallelLinear( - args.hidden_size, - 3 * projection_size, - gather_output=False, - init_method=init_method, - no_async_tensor_model_parallel_allreduce=not args.async_tensor_model_parallel_allreduce, - sequence_parallel_enabled=args.sequence_parallel, - ) - else: - assert attention_type == AttnType.cross_attn - self.query = ColumnParallelLinear( - args.hidden_size, - projection_size, - gather_output=False, - init_method=init_method, - no_async_tensor_model_parallel_allreduce=not args.async_tensor_model_parallel_allreduce, - sequence_parallel_enabled=args.sequence_parallel, - ) - - self.key_value = ColumnParallelLinear( - args.hidden_size, - 2 * projection_size, - gather_output=False, - init_method=init_method, - no_async_tensor_model_parallel_allreduce=not args.async_tensor_model_parallel_allreduce, - sequence_parallel_enabled=args.sequence_parallel, - ) - - self.core_attention = CoreAttention(self.layer_number, self.attn_mask_type) - self.checkpoint_core_attention = args.recompute_granularity == "selective" - - # Output. - self.dense = RowParallelLinear( - projection_size, - args.hidden_size, - input_is_parallel=True, - init_method=output_layer_init_method, - skip_bias_add=True, - sequence_parallel_enabled=args.sequence_parallel, - ) - - def _checkpointed_attention_forward( - self, query_layer, key_layer, value_layer, attention_mask - ): - """Forward method with activation checkpointing.""" - - def custom_forward(*inputs): - query_layer = inputs[0] - key_layer = inputs[1] - value_layer = inputs[2] - attention_mask = inputs[3] - output_ = self.core_attention( - query_layer, key_layer, value_layer, attention_mask - ) - return output_ - - hidden_states = tensor_parallel.checkpoint( - custom_forward, False, query_layer, key_layer, value_layer, attention_mask - ) - - return hidden_states - - def _allocate_memory(self, inference_max_sequence_len, batch_size): - return torch.empty( - inference_max_sequence_len, - batch_size, - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - dtype=self.params_dtype, - device=torch.cuda.current_device(), - ) - - def forward( - self, hidden_states, attention_mask, encoder_output=None, inference_params=None - ): - # hidden_states: [sq, b, h] - - # ================================================= - # Pre-allocate memory for key-values for inference. - # ================================================= - if inference_params: - if self.layer_number not in inference_params.key_value_memory_dict: - inf_max_seq_len = inference_params.max_sequence_len - inf_max_batch_size = inference_params.max_batch_size - inference_key_memory = self._allocate_memory( - inf_max_seq_len, inf_max_batch_size - ) - inference_value_memory = self._allocate_memory( - inf_max_seq_len, inf_max_batch_size - ) - inference_params.key_value_memory_dict[self.layer_number] = ( - inference_key_memory, - inference_value_memory, - ) - else: - ( - inference_key_memory, - inference_value_memory, - ) = inference_params.key_value_memory_dict[self.layer_number] - - # ===================== - # Query, Key, and Value - # ===================== - - if self.attention_type == AttnType.self_attn: - # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] - mixed_x_layer, _ = self.query_key_value(hidden_states) - - # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] - new_tensor_shape = mixed_x_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head, - ) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - - # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] - ( - query_layer, - key_layer, - value_layer, - ) = tensor_parallel.utils.split_tensor_along_last_dim(mixed_x_layer, 3) - else: - # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] - mixed_kv_layer, _ = self.key_value(encoder_output) - - # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] - new_tensor_shape = mixed_kv_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - 2 * self.hidden_size_per_attention_head, - ) - mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) - - # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] - ( - key_layer, - value_layer, - ) = tensor_parallel.utils.split_tensor_along_last_dim(mixed_kv_layer, 2) - - # Attention head [sq, b, h] --> [sq, b, hp] - query_layer, _ = self.query(hidden_states) - # [sq, b, hp] --> [sq, b, np, hn] - new_tensor_shape = query_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - ) - query_layer = query_layer.view(*new_tensor_shape) - - # ================================== - # Adjust key and value for inference - # ================================== - - if inference_params: - batch_start = inference_params.batch_size_offset - batch_end = batch_start + key_layer.size(1) - assert batch_end <= inference_key_memory.size(1) - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + key_layer.size(0) - assert sequence_end <= inference_key_memory.size(0) - # Copy key and values. - inference_key_memory[ - sequence_start:sequence_end, batch_start:batch_end, ... - ] = key_layer - inference_value_memory[ - sequence_start:sequence_end, batch_start:batch_end, ... - ] = value_layer - key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] - value_layer = inference_value_memory[ - :sequence_end, batch_start:batch_end, ... - ] - - # ================================== - # core attention computation - # ================================== - - if self.checkpoint_core_attention: - context_layer = self._checkpointed_attention_forward( - query_layer, key_layer, value_layer, attention_mask - ) - else: - context_layer = self.core_attention( - query_layer, key_layer, value_layer, attention_mask - ) - - # ================= - # Output. [sq, b, h] - # ================= - - output, bias = self.dense(context_layer) - - return output, bias - - -def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor: - out = torch.nn.functional.dropout(x + bias, p=prob, training=training) - out = residual + out - return out - - -def get_bias_dropout_add(training): - def _bias_dropout_add(x, bias, residual, prob): - return bias_dropout_add(x, bias, residual, prob, training) - - return _bias_dropout_add - - -class ParallelTransformerLayer(MegatronModule): - """A single transformer layer. - - Transformer layer takes input with size [s, b, h] and returns an - output of the same size. - """ - - def __init__( - self, - init_method, - output_layer_init_method, - layer_number, - layer_type=LayerType.encoder, - self_attn_mask_type=AttnMaskType.padding, - drop_path_rate=0.0, - ): - args = get_args() - - super().__init__() - self.layer_number = layer_number - self.layer_type = layer_type - - self.apply_residual_connection_post_layernorm = ( - args.apply_residual_connection_post_layernorm - ) - - self.bf16 = args.bf16 - self.fp32_residual_connection = args.fp32_residual_connection - - # Layernorm on the input data. - self.input_layernorm = LayerNorm( - args.hidden_size, - eps=args.layernorm_epsilon, - # no_persist_layer_norm=args.no_persist_layer_norm, - sequence_parallel_enabled=args.sequence_parallel, - ) - - # Self attention. - self.self_attention = ParallelAttention( - init_method, - output_layer_init_method, - layer_number, - attention_type=AttnType.self_attn, - attn_mask_type=self_attn_mask_type, - ) - self.hidden_dropout = args.hidden_dropout - self.bias_dropout_fusion = args.bias_dropout_fusion - # note(mkozuki) - # self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None - assert drop_path_rate <= 0.0 - self.drop_path = None - - # Layernorm on the attention output - self.post_attention_layernorm = LayerNorm( - args.hidden_size, - eps=args.layernorm_epsilon, - # no_persist_layer_norm=args.no_persist_layer_norm, - sequence_parallel_enabled=args.sequence_parallel, - ) - - if self.layer_type == LayerType.decoder: - self.inter_attention = ParallelAttention( - init_method, - output_layer_init_method, - layer_number, - attention_type=AttnType.cross_attn, - ) - # Layernorm on the attention output. - self.post_inter_attention_layernorm = LayerNorm( - args.hidden_size, - eps=args.layernorm_epsilon, - # no_persist_layer_norm=args.no_persist_layer_norm, - sequence_parallel_enabled=args.sequence_parallel, - ) - - # MLP - # note(mkozuki) - assert args.num_experts is None - # if args.num_experts is not None: - # self.mlp = SwitchMLP(init_method, output_layer_init_method) - # else: - # self.mlp = ParallelMLP(init_method, output_layer_init_method) - self.mlp = ParallelMLP(init_method, output_layer_init_method) - - # Set bias+dropout+add fusion grad_enable execution handler. - TORCH_MAJOR = int(torch.__version__.split(".")[0]) - TORCH_MINOR = int(torch.__version__.split(".")[1]) - use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) - self.bias_dropout_add_exec_handler = ( - contextlib.nullcontext if use_nvfuser else torch.enable_grad - ) - - def forward( - self, - hidden_states, - attention_mask, - encoder_output=None, - enc_dec_attn_mask=None, - inference_params=None, - ): - # hidden_states: [s, b, h] - # Layer norm at the beginning of the transformer layer. - layernorm_output = self.input_layernorm(hidden_states) - # Self attention. - attention_output, attention_bias = self.self_attention( - layernorm_output, attention_mask, inference_params=inference_params - ) - - # Residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = hidden_states - - if self.drop_path is None: - bias_dropout_add_func = get_bias_dropout_add(self.training) - - with self.bias_dropout_add_exec_handler(): - layernorm_input = bias_dropout_add_func( - attention_output, - attention_bias.expand_as(residual), - residual, - self.hidden_dropout, - ) - else: - out = torch.nn.functional.dropout( - attention_output + attention_bias, - p=self.hidden_dropout, - training=self.training, - ) - layernorm_input = residual + self.drop_path(out) - - # Layer norm post the self attention. - layernorm_output = self.post_attention_layernorm(layernorm_input) - - if self.layer_type == LayerType.decoder: - attention_output, attention_bias = self.inter_attention( - layernorm_output, enc_dec_attn_mask, encoder_output=encoder_output - ) - # residual connection - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = layernorm_input - - with self.bias_dropout_add_exec_handler(): - layernorm_input = bias_dropout_add_func( - attention_output, - attention_bias.expand_as(residual), - residual, - self.hidden_dropout, - ) - - # Layer norm post the decoder attention - layernorm_output = self.post_inter_attention_layernorm(layernorm_input) - - # MLP. - mlp_output, mlp_bias = self.mlp(layernorm_output) - - # Second residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = layernorm_input - - if self.drop_path is None: - with self.bias_dropout_add_exec_handler(): - output = bias_dropout_add_func( - mlp_output, - mlp_bias.expand_as(residual), - residual, - self.hidden_dropout, - ) - else: - out = torch.nn.functional.dropout( - mlp_output + mlp_bias, p=self.hidden_dropout, training=self.training - ) - output = residual + self.drop_path(out) - - return output - - -class ParallelTransformer(MegatronModule): - """Transformer class.""" - - def __init__( - self, - init_method, - output_layer_init_method, - layer_type=LayerType.encoder, - self_attn_mask_type=AttnMaskType.padding, - post_layer_norm=True, - pre_process=True, - post_process=True, - drop_path_rate=0.0, - ): - super().__init__() - args = get_args() - - self.layer_type = layer_type - self.model_type = args.model_type - self.bf16 = args.bf16 - self.fp32_residual_connection = args.fp32_residual_connection - self.post_layer_norm = post_layer_norm - self.pre_process = pre_process - self.post_process = post_process - self.input_tensor = None - self.drop_path_rate = drop_path_rate - - # Store activation checkpoiting flag. - self.recompute_granularity = args.recompute_granularity - self.recompute_method = args.recompute_method - self.recompute_num_layers = args.recompute_num_layers - self.distribute_saved_activations = ( - args.distribute_saved_activations and not args.sequence_parallel - ) - - self.sequence_parallel = args.sequence_parallel - - # Number of layers. - self.num_layers = get_num_layers( - args, args.model_type == ModelType.encoder_and_decoder - ) - - self.drop_path_rates = [ - rate.item() - for rate in torch.linspace(0, self.drop_path_rate, args.num_layers) - ] - - # Transformer layers. - def build_layer(layer_number): - return ParallelTransformerLayer( - init_method, - output_layer_init_method, - layer_number, - layer_type=layer_type, - self_attn_mask_type=self_attn_mask_type, - drop_path_rate=self.drop_path_rates[layer_number - 1], - ) - - if args.virtual_pipeline_model_parallel_size is not None: - assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, ( - "num_layers_per_stage must be divisible by " - "virtual_pipeline_model_parallel_size" - ) - assert args.model_type != ModelType.encoder_and_decoder - # Number of layers in each model chunk is the number of layers in the stage, - # divided by the number of model chunks in a stage. - self.num_layers = ( - self.num_layers // args.virtual_pipeline_model_parallel_size - ) - # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of - # layers to stages like (each list is a model chunk): - # Stage 0: [0] [2] [4] [6] - # Stage 1: [1] [3] [5] [7] - # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of - # layers to stages like (each list is a model chunk): - # Stage 0: [0, 1] [4, 5] - # Stage 1: [2, 3] [6, 7] - offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * ( - args.num_layers // args.virtual_pipeline_model_parallel_size - ) + (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers) - else: - # Each stage gets a contiguous set of layers. - if ( - args.model_type == ModelType.encoder_and_decoder - and parallel_state.get_pipeline_model_parallel_world_size() > 1 - ): - pipeline_rank = parallel_state.get_pipeline_model_parallel_rank() - if layer_type == LayerType.encoder: - offset = pipeline_rank * self.num_layers - else: - num_ranks_in_enc = args.pipeline_model_parallel_split_rank - offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers - else: - offset = ( - parallel_state.get_pipeline_model_parallel_rank() * self.num_layers - ) - - if self.num_layers == 0: - # When a standalone embedding stage is used (e.g., - # args.standalone_embedding_stage == True), virtual pipeline ranks - # on pipeline rank 0 will have zero transformer layers assigned to - # them. This results in the model's input and output tensors to be - # the same, which will cause failure for certain output tensor - # optimizations (e.g., pipeline output deallocation). To remedy - # this, we assign a 'no-op' layer on these ranks, which will - # disconnect the input tensor from the output tensor. - self.num_layers = 1 - self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)]) - else: - self.layers = torch.nn.ModuleList( - [build_layer(i + 1 + offset) for i in range(self.num_layers)] - ) - - if self.post_process and self.post_layer_norm: - # Final layer norm before output. - self.final_layernorm = LayerNorm( - args.hidden_size, - eps=args.layernorm_epsilon, - # no_persist_layer_norm=args.no_persist_layer_norm, - sequence_parallel_enabled=args.sequence_parallel, - ) - - def _get_layer(self, layer_number): - return self.layers[layer_number] - - def _checkpointed_forward( - self, hidden_states, attention_mask, encoder_output, enc_dec_attn_mask - ): - """Forward method with activation checkpointing.""" - - def custom(start, end): - def custom_forward(*inputs): - x_ = inputs[0] - attention_mask = inputs[1] - encoder_output = inputs[2] - enc_dec_attn_mask = inputs[3] - for index in range(start, end): - layer = self._get_layer(index) - x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask) - return x_ - - return custom_forward - - if self.recompute_method == "uniform": - # Uniformly divide the total number of Transformer layers and checkpoint - # the input activation of each divided chunk. - # A method to further reduce memory usage reducing checkpoints. - l = 0 - while l < self.num_layers: - hidden_states = tensor_parallel.random.checkpoint( - custom(l, l + self.recompute_num_layers), - self.distribute_saved_activations, - hidden_states, - attention_mask, - encoder_output, - enc_dec_attn_mask, - ) - l += self.recompute_num_layers - - elif self.recompute_method == "block": - # Checkpoint the input activation of only a set number of individual - # Transformer layers and skip the rest. - # A method fully use the device memory removing redundant re-computation. - for l in range(self.num_layers): - if l < self.recompute_num_layers: - hidden_states = tensor_parallel.random.checkpoint( - custom(l, l + 1), - self.distribute_saved_activations, - hidden_states, - attention_mask, - encoder_output, - enc_dec_attn_mask, - ) - else: - hidden_states = custom(l, l + 1)( - hidden_states, attention_mask, encoder_output, enc_dec_attn_mask - ) - else: - raise ValueError("Invalid activation recompute method.") - - return hidden_states - - def set_input_tensor(self, input_tensor): - """Set input tensor to be used instead of forward()'s input. - - When doing pipeline parallelism the input from the previous - stage comes from communication, not from the input, so the - model's forward_step_func won't have it. This function is thus - used by internal code to bypass the input provided by the - forward_step_func""" - self.input_tensor = input_tensor - - def forward( - self, - hidden_states, - attention_mask, - encoder_output=None, - enc_dec_attn_mask=None, - inference_params=None, - ): - # hidden_states: [s, b, h] - - # Checks. - if inference_params: - assert ( - self.recompute_granularity is None - ), "inference does not work with activation checkpointing" - - if not self.pre_process: - # See set_input_tensor() - hidden_states = self.input_tensor - - # Viewless tensor. - # - We only need to create a viewless tensor in the case of micro batch - # size (mbs) == 1, since in this case, 'hidden_states.transpose()' - # above creates a view tensor, and '.contiguous()' is a pass-through. - # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating - # the need to make it viewless. - # - # However, we don't explicitly check mbs == 1 here because - # make_viewless_tensor() has negligible overhead when its input - # is already viewless. - # - # - For the 'else' case above, calling make_viewless_tensor() here is - # likely redundant, since p2p_communication.py (likely originator) - # already creates viewless tensors. That said, make_viewless_tensor() - # is called here to be future-proof and corner-case-proof. - # hidden_states = mpu.make_viewless_tensor(hidden_states, requires_grad=True, keep_graph=True) - - if self.sequence_parallel: - rng_context = tensor_parallel.get_cuda_rng_tracker().fork() - else: - rng_context = contextlib.nullcontext() - - with rng_context: - # Forward pass. - if self.recompute_granularity == "full": - hidden_states = self._checkpointed_forward( - hidden_states, attention_mask, encoder_output, enc_dec_attn_mask - ) - else: - for index in range(self.num_layers): - layer = self._get_layer(index) - hidden_states = layer( - hidden_states, - attention_mask, - encoder_output=encoder_output, - enc_dec_attn_mask=enc_dec_attn_mask, - inference_params=inference_params, - ) - - # Final layer norm. - if self.post_process and self.post_layer_norm: - hidden_states = self.final_layernorm(hidden_states) - - return hidden_states - - -def get_num_layers(args, is_encoder_and_decoder_model): - """Compute the number of transformer layers resident on the current rank.""" - if parallel_state.get_pipeline_model_parallel_world_size() > 1: - if is_encoder_and_decoder_model: - assert args.pipeline_model_parallel_split_rank is not None - - # When a standalone embedding stage is used, a rank is taken from - # the encoder's ranks, to be used for the encoder's embedding - # layer. This way, the rank referenced by the 'split rank' remains - # the same whether or not a standalone embedding stage is used. - num_ranks_in_encoder = ( - args.pipeline_model_parallel_split_rank - 1 - if args.standalone_embedding_stage - else args.pipeline_model_parallel_split_rank - ) - num_ranks_in_decoder = ( - args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder - ) - assert args.num_layers % num_ranks_in_encoder == 0, ( - "num_layers (%d) must be divisible by number of ranks given to encoder (%d)" - % ( - args.num_layers, - num_ranks_in_encoder, - ) - ) - assert args.num_layers % num_ranks_in_decoder == 0, ( - "num_layers (%d) must be divisible by number of ranks given to decoder (%d)" - % ( - args.num_layers, - num_ranks_in_decoder, - ) - ) - if parallel_state.is_pipeline_stage_before_split(): - num_layers = ( - 0 - if args.standalone_embedding_stage - and parallel_state.get_pipeline_model_parallel_rank() == 0 - else args.num_layers // num_ranks_in_encoder - ) - else: - num_layers = args.num_layers // num_ranks_in_decoder - else: - assert ( - args.num_layers % args.transformer_pipeline_model_parallel_size == 0 - ), "num_layers must be divisible by transformer_pipeline_model_parallel_size" - - # When a standalone embedding stage is used, all transformer layers - # are divided among pipeline rank >= 1, while on pipeline rank 0, - # ranks either contain the input embedding layer (virtual pp rank 0), - # or no layers at all (virtual pp rank >= 1). - num_layers = ( - 0 - if args.standalone_embedding_stage - and parallel_state.get_pipeline_model_parallel_rank() == 0 - else args.num_layers // args.transformer_pipeline_model_parallel_size - ) - else: - num_layers = args.num_layers - return num_layers - - -class NoopTransformerLayer(MegatronModule): - """A single 'no-op' transformer layer. - - The sole purpose of this layer is for when a standalone embedding layer - is used (i.e., args.standalone_embedding_stage == True). In this case, - zero transformer layers are assigned when pipeline rank == 0. Additionally, - when virtual pipeline rank >= 1, zero total model parameters are created - (virtual rank 0 contains the input embedding). This results in the model's - input and output tensors being the same, which causes an error when - performing certain memory optimiations on the output tensor (e.g., - deallocating it). Thus, this layer disconnects the input from the output - via a clone. Since ranks containing a no-op layer are generally under- - utilized (both compute and memory), there's no worry of any performance - degredation. - """ - - def __init__(self, layer_number): - super().__init__() - self.layer_number = layer_number - - def forward( - self, - hidden_states, - attention_mask, - encoder_output=None, - enc_dec_attn_mask=None, - inference_params=None, - ): - return hidden_states.clone() - - -def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None): - """LM logits using word embedding weights.""" - args = get_args() - # Parallel logits. - if args.async_tensor_model_parallel_allreduce or args.sequence_parallel: - input_parallel = input_ - model_parallel = parallel_state.get_tensor_model_parallel_world_size() > 1 - async_grad_allreduce = ( - args.async_tensor_model_parallel_allreduce - and model_parallel - and not args.sequence_parallel - ) - else: - input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_) - async_grad_allreduce = False - - # Matrix multiply. - # logits_parallel = tensor_parallel.layers.LinearWithGradAccumulationAndAsyncCommunication.apply( - # input_parallel, word_embeddings_weight, bias, args.gradient_accumulation_fusion, async_grad_allreduce, args.sequence_parallel) - logits_parallel = ( - tensor_parallel.layers.linear_with_grad_accumulation_and_async_allreduce( - input_parallel, - word_embeddings_weight, - bias, - args.gradient_accumulation_fusion, - async_grad_allreduce, - args.sequence_parallel, - ) - ) - # Gather if needed. - - if parallel_output: - return logits_parallel - - return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel) - - -def get_language_model( - num_tokentypes, - add_pooler, - encoder_attn_mask_type, - init_method=None, - scaled_init_method=None, - add_encoder=True, - add_decoder=False, - decoder_attn_mask_type=AttnMaskType.causal, - pre_process=True, - post_process=True, -): - """Build language model and return along with the key to save.""" - args = get_args() - - if init_method is None: - init_method = init_method_normal(args.init_method_std) - if scaled_init_method is None: - scaled_init_method = scaled_init_method_normal( - args.init_method_std, args.num_layers - ) - - # Language model. - language_model = TransformerLanguageModel( - init_method, - scaled_init_method, - encoder_attn_mask_type, - num_tokentypes=num_tokentypes, - add_encoder=add_encoder, - add_decoder=add_decoder, - decoder_attn_mask_type=decoder_attn_mask_type, - add_pooler=add_pooler, - pre_process=pre_process, - post_process=post_process, - ) - # key used for checkpoints. - language_model_key = "language_model" - - return language_model, language_model_key - - -class Pooler(MegatronModule): - """Pooler layer. - - Pool hidden states of a specific token (for example start of the - sequence) and add a linear transformation followed by a tanh. - - Arguments: - hidden_size: hidden size - init_method: weight initialization method for the linear layer. - bias is set to zero. - """ - - def __init__(self, hidden_size, init_method): - super().__init__() - args = get_args() - self.dense = get_linear_layer(hidden_size, hidden_size, init_method) - self.sequence_parallel = args.sequence_parallel - - def forward(self, hidden_states, sequence_index=0): - # hidden_states: [s, b, h] - # sequence_index: index of the token to pool. - # gather data along sequence dimensions - # same pooler is run on all tensor parallel nodes - if self.sequence_parallel: - hidden_states = tensor_parallel.mappings.gather_from_sequence_parallel_region(hidden_states) - pooled = hidden_states[sequence_index, :, :] - pooled = self.dense(pooled) - pooled = torch.tanh(pooled) - return pooled - - -class Embedding(MegatronModule): - """Language model embeddings. - - Arguments: - hidden_size: hidden size - vocab_size: vocabulary size - max_sequence_length: maximum size of sequence. This - is used for positional embedding - embedding_dropout_prob: dropout probability for embeddings - init_method: weight initialization method - num_tokentypes: size of the token-type embeddings. 0 value - will ignore this embedding - """ - - def __init__( - self, - hidden_size, - vocab_size, - max_sequence_length, - embedding_dropout_prob, - init_method, - num_tokentypes=0, - ): - super().__init__() - - self.hidden_size = hidden_size - self.init_method = init_method - self.num_tokentypes = num_tokentypes - - args = get_args() - - # Word embeddings (parallel). - self.word_embeddings = VocabParallelEmbedding( - vocab_size, self.hidden_size, init_method=self.init_method - ) - self._word_embeddings_key = "word_embeddings" - - # Position embedding (serial). - self.position_embeddings = torch.nn.Embedding( - max_sequence_length, self.hidden_size - ) - self._position_embeddings_key = "position_embeddings" - # Initialize the position embeddings. - self.init_method(self.position_embeddings.weight) - - # Token type embedding. - # Add this as an optional field that can be added through - # method call so we can load a pretrain model without - # token types and add them as needed. - self._tokentype_embeddings_key = "tokentype_embeddings" - if self.num_tokentypes > 0: - self.tokentype_embeddings = torch.nn.Embedding( - self.num_tokentypes, self.hidden_size - ) - # Initialize the token-type embeddings. - self.init_method(self.tokentype_embeddings.weight) - else: - self.tokentype_embeddings = None - - self.fp32_residual_connection = args.fp32_residual_connection - self.sequence_parallel = args.sequence_parallel - # Embeddings dropout - self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) - - def zero_parameters(self): - """Zero out all parameters in embedding.""" - self.word_embeddings.weight.data.fill_(0) - self.word_embeddings.weight.shared = True - self.position_embeddings.weight.data.fill_(0) - self.position_embeddings.weight.shared = True - if self.num_tokentypes > 0: - self.tokentype_embeddings.weight.fill_(0) - self.tokentype_embeddings.weight.shared = True - - def add_tokentype_embeddings(self, num_tokentypes): - """Add token-type embedding. This function is provided so we can add - token-type embeddings in case the pretrained model does not have it. - This allows us to load the model normally and then add this embedding. - """ - if self.tokentype_embeddings is not None: - raise Exception("tokentype embeddings is already initialized") - if torch.distributed.get_rank() == 0: - print( - "adding embedding for {} tokentypes".format(num_tokentypes), flush=True - ) - self.num_tokentypes = num_tokentypes - self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size) - # Initialize the token-type embeddings. - self.init_method(self.tokentype_embeddings.weight) - - def forward(self, input_ids, position_ids, tokentype_ids=None): - # Embeddings. - words_embeddings = self.word_embeddings(input_ids) - position_embeddings = self.position_embeddings(position_ids) - embeddings = words_embeddings + position_embeddings - if tokentype_ids is not None: - assert self.tokentype_embeddings is not None - embeddings = embeddings + self.tokentype_embeddings(tokentype_ids) - else: - assert self.tokentype_embeddings is None - - # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. - embeddings = embeddings.transpose(0, 1).contiguous() - - # If the input flag for fp32 residual connection is set, convert for float. - if self.fp32_residual_connection: - embeddings = embeddings.float() - - # Dropout. - if self.sequence_parallel: - embeddings = scatter_to_sequence_parallel_region(embeddings) - with tensor_parallel.get_cuda_rng_tracker().fork(): - embeddings = self.embedding_dropout(embeddings) - else: - embeddings = self.embedding_dropout(embeddings) - - return embeddings - - -class TransformerLanguageModel(MegatronModule): - """Transformer language model. - - Arguments: - transformer_hparams: transformer hyperparameters - vocab_size: vocabulary size - max_sequence_length: maximum size of sequence. This - is used for positional embedding - embedding_dropout_prob: dropout probability for embeddings - num_tokentypes: size of the token-type embeddings. 0 value - will ignore this embedding - """ - - def __init__( - self, - init_method, - output_layer_init_method, - encoder_attn_mask_type, - num_tokentypes=0, - add_encoder=True, - add_decoder=False, - decoder_attn_mask_type=AttnMaskType.causal, - add_pooler=False, - pre_process=True, - post_process=True, - ): - super().__init__() - args = get_args() - - self.pre_process = pre_process - self.post_process = post_process - self.hidden_size = args.hidden_size - self.num_tokentypes = num_tokentypes - self.init_method = init_method - self.add_encoder = add_encoder - self.encoder_attn_mask_type = encoder_attn_mask_type - self.add_decoder = add_decoder - self.decoder_attn_mask_type = decoder_attn_mask_type - self.add_pooler = add_pooler - self.encoder_hidden_state = None - - # Embeddings. - if self.pre_process: - self.embedding = Embedding( - self.hidden_size, - args.padded_vocab_size, - args.max_position_embeddings, - args.hidden_dropout, - self.init_method, - self.num_tokentypes, - ) - self._embedding_key = "embedding" - - # Transformer. - # Encoder (usually set to True, False if part of an encoder-decoder - # architecture and in encoder-only stage). - if self.add_encoder: - self.encoder = ParallelTransformer( - self.init_method, - output_layer_init_method, - self_attn_mask_type=self.encoder_attn_mask_type, - pre_process=self.pre_process, - post_process=self.post_process, - ) - self._encoder_key = "encoder" - else: - self.encoder = None - - # Decoder (usually set to False, True if part of an encoder-decoder - # architecture and in decoder-only stage). - if self.add_decoder: - self.decoder = ParallelTransformer( - self.init_method, - output_layer_init_method, - layer_type=LayerType.decoder, - self_attn_mask_type=self.decoder_attn_mask_type, - pre_process=self.pre_process, - post_process=self.post_process, - ) - self._decoder_key = "decoder" - else: - self.decoder = None - - if self.post_process: - # Pooler. - if self.add_pooler: - self.pooler = Pooler(self.hidden_size, self.init_method) - self._pooler_key = "pooler" - - def set_input_tensor(self, input_tensor): - """See megatron.model.transformer.set_input_tensor()""" - - # This is usually handled in schedules.py but some inference code still - # gives us non-lists or None - if not isinstance(input_tensor, list): - input_tensor = [input_tensor] - - if self.add_encoder and self.add_decoder: - assert ( - len(input_tensor) == 1 - ), "input_tensor should only be length 1 for stage with both encoder and decoder" - self.encoder.set_input_tensor(input_tensor[0]) - elif self.add_encoder: - assert ( - len(input_tensor) == 1 - ), "input_tensor should only be length 1 for stage with only encoder" - self.encoder.set_input_tensor(input_tensor[0]) - elif self.add_decoder: - if len(input_tensor) == 2: - self.decoder.set_input_tensor(input_tensor[0]) - self.encoder_hidden_state = input_tensor[1] - elif len(input_tensor) == 1: - self.decoder.set_input_tensor(None) - self.encoder_hidden_state = input_tensor[0] - else: - raise Exception("input_tensor must have either length 1 or 2") - else: - raise Exception("Stage must have at least either encoder or decoder") - - def forward( - self, - enc_input_ids, - enc_position_ids, - enc_attn_mask, - dec_input_ids=None, - dec_position_ids=None, - dec_attn_mask=None, - enc_dec_attn_mask=None, - tokentype_ids=None, - inference_params=None, - pooling_sequence_index=0, - enc_hidden_states=None, - output_enc_hidden=False, - ): - - args = get_args() - # Encoder embedding. - if self.pre_process: - encoder_input = self.embedding( - enc_input_ids, enc_position_ids, tokentype_ids=tokentype_ids - ) - else: - encoder_input = None - - # Run encoder. - if enc_hidden_states is None: - if self.encoder is not None: - encoder_output = self.encoder( - encoder_input, enc_attn_mask, inference_params=inference_params - ) - else: - encoder_output = self.encoder_hidden_state - else: - encoder_output = enc_hidden_states.to(encoder_input.dtype) - - if self.post_process: - if self.add_pooler: - pooled_output = self.pooler(encoder_output, pooling_sequence_index) - - # output_enc_hidden refers to when we just need the encoder's - # output. For example, it is helpful to compute - # similarity between two sequences by average pooling - if not self.add_decoder or output_enc_hidden: - if self.add_pooler and self.post_process: - return encoder_output, pooled_output - else: - return encoder_output - - # Decoder embedding. - if self.pre_process: - decoder_input = self.embedding(dec_input_ids, dec_position_ids) - else: - decoder_input = None - - # Run decoder. - decoder_output = self.decoder( - decoder_input, - dec_attn_mask, - encoder_output=encoder_output, - enc_dec_attn_mask=enc_dec_attn_mask, - inference_params=inference_params, - ) - - if self.add_pooler and self.post_process: - return decoder_output, encoder_output, pooled_output - else: - return decoder_output, encoder_output - - -def post_language_model_processing( - lm_output, labels, logit_weights, parallel_output, fp16_lm_cross_entropy -): - # Output. - output = parallel_lm_logits(lm_output, logit_weights, parallel_output) - - if labels is None: - return output - else: - if fp16_lm_cross_entropy: - assert output.dtype == torch.half - loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels) - else: - loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels) - return loss - - -def module_size(m: torch.nn.Module, only_trainable: bool = False): - """ - returns the total number of parameters used by `m` (only counting - shared parameters once); if `only_trainable` is True, then only - includes parameters with `requires_grad = True` - """ - parameters = list(m.parameters()) - if only_trainable: - parameters = [p for p in parameters if p.requires_grad] - unique = {p.data_ptr(): p for p in parameters}.values() - return sum(p.numel() for p in unique) diff --git a/apex/transformer/utils.py b/apex/transformer/utils.py deleted file mode 100644 index 2377592..0000000 --- a/apex/transformer/utils.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Utility functions used by both `pipeline_parallel` and `tensor_parallel`""" -import torch - -from apex.transformer import parallel_state - - -def ensure_divisibility(numerator, denominator): - """Ensure that numerator is divisible by the denominator.""" - assert numerator % denominator == 0, "{} is not divisible by {}".format( - numerator, denominator - ) - - -def divide(numerator, denominator): - """Ensure that numerator is divisible by the denominator and return - the division value.""" - ensure_divisibility(numerator, denominator) - return numerator // denominator - - -def split_tensor_into_1d_equal_chunks(tensor): - """Break a tensor into equal 1D chunks.""" - data = tensor.view(-1) - partition_size = ( - torch.numel(data) // parallel_state.get_tensor_model_parallel_world_size() - ) - start_index = partition_size * parallel_state.get_tensor_model_parallel_rank() - end_index = start_index + partition_size - return data[start_index:end_index] - - -def gather_split_1d_tensor(tensor): - """Opposite of above function, gather values from model parallel ranks.""" - world_size = parallel_state.get_tensor_model_parallel_world_size() - numel = torch.numel(tensor) - numel_gathered = world_size * numel - gathered = torch.empty( - numel_gathered, - dtype=tensor.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) - torch.distributed._all_gather_base( - gathered, - tensor, - group=parallel_state.get_tensor_model_parallel_group() - ) - return gathered diff --git a/csrc/amp_C_frontend.cpp b/csrc/amp_C_frontend.cpp deleted file mode 100644 index c27ef91..0000000 --- a/csrc/amp_C_frontend.cpp +++ /dev/null @@ -1,194 +0,0 @@ -#include - -void multi_tensor_scale_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - float scale); - -void multi_tensor_sgd_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - float wd, - float momentum, - float dampening, - float lr, - bool nesterov, - bool first_run, - bool wd_after_momentum, - float scale); - -void multi_tensor_axpby_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - float a, - float b, - int arg_to_check); - -std::tuple multi_tensor_l2norm_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::optional per_tensor_python); - -std::tuple multi_tensor_l2norm_mp_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::optional per_tensor_python); - -std::tuple multi_tensor_l2norm_scale_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - float scale, - at::optional per_tensor_python); - -void multi_tensor_lamb_stage1_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor per_tensor_decay, - const int step, - const float beta1, - const float beta2, - const float epsilon, - at::Tensor global_grad_norm, - const float max_global_grad_norm); - -void multi_tensor_lamb_stage2_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor per_tensor_param_norm, - at::Tensor per_tensor_update_norm, - const float lr, - const float weight_decay, - at::optional use_nvlamb_python); - -void multi_tensor_adam_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - const float lr, - const float beta1, - const float beta2, - const float epsilon, - const int step, - const int mode, - const int bias_correction, - const float weight_decay); - - -void multi_tensor_adagrad_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - const float lr, - const float epsilon, - const int mode, - const float weight_decay); - - -void multi_tensor_novograd_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor grad_norms, - const float lr, - const float beta1, - const float beta2, - const float epsilon, - const int step, - const int bias_correction, - const float weight_decay, - const int grad_averaging, - const int mode, - const int norm_type); - -void multi_tensor_lamb_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - const float lr, - const float beta1, - const float beta2, - const float epsilon, - const int step, - const int bias_correction, - const float weight_decay, - const int grad_averaging, - const int mode, - at::Tensor global_grad_norm, - const float max_grad_norm, - at::optional use_nvlamb_python); - -void multi_tensor_lamb_mp_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor lr, - const float beta1, - const float beta2, - const float epsilon, - at::Tensor step, - const int bias_correction, - const float weight_decay, - const int grad_averaging, - const int mode, - at::Tensor global_grad_norm, - at::Tensor max_grad_norm, - at::optional use_nvlamb_python, - at::Tensor found_inf, - at::Tensor inv_scale); - -void multi_tensor_lars_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor grad_norms, - at::Tensor param_norms, - float lr, - float trust_coefficient, - float epsilon, - float weight_decay, - float momentum, - float dampening, - bool nesterov, - bool first_run, - bool wd_after_momentum, - float scale, - const bool is_skipped); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("multi_tensor_scale", &multi_tensor_scale_cuda, - "Fused overflow check + scale for a list of contiguous tensors"); - m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda, - "Fused SGD optimizer for list of contiguous tensors"); - m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda, - "out = a*x + b*y for a list of contiguous tensors"); - m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda, - "Computes L2 norm for a list of contiguous tensors"); - m.def("multi_tensor_l2norm_mp", &multi_tensor_l2norm_mp_cuda, - "Computes L2 norm for a list of contiguous tensors"); - m.def("multi_tensor_l2norm_scale", &multi_tensor_l2norm_scale_cuda, - "Computes L2 norm for a list of contiguous tensors and does scaling"); - m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda, - "Computes update part of LAMB optimizer"); - m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda, - "Completes application of gradient to parameters for LAMB optimizer"); - m.def("multi_tensor_adam", &multi_tensor_adam_cuda, - "Compute and apply gradient update to parameters for Adam optimizer"); - m.def("multi_tensor_adagrad", &multi_tensor_adagrad_cuda, - "Compute and apply gradient update to parameters for Adam optimizer"); - m.def("multi_tensor_novograd", &multi_tensor_novograd_cuda, - "Compute and apply gradient update to parameters for Adam optimizer"); - m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda, - "Computes and apply update for LAMB optimizer"); - m.def("multi_tensor_lamb_mp", &multi_tensor_lamb_mp_cuda, - "Computes and apply update for LAMB optimizer"); - m.def("multi_tensor_lars", &multi_tensor_lars_cuda, - "Fused LARS optimizer for list of contiguous tensors"); -} diff --git a/csrc/compat.h b/csrc/compat.h deleted file mode 100644 index acafb05..0000000 --- a/csrc/compat.h +++ /dev/null @@ -1,9 +0,0 @@ -#ifndef TORCH_CHECK -#define TORCH_CHECK AT_CHECK -#endif - -#ifdef VERSION_GE_1_3 -#define DATA_PTR data_ptr -#else -#define DATA_PTR data -#endif diff --git a/csrc/flatten_unflatten.cpp b/csrc/flatten_unflatten.cpp deleted file mode 100644 index d49ce75..0000000 --- a/csrc/flatten_unflatten.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include -#include -// https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_flatten.h - -at::Tensor flatten(std::vector tensors) -{ - return torch::utils::flatten_dense_tensors(tensors); -} - -std::vector unflatten(at::Tensor flat, std::vector tensors) -{ - return torch::utils::unflatten_dense_tensors(flat, tensors); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("flatten", &flatten, "Flatten dense tensors"); - m.def("unflatten", &unflatten, "Unflatten dense tensors"); -} diff --git a/csrc/fused_dense.cpp b/csrc/fused_dense.cpp deleted file mode 100644 index da8e71f..0000000 --- a/csrc/fused_dense.cpp +++ /dev/null @@ -1,192 +0,0 @@ -#include -#include -#include - -#include - - -template -int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace); - -template -int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, void *lt_workspace); - -template -int linear_gelu_linear_forward_cuda(T *input, T *weight1, T *bias1, T *weight2, T *bias2, int in_features, int hidden_features, int batch_size, int out_features, T *output1, T *output2, T *gelu_in, void *lt_workspace) ; - -template -int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, void *lt_workspace); - -at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor bias) { - - auto batch_size = input.size(0); - auto in_features = input.size(1); - - int out_features = weight.size(0); - - //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data()); - - // create output/workspace tensor - auto out = at::empty({batch_size, out_features}, input.type()); - //auto reserved_space = at::empty({reserved_size}, inputs[0].type()); - // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB - auto lt_workspace = at::empty({1 << 22}, input.type()); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "linear_bias_forward", [&] { - scalar_t* w_ptr = weight.data_ptr(); - scalar_t* b_ptr = bias.data_ptr(); - auto result = linear_bias_forward_cuda( - input, - w_ptr, - bias, - in_features, - batch_size, - out_features, - out, - //out.data_ptr(), - // reserved_space.data_ptr(), - (void*) (lt_workspace.data_ptr())); - }); - - return {out}; -} - -std::vector linear_bias_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output) { - - auto batch_size = input.size(0); - auto in_features = input.size(1); - - int out_features = weight.size(0); - - //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data()); - - // create output/workspace tensor - auto d_weight = at::empty({out_features, in_features}, input.type()); -#if (defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600) || __HIP_PLATFORM_HCC__ - auto d_bias = d_output.view({-1, out_features}).sum(0, false); -#else - auto d_bias = at::empty({out_features}, input.type()); -#endif - auto d_input = at::empty({batch_size, in_features}, input.type()); - //auto reserved_space = at::empty({reserved_size}, inputs[0].type()); - // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB - auto lt_workspace = at::empty({1 << 22}, input.type()); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "linear_bias_backward", [&] { - scalar_t* w_ptr = weight.data_ptr(); - scalar_t* d_b_ptr = d_bias.data_ptr(); - auto result = linear_bias_backward_cuda( - input.data_ptr(), - w_ptr, - d_output.data_ptr(), - in_features, - batch_size, - out_features, - d_weight.data_ptr(), - d_bias.data_ptr(), - d_input.data_ptr(), - // reserved_space.data_ptr(), - (void*) (lt_workspace.data_ptr())); - }); - - return {d_input, d_weight, d_bias}; -} - -std::vector linear_gelu_linear_forward(at::Tensor input, at::Tensor weight1, at::Tensor bias1, at::Tensor weight2, at::Tensor bias2) { - - auto batch_size = input.size(0); - auto in_features = input.size(1); - - int hidden_features = weight1.size(0); - int out_features = weight2.size(0); - - //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data()); - - // create output/workspace tensor - auto output1 = at::empty({batch_size, hidden_features}, input.type()); - auto gelu_in = at::empty({batch_size, hidden_features}, input.type()); - auto output2 = at::empty({batch_size, out_features}, input.type()); - //auto reserved_space = at::empty({reserved_size}, inputs[0].type()); - // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB - auto lt_workspace = at::empty({1 << 22}, input.type()); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "linear_gelu_linear_forward", [&] { - scalar_t* w1_ptr = weight1.data_ptr(); - scalar_t* b1_ptr = bias1.data_ptr(); - scalar_t* w2_ptr = weight2.data_ptr(); - scalar_t* b2_ptr = bias2.data_ptr(); - auto result = linear_gelu_linear_forward_cuda( - input.data_ptr(), - w1_ptr, - b1_ptr, - w2_ptr, - b2_ptr, - in_features, - hidden_features, - batch_size, - out_features, - output1.data_ptr(), - output2.data_ptr(), - gelu_in.data_ptr(), - // reserved_space.data_ptr(), - (void*) (lt_workspace.data_ptr())); - }); - - return {output1, output2, gelu_in}; -} - -std::vector linear_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2) { - - auto batch_size = input.size(0); - auto in_features = input.size(1); - - int hidden_features = weight1.size(0); - int out_features = weight2.size(0); - - //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data()); - - // create output/workspace tensor - auto d_weight1 = at::empty({hidden_features, in_features}, input.type()); - auto d_weight2 = at::empty({out_features, hidden_features}, input.type()); - auto d_bias1 = at::empty({hidden_features}, input.type()); - auto d_bias2 = at::empty({out_features}, input.type()); - auto d_input = at::empty({batch_size, in_features}, input.type()); - auto d_output1 = at::empty({batch_size, hidden_features}, input.type()); - //auto reserved_space = at::empty({reserved_size}, inputs[0].type()); - // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB - auto lt_workspace = at::empty({1 << 22}, input.type()); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "linear_bias_backward", [&] { - //scalar_t* w_ptr = weight.data_ptr(); - //scalar_t* d_b_ptr = d_bias.data_ptr(); - auto result = linear_gelu_linear_backward_cuda( - input.data_ptr(), - gelu_in.data_ptr(), - output1.data_ptr(), - weight1.data_ptr(), - weight2.data_ptr(), - d_output1.data_ptr(), - d_output2.data_ptr(), - in_features, - batch_size, - hidden_features, - out_features, - d_weight1.data_ptr(), - d_weight2.data_ptr(), - d_bias1.data_ptr(), - d_bias2.data_ptr(), - d_input.data_ptr(), - // reserved_space.data_ptr(), - (void*) (lt_workspace.data_ptr())); - }); - - return {d_input, d_weight1, d_bias1, d_weight2, d_bias2}; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("linear_bias_forward", &linear_bias_forward, "linear bias forward"); - m.def("linear_bias_backward", &linear_bias_backward, "linear bias backward"); - m.def("linear_gelu_linear_forward", &linear_gelu_linear_forward, "linear gelu linear forward"); - m.def("linear_gelu_linear_backward", &linear_gelu_linear_backward, "linear gelu linear backward"); -} - diff --git a/csrc/fused_dense_cuda.cu b/csrc/fused_dense_cuda.cu deleted file mode 100644 index 7b01a38..0000000 --- a/csrc/fused_dense_cuda.cu +++ /dev/null @@ -1,1525 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -/* Includes, cuda */ -#include -#include - -#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 -// includes cublaslt -#include -#endif -// FP64 Wrapper around cublas GEMMEx -cublasStatus_t gemm_bias( - cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float* alpha, - double* A, - int lda, - double* B, - int ldb, - const float* beta, - double* C, - int ldc) { -#ifdef __HIP_PLATFORM_HCC__ - return rocblas_gemm_ex( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - rocblas_datatype_f64_r, - lda, - B, - rocblas_datatype_f64_r, - ldb, - beta, - C, - rocblas_datatype_f64_r, - ldc, - C, - rocblas_datatype_f64_r, - ldc, - rocblas_datatype_f64_r, - rocblas_gemm_algo_standard, - 0, - 0); -#else - return cublasGemmEx( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - CUDA_R_64F, - lda, - B, - CUDA_R_64F, - ldb, - beta, - C, - CUDA_R_64F, - ldc, - CUDA_R_64F, - CUBLAS_GEMM_DEFAULT); -#endif -} - -// FP32 Wrapper around cublas GEMMEx -cublasStatus_t gemm_bias( - cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float* alpha, - float* A, - int lda, - float* B, - int ldb, - const float* beta, - float* C, - int ldc) { -#ifdef __HIP_PLATFORM_HCC__ - return rocblas_gemm_ex( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - rocblas_datatype_f32_r, - lda, - B, - rocblas_datatype_f32_r, - ldb, - beta, - C, - rocblas_datatype_f32_r, - ldc, - C, - rocblas_datatype_f32_r, - ldc, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, - 0, - 0); - -#else - return cublasGemmEx( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - CUDA_R_32F, - lda, - B, - CUDA_R_32F, - ldb, - beta, - C, - CUDA_R_32F, - ldc, - CUDA_R_32F, - CUBLAS_GEMM_DEFAULT); -#endif -} - -// FP16 Tensor core wrapper around cublas GEMMEx -cublasStatus_t gemm_bias( - cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float* alpha, - at::Half* A, - int lda, - at::Half* B, - int ldb, - const float* beta, - at::Half* C, - int ldc) { -#ifdef __HIP_PLATFORM_HCC__ - return rocblas_gemm_ex( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - rocblas_datatype_f16_r, - lda, - B, - rocblas_datatype_f16_r, - ldb, - beta, - C, - rocblas_datatype_f16_r, - ldc, - C, - rocblas_datatype_f16_r, - ldc, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, - 0, - 0); -#else - return cublasGemmEx( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - CUDA_R_16F, - lda, - B, - CUDA_R_16F, - ldb, - beta, - C, - CUDA_R_16F, - ldc, - CUDA_R_32F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP); -#endif -} - - -#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 - - -int gemm_bias_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - at::Half* A, - int lda, - at::Half* B, - int ldb, - const float *beta, /* host pointer */ - at::Half* C, - int ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - bool use_bias, - const void* bias) { - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - - cublasLtMatmulDescOpaque_t operationDesc = {}; - cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; - cublasLtMatmulPreferenceOpaque_t preference = {}; - - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t - // for details about defaults; here we just set the transforms for - // A and B. - status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (use_bias) { - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - epilogue = CUBLASLT_EPILOGUE_BIAS; - } - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - - // Create matrix descriptors. Not setting any extra attributes. - status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // Create preference handle; In general, extra attributes can be - // used here to disable tensor ops or to make sure algo selected - // will work with badly aligned A, B, C. However, for simplicity - // here we assume A,B,C are always well aligned (e.g., directly - // come from cudaMalloc) - status = cublasLtMatmulPreferenceInit(&preference); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // We just need the best available heuristic to try and run matmul. - // There is no guarantee that this will work. For example, if A is - // badly aligned, you can request more (e.g. 32) algos and try to - // run them one by one until something works. - status = cublasLtMatmulAlgoGetHeuristic( - ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (returnedResults == 0) { - status = CUBLAS_STATUS_NOT_SUPPORTED; - goto CLEANUP; - } - status = cublasLtMatmul(ltHandle, - &operationDesc, - alpha, - A, - &Adesc, - B, - &Bdesc, - beta, - C, - &Cdesc, - C, - &Cdesc, - //&heuristicResult.algo, - NULL, - workspace, - workspaceSize, - stream); - -CLEANUP: - // Descriptors are no longer needed as all GPU work was already - // enqueued. - return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; -} - - - - - - - -int gemm_bias_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - double* A, - int lda, - double* B, - int ldb, - const float *beta, /* host pointer */ - double* C, - int ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - bool use_bias, - const void* bias) { - return 1; -} - -int gemm_bias_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - float *A, - int lda, - float *B, - int ldb, - const float *beta, /* host pointer */ - float *C, - int ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - bool use_bias, - const void* bias) { - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - - cublasLtMatmulDescOpaque_t operationDesc = {}; - cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; - cublasLtMatmulPreferenceOpaque_t preference = {}; - - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t - // for details about defaults; here we just set the transforms for - // A and B. - status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (use_bias) { - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - epilogue = CUBLASLT_EPILOGUE_BIAS; - } - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - - // Create matrix descriptors. Not setting any extra attributes. - status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // Create preference handle; In general, extra attributes can be - // used here to disable tensor ops or to make sure algo selected - // will work with badly aligned A, B, C. However, for simplicity - // here we assume A,B,C are always well aligned (e.g., directly - // come from cudaMalloc) - status = cublasLtMatmulPreferenceInit(&preference); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // We just need the best available heuristic to try and run matmul. - // There is no guarantee that this will work. For example, if A is - // badly aligned, you can request more (e.g. 32) algos and try to - // run them one by one until something works. - status = cublasLtMatmulAlgoGetHeuristic( - ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (returnedResults == 0) { - status = CUBLAS_STATUS_NOT_SUPPORTED; - goto CLEANUP; - } - - status = cublasLtMatmul(ltHandle, - &operationDesc, - alpha, - A, - &Adesc, - B, - &Bdesc, - beta, - C, - &Cdesc, - C, - &Cdesc, - &heuristicResult.algo, - workspace, - workspaceSize, - stream); - -CLEANUP: - // Descriptors are no longer needed as all GPU work was already - // enqueued. - return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; -} - - - -int gemm_bias_gelu_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - at::Half* A, - int lda, - at::Half* B, - int ldb, - const float *beta, /* host pointer */ - at::Half* C, - int64_t ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - bool use_bias, - const void* gelu_in, - const void* bias) { - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - - cublasLtMatmulDescOpaque_t operationDesc = {}; - cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; - cublasLtMatmulPreferenceOpaque_t preference = {}; - - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_GELU_AUX; - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t - // for details about defaults; here we just set the transforms for - // A and B. - status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in)); - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc)); - - if (use_bias) { - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS; - } - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - - // Create matrix descriptors. Not setting any extra attributes. - status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // Create preference handle; In general, extra attributes can be - // used here to disable tensor ops or to make sure algo selected - // will work with badly aligned A, B, C. However, for simplicity - // here we assume A,B,C are always well aligned (e.g., directly - // come from cudaMalloc) - status = cublasLtMatmulPreferenceInit(&preference); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // We just need the best available heuristic to try and run matmul. - // There is no guarantee that this will work. For example, if A is - // badly aligned, you can request more (e.g. 32) algos and try to - // run them one by one until something works. - status = cublasLtMatmulAlgoGetHeuristic( - ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (returnedResults == 0) { - status = CUBLAS_STATUS_NOT_SUPPORTED; - goto CLEANUP; - } - status = cublasLtMatmul(ltHandle, - &operationDesc, - alpha, - A, - &Adesc, - B, - &Bdesc, - beta, - C, - &Cdesc, - C, - &Cdesc, - //&heuristicResult.algo, - NULL, - workspace, - workspaceSize, - stream); - -CLEANUP: - // Descriptors are no longer needed as all GPU work was already - // enqueued. - return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; -} - - -int gemm_bias_gelu_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - double* A, - int lda, - double* B, - int ldb, - const float *beta, /* host pointer */ - double* C, - int ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - bool use_bias, - const void *gelu_in, - const void* bias) { - return 1; -} - - -int gemm_bias_gelu_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - float *A, - int lda, - float *B, - int ldb, - const float *beta, /* host pointer */ - float *C, - int64_t ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - bool use_bias, - const void* gelu_in, - const void* bias) { - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - - cublasLtMatmulDescOpaque_t operationDesc = {}; - cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; - cublasLtMatmulPreferenceOpaque_t preference = {}; - - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_GELU_AUX; - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t - // for details about defaults; here we just set the transforms for - // A and B. - status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in)); - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc)); - - if (use_bias) { - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS; - } - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - - // Create matrix descriptors. Not setting any extra attributes. - status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // Create preference handle; In general, extra attributes can be - // used here to disable tensor ops or to make sure algo selected - // will work with badly aligned A, B, C. However, for simplicity - // here we assume A,B,C are always well aligned (e.g., directly - // come from cudaMalloc) - status = cublasLtMatmulPreferenceInit(&preference); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // We just need the best available heuristic to try and run matmul. - // There is no guarantee that this will work. For example, if A is - // badly aligned, you can request more (e.g. 32) algos and try to - // run them one by one until something works. - status = cublasLtMatmulAlgoGetHeuristic( - ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (returnedResults == 0) { - status = CUBLAS_STATUS_NOT_SUPPORTED; - goto CLEANUP; - } - status = cublasLtMatmul(ltHandle, - &operationDesc, - alpha, - A, - &Adesc, - B, - &Bdesc, - beta, - C, - &Cdesc, - C, - &Cdesc, - //&heuristicResult.algo, - NULL, - workspace, - workspaceSize, - stream); - -CLEANUP: - // Descriptors are no longer needed as all GPU work was already - // enqueued. - return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; -} - - - -int gemm_bgradb_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - at::Half* A, - int lda, - at::Half* B, - int ldb, - const float *beta, /* host pointer */ - at::Half* C, - int ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - bool use_bias, - const void* bgrad) { - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - - cublasLtMatmulDescOpaque_t operationDesc = {}; - cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; - cublasLtMatmulPreferenceOpaque_t preference = {}; - - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t - // for details about defaults; here we just set the transforms for - // A and B. - status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (use_bias) { - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - epilogue = CUBLASLT_EPILOGUE_BGRADB; - } - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - - // Create matrix descriptors. Not setting any extra attributes. - status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // Create preference handle; In general, extra attributes can be - // used here to disable tensor ops or to make sure algo selected - // will work with badly aligned A, B, C. However, for simplicity - // here we assume A,B,C are always well aligned (e.g., directly - // come from cudaMalloc) - status = cublasLtMatmulPreferenceInit(&preference); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // We just need the best available heuristic to try and run matmul. - // There is no guarantee that this will work. For example, if A is - // badly aligned, you can request more (e.g. 32) algos and try to - // run them one by one until something works. - status = cublasLtMatmulAlgoGetHeuristic( - ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (returnedResults == 0) { - status = CUBLAS_STATUS_NOT_SUPPORTED; - goto CLEANUP; - } - status = cublasLtMatmul(ltHandle, - &operationDesc, - alpha, - A, - &Adesc, - B, - &Bdesc, - beta, - C, - &Cdesc, - C, - &Cdesc, - //&heuristicResult.algo, - NULL, - workspace, - workspaceSize, - stream); - -CLEANUP: - // Descriptors are no longer needed as all GPU work was already - // enqueued. - return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; -} - - - - - - - -int gemm_bgradb_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - double* A, - int lda, - double* B, - int ldb, - const float *beta, /* host pointer */ - double* C, - int ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - bool use_bias, - const void* bgrad) { - return 1; -} - -int gemm_bgradb_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - float *A, - int lda, - float *B, - int ldb, - const float *beta, /* host pointer */ - float *C, - int ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - bool use_bias, - const void* bgrad) { - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - - cublasLtMatmulDescOpaque_t operationDesc = {}; - cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; - cublasLtMatmulPreferenceOpaque_t preference = {}; - - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t - // for details about defaults; here we just set the transforms for - // A and B. - status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (use_bias) { - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - epilogue = CUBLASLT_EPILOGUE_BGRADB; - } - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - - // Create matrix descriptors. Not setting any extra attributes. - status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // Create preference handle; In general, extra attributes can be - // used here to disable tensor ops or to make sure algo selected - // will work with badly aligned A, B, C. However, for simplicity - // here we assume A,B,C are always well aligned (e.g., directly - // come from cudaMalloc) - status = cublasLtMatmulPreferenceInit(&preference); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // We just need the best available heuristic to try and run matmul. - // There is no guarantee that this will work. For example, if A is - // badly aligned, you can request more (e.g. 32) algos and try to - // run them one by one until something works. - status = cublasLtMatmulAlgoGetHeuristic( - ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (returnedResults == 0) { - status = CUBLAS_STATUS_NOT_SUPPORTED; - goto CLEANUP; - } - - status = cublasLtMatmul(ltHandle, - &operationDesc, - alpha, - A, - &Adesc, - B, - &Bdesc, - beta, - C, - &Cdesc, - C, - &Cdesc, - &heuristicResult.algo, - workspace, - workspaceSize, - stream); - -CLEANUP: - // Descriptors are no longer needed as all GPU work was already - // enqueued. - return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; -} - - -int gemm_dgelu_bgradb_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - at::Half* A, - int lda, - at::Half* B, - int ldb, - const float *beta, /* host pointer */ - at::Half* C, - int64_t ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - const void *gelu_in, - const void *bgrad) { - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - - cublasLtMatmulDescOpaque_t operationDesc = {}; - cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; - cublasLtMatmulPreferenceOpaque_t preference = {}; - - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD; - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t - // for details about defaults; here we just set the transforms for - // A and B. - status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc)); - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - - // Create matrix descriptors. Not setting any extra attributes. - status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // Create preference handle; In general, extra attributes can be - // used here to disable tensor ops or to make sure algo selected - // will work with badly aligned A, B, C. However, for simplicity - // here we assume A,B,C are always well aligned (e.g., directly - // come from cudaMalloc) - status = cublasLtMatmulPreferenceInit(&preference); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // We just need the best available heuristic to try and run matmul. - // There is no guarantee that this will work. For example, if A is - // badly aligned, you can request more (e.g. 32) algos and try to - // run them one by one until something works. - status = cublasLtMatmulAlgoGetHeuristic( - ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (returnedResults == 0) { - status = CUBLAS_STATUS_NOT_SUPPORTED; - goto CLEANUP; - } - status = cublasLtMatmul(ltHandle, - &operationDesc, - alpha, - A, - &Adesc, - B, - &Bdesc, - beta, - C, - &Cdesc, - C, - &Cdesc, - //&heuristicResult.algo, - NULL, - workspace, - workspaceSize, - stream); - -CLEANUP: - // Descriptors are no longer needed as all GPU work was already - // enqueued. - return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; -} - -int gemm_dgelu_bgradb_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - double *A, - int lda, - double *B, - int ldb, - const float *beta, /* host pointer */ - double *C, - int ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - const void *gelu_in, - const void *bgrad) { - return 1; -} - -int gemm_dgelu_bgradb_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, /* host pointer */ - float *A, - int lda, - float *B, - int ldb, - const float *beta, /* host pointer */ - float *C, - int64_t ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - const void *gelu_in, - const void *bgrad) { - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - - cublasLtMatmulDescOpaque_t operationDesc = {}; - cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; - cublasLtMatmulPreferenceOpaque_t preference = {}; - - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD; - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t - // for details about defaults; here we just set the transforms for - // A and B. - status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &gelu_in, sizeof(gelu_in)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc)); - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - - // Create matrix descriptors. Not setting any extra attributes. - status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // Create preference handle; In general, extra attributes can be - // used here to disable tensor ops or to make sure algo selected - // will work with badly aligned A, B, C. However, for simplicity - // here we assume A,B,C are always well aligned (e.g., directly - // come from cudaMalloc) - status = cublasLtMatmulPreferenceInit(&preference); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // We just need the best available heuristic to try and run matmul. - // There is no guarantee that this will work. For example, if A is - // badly aligned, you can request more (e.g. 32) algos and try to - // run them one by one until something works. - status = cublasLtMatmulAlgoGetHeuristic( - ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (returnedResults == 0) { - status = CUBLAS_STATUS_NOT_SUPPORTED; - goto CLEANUP; - } - status = cublasLtMatmul(ltHandle, - &operationDesc, - alpha, - A, - &Adesc, - B, - &Bdesc, - beta, - C, - &Cdesc, - C, - &Cdesc, - //&heuristicResult.algo, - NULL, - workspace, - workspaceSize, - stream); - -CLEANUP: - // Descriptors are no longer needed as all GPU work was already - // enqueued. - return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; -} - -#endif - -template -int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace) { - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - // Get the stream from cublas handle to reuse for biasReLU kernel. - cudaStream_t stream; - cublasGetStream(handle, &stream); - const float alpha = 1.0; - const float beta_zero = 0.0; - const float beta_one = 1.0; - int status = 1; -#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 - status = gemm_bias_lt( - (cublasLtHandle_t)handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - out_features, - batch_size, - in_features, - &alpha, /* host pointer */ - weight, - in_features, - input.data_ptr(), - in_features, - &beta_zero, /* host pointer */ - output.data_ptr(), - out_features, - lt_workspace, - 1 << 22, - stream, - true, - static_cast(bias.data_ptr())); -#endif - if (status != 0){ - output.copy_(bias); - status = gemm_bias( - handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - out_features, - batch_size, - in_features, - &alpha, - weight, - in_features, - input.data_ptr(), - in_features, - &beta_one, - output.data_ptr(), - out_features); - } - return status; -} - - -template -int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, void *lt_workspace) { - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - // Get the stream from cublas handle to reuse for biasReLU kernel. - cudaStream_t stream; - cublasGetStream(handle, &stream); - const float alpha = 1.0; - const float beta_zero = 0.0; - const float beta_one = 1.0; - int status = 1; -#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600 - status = gemm_bgradb_lt( - (cublasLtHandle_t)handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - in_features, - out_features, - batch_size, - &alpha, /* host pointer */ - input, - in_features, - d_output, - out_features, - &beta_zero, /* host pointer */ - d_weight, - in_features, - lt_workspace, - 1 << 22, - stream, - true, - static_cast(d_bias)); -#endif - - - if (status != 0){ - - status = gemm_bias( - handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - in_features, - out_features, - batch_size, - &alpha, - input, - in_features, - d_output, - out_features, - &beta_zero, - d_weight, - in_features); - } - - status = gemm_bias( - handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - in_features, - batch_size, - out_features, - &alpha, - weight, - in_features, - d_output, - out_features, - &beta_zero, - d_input, - in_features); - return status; - -} - -template -int linear_gelu_linear_forward_cuda(T *input, T *weight1, T *bias1, T *weight2, T *bias2, int in_features, int hidden_features, int batch_size, int out_features, T *output1, T *output2, T *gelu_in, void *lt_workspace) { - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - // Get the stream from cublas handle to reuse for biasReLU kernel. - cudaStream_t stream; - cublasGetStream(handle, &stream); - const float alpha = 1.0; - const float beta_zero = 0.0; - int status = 1; -#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 - status = gemm_bias_gelu_lt( - (cublasLtHandle_t)handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - hidden_features, - batch_size, - in_features, - &alpha, /* host pointer */ - weight1, - in_features, - input, - in_features, - &beta_zero, /* host pointer */ - output1, - hidden_features, - lt_workspace, - 1 << 22, - stream, - true, - static_cast(gelu_in), - static_cast(bias1)); - status = gemm_bias_lt( - (cublasLtHandle_t)handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - out_features, - batch_size, - hidden_features, - &alpha, /* host pointer */ - weight2, - hidden_features, - output1, - hidden_features, - &beta_zero, /* host pointer */ - output2, - out_features, - lt_workspace, - 1 << 22, - stream, - true, - static_cast(bias2)); - return status; -#else - return 1; -#endif -} - -template -int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, void *lt_workspace) { - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - // Get the stream from cublas handle to reuse for biasReLU kernel. - cudaStream_t stream; - cublasGetStream(handle, &stream); - const float alpha = 1.0; - const float beta_zero = 0.0; - const float beta_one = 1.0; - int status = 1; -#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 -//wgrad for first gemm - status = gemm_bgradb_lt( - (cublasLtHandle_t)handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - hidden_features, - out_features, - batch_size, - &alpha, /* host pointer */ - output1, - hidden_features, - d_output2, - out_features, - &beta_zero, /* host pointer */ - d_weight2, - hidden_features, - lt_workspace, - 1 << 22, - stream, - true, - static_cast(d_bias2)); -//dgrad for second GEMM - status = gemm_dgelu_bgradb_lt( - (cublasLtHandle_t)handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - hidden_features, - batch_size, - out_features, - &alpha, /* host pointer */ - weight2, - hidden_features, - d_output2, - out_features, - &beta_zero, /* host pointer */ - d_output1, - hidden_features, - lt_workspace, - 1 << 22, - stream, - static_cast(gelu_in), - static_cast(d_bias1)); -//wgrad for the first GEMM - status = gemm_bias( - handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - in_features, - hidden_features, - batch_size, - &alpha, - input, - in_features, - d_output1, - hidden_features, - &beta_zero, - d_weight1, - in_features); - -//dgrad for the first GEMM - status = gemm_bias( - handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - in_features, - batch_size, - hidden_features, - &alpha, - weight1, - in_features, - d_output1, - hidden_features, - &beta_zero, - d_input, - in_features); -#endif - return status; - -} - - -template int linear_bias_forward_cuda(at::Tensor input, at::Half *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace); - -template int linear_bias_forward_cuda(at::Tensor input, float *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace); - -template int linear_bias_forward_cuda(at::Tensor input, double *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace); - -template int linear_bias_backward_cuda(at::Half *input, at::Half *weight, at::Half *d_output, int in_features, int batch_size, int out_features, at::Half *d_weight, at::Half *d_bias, at::Half *d_input, void *lt_workspace) ; - -template int linear_bias_backward_cuda(float *input, float *weight, float *d_output, int in_features, int batch_size, int out_features, float *d_weight, float *d_bias, float *d_input, void *lt_workspace) ; - -template int linear_bias_backward_cuda(double *input, double *weight, double *d_output, int in_features, int batch_size, int out_features, double *d_weight, double *d_bias, double *d_input, void *lt_workspace) ; - - -template int linear_gelu_linear_forward_cuda(at::Half *input, at::Half *weight1, at::Half *bias1, at::Half *weight2, at::Half *bias2, int in_features, int hidden_features, int batch_size, int out_features, at::Half *output1, at::Half *output2, at::Half *gelu_in, void *lt_workspace) ; - -template int linear_gelu_linear_forward_cuda(float *input, float *weight1, float *bias1, float *weight2, float *bias2, int in_features, int hidden_features, int batch_size, int out_features, float *output1, float *output2, float *gelu_in, void *lt_workspace); - -template int linear_gelu_linear_forward_cuda(double *input, double *weight1, double *bias1, double *weight2, double *bias2, int in_features, int hidden_features, int batch_size, int out_features, double *output1, double *output2, double *gelu_in, void *lt_workspace) ; - -template int linear_gelu_linear_backward_cuda(at::Half *input, at::Half *gelu_in, at::Half *output1, at::Half *weight1, at::Half *weight2, at::Half *d_output1, at::Half *d_output2, int in_features, int batch_size, int hidden_features, int out_features, at::Half *d_weight1, at::Half *d_weight2, at::Half *d_bias1, at::Half *d_bias2, at::Half *d_input, void *lt_workspace); - -template int linear_gelu_linear_backward_cuda(float *input, float *gelu_in, float *output1, float *weight1, float *weight2, float *d_output1, float *d_output2, int in_features, int batch_size, int hidden_features, int out_features, float *d_weight1, float *d_weight2, float *d_bias1, float *d_bias2, float *d_input, void *lt_workspace); - -template int linear_gelu_linear_backward_cuda(double *input, double *gelu_in, double *output1, double *weight1, double *weight2, double *d_output1, double *d_output2, int in_features, int batch_size, int hidden_features, int out_features, double *d_weight1, double *d_weight2, double *d_bias1, double *d_bias2, double *d_input, void *lt_workspace); - diff --git a/csrc/layer_norm_cuda.cpp b/csrc/layer_norm_cuda.cpp deleted file mode 100644 index 8698701..0000000 --- a/csrc/layer_norm_cuda.cpp +++ /dev/null @@ -1,442 +0,0 @@ -#include -#include -#include -#include "compat.h" - -namespace { -void compute_n1_n2( - at::Tensor input, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - int& n1, - int& n2) -{ - int idiff = input.ndimension() - normalized_shape.size(); - n2 = 1; - for (int i = 0; i < (int)normalized_shape.size(); ++i) { - assert( input.sizes()[i+idiff] == normalized_shape[i] ); - n2 *= normalized_shape[i]; - } - n1 = 1; - for (int i = 0; i < idiff; ++i) { - n1 *= input.sizes()[i]; - } -} - -void check_args( - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - at::Tensor gamma, - at::Tensor beta - ) -{ - TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape)); - TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape)); -} - -void check_args( - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - at::Tensor gamma - ) -{ - TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape)); -} - - -void check_args( - at::Tensor input, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - int& n1, - int& n2 - ) -{ - int64_t normalized_ndim = normalized_shape.size(); - - if (normalized_ndim < 1) { - std::stringstream ss; - ss << "Expected normalized_shape to be at least 1-dimensional, i.e., " - << "containing at least one element, but got normalized_shape=" - << normalized_shape; - throw std::runtime_error(ss.str()); - } - - auto input_shape = input.sizes(); - auto input_ndim = input.dim(); - - if (input_ndim < normalized_ndim || - !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) { - std::stringstream ss; - ss << "Given normalized_shape=" << normalized_shape - << ", expected input with shape [*"; - for (auto size : normalized_shape) { - ss << ", " << size; - } - ss << "], but got input of size" << input_shape; - throw std::runtime_error(ss.str()); - } - - compute_n1_n2(input,normalized_shape,n1,n2); -} - -void check_args( - at::Tensor input, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - at::Tensor gamma, - at::Tensor beta, - int& n1, - int& n2 - ) -{ - check_args(input,normalized_shape,n1,n2); - check_args(normalized_shape,gamma,beta); -} - -void check_args( - at::Tensor input, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - at::Tensor gamma, - int& n1, - int& n2 - ) -{ - check_args(input,normalized_shape,n1,n2); - check_args(normalized_shape,gamma); -} -} - -void cuda_layer_norm( - at::Tensor* output, - at::Tensor* mean, - at::Tensor* invvar, - at::Tensor* input, - int n1, - int n2, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - at::Tensor* gamma, - at::Tensor* beta, - double epsilon); - -#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -std::vector layer_norm( - at::Tensor input, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - double epsilon) { - CHECK_INPUT(input); - int n1,n2; - check_args(input,normalized_shape,n1,n2); - at::Tensor output = at::empty_like(input); - at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half || input.scalar_type()==at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type())); - at::Tensor invvar = at::empty_like(mean); - cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2, - normalized_shape,NULL,NULL,epsilon); - return {output, mean, invvar}; -} - -std::vector layer_norm_affine( - at::Tensor input, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - at::Tensor gamma, - at::Tensor beta, - double epsilon) { - CHECK_INPUT(input); - CHECK_INPUT(gamma); - CHECK_INPUT(beta); - int n1,n2; - check_args(input,normalized_shape,gamma,beta,n1,n2); - at::Tensor output = at::empty_like(input); - const auto stats_dtype = (input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type(); - at::Tensor mean = at::empty({n1}, input.options().dtype(stats_dtype)); - at::Tensor invvar = at::empty_like(mean); - cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2, - normalized_shape,&gamma,&beta,epsilon); - return {output, mean, invvar}; -} - -std::vector layer_norm_affine_mixed_dtypes( - at::Tensor input, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - at::Tensor gamma, - at::Tensor beta, - double epsilon) { - CHECK_INPUT(input); - int n1, n2; - check_args(input, normalized_shape, n1, n2); - at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type())); - at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type())); - at::Tensor invvar = at::empty_like(mean); - cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, - normalized_shape, &gamma, &beta, epsilon); - return {output, mean, invvar}; -} - -void cuda_layer_norm_gradient( - at::Tensor* dout, - at::Tensor* mean, - at::Tensor* invvar, - at::Tensor* input, - int n1, - int n2, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - at::Tensor* gamma, - at::Tensor* beta, - double epsilon, - at::Tensor* grad_input, - at::Tensor* grad_gamma, - at::Tensor* grad_beta - ); - -at::Tensor layer_norm_gradient( - at::Tensor dout, - at::Tensor mean, - at::Tensor invvar, - at::Tensor input, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - double epsilon) { - CHECK_INPUT(dout); - CHECK_INPUT(mean); - CHECK_INPUT(invvar); - CHECK_INPUT(input); - int n1,n2; - check_args(input,normalized_shape,n1,n2); - at::Tensor grad_input = at::empty_like(input); - cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2, - normalized_shape,NULL,NULL,epsilon, - &grad_input,NULL,NULL); - return grad_input; -} - -std::vector layer_norm_gradient_affine( - at::Tensor dout, - at::Tensor mean, - at::Tensor invvar, - at::Tensor input, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - at::Tensor gamma, - at::Tensor beta, - double epsilon) { - CHECK_INPUT(dout); - CHECK_INPUT(mean); - CHECK_INPUT(invvar); - CHECK_INPUT(input); - CHECK_INPUT(gamma); - CHECK_INPUT(beta); - int n1,n2; - check_args(input,normalized_shape,gamma,beta,n1,n2); - at::Tensor grad_input = at::empty_like(input); - at::Tensor grad_gamma = at::empty_like(gamma); - at::Tensor grad_beta = at::empty_like(beta); - cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2, - normalized_shape,&gamma,&beta,epsilon, - &grad_input,&grad_gamma,&grad_beta); - return {grad_input, grad_gamma, grad_beta}; -} - -void cuda_rms_norm( - at::Tensor* output, - at::Tensor* invvar, - at::Tensor* input, - int n1, - int n2, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - at::Tensor* gamma, - double epsilon); - -#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -std::vector rms_norm( - at::Tensor input, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - double epsilon) { - CHECK_INPUT(input); - int n1,n2; - check_args(input,normalized_shape,n1,n2); - at::Tensor output = at::empty_like(input); - at::Tensor invvar = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half || input.scalar_type()==at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type())); - cuda_rms_norm(&output,&invvar,&input,n1,n2, - normalized_shape,NULL,epsilon); - return {output, invvar}; -} - -std::vector rms_norm_affine( - at::Tensor input, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - at::Tensor gamma, - double epsilon) { - CHECK_INPUT(input); - CHECK_INPUT(gamma); - int n1,n2; - check_args(input,normalized_shape,gamma,n1,n2); - at::Tensor output = at::empty_like(input); - const auto stats_dtype = (input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type(); - at::Tensor invvar = at::empty({n1}, input.options().dtype(stats_dtype)); - cuda_rms_norm(&output,&invvar,&input,n1,n2, - normalized_shape,&gamma,epsilon); - return {output, invvar}; -} - -std::vector rms_norm_affine_mixed_dtypes( - at::Tensor input, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - at::Tensor gamma, - double epsilon) { - CHECK_INPUT(input); - int n1, n2; - check_args(input, normalized_shape, n1, n2); - at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type())); - at::Tensor invvar = at::empty({n1}, input.options().dtype(input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type())); - - cuda_rms_norm(&output,&invvar, &input, n1, n2, - normalized_shape, &gamma,epsilon); - return {output,invvar}; -} - -void cuda_rms_norm_gradient( - at::Tensor* dout, - at::Tensor* invvar, - at::Tensor* input, - int n1, - int n2, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - at::Tensor* gamma, - double epsilon, - at::Tensor* grad_input, - at::Tensor* grad_gamma); - -at::Tensor rms_norm_gradient( - at::Tensor dout, - at::Tensor invvar, - at::Tensor input, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - double epsilon) { - CHECK_INPUT(dout); - CHECK_INPUT(invvar); - CHECK_INPUT(input); - int n1,n2; - check_args(input,normalized_shape,n1,n2); - at::Tensor grad_input = at::empty_like(input); - cuda_rms_norm_gradient(&dout,&invvar,&input,n1,n2, - normalized_shape,NULL,epsilon, - &grad_input,NULL); - return grad_input; -} - -std::vector rms_norm_gradient_affine( - at::Tensor dout, - at::Tensor invvar, - at::Tensor input, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - at::Tensor gamma, - double epsilon) { - CHECK_INPUT(dout); - CHECK_INPUT(invvar); - CHECK_INPUT(input); - CHECK_INPUT(gamma); - int n1,n2; - check_args(input,normalized_shape,gamma,n1,n2); - at::Tensor grad_input = at::empty_like(input); - at::Tensor grad_gamma = at::empty_like(gamma); - cuda_rms_norm_gradient(&dout,&invvar,&input,n1,n2, - normalized_shape,&gamma,epsilon, - &grad_input,&grad_gamma); - return {grad_input, grad_gamma}; -} - - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)"); - m.def("forward", &layer_norm, "LayerNorm forward (CUDA)"); - m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)"); - m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)"); - - m.def("forward_affine_mixed_dtypes", &layer_norm_affine_mixed_dtypes, "LayerNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation"); - - m.def("rms_forward_affine", &rms_norm_affine, "RMSNorm forward (CUDA)"); - m.def("rms_forward", &rms_norm, "RMSNorm forward (CUDA)"); - m.def("rms_backward_affine", &rms_norm_gradient_affine, "RMSNorm backward (CUDA)"); - m.def("rms_backward", &rms_norm_gradient, "RMSNorm backward (CUDA)"); - - m.def("rms_forward_affine_mixed_dtypes", &rms_norm_affine_mixed_dtypes, "RMSNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation"); -} diff --git a/csrc/layer_norm_cuda_kernel.cu b/csrc/layer_norm_cuda_kernel.cu deleted file mode 100644 index 6b6664b..0000000 --- a/csrc/layer_norm_cuda_kernel.cu +++ /dev/null @@ -1,1229 +0,0 @@ -#include "ATen/ATen.h" -#include "ATen/AccumulateType.h" -#include "ATen/cuda/CUDAContext.h" -#include "ATen/cuda/DeviceUtils.cuh" - -#include -#include - -#include "type_shim.h" - - -template __device__ -void cuWelfordOnlineSum( - const U curr, - U& mu, - U& sigma2, - U& count) -{ - count = count + U(1); - U delta = curr - mu; - U lmean = mu + delta / count; - mu = lmean; - U delta2 = curr - lmean; - sigma2 = sigma2 + delta * delta2; -} - -template __device__ -void cuChanOnlineSum( - const U muB, - const U sigma2B, - const U countB, - U& mu, - U& sigma2, - U& count) -{ - U delta = muB - mu; - U nA = count; - U nB = countB; - count = count + countB; - U nX = count; - if (nX > U(0)) { - nA = nA / nX; - nB = nB / nX; - mu = nA*mu + nB*muB; - sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; - } else { - mu = U(0); - sigma2 = U(0); - } -} - -template __device__ -void cuRMSOnlineSum( - const U curr, - U& sigma2) -{ - sigma2 = sigma2 + curr * curr; -} - -template __device__ -void cuChanRMSOnlineSum( - const U sigma2B, - U& sigma2) -{ - sigma2 = sigma2 + sigma2B; -} - - -template __device__ -void cuWelfordMuSigma2( - const T* __restrict__ vals, - const int n1, - const int n2, - const int i1, - U& mu, - U& sigma2, - U* buf, - const int GPU_WARP_SIZE, - bool rms_only) -{ - // Assumptions: - // 1) blockDim.x == warpSize - // 2) Tensor is contiguous - // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. - // - // compute variance and mean over n2 - U count = U(0); - mu= U(0); - sigma2 = U(0); - if (i1 < n1) { - // one warp normalizes one n1 index, - // synchronization is implicit - // initialize with standard Welford algorithm - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - const T* lvals = vals + i1*n2; - int l = 4*thrx; - for (; l+3 < n2; l+=4*numx) { - for (int k = 0; k < 4; ++k) { - U curr = static_cast(lvals[l+k]); - if (!rms_only) { - cuWelfordOnlineSum(curr,mu,sigma2,count); - } else { - cuRMSOnlineSum(curr, sigma2); - } - } - } - for (; l < n2; ++l) { - U curr = static_cast(lvals[l]); - if (!rms_only) { - cuWelfordOnlineSum(curr,mu,sigma2,count); - } else { - cuRMSOnlineSum(curr, sigma2); - } - } - // intra-warp reductions - #pragma unroll - for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) { - U sigma2B = WARP_SHFL_DOWN(sigma2, stride); - if (!rms_only) { - U muB = WARP_SHFL_DOWN(mu, stride); - U countB = WARP_SHFL_DOWN(count, stride); - cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); - } else { - cuChanRMSOnlineSum(sigma2B, sigma2); - } - } - // threadIdx.x == 0 has correct values for each warp - // inter-warp reductions - if (blockDim.y > 1) { - U* ubuf = (U*)buf; - U* ibuf = (U*)(ubuf + blockDim.y); - for (int offset = blockDim.y/2; offset > 0; offset /= 2) { - // upper half of warps write to shared - if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) { - const int wrt_y = threadIdx.y - offset; - if (!rms_only) { - ubuf[2*wrt_y] = mu; - ibuf[wrt_y] = count; - } - ubuf[2*wrt_y+1] = sigma2; - } - __syncthreads(); - // lower half merges - if (threadIdx.x == 0 && threadIdx.y < offset) { - U sigma2B = ubuf[2*threadIdx.y+1]; - if (!rms_only) { - U muB = ubuf[2*threadIdx.y]; - U countB = ibuf[threadIdx.y]; - cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); - } else { - cuChanRMSOnlineSum(sigma2B,sigma2); - } - } - __syncthreads(); - } - // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values - if (threadIdx.x == 0 && threadIdx.y == 0) { - if (!rms_only) { - ubuf[0] = mu; - } - ubuf[1] = sigma2; - } - __syncthreads(); - if (!rms_only) { - mu = ubuf[0]; - } - sigma2 = ubuf[1]/U(n2); - // don't care about final value of count, we know count == n2 - } else { - if (!rms_only) { - mu = WARP_SHFL(mu, 0); - } - sigma2 = WARP_SHFL(sigma2/U(n2), 0); - } - } -} - -template<> __device__ -void cuWelfordMuSigma2( - const at::Half* __restrict__ vals, - const int n1, - const int n2, - const int i1, - float& mu, - float& sigma2, - float* buf, - const int GPU_WARP_SIZE, - bool rms_only) -{ - // Assumptions: - // 1) blockDim.x == warpSize - // 2) Tensor is contiguous - // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. - // - // compute variance and mean over n2 - float count = 0.0f; - mu= float(0); - sigma2 = float(0); - if (i1 < n1) { - // one warp normalizes one n1 index, - // synchronization is implicit - // initialize with standard Welford algorithm - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - const at::Half* lvals = vals + i1*n2; - int l = 8*thrx; - if ((((size_t)lvals)&3) != 0) { - // 16 bit alignment - // first thread consumes first point - if (thrx == 0) { - float curr = static_cast(lvals[0]); - if (!rms_only) { - cuWelfordOnlineSum(curr,mu,sigma2,count); - } else { - cuRMSOnlineSum(curr, sigma2); - } - - } - ++l; - } - // at this point, lvals[l] are 32 bit aligned for all threads. - for (; l+7 < n2; l+=8*numx) { - for (int k = 0; k < 8; k+=2) { - float2 curr = __half22float2(*((__half2*)(lvals+l+k))); - if (!rms_only) { - cuWelfordOnlineSum(curr.x,mu,sigma2,count); - cuWelfordOnlineSum(curr.y,mu,sigma2,count); - } else { - cuRMSOnlineSum(curr.x, sigma2); - cuRMSOnlineSum(curr.y, sigma2); - } - } - } - for (; l < n2; ++l) { - float curr = static_cast(lvals[l]); - if (!rms_only) { - cuWelfordOnlineSum(curr,mu,sigma2,count); - } else { - cuRMSOnlineSum(curr, sigma2); - } - } - // intra-warp reductions - #pragma unroll - for (int stride = GPU_WARP_SIZE / 2; stride > 0; stride /= 2) { - float sigma2B = WARP_SHFL_DOWN(sigma2, stride); - if (!rms_only) { - float muB = WARP_SHFL_DOWN(mu, stride); - float countB = WARP_SHFL_DOWN(count, stride); - cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count); - } else { - cuChanRMSOnlineSum(sigma2B, sigma2); - } - } - // threadIdx.x == 0 has correct values for each warp - // inter-warp reductions - if (blockDim.y > 1) { - float* ubuf = (float*)buf; - float* ibuf = (float*)(ubuf + blockDim.y); - for (int offset = blockDim.y/2; offset > 0; offset /= 2) { - // upper half of warps write to shared - if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) { - const int wrt_y = threadIdx.y - offset; - ubuf[2*wrt_y+1] = sigma2; - if (!rms_only) { - ubuf[2*wrt_y] = mu; - ibuf[wrt_y] = count; - } - } - __syncthreads(); - // lower half merges - if (threadIdx.x == 0 && threadIdx.y < offset) { - float sigma2B = ubuf[2*threadIdx.y+1]; - if (!rms_only) { - float muB = ubuf[2*threadIdx.y]; - float countB = ibuf[threadIdx.y]; - cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); - } else { - cuChanRMSOnlineSum(sigma2B, sigma2); - } - } - __syncthreads(); - } - // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values - if (threadIdx.x == 0 && threadIdx.y == 0) { - if (!rms_only) { - ubuf[0] = mu; - } - ubuf[1] = sigma2; - } - __syncthreads(); - if (!rms_only) { - mu = ubuf[0]; - } - sigma2 = ubuf[1]/float(n2); - // don't care about final value of count, we know count == n2 - } else { - if (!rms_only) { - mu = WARP_SHFL(mu, 0); - } - sigma2 = WARP_SHFL(sigma2/float(n2), 0); - } - } -} - -template U rsqrt(U v) { - return U(1) / sqrt(v); -} -#if defined __HIP_PLATFORM_HCC__ -__device__ float rsqrt(float v) { - return rsqrtf(v); -} -#else -template<> float rsqrt(float v) { - return rsqrtf(v); -} -#endif -template<> double rsqrt(double v) { - return rsqrt(v); -} - -namespace { -// This is the un-specialized struct. Note that we prevent instantiation of this -// struct by putting an undefined symbol in the function body so it won't compile. -// template -// struct SharedMemory -// { -// // Ensure that we won't compile any un-specialized types -// __device__ T *getPointer() -// { -// extern __device__ void error(void); -// error(); -// return NULL; -// } -// }; -// https://github.com/NVIDIA/apex/issues/246 -template -struct SharedMemory; - -template <> -struct SharedMemory -{ - __device__ float *getPointer() - { - extern __shared__ float s_float[]; - return s_float; - } -}; - -template <> -struct SharedMemory -{ - __device__ double *getPointer() - { - extern __shared__ double s_double[]; - return s_double; - } -}; -} - -template __device__ -void cuApplyLayerNorm_( - V* __restrict__ output_vals, - U* __restrict__ mean, - U* __restrict__ invvar, - const T* __restrict__ vals, - const int n1, - const int n2, - const U epsilon, - const V* __restrict__ gamma, - const V* __restrict__ beta, - const int GPU_WARP_SIZE, - bool rms_only) -{ - // Assumptions: - // 1) blockDim.x == warpSize - // 2) Tensors are contiguous - // - for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { - SharedMemory shared; - U* buf = shared.getPointer(); - U mu,sigma2; - cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf, GPU_WARP_SIZE, rms_only); - const T* lvals = vals + i1*n2; - V* ovals = output_vals + i1*n2; - U c_invvar = rsqrt(sigma2 + epsilon); - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - if (gamma != NULL && (beta != NULL || rms_only)) { - for (int i = thrx; i < n2; i+=numx) { - U curr = static_cast(lvals[i]); - if (!rms_only) { - ovals[i] = gamma[i] * static_cast(c_invvar * (curr - mu)) + beta[i]; - } else { - ovals[i] = gamma[i] * static_cast(c_invvar * curr); - } - - } - } else { - for (int i = thrx; i < n2; i+=numx) { - U curr = static_cast(lvals[i]); - if (!rms_only) { - ovals[i] = static_cast(c_invvar * (curr - mu)); - } else { - ovals[i] = static_cast(c_invvar * curr); - } - } - } - if (threadIdx.x == 0 && threadIdx.y == 0) { - if (!rms_only) { - mean[i1] = mu; - } - invvar[i1] = c_invvar; - } - __syncthreads(); - } -} - -template __global__ -void cuApplyLayerNorm( - V* __restrict__ output_vals, - U* __restrict__ mean, - U* __restrict__ invvar, - const T* __restrict__ vals, - const int n1, - const int n2, - const U epsilon, - const V* __restrict__ gamma, - const V* __restrict__ beta, - const int warp_size) -{ - cuApplyLayerNorm_(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta, warp_size, false); -} - -template __global__ -void cuApplyRMSNorm( - V* __restrict__ output_vals, - U* __restrict__ invvar, - const T* __restrict__ vals, - const int n1, - const int n2, - const U epsilon, - const V* __restrict__ gamma, - const int warp_size) -{ - cuApplyLayerNorm_(output_vals, NULL, invvar, vals, n1, n2, epsilon, gamma, NULL, warp_size, true); -} - -template __device__ -void cuLoadWriteStridedInputs( - const int i1_block, - const int thr_load_row_off, - const int thr_load_col_off, - const int i2_off, - const int row_stride, - U* warp_buf1, - U* warp_buf2, - const T* input, - const V* dout, - const int i1_end, - const int n2, - const U* __restrict__ mean, - const U* __restrict__ invvar, - bool rms_only - ) -{ - int i1 = i1_block+thr_load_row_off; - if (i1 < i1_end) { - U curr_mean; - if (!rms_only) { - curr_mean = mean[i1]; - } - U curr_invvar = invvar[i1]; - for (int k = 0; k < blockDim.y; ++k) { - int i2 = i2_off + k; - int load_idx = i1*n2+i2; - int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; - if (i2(input[load_idx]); - U curr_dout = static_cast(dout[load_idx]); - if (!rms_only) { - warp_buf1[write_idx] = curr_dout; - warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; - } else { - warp_buf2[write_idx] = curr_dout * (curr_input) * curr_invvar; - } - } else { - if (!rms_only) { - warp_buf1[write_idx] = U(0); - } - warp_buf2[write_idx] = U(0); - } - } - } else { - for (int k = 0; k < blockDim.y; ++k) { - int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; - if (!rms_only) { - warp_buf1[write_idx] = U(0); - } - warp_buf2[write_idx] = U(0); - } - } -} -template __device__ -void cuLoadAddStridedInputs( - const int i1_block, - const int thr_load_row_off, - const int thr_load_col_off, - const int i2_off, - const int row_stride, - U* warp_buf1, - U* warp_buf2, - const T* input, - const V* dout, - const int i1_end, - const int n2, - const U* __restrict__ mean, - const U* __restrict__ invvar, - bool rms_only - ) -{ - int i1 = i1_block+thr_load_row_off; - if (i1 < i1_end) { - U curr_mean; - if (!rms_only) { - curr_mean = mean[i1]; - } - U curr_invvar = invvar[i1]; - for (int k = 0; k < blockDim.y; ++k) { - int i2 = i2_off + k; - int load_idx = i1*n2+i2; - int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; - if (i2(input[load_idx]); - U curr_dout = static_cast(dout[load_idx]); - if (!rms_only) { - warp_buf1[write_idx] += curr_dout; - warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; - } else { - warp_buf2[write_idx] += curr_dout * (curr_input) * curr_invvar; - } - } - } - } -} - - -template __global__ -void cuComputePartGradGammaBeta( - const V* __restrict__ dout, - const T* __restrict__ input, - const int n1, - const int n2, - const U* __restrict__ mean, - const U* __restrict__ invvar, - U epsilon, - U* part_grad_gamma, - U* part_grad_beta, - bool rms_only) -{ - const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y); - const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; - const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y; - const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y; - const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1; - const int row_stride = blockDim.x+1; - const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1); - const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y; - const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; - SharedMemory shared; - U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements - U* warp_buf1 = (U*)buf; - U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; - // compute partial sums from strided inputs - // do this to increase number of loads in flight - cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar, rms_only); - for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { - cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar, rms_only); - } - __syncthreads(); - // inter-warp reductions - // sum within each warp - U acc1 = U(0); - U acc2 = U(0); - for (int k = 0; k < blockDim.y; ++k) { - int row1 = threadIdx.y + k*blockDim.y; - int idx1 = row1*row_stride + threadIdx.x; - if (!rms_only) { - acc1 += warp_buf1[idx1]; - } - acc2 += warp_buf2[idx1]; - } - if (!rms_only) { - warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1; - } - warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2; - __syncthreads(); - // sum all warps - for (int offset = blockDim.y/2; offset > 1; offset /= 2) { - if (threadIdx.y < offset) { - int row1 = threadIdx.y; - int row2 = threadIdx.y + offset; - int idx1 = row1*row_stride + threadIdx.x; - int idx2 = row2*row_stride + threadIdx.x; - if (!rms_only) { - warp_buf1[idx1] += warp_buf1[idx2]; - } - warp_buf2[idx1] += warp_buf2[idx2]; - } - __syncthreads(); - } - int i2 = blockIdx.x * blockDim.x + threadIdx.x; - if (threadIdx.y == 0 && i2 < n2) { - int row1 = threadIdx.y; - int row2 = threadIdx.y + 1; - int idx1 = row1*row_stride + threadIdx.x; - int idx2 = row2*row_stride + threadIdx.x; - if (!rms_only) { - part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2]; - } - part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2]; - } -} - -template __global__ -void cuComputeGradGammaBeta( - const U* part_grad_gamma, - const U* part_grad_beta, - const int part_size, - const int n1, - const int n2, - V* grad_gamma, - V* grad_beta, - bool rms_only) -{ - // sum partial gradients for gamma and beta - SharedMemory shared; - U* buf = shared.getPointer(); - int i2 = blockIdx.x * blockDim.x + threadIdx.x; - if (i2 < n2) { - // each warp does sequential reductions until reduced part_size is num_warps - int num_warp_reductions = part_size / blockDim.y; - U sum_gamma = U(0); - U sum_beta = U(0); - const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; - const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; - for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) { - sum_gamma += part_grad_gamma_ptr[warp_offset*n2]; - if (!rms_only) { - sum_beta += part_grad_beta_ptr[warp_offset*n2]; - } - } - // inter-warp reductions - const int nbsize3 = blockDim.x * blockDim.y / 2; - for (int offset = blockDim.y/2; offset >= 1; offset /= 2) { - // top half write to shared memory - if (threadIdx.y >= offset && threadIdx.y < 2*offset) { - const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; - buf[write_idx] = sum_gamma; - if (!rms_only) { - buf[write_idx+nbsize3] = sum_beta; - } - } - __syncthreads(); - // bottom half sums - if (threadIdx.y < offset) { - const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; - sum_gamma += buf[read_idx]; - if (!rms_only) { - sum_beta += buf[read_idx+nbsize3]; - } - } - __syncthreads(); - } - // write out fully summed gradients - if (threadIdx.y == 0) { - grad_gamma[i2] = sum_gamma; - if (!rms_only) { - grad_beta[i2] = sum_beta; - } - } - } -} - - -template __global__ -void cuComputeGradInput( - const V* __restrict__ dout, - const T* __restrict__ input, - const int n1, - const int n2, - const U* __restrict__ mean, - const U* __restrict__ invvar, - U epsilon, - const V* gamma, - T* grad_input, - bool rms_only) -{ - for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { - U sum_loss1 = U(0); - U sum_loss2 = U(0); - U c_mean; - if (!rms_only) { - c_mean = mean[i1]; - } - const U c_invvar = invvar[i1]; - const T* k_input = input + i1*n2; - const V* k_dout = dout + i1*n2; - const int numx = blockDim.x * blockDim.y; - const int thrx = threadIdx.x + threadIdx.y * blockDim.x; - if (gamma != NULL) { - #ifndef __HIP_PLATFORM_HCC__ - int l = 4*thrx; - for (; l+3 < n2; l+=4*numx) { - for (int k = 0; k < 4; ++k) { - const U c_h = static_cast(k_input[l+k]); - const U c_loss = static_cast(k_dout[l+k]); - if (!rms_only) { - sum_loss1 += c_loss * gamma[l+k]; - sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar; - } else { - sum_loss2 += c_loss * gamma[l+k] * (c_h) * c_invvar; - } - } - } - for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - if (!rms_only) { - sum_loss1 += c_loss * gamma[l]; - sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; - } else { - sum_loss2 += c_loss * gamma[l] * (c_h) * c_invvar; - } - - } - #else - // Optimization for ROCm MI100 - for( int l = 0; l < n2 ; l += numx) { - int idx = l + thrx; - const U gamma_idx = static_cast((idx((idx((idx(k_input[l+k]); - const U c_loss = static_cast(k_dout[l+k]); - if (!rms_only) { - sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; - } else { - sum_loss2 += c_loss * (c_h) * c_invvar; - } - } - } - for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - if (!rms_only) { - sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; - } else { - sum_loss2 += c_loss * (c_h) * c_invvar; - } - } - #else - for( int l = 0; l < n2 ; l += numx) { - int idx = l + thrx; - const U c_h = static_cast((idx((idx 0; mask /= 2) { - if (!rms_only) { - sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); - } - sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); - } - // inter-warp reductions - if (blockDim.y > 1) { - SharedMemory shared; - U* buf = shared.getPointer(); - for (int offset = blockDim.y/2; offset > 0; offset /= 2) { - // upper half of warps write to shared - if (threadIdx.y >= offset && threadIdx.y < 2*offset) { - const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; - if (!rms_only) { - buf[2*wrt_i] = sum_loss1; - } - buf[2*wrt_i+1] = sum_loss2; - } - __syncthreads(); - // lower half merges - if (threadIdx.y < offset) { - const int read_i = threadIdx.y * blockDim.x + threadIdx.x; - if (!rms_only) { - sum_loss1 += buf[2*read_i]; - } - sum_loss2 += buf[2*read_i+1]; - } - __syncthreads(); - } - if (threadIdx.y == 0) { - if (!rms_only) { - buf[2*threadIdx.x] = sum_loss1; - } - buf[2*threadIdx.x+1] = sum_loss2; - } - __syncthreads(); - if (threadIdx.y !=0) { - if (!rms_only) { - sum_loss1 = buf[2*threadIdx.x]; - } - sum_loss2 = buf[2*threadIdx.x+1]; - } - } - // all threads now have the two sums over l - U fH = (U)n2; - U term1 = (U(1) / fH) * c_invvar; - T* k_grad_input = grad_input + i1*n2; - if (gamma != NULL) { - for (int l = thrx; l < n2; l+=numx) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - U f_grad_input = fH * c_loss * gamma[l]; - if (!rms_only) { - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; - } else { - f_grad_input -= (c_h) * c_invvar * sum_loss2; - } - f_grad_input *= term1; - k_grad_input[l] = static_cast(f_grad_input); - } - } else { - for (int l = thrx; l < n2; l+=numx) { - const U c_h = static_cast(k_input[l]); - const U c_loss = static_cast(k_dout[l]); - U f_grad_input = fH * c_loss; - if (!rms_only) { - f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; - } else { - f_grad_input -= (c_h) * c_invvar * sum_loss2; - } - f_grad_input *= term1; - k_grad_input[l] = static_cast(f_grad_input); - } - } - // prevent race where buf is written again before reads are done - __syncthreads(); - } -} - - -template -void HostApplyLayerNorm( - V* output, - U* mean, - U* invvar, - const T* input, - int n1, - int n2, - double epsilon, - const V* gamma, - const V* beta - ) -{ - auto stream = at::cuda::getCurrentCUDAStream().stream(); - const int warp_size = at::cuda::warp_size(); - dim3 threads(warp_size ,4, 1); // MI100 wavefront/warp = 64 - #ifdef __HIP_PLATFORM_HCC__ - // Optimization for ROCm MI100 - threads.y = 1; - #endif - - const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; - const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); - int nshared = - threads.y > 1 ? - threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : - 0; - cuApplyLayerNorm<<>>( - output, mean, invvar, input, n1, n2, U(epsilon), gamma, beta, warp_size); -} - -// Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files -template -void HostApplyRMSNorm( - V* output, - U* invvar, - const T* input, - int n1, - int n2, - double epsilon, - const V* gamma) -{ - auto stream = at::cuda::getCurrentCUDAStream().stream(); - const int warp_size = at::cuda::warp_size(); - const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; - const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1); - dim3 threads(warp_size,4,1); - #ifdef __HIP_PLATFORM_HCC__ - // Optimization for ROCm MI100 - threads.y = 2; - #endif - int nshared = - threads.y > 1 ? - threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : - 0; - cuApplyRMSNorm<<>>( - output, invvar, input, n1, n2, U(epsilon), gamma, warp_size); -} - -void cuda_layer_norm( - at::Tensor* output, - at::Tensor* mean, - at::Tensor* invvar, - at::Tensor* input, - int n1, - int n2, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - at::Tensor* gamma, - at::Tensor* beta, - double epsilon) -{ - using namespace at; - DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( - input->scalar_type(), output->scalar_type(), "layer_norm_cuda_kernel", - using accscalar_t = at::acc_type; - HostApplyLayerNorm( - output->DATA_PTR(), - mean->DATA_PTR(), - invvar->DATA_PTR(), - input->DATA_PTR(), - n1,n2, - epsilon, - gamma != NULL ? gamma->DATA_PTR() : NULL, - beta != NULL ? beta->DATA_PTR() : NULL); - ) -} - -void cuda_rms_norm( - at::Tensor* output, - at::Tensor* invvar, - at::Tensor* input, - int n1, - int n2, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - at::Tensor* gamma, - double epsilon) -{ - using namespace at; - DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( - input->scalar_type(), output->scalar_type(), "rms_norm_cuda_kernel", - using accscalar_t = at::acc_type; - HostApplyRMSNorm( - output->DATA_PTR(), - invvar->DATA_PTR(), - input->DATA_PTR(), - n1,n2, - epsilon, - gamma != NULL ? gamma->DATA_PTR() : NULL); - ) -} - - -template -void HostLayerNormGradient( - const V* dout, - const U* mean, - const U* invvar, - at::Tensor* input, - int n1, - int n2, - const V* gamma, - const V* beta, - double epsilon, - T* grad_input, - V* grad_gamma, - V* grad_beta - ) -{ - auto stream = at::cuda::getCurrentCUDAStream().stream(); - const int warp_size = at::cuda::warp_size(); - - if (gamma != NULL && beta != NULL) { - // compute grad_gamma(j) and grad_beta(j) - // Optimize layer normalization for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files - const int part_size = warp_size; - const dim3 threads2(warp_size, 4, 1); - const dim3 blocks2((n2+threads2.x-1) / threads2.x,part_size, 1); - const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); - const int nshared2_b = threads2.x * threads2.y * sizeof(U); - const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; - // note (mkozuki): I can hard code part_grad_gamma's dtype as float given that - // the `cuda_layer_norm_gradient` doesn't support double. - const auto part_grad_dtype = - (input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ? - at::ScalarType::Float : - input->scalar_type(); - at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(part_grad_dtype)); - at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); - cuComputePartGradGammaBeta<<>>( - dout, - input->DATA_PTR(), - n1,n2, - mean, - invvar, - U(epsilon), - part_grad_gamma.DATA_PTR(), - part_grad_beta.DATA_PTR(), - false); - - const dim3 threads3(warp_size, 8, 1); - const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1); - const int nshared3 = threads3.x * threads3.y * sizeof(U); - cuComputeGradGammaBeta<<>>( - part_grad_gamma.DATA_PTR(), - part_grad_beta.DATA_PTR(), - part_size, - n1,n2, - grad_gamma, - grad_beta, - false); - } - - // compute grad_input - // https://github.com/microsoft/onnxruntime/pull/7682/files#diff-f9eace25e62b646410b067f96cd930c7fe843326dca1e8d383631ca27f1a8d00R540 - // https://github.com/amathews-amd/onnxruntime/blob/80c0555c2bc17fb109190e2082cd3fda0a37984c/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu#L541 - - const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; - const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); - dim3 threads1(warp_size,4,1); // MI100 wavefront/warp = 64 - #ifdef __HIP_PLATFORM_HCC__ - // Optimization for ROCm MI100 - threads1.y = 2; - #endif - int nshared = - threads1.y > 1 ? - threads1.y*threads1.x*sizeof(U) : - 0; - cuComputeGradInput<<>>( - dout, - input->DATA_PTR(), - n1,n2, - mean, - invvar, - U(epsilon), - gamma, - grad_input, - false); -} -// Optimize HostRMSNormGradient for AMD GPUs: https://github.com/ROCmSoftwarePlatform/apex/pull/66/files -template -void HostRMSNormGradient( - const V* dout, - const U* invvar, - at::Tensor* input, - int n1, - int n2, - const V* gamma, - double epsilon, - T* grad_input, - V* grad_gamma) -{ - auto stream = at::cuda::getCurrentCUDAStream().stream(); - const int warp_size = at::cuda::warp_size(); - if (gamma != NULL) { - const int part_size = warp_size; - const dim3 threads2(warp_size,4,1); - const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1); - const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1); - const int nshared2_b = threads2.x * threads2.y * sizeof(U); - const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; - // note (mkozuki): I can hard code part_grad_gamma's dtype as float given that - // the `cuda_layer_norm_gradient` doesn't support double. - const auto part_grad_dtype = - (input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ? - at::ScalarType::Float : - input->scalar_type(); - at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(part_grad_dtype)); - cuComputePartGradGammaBeta<<>>( - dout, - input->DATA_PTR(), - n1,n2, - invvar, // unused - invvar, - U(epsilon), - part_grad_gamma.DATA_PTR(), - part_grad_gamma.DATA_PTR(), /* unused */ - true); - - const dim3 threads3(warp_size,8,1); - const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1); - const int nshared3 = threads3.x * threads3.y * sizeof(U); - cuComputeGradGammaBeta<<>>( - part_grad_gamma.DATA_PTR(), - part_grad_gamma.DATA_PTR(), /* unused */ - part_size, - n1,n2, - grad_gamma, - grad_gamma, /* unused */ - true); - } - - // compute grad_input - const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; - const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1); - const dim3 threads1(warp_size,4,1); - int nshared = - threads1.y > 1 ? - threads1.y*threads1.x*sizeof(U) : - 0; - cuComputeGradInput<<>>( - dout, - input->DATA_PTR(), - n1,n2, - invvar, /* unused */ - invvar, - U(epsilon), - gamma, - grad_input, - true); -} - -void cuda_layer_norm_gradient( - at::Tensor* dout, - at::Tensor* mean, - at::Tensor* invvar, - at::Tensor* input, - int n1, - int n2, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - at::Tensor* gamma, - at::Tensor* beta, - double epsilon, - at::Tensor* grad_input, - at::Tensor* grad_gamma, - at::Tensor* grad_beta) -{ - using namespace at; - // we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16 - DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( - input->scalar_type(), gamma == NULL ? input->scalar_type() : gamma->scalar_type(), "cuComputeGradInput", - using accscalar_t = at::acc_type; - HostLayerNormGradient( - dout->DATA_PTR(), - mean->DATA_PTR(), - invvar->DATA_PTR(), - input, - n1,n2, - // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta - // if gamma Tensor is NULL on input. - gamma != NULL ? gamma->DATA_PTR() : NULL, - gamma != NULL ? beta->DATA_PTR() : NULL, - epsilon, - grad_input->DATA_PTR(), - gamma != NULL ? grad_gamma->DATA_PTR() : NULL, - gamma != NULL ? grad_beta->DATA_PTR() : NULL); - ) -} - -void cuda_rms_norm_gradient( - at::Tensor* dout, - at::Tensor* invvar, - at::Tensor* input, - int n1, - int n2, - #ifdef VERSION_GE_1_1 - at::IntArrayRef normalized_shape, - #else - at::IntList normalized_shape, - #endif - at::Tensor* gamma, - double epsilon, - at::Tensor* grad_input, - at::Tensor* grad_gamma) -{ - using namespace at; - // we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16 - // DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( - DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( - input->scalar_type(), gamma == NULL ? input->scalar_type() : gamma->scalar_type(), "cuComputeGradInputRMS", - using accscalar_t = at::acc_type; - HostRMSNormGradient( - dout->DATA_PTR(), - invvar->DATA_PTR(), - input, - n1,n2, - // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta - // if gamma Tensor is NULL on input. - gamma != NULL ? gamma->DATA_PTR() : NULL, - epsilon, - grad_input->DATA_PTR(), - gamma != NULL ? grad_gamma->DATA_PTR() : NULL); - ) -} diff --git a/csrc/megatron/fused_weight_gradient_dense.cpp b/csrc/megatron/fused_weight_gradient_dense.cpp deleted file mode 100644 index a14c2b2..0000000 --- a/csrc/megatron/fused_weight_gradient_dense.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include - -#include -#include - -void wgrad_gemm_accum_fp32_cuda_stub( - at::Tensor &input_2d, - at::Tensor &d_output_2d, - at::Tensor &d_weight -); - -void wgrad_gemm_accum_fp16_cuda_stub( - at::Tensor &input_2d, - at::Tensor &d_output_2d, - at::Tensor &d_weight -); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32_cuda_stub, "wgrad gemm accum in fp32"); - m.def("wgrad_gemm_accum_fp16", &wgrad_gemm_accum_fp16_cuda_stub, "wgrad gemm accum in fp16"); -} diff --git a/csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu b/csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu deleted file mode 100644 index 60d1e8d..0000000 --- a/csrc/megatron/fused_weight_gradient_dense_16bit_prec_cuda.cu +++ /dev/null @@ -1,155 +0,0 @@ -#include -#include -#include -#include - -#include -#include -#include - -/* Includes, cuda */ -#include -#include - -#include "type_shim.h" - - -// BF16 inputs and BF16 accumulation -void gemmex_wrapper_fp16( - cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float* alpha, - at::BFloat16* A, - int lda, - at::BFloat16* B, - int ldb, - const float* beta, - at::BFloat16* C, - int ldc) { - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - CUDA_R_16BF, - lda, - B, - CUDA_R_16BF, - ldb, - beta, - C, - CUDA_R_16BF, - ldc, - CUDA_R_32F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); -} - -// FP16 inputs and FP16 accumulation -void gemmex_wrapper_fp16( - cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float* alpha, - at::Half* A, - int lda, - at::Half* B, - int ldb, - const float* beta, - at::Half* C, - int ldc) { - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - CUDA_R_16F, - lda, - B, - CUDA_R_16F, - ldb, - beta, - C, - CUDA_R_16F, - ldc, - CUDA_R_32F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); -} - -template -void wgrad_gemm_accum_fp16_cuda(T *input, T *d_output, T *d_weight, int in_dim, int hidden_dim, int out_dim) { - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream; - cublasGetStream(handle, &stream); - const float alpha = 1.0; - const float beta = 1.0; - - gemmex_wrapper_fp16( - handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - in_dim, - out_dim, - hidden_dim, - &alpha, - input, - in_dim, - d_output, - out_dim, - &beta, - d_weight, - in_dim); -} - -template void wgrad_gemm_accum_fp16_cuda(at::Half *input, at::Half *d_output, at::Half *d_weight, int in_dim, int hidden_dim, int out_dim); -template void wgrad_gemm_accum_fp16_cuda(at::BFloat16 *input, at::BFloat16 *d_output, at::BFloat16 *d_weight, int in_dim, int hidden_dim, int out_dim); - -void wgrad_gemm_accum_fp16_cuda_stub( - at::Tensor &input, - at::Tensor &d_output, - at::Tensor &d_weight -) { - at::Tensor input_2d, d_output_2d; - // input tensor: collapse to the first dim - auto in_sizes = input.sizes(); - if (input.dim() > 2) { - input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]}); - } else { - input_2d = input; - } - // d_output tensor: collapse to the first dim - auto d_out_sizes = d_output.sizes(); - if (d_output.dim() > 2) { - d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]}); - } else { - d_output_2d = d_output; - } - - const int hidden_dim = input_2d.size(0); - const int in_dim = input_2d.size(1); - const int out_dim = d_weight.size(0); - - DISPATCH_HALF_AND_BFLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp16", - wgrad_gemm_accum_fp16_cuda( - input_2d.data_ptr(), - d_output_2d.data_ptr(), - d_weight.data_ptr(), - in_dim, - hidden_dim, - out_dim); - ); -} diff --git a/csrc/megatron/fused_weight_gradient_dense_cuda.cu b/csrc/megatron/fused_weight_gradient_dense_cuda.cu deleted file mode 100644 index dfaa134..0000000 --- a/csrc/megatron/fused_weight_gradient_dense_cuda.cu +++ /dev/null @@ -1,195 +0,0 @@ -#include -#include -#include -#include - -#include -#include -#include - -/* Includes, cuda */ -#include -#include - -#include "type_shim.h" - - -// BF16 Tensor core wrapper around cublas GEMMEx -void gemmex_wrapper( - cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float* alpha, - at::BFloat16* A, - int lda, - at::BFloat16* B, - int ldb, - const float* beta, - float* C, - int ldc) { - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - CUDA_R_16BF, - lda, - B, - CUDA_R_16BF, - ldb, - beta, - C, - CUDA_R_32F, - ldc, - CUDA_R_32F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); -} - -// FP16 Tensor core wrapper around cublas GEMMEx -void gemmex_wrapper( - cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float* alpha, - at::Half* A, - int lda, - at::Half* B, - int ldb, - const float* beta, - float* C, - int ldc) { - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - CUDA_R_16F, - lda, - B, - CUDA_R_16F, - ldb, - beta, - C, - CUDA_R_32F, - ldc, - CUDA_R_32F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); -} - -// FP32 wrapper around cublas GEMMEx -void gemmex_wrapper( - cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float *alpha, - float *A, - int lda, - float *B, - int ldb, - const float *beta, - float *C, - int ldc) { - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - CUDA_R_32F, - lda, - B, - CUDA_R_32F, - ldb, - beta, - C, - CUDA_R_32F, - ldc, - CUDA_R_32F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); -} - -template -void wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) { - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - cudaStream_t stream; - cublasGetStream(handle, &stream); - const float alpha = 1.0; - const float beta = 1.0; - - gemmex_wrapper( - handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - in_dim, - out_dim, - hidden_dim, - &alpha, - input, - in_dim, - d_output, - out_dim, - &beta, - d_weight, - in_dim); -} - -template void wgrad_gemm_accum_fp32_cuda(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); -template void wgrad_gemm_accum_fp32_cuda(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); -template void wgrad_gemm_accum_fp32_cuda(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim); - - -void wgrad_gemm_accum_fp32_cuda_stub( - at::Tensor &input, - at::Tensor &d_output, - at::Tensor &d_weight -) { - at::Tensor input_2d, d_output_2d; - // input tensor: collapse to the first dim - auto in_sizes = input.sizes(); - if (input.dim() > 2) { - input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]}); - } else { - input_2d = input; - } - // d_output tensor: collapse to the first dim - auto d_out_sizes = d_output.sizes(); - if (d_output.dim() > 2) { - d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]}); - } else { - d_output_2d = d_output; - } - - const int hidden_dim = input_2d.size(0); - const int in_dim = input_2d.size(1); - const int out_dim = d_weight.size(0); - - DISPATCH_FLOAT_HALF_AND_BFLOAT(input_2d.scalar_type(), 0, "wgrad_gemm_accum_fp32", - wgrad_gemm_accum_fp32_cuda( - input_2d.data_ptr(), - d_output_2d.data_ptr(), - d_weight.data_ptr(), - in_dim, - hidden_dim, - out_dim); - ); -} diff --git a/csrc/megatron/scaled_masked_softmax.cpp b/csrc/megatron/scaled_masked_softmax.cpp deleted file mode 100644 index dd471a0..0000000 --- a/csrc/megatron/scaled_masked_softmax.cpp +++ /dev/null @@ -1,96 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_masked_softmax { - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor); - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor); - -int get_batch_per_block_cuda( - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads); - -torch::Tensor fwd( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor) { - AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); - - return fwd_cuda(input, mask, scale_factor); -} - -torch::Tensor bwd( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor) { - - AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); - - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - return bwd_cuda(output_grads, softmax_results, scale_factor); -} - -int get_batch_per_block( - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads) { - return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); -} - -} // end namespace scaled_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", - &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward."); - - m.def("backward", - &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward."); - - m.def("get_batch_per_block", - &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, - "Return Batch per block size." - ); -} diff --git a/csrc/megatron/scaled_masked_softmax.h b/csrc/megatron/scaled_masked_softmax.h deleted file mode 100644 index 78a29cf..0000000 --- a/csrc/megatron/scaled_masked_softmax.h +++ /dev/null @@ -1,505 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -namespace { - -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } - -int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - -template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) -{ -#if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); -#else - return __shfl_xor(value, laneMask, width); -#endif -} - -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } - } -} - -/* - * Extended softmax (from native aten pytorch) with following additional features - * 1) input scaling - * 2) Explicit masking - */ -template -__global__ void scaled_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const acc_t scale, - int micro_batch_size, - int element_count, - int pad_batches) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; - int pad_first_batch = 0; - if (pad_batches != 1) { // bert style - pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; - } else { // gpt2 style - pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - } - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - int itr_idx = i*element_count+it*WARP_SIZE; - copy_vector(temp_data, src + itr_idx); - copy_vector(temp_mask, mask + itr_idx); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (temp_mask[element] != 1) { - elements[i][it + element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -10000.0; - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } - } - - // compute max_value - acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } - } - warp_reduce(max_value); - - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = elements[i][it + element] / sum[i]; - } - copy_vector(dst + i * element_count + it * WARP_SIZE, out); - } else { - break; - } - } - } -} - -template -__global__ void scaled_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) - // gridDim/blockIdx = (seq_len, attn_heads, batches) - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : element_count; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count + it * WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } - } - } - - acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } - } - warp_reduce(sum); - - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); - } - } - } -} -} // end of anonymous namespace - -int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - constexpr int threads_per_block = 128; - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - - return batches_per_block; -} - -template -void dispatch_scaled_masked_softmax_forward( - output_t *dst, - const input_t *src, - const uint8_t *mask, - const input_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads, - int pad_batches) -{ - TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 2048 ); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); - dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 1: // 2 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 2: // 4 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 3: // 8 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 4: // 16 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 5: // 32 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 6: // 64 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 7: // 128 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 8: // 256 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 9: // 512 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 10: // 1024 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - case 11: // 2048 - scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); - break; - default: - break; - } - } -} - -template -void dispatch_scaled_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int query_seq_len, - int key_seq_len, - int batches, - int attn_heads) -{ - TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 2048 ); - if (key_seq_len == 0) { - return; - } else { - int log2_elements = log2_ceil(key_seq_len); - const int next_power_of_two = 1 << log2_elements; - int batch_count = batches * attn_heads * query_seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - int blocks = batch_count/batches_per_block; - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 1: // 2 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 2: // 4 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 3: // 8 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 4: // 16 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 5: // 32 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 6: // 64 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 7: // 128 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 8: // 256 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 9: // 512 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 10: // 1024 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - case 11: // 2048 - scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); - break; - default: - break; - } - } -} diff --git a/csrc/megatron/scaled_masked_softmax_cuda.cu b/csrc/megatron/scaled_masked_softmax_cuda.cu deleted file mode 100644 index 6096670..0000000 --- a/csrc/megatron/scaled_masked_softmax_cuda.cu +++ /dev/null @@ -1,117 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -//#include -#include -#include -#include "scaled_masked_softmax.h" -#include "type_shim.h" - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_masked_softmax { - -int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){ - return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); -} - - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - torch::Tensor const& mask, - float scale_factor) -{ - // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] - const int batches = input.size(0); - const int pad_batches = mask.size(0); - const int attn_heads = input.size(1); - const int query_seq_len = input.size(2); - const int key_seq_len = input.size(3); - TORCH_INTERNAL_ASSERT(key_seq_len <= 2048); - TORCH_INTERNAL_ASSERT(query_seq_len > 1); - TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); - TORCH_INTERNAL_ASSERT(mask.size(1) == 1); - TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); - TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); - - // Output - auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = - torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); - - // Softmax Intermediate Result Ptr - void* input_ptr = static_cast(input.data_ptr()); - void* mask_ptr = static_cast(mask.data_ptr()); - void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - DISPATCH_HALF_AND_BFLOAT( - input.scalar_type(), - "dispatch_scaled_masked_softmax_forward", - dispatch_scaled_masked_softmax_forward( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), - reinterpret_cast(mask_ptr), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads, - pad_batches); - ); - return softmax_results; -} - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, - float scale_factor) { - - auto output_grads = output_grads_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); - - //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] - const int batches = output_grads.size(0); - const int attn_heads = output_grads.size(1); - const int query_seq_len = output_grads.size(2); - const int key_seq_len = output_grads.size(3); - - void* output_grads_ptr = static_cast(output_grads.data_ptr()); - - //Softmax Grad - DISPATCH_HALF_AND_BFLOAT( - output_grads_.scalar_type(), - "dispatch_scaled_masked_softmax_backward", - dispatch_scaled_masked_softmax_backward( - reinterpret_cast(output_grads_ptr), - reinterpret_cast(output_grads_ptr), - reinterpret_cast(softmax_results.data_ptr()), - scale_factor, - query_seq_len, - key_seq_len, - batches, - attn_heads); - ); - - //backward pass is completely in-place - return output_grads; -} -} -} -} diff --git a/csrc/megatron/scaled_upper_triang_masked_softmax.cpp b/csrc/megatron/scaled_upper_triang_masked_softmax.cpp deleted file mode 100644 index 12cec7f..0000000 --- a/csrc/megatron/scaled_upper_triang_masked_softmax.cpp +++ /dev/null @@ -1,71 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_upper_triang_masked_softmax { - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - float scale_factor); - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor); - -torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { - AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); - AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || - (input.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - return fwd_cuda(input, scale_factor); -} - -torch::Tensor bwd( - torch::Tensor const& output_grads, - torch::Tensor const& softmax_results, - float scale_factor) { - - AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); - AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); - - AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || - (output_grads.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || - (softmax_results.scalar_type() == at::ScalarType::BFloat16), - "Only fp16 and bf16 are supported"); - - return bwd_cuda(output_grads, softmax_results, scale_factor); -} - -} // end namespace scaled_upper_triang_masked_softmax -} // end namespace fused_softmax -} // end namespace multihead_attn - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", - &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, - "Self Multihead Attention scaled, time masked softmax -- Forward."); - m.def("backward", - &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, - "Self Multihead Attention scaled, time masked softmax -- Backward."); -} diff --git a/csrc/megatron/scaled_upper_triang_masked_softmax.h b/csrc/megatron/scaled_upper_triang_masked_softmax.h deleted file mode 100644 index 445e0d8..0000000 --- a/csrc/megatron/scaled_upper_triang_masked_softmax.h +++ /dev/null @@ -1,513 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include - -namespace { - -template -__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } - -template <> -__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } - -template -__device__ __inline__ void copy_zero_vector(Datatype *dst); - -template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } - -template <> -__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } - -template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } - -template <> -__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } - - -int log2_ceil(int value) { - int log2_value = 0; - while ((1 << log2_value) < value) ++log2_value; - return log2_value; -} - -template -struct Add { - __device__ __forceinline__ T operator()(T a, T b) const { - return a + b; - } -}; - -template -struct Max { - __device__ __forceinline__ T operator()(T a, T b) const { - return a < b ? b : a; - } -}; - -template -__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) -{ -#if CUDA_VERSION >= 9000 - return __shfl_xor_sync(mask, value, laneMask, width); -#else - return __shfl_xor(value, laneMask, width); -#endif -} - -template class ReduceOp> -__device__ __forceinline__ void warp_reduce(acc_t* sum) { - ReduceOp r; - #pragma unroll - for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); - sum[i] = r(sum[i], b); - } - } -} - -/* - * Extended softmax (from native aten pytorch) with following additional features - * 1) input scaling - * 2) Implicit time (diagonal masking) - */ -template -__global__ void scaled_upper_triang_masked_softmax_warp_forward( - output_t *dst, - const input_t *src, - const acc_t scale, - int micro_batch_size, - int stride, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_forward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - - // load data from global memory - acc_t elements[WARP_BATCH][WARP_ITERATIONS]; - input_t temp_data[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : local_seq; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < batch_element_count) { - copy_vector(temp_data, src + i*element_count*stride + it*WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if ((element_index + element) < batch_element_count) { - elements[i][it+element] = (acc_t)temp_data[element] * scale; - } else { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } else { - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - elements[i][it + element] = -std::numeric_limits::infinity(); - } - } - } - } - - // compute max_value - acc_t max_value[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - max_value[i] = elements[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; - } - } - warp_reduce(max_value); - - acc_t sum[WARP_BATCH] { 0.0f }; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; ++it) { - if (it < warp_iteration_limit) { - elements[i][it] = std::exp((elements[i][it] - max_value[i])); - sum[i] += elements[i][it]; - } - } - } - warp_reduce(sum); - - // store result - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - - if (element_index < local_seq) { - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < local_seq) { - out[element] = elements[i][it + element] / sum[i]; - } else { - out[element] = 0; - } - } - copy_vector(dst + i * element_count * stride + it * WARP_SIZE, out); - } else if (element_index < element_count) { - copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); - } else { - break; - } - } - } -} - -template -__global__ void scaled_upper_triang_masked_softmax_warp_backward( - output_t *gradInput, - input_t *grad, - const input_t *output, - acc_t scale, - int micro_batch_size, - int stride, - int element_count) -{ - // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and - // warp_size of method warp_softmax_backward_kernel. - constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; - constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; - constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4; - - int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; - int local_seq = blockIdx.x + 1; - - // micro_batch_size might not be a multiple of WARP_BATCH. Check how - // many batches have to computed within this WARP. - int local_batches = micro_batch_size - first_batch; - if (local_batches > WARP_BATCH) - local_batches = WARP_BATCH; - - // there might be multiple batches per warp. compute the index within the batch - int local_idx = threadIdx.x; - - // the first element to process by the current thread - int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; - grad += thread_offset; - output += thread_offset; - gradInput += thread_offset; - - // load data from global memory - acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; - input_t temp_grad[ELEMENTS_PER_LDG_STG]; - input_t temp_output[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - int batch_element_count = (i >= local_batches) ? 0 : local_seq; - - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < batch_element_count) { - copy_vector(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); - copy_vector(temp_output, output + i * element_count * stride + it * WARP_SIZE); - - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - output_reg[i][it + element] = (acc_t)temp_output[element]; - } - } - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - if (element_index + element < batch_element_count) { - grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; - } - } - } - } - } - - acc_t sum[WARP_BATCH]; - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - sum[i] = grad_reg[i][0]; - #pragma unroll - for (int it = 1; it < WARP_ITERATIONS; ++it) { - sum[i] += grad_reg[i][it]; - } - } - warp_reduce(sum); - - // store result - #pragma unroll - for (int i = 0; i < WARP_BATCH; ++i) { - if (i >= local_batches) - break; - #pragma unroll - for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { - int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; - if (element_index < element_count) { - // compute gradients - output_t out[ELEMENTS_PER_LDG_STG]; - #pragma unroll - for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { - out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); - } - copy_vector(gradInput + i * element_count * stride + it * WARP_SIZE, out); - } - } - } -} - -} // end of anonymous namespace - -template -void dispatch_scaled_upper_triang_masked_softmax_forward( - output_t *dst, - const input_t *src, - const input_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) -{ - TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048 ); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: - break; - } - } -} - -template -void dispatch_scaled_upper_triang_masked_softmax_backward( - output_t *grad_input, - input_t *grad, - const input_t *output, - const acc_t scale, - int softmax_elements, - int softmax_elements_stride, - int attn_batches) -{ - TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 2048 ); - if (softmax_elements == 0) { - return; - } else { - int log2_elements = log2_ceil(softmax_elements); - const int next_power_of_two = 1 << log2_elements; - int seq_len = softmax_elements; - int batch_count = attn_batches * seq_len; - - // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. - int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; - - // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. - int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; - - // use 128 threads per block to maximimize gpu utilization - constexpr int threads_per_block = 128; - - int warps_per_block = (threads_per_block / warp_size); - int batches_per_block = warps_per_block * batches_per_warp; - TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); - - int blocks_per_seq = attn_batches / batches_per_block; - dim3 blocks(seq_len, blocks_per_seq, 1); - dim3 threads(warp_size, warps_per_block, 1); - // Launch code would be more elegant if C++ supported FOR CONSTEXPR - switch (log2_elements) { - case 0: // 1 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 1: // 2 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 2: // 4 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 3: // 8 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 4: // 16 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 5: // 32 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 6: // 64 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 7: // 128 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 8: // 256 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 9: // 512 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 10: // 1024 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - case 11: // 2048 - scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); - break; - default: - break; - } - } -} diff --git a/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu b/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu deleted file mode 100644 index df022cb..0000000 --- a/csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu +++ /dev/null @@ -1,98 +0,0 @@ -/* coding=utf-8 - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -//#include -#include -#include -#include "scaled_upper_triang_masked_softmax.h" -#include "type_shim.h" - -namespace multihead_attn { -namespace fused_softmax { -namespace scaled_upper_triang_masked_softmax { - -torch::Tensor fwd_cuda( - torch::Tensor const& input, - float scale_factor) -{ - // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] - const int attn_batches = input.size(0); - const int seq_len = input.size(1); - TORCH_INTERNAL_ASSERT(seq_len <= 2048); - - // Output - auto act_options = input.options().requires_grad(false); - torch::Tensor softmax_results = - torch::empty({attn_batches, seq_len, seq_len}, act_options); - - // Softmax Intermediate Result Ptr - void* input_ptr = static_cast(input.data_ptr()); - void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); - - DISPATCH_HALF_AND_BFLOAT( - input.scalar_type(), - "dispatch_scaled_upper_triang_masked_softmax_forward", - dispatch_scaled_upper_triang_masked_softmax_forward( - reinterpret_cast(softmax_results_ptr), - reinterpret_cast(input_ptr), - scale_factor, - seq_len, - seq_len, - attn_batches); - ); - return softmax_results; -} - - -torch::Tensor bwd_cuda( - torch::Tensor const& output_grads_, - torch::Tensor const& softmax_results_, - float scale_factor) { - - auto output_grads = output_grads_.contiguous(); - auto softmax_results = softmax_results_.contiguous(); - - //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] - const int attn_batches = output_grads.size(0); - const int seq_len = output_grads.size(1); - TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); - - void* output_grads_ptr = static_cast(output_grads.data_ptr()); - - //Softmax Grad - DISPATCH_HALF_AND_BFLOAT( - output_grads_.scalar_type(), - "dispatch_scaled_upper_triang_masked_softmax_backward", - dispatch_scaled_upper_triang_masked_softmax_backward( - reinterpret_cast(output_grads_ptr), - reinterpret_cast(output_grads_ptr), - reinterpret_cast(softmax_results.data_ptr()), - scale_factor, - seq_len, - seq_len, - attn_batches); - ); - - //backward pass is completely in-place - return output_grads; -} -} -} -} diff --git a/csrc/mlp.cpp b/csrc/mlp.cpp deleted file mode 100644 index 830d606..0000000 --- a/csrc/mlp.cpp +++ /dev/null @@ -1,166 +0,0 @@ -#include -#include -#include - -#include - -size_t get_mlp_reserved_space(int64_t batch_size, int num_layers, const int* output_features); - -template -size_t get_mlp_bp_workspace_in_bytes(int batch_size, int num_layers, const int* output_features); - -template -int mlp_fp( - T* X, - int input_features, - int batch_size, - T** WPtr, - int num_layers, - int* output_features, - T** BPtr, - T* Y, - T* reserved_space, - int use_bias, - int activation, - void* lt_workspace); - -template -int mlp_bp( - T* X, - T* Y, - int input_features, - int batch_size, - T** WPtr, - int num_layers, - int* output_features, - T* dY, - T* reserved_space, - T* work_space, - T* dX, - T** dwPtr, - T** dbPtr, - bool requires_grad, - int use_bias, - int activation); - -std::vector mlp_forward(int use_bias, int activation, std::vector inputs) { - - auto num_layers = inputs.size() - 1; - if (use_bias) { - // inputs contains (input, weights, biases) - num_layers /= 2; - } - auto batch_size = inputs[0].size(0); - auto input_features = inputs[0].size(1); - - std::vector output_features; - for (int i = 0; i < num_layers; i++) { - output_features.push_back(inputs[i + 1].size(0)); - } - - auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data()); - - // create output/workspace tensor - auto out = at::empty({batch_size, output_features.back()}, inputs[0].type()); - auto reserved_space = at::empty({static_cast(reserved_size)}, inputs[0].type()); - // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB - auto lt_workspace = at::empty({1 << 22}, inputs[0].type()); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_forward", [&] { - std::vector w_ptr; - std::vector b_ptr; - for (int i = 0; i < num_layers; i++) { - w_ptr.push_back(inputs[i + 1].data_ptr()); - if (use_bias) { - b_ptr.push_back(inputs[i + 1 + num_layers].data_ptr()); - } - } - auto result = mlp_fp( - inputs[0].data_ptr(), - input_features, - batch_size, - w_ptr.data(), - num_layers, - output_features.data(), - b_ptr.data(), - out.data_ptr(), - reserved_space.data_ptr(), - use_bias, - activation, - (void*) (lt_workspace.data_ptr())); - }); - - return {out, reserved_space}; -} - -std::vector mlp_backward( - int use_bias, - int activation, - at::Tensor grad_o, - std::vector fprop_outputs, - std::vector inputs) { - - auto num_layers = inputs.size() - 1; - if (use_bias) { - // inputs contains (input, weights, biases) - num_layers /= 2; - } - - auto batch_size = inputs[0].size(0); - auto input_features = inputs[0].size(1); - - bool requires_grad = inputs[0].requires_grad(); - - std::vector output_features; - for (int i = 0; i < num_layers; i++) { - output_features.push_back(inputs[i + 1].size(0)); - } - // create outputs, length of inputs - std::vector outputs; - for (int i = 0; i < inputs.size(); i++) { - outputs.push_back(at::empty(inputs[i].sizes(), inputs[i].type())); // clone for testing now - } - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_backward", [&] { - std::vector w_ptr; - for (int i = 0; i < num_layers; i++) { - w_ptr.push_back(inputs[i + 1].data_ptr()); - } - std::vector outputs_ptr; - for (int i = 0; i < inputs.size(); i++) { - outputs_ptr.push_back(outputs[i].data_ptr()); - } - - auto work_size = - get_mlp_bp_workspace_in_bytes(batch_size, num_layers, output_features.data()); - - // auto work_space = at::empty({work_size*4}, at::kByte); - auto work_space = at::empty({static_cast(work_size / sizeof(scalar_t))}, inputs[0].type()); - - auto result = mlp_bp( - inputs[0].data_ptr(), - fprop_outputs[0].data_ptr(), - input_features, - batch_size, - w_ptr.data(), - num_layers, - output_features.data(), - grad_o.contiguous().data_ptr(), - fprop_outputs[1].data_ptr(), - work_space.data_ptr(), - outputs_ptr[0], - outputs_ptr.data() + 1, - outputs_ptr.data() + 1 + num_layers, - requires_grad, - use_bias, - activation); - }); - - return outputs; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &mlp_forward, "MLP forward"); - m.def("backward", &mlp_backward, "MLP backward"); -} - diff --git a/csrc/mlp_cuda.cu b/csrc/mlp_cuda.cu deleted file mode 100644 index 8290dea..0000000 --- a/csrc/mlp_cuda.cu +++ /dev/null @@ -1,1783 +0,0 @@ -// New MLP with denorm mitigation only for backprop - -#include -#include -#include -#include -#include -#include -#include - -/* Includes, cuda */ -#include -#include -#include "utils.h" - -#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 -// includes cublaslt -#include -#endif -// constants for fused bias+relu kernel -#define BIAS_RELU_FW_NTHREADS 128 // forward number of thread per block -#define BIAS_RELU_BW_NTHREADS_X 32 // backward number of thread in feature dim -#define BIAS_RELU_BW_NTHREADS_Y 16 // backward number of thread in batch dim -#define BIAS_RELU_RED_PER_THREAD 16 // backward minimal reduction length per thread - - - -// move to a header later on -#define ILP 4 -template -__host__ __device__ __forceinline__ bool is_aligned(T* p){ - return ((uint64_t)p) % (ILP*sizeof(T)) == 0; -} - -template -__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){ - typedef typename std::aligned_storage::type LT; - ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; -} -template -__device__ __forceinline__ void load_store(T* dst, volatile T* src, int dst_offset, int src_offset){ - typedef typename std::aligned_storage::type LT; - ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; -} -template -__device__ __forceinline__ void load_store(volatile T* dst, T* src, int dst_offset, int src_offset){ - typedef typename std::aligned_storage::type LT; - ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; -} - -// Keep ReLU in float only. When using half, cast to float before calling. -__device__ __inline__ float relu(float a) { - float retf = max(a, 0.f); - return (retf); -} - -// Keep Sigmoid in float only. When using half, cast to float before calling. -__device__ __inline__ float sigmoid(float a) { - float retf = 1.f / (1.f + expf(-a)); - return (retf); -} - -// FP64 Wrapper around cublas GEMMEx -cublasStatus_t mlp_gemm( - cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - float* alpha, - const double* A, - int lda, - const double* B, - int ldb, - const float* beta, - double* C, - int ldc, - int flag) { -#ifdef __HIP_PLATFORM_HCC__ - return rocblas_gemm_ex( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - rocblas_datatype_f64_r, - lda, - B, - rocblas_datatype_f64_r, - ldb, - beta, - C, - rocblas_datatype_f64_r, - ldc, - C, - rocblas_datatype_f64_r, - ldc, - rocblas_datatype_f64_r, - rocblas_gemm_algo_standard, - 0, - flag); -#else - return cublasGemmEx( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - CUDA_R_64F, - lda, - B, - CUDA_R_64F, - ldb, - beta, - C, - CUDA_R_64F, - ldc, - CUDA_R_64F, - CUBLAS_GEMM_DEFAULT); -#endif -} - -// FP32 Wrapper around cublas GEMMEx -cublasStatus_t mlp_gemm( - cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - float* alpha, - const float* A, - int lda, - const float* B, - int ldb, - const float* beta, - float* C, - int ldc, - int flag) { -#ifdef __HIP_PLATFORM_HCC__ - return rocblas_gemm_ex( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - rocblas_datatype_f32_r, - lda, - B, - rocblas_datatype_f32_r, - ldb, - beta, - C, - rocblas_datatype_f32_r, - ldc, - C, - rocblas_datatype_f32_r, - ldc, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, - 0, - flag); - -#else - return cublasGemmEx( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - CUDA_R_32F, - lda, - B, - CUDA_R_32F, - ldb, - beta, - C, - CUDA_R_32F, - ldc, - CUDA_R_32F, - CUBLAS_GEMM_DEFAULT); -#endif -} - -// FP16 Tensor core wrapper around cublas GEMMEx -cublasStatus_t mlp_gemm( - cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - float* alpha, - const at::Half* A, - int lda, - const at::Half* B, - int ldb, - float* beta, - at::Half* C, - int ldc, - int flag) { -#ifdef __HIP_PLATFORM_HCC__ - return rocblas_gemm_ex( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - rocblas_datatype_f16_r, - lda, - B, - rocblas_datatype_f16_r, - ldb, - beta, - C, - rocblas_datatype_f16_r, - ldc, - C, - rocblas_datatype_f16_r, - ldc, - rocblas_datatype_f32_r, - rocblas_gemm_algo_standard, - 0, - flag); -#else - return cublasGemmEx( - handle, - transa, - transb, - m, - n, - k, - alpha, - A, - CUDA_R_16F, - lda, - B, - CUDA_R_16F, - ldb, - beta, - C, - CUDA_R_16F, - ldc, - CUDA_R_32F, - CUBLAS_GEMM_DEFAULT_TENSOR_OP); -#endif -} -#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 -int mlp_gemm_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - float *alpha, /* host pointer */ - const at::Half* A, - int lda, - const at::Half* B, - int ldb, - float *beta, /* host pointer */ - at::Half* C, - int ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - bool use_bias, - bool use_relu, - const void* bias) { - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - - cublasLtMatmulDescOpaque_t operationDesc = {}; - cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; - cublasLtMatmulPreferenceOpaque_t preference = {}; - - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t - // for details about defaults; here we just set the transforms for - // A and B. - status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (use_bias) { - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - if (use_relu) { - epilogue = CUBLASLT_EPILOGUE_RELU_BIAS; - } else { - epilogue = CUBLASLT_EPILOGUE_BIAS; - } - } else { - if (use_relu) { - epilogue = CUBLASLT_EPILOGUE_RELU; - } - } - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - - // Create matrix descriptors. Not setting any extra attributes. - status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // Create preference handle; In general, extra attributes can be - // used here to disable tensor ops or to make sure algo selected - // will work with badly aligned A, B, C. However, for simplicity - // here we assume A,B,C are always well aligned (e.g., directly - // come from cudaMalloc) - status = cublasLtMatmulPreferenceInit(&preference); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // We just need the best available heuristic to try and run matmul. - // There is no guarantee that this will work. For example, if A is - // badly aligned, you can request more (e.g. 32) algos and try to - // run them one by one until something works. - status = cublasLtMatmulAlgoGetHeuristic( - ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (returnedResults == 0) { - status = CUBLAS_STATUS_NOT_SUPPORTED; - goto CLEANUP; - } - status = cublasLtMatmul(ltHandle, - &operationDesc, - alpha, - A, - &Adesc, - B, - &Bdesc, - beta, - C, - &Cdesc, - C, - &Cdesc, - &heuristicResult.algo, - workspace, - workspaceSize, - stream); - -CLEANUP: - // Descriptors are no longer needed as all GPU work was already - // enqueued. - return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; -} - -int mlp_gemm_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - float *alpha, /* host pointer */ - const double* A, - int lda, - const double* B, - int ldb, - float *beta, /* host pointer */ - double* C, - int ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - bool use_bias, - bool use_relu, - const void* bias) { - return 1; -} - -int mlp_gemm_lt( - cublasLtHandle_t ltHandle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - float *alpha, /* host pointer */ - const float *A, - int lda, - const float *B, - int ldb, - float *beta, /* host pointer */ - float *C, - int ldc, - void *workspace, - size_t workspaceSize, - cudaStream_t stream, - bool use_bias, - bool use_relu, - const void* bias) { - cublasStatus_t status = CUBLAS_STATUS_SUCCESS; - - cublasLtMatmulDescOpaque_t operationDesc = {}; - cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {}; - cublasLtMatmulPreferenceOpaque_t preference = {}; - - int returnedResults = 0; - cublasLtMatmulHeuristicResult_t heuristicResult = {}; - cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT; - - // Create operation descriptor; see cublasLtMatmulDescAttributes_t - // for details about defaults; here we just set the transforms for - // A and B. - status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (use_bias) { - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - if (use_relu) { - epilogue = CUBLASLT_EPILOGUE_RELU_BIAS; - } else { - epilogue = CUBLASLT_EPILOGUE_BIAS; - } - } else { - if (use_relu) { - epilogue = CUBLASLT_EPILOGUE_RELU; - } - } - - status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)); - if (status != CUBLAS_STATUS_SUCCESS) { - goto CLEANUP; - } - - // Create matrix descriptors. Not setting any extra attributes. - status = cublasLtMatrixLayoutInit( - &Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit( - &Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // Create preference handle; In general, extra attributes can be - // used here to disable tensor ops or to make sure algo selected - // will work with badly aligned A, B, C. However, for simplicity - // here we assume A,B,C are always well aligned (e.g., directly - // come from cudaMalloc) - status = cublasLtMatmulPreferenceInit(&preference); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - status = cublasLtMatmulPreferenceSetAttribute( - &preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize)); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - // We just need the best available heuristic to try and run matmul. - // There is no guarantee that this will work. For example, if A is - // badly aligned, you can request more (e.g. 32) algos and try to - // run them one by one until something works. - status = cublasLtMatmulAlgoGetHeuristic( - ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults); - if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP; - - if (returnedResults == 0) { - status = CUBLAS_STATUS_NOT_SUPPORTED; - goto CLEANUP; - } - - status = cublasLtMatmul(ltHandle, - &operationDesc, - alpha, - A, - &Adesc, - B, - &Bdesc, - beta, - C, - &Cdesc, - C, - &Cdesc, - &heuristicResult.algo, - workspace, - workspaceSize, - stream); - -CLEANUP: - // Descriptors are no longer needed as all GPU work was already - // enqueued. - return status == CUBLAS_STATUS_SUCCESS ? 0 : 1; -} -#endif - -// Bias ADD. Assume input X is [features x batch size], column major. -// Bias is one 'features' long vector, with implicit broadcast. -template -__global__ void biasAdd_fprop(T *X, T *b, uint batch_size, uint features) { - T r_x[ILP]; - T r_b[ILP]; - if(is_aligned(X) && is_aligned(b) && features % ILP ==0) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) { - int row = tid % (features / ILP); - load_store(r_x, X, 0 , tid); - load_store(r_b, b, 0 , row); -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - float bias_sum = static_cast(r_x[ii]) + static_cast(r_b[ii]); - r_x[ii] = bias_sum; - } - load_store(X, r_x, tid , 0); - } - } else { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) { -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - int idx = tid + ii * blockDim.x * gridDim.x; - if(idx < features * batch_size) { - int row = tid % features; - r_x[ii] = X[idx]; - r_b[ii] = b[row]; - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - float bias_sum = static_cast(r_x[ii]) + static_cast(r_b[ii]); - r_x[ii] = bias_sum; - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - int idx = tid + ii * blockDim.x * gridDim.x; - if(idx < features * batch_size) { - X[idx] = r_x[ii]; - } - } - } - } -} - -// Bias ADD + ReLU. Assume input X is [features x batch size], column major. -// Activation support fuesed ReLU. Safe to call in-place. -template -__global__ void biasAddRelu_fprop(T *X, T *b, uint batch_size, uint features) { - T r_x[ILP]; - T r_b[ILP]; - if(is_aligned(X) && is_aligned(b) && features % ILP ==0) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) { - int row = tid % (features / ILP); - load_store(r_x, X, 0 , tid); - load_store(r_b, b, 0 , row); -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - float bias_sum = static_cast(r_x[ii]) + static_cast(r_b[ii]); - r_x[ii] = relu(bias_sum); - } - load_store(X, r_x, tid , 0); - } - } else { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) { -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - int idx = tid + ii * blockDim.x * gridDim.x; - if(idx < features * batch_size) { - int row = tid % features; - r_x[ii] = X[idx]; - r_b[ii] = b[row]; - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - float bias_sum = static_cast(r_x[ii]) + static_cast(r_b[ii]); - r_x[ii] = relu(bias_sum); - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - int idx = tid + ii * blockDim.x * gridDim.x; - if(idx < features * batch_size) { - X[idx] = r_x[ii]; - } - } - } - } -} - -// ReLU. Assume input X is [features x batch size], column major. -// Safe to call in-place. -template -__global__ void Relu_fprop(T *X, uint batch_size, uint features) { - T r_x[ILP]; - if(is_aligned(X) && features % ILP ==0) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) { - load_store(r_x, X, 0 , tid); -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - r_x[ii] = relu(static_cast(r_x[ii])); - } - load_store(X, r_x, tid , 0); - } - } else { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) { -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - int idx = tid + ii * blockDim.x * gridDim.x; - if(idx < features * batch_size) { - r_x[ii] = X[idx]; - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - r_x[ii] = relu(static_cast(r_x[ii])); - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - int idx = tid + ii * blockDim.x * gridDim.x; - if(idx < features * batch_size) { - X[idx] = r_x[ii]; - } - } - } - } -} - -// Sigmoid. Assume input X is [features x batch size], column major. -// Safe to call in-place. -template -__global__ void Sigmoid_fprop(T *X, uint batch_size, uint features) { - T r_x[ILP]; - if(is_aligned(X) && features % ILP ==0) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) { - load_store(r_x, X, 0 , tid); -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - r_x[ii] = sigmoid(static_cast(r_x[ii])); - } - load_store(X, r_x, tid , 0); - } - } else { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - for (; tid < features * batch_size; tid += ILP * blockDim.x * gridDim.x) { -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - int idx = tid + ii * blockDim.x * gridDim.x; - if(idx < features * batch_size) { - r_x[ii] = X[idx]; - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - r_x[ii] = sigmoid(static_cast(r_x[ii])); - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) { - int idx = tid + ii * blockDim.x * gridDim.x; - if(idx < features * batch_size) { - X[idx] = r_x[ii]; - } - } - } - } -} - -// ReLU. Assume input X is [features x batch size], column major. -// Safe to call in-place. -template -__global__ void Relu_bprop(T *dY, T *Y, uint batch_size, uint features, T *dX) { - T r_dy[ILP]; - T r_y[ILP]; - if(is_aligned(dY) && - is_aligned(Y) && - is_aligned(dX) && - features % ILP ==0) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) { - load_store(r_dy, dY, 0 , tid); - load_store(r_y, Y, 0 , tid); -#pragma unroll - for(int ii=0;ii -__global__ void Sigmoid_bprop(T *dY, T *Y, uint batch_size, uint features, T *dX) { - T r_dy[ILP]; - T r_y[ILP]; - if(is_aligned(dY) && - is_aligned(Y) && - is_aligned(dX) && - features % ILP ==0) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - for (; tid*ILP < features * batch_size; tid += blockDim.x * gridDim.x) { - load_store(r_dy, dY, 0 , tid); - load_store(r_y, Y, 0 , tid); -#pragma unroll - for(int ii=0;iimultiProcessorCount; - // can switch to occupancy calculation. use 4 below now for sm_70 - int max_blocks_y = (num_SMs * 4+(*grid_x)-1) / (*grid_x); - // block_y should be from minimal work per thread - int nRedSplits = (batch_size + block_y - 1) / block_y; - // increase number of elem per thread redcution to not launch more than enough - // kernel adjust work, so here we just launch max block - *grid_y = std::min(nRedSplits, max_blocks_y); - return; -} - -// Addition done deterministically via a 2-pass approach. Each CTA writes out partial -// sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result. -template -__global__ void biasAdd_bprop( - T* dY, - int features, - int batch_size, - volatile float* intermediate, - int* semaphores, - T* db) { - // The feature that this thread is responsible for - int f = blockIdx.x * blockDim.x + threadIdx.x; - - // Compute the span this thread is responsible for - // For this block - int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y; - int b_nStart = blockIdx.y * b_chunkSize; - int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart; - // For this thread - int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y; - int nStart = threadIdx.y * chunkSize + b_nStart; - int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart; - - volatile float* out = intermediate + blockIdx.y * features; - - // Flag to trigger last reduction. - __shared__ bool isLastBlock; - // we know block size for now - __shared__ float smem[BIAS_RELU_BW_NTHREADS_X*BIAS_RELU_BW_NTHREADS_Y]; - - // Accumulate db in FP32 always - float db_local = 0; - if (f < features) { - int nidx = 0; - // Handle non-multiple of UNROLL_FACTOR residue - for (; nidx < nSpan % UNROLL_FACTOR; nidx++) { - int64_t row, col, flat_idx; - row = f; - col = nStart + nidx; - flat_idx = col * features + row; - db_local += (float)dY[flat_idx]; - } - - // Handle meat of work - for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) { - int64_t row, col, flat_idx; - row = f; - col = nStart + nidx; - flat_idx = col * features + row; -#pragma unroll 4 - for (int u = 0; u < UNROLL_FACTOR; u++) { - db_local += (float)dY[flat_idx]; - flat_idx += features; - } - } - - // naive block reduction on y-dim - int linear_idx = threadIdx.y * blockDim.x + threadIdx.x; - smem[linear_idx] = db_local; - } - __syncthreads(); - if (f < features) { - if(threadIdx.y == 0) { - for(int yidx = 1; yidx < blockDim.y; yidx++){ - db_local += smem[yidx * blockDim.x + threadIdx.x]; - } - - // block result is in db_local now for all threadIdx.y == 0 - // Write out partial result - out[f] = db_local; - } - } - __threadfence(); - __syncthreads(); - - // Increment semaphore and check if this is the last CTA in the grid_y dimension. - // Only thread (0,0) calls this - if (threadIdx.x == 0 && threadIdx.y == 0 && f < features) { - unsigned int sum_idx; - sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1); - isLastBlock = (sum_idx == (gridDim.y - 1)); - } - __syncthreads(); - - db_local = 0; - // No block reduction for now, only thread (*,0) do grid reduction - if (isLastBlock && f < features) { - if(threadIdx.y == 0) { - for (int n = 0; n < gridDim.y; n++) { - int row, col; - row = f; - col = n; - db_local += (float)(intermediate[col * features + row]); - } - db[f] = (T)db_local; - } - } -} - -// Addition done deterministically via a 2-pass approach. Each CTA writes out partial -// sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result. -template -__global__ void biasAddRelu_bprop( - T* Y, - T* dY, - int features, - int batch_size, - T* dX, - volatile float* intermediate, - int* semaphores, - T* db) { - // The feature that this thread is responsible for - int f = blockIdx.x * blockDim.x + threadIdx.x; - - // Compute the span this thread is responsible for - // For this block - int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y; - int b_nStart = blockIdx.y * b_chunkSize; - int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart; - // For this thread - int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y; - int nStart = threadIdx.y * chunkSize + b_nStart; - int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart; - - volatile float* out = intermediate + blockIdx.y * features; - - // Flag to trigger last reduction. - __shared__ bool isLastBlock; - // we know block size for now - __shared__ float smem[BIAS_RELU_BW_NTHREADS_X*BIAS_RELU_BW_NTHREADS_Y]; - - // Accumulate db in FP32 always - float db_local = 0; - if (f < features) { - int nidx = 0; - // Handle non-multiple of UNROLL_FACTOR residue - for (; nidx < nSpan % UNROLL_FACTOR; nidx++) { - int row, col, flat_idx; - row = f; - col = nStart + nidx; - flat_idx = col * features + row; - T y_val = Y[flat_idx]; - T dy_val = dY[flat_idx]; - T dx_val; - if ((float)y_val > 0.f) - dx_val = dy_val; - else - dx_val = 0; - dX[flat_idx] = dx_val; - db_local += (float)dx_val; - } - - // Handle meat of work - for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) { - int row, col, flat_idx; - row = f; - col = nStart + nidx; - flat_idx = col * features + row; -#pragma unroll 4 - for (int u = 0; u < UNROLL_FACTOR; u++) { - T y_val = Y[flat_idx]; - T dy_val = dY[flat_idx]; - T dx_val; - if ((float)y_val > 0.f) - dx_val = dy_val; - else - dx_val = 0; - dX[flat_idx] = dx_val; - db_local += (float)dx_val; - flat_idx += features; - } - } - - // naive block reduction on y-dim - int linear_idx = threadIdx.y * blockDim.x + threadIdx.x; - smem[linear_idx] = db_local; - } - __syncthreads(); - if (f < features) { - if(threadIdx.y == 0) { - for(int yidx = 1; yidx < blockDim.y; yidx++){ - db_local += smem[yidx * blockDim.x + threadIdx.x]; - } - - // block result is in db_local now for all threadIdx.y == 0 - // Write out partial result - out[f] = db_local; - } - } - __threadfence(); - __syncthreads(); - - // Increment semaphore and check if this is the last CTA in the grid_y dimension. - // Only thread (0,0) calls this - if (threadIdx.x == 0 && threadIdx.y == 0 && f < features) { - unsigned int sum_idx; - sum_idx = atomicAdd(&(semaphores[blockIdx.x]), 1); - isLastBlock = (sum_idx == (gridDim.y - 1)); - } - __syncthreads(); - - db_local = 0; - // No block reduction for now, only thread (*,0) do grid reduction - if (isLastBlock && f < features) { - if(threadIdx.y == 0) { - for (int n = 0; n < gridDim.y; n++) { - int row, col; - row = f; - col = n; - db_local += (float)(intermediate[col * features + row]); - } - db[f] = (T)db_local; - } - } -} - -// Addition done deterministically via a 2-pass approach. Each CTA writes out partial -// sum, and the last CTA in grid Y dimension accumulates partials serially and writes to result. -template -__global__ void biasAddRelu_bprop_aligned( - T* Y, - T* dY, - int features, - int batch_size, - T* dX, - volatile float* intermediate, - int* semaphores, - T* db) { - // The feature that this thread is responsible for - int f = blockIdx.x * blockDim.x + threadIdx.x; - - // Compute the span this thread is responsible for - // For this block - int b_chunkSize = (batch_size + gridDim.y - 1) / gridDim.y; - int b_nStart = blockIdx.y * b_chunkSize; - int b_nSpan = min(batch_size, b_nStart + b_chunkSize) - b_nStart; - // For this thread - int chunkSize = (b_chunkSize + blockDim.y - 1) / blockDim.y; - int nStart = threadIdx.y * chunkSize + b_nStart; - int nSpan = min(b_nStart + b_nSpan, nStart + chunkSize) - nStart; - - volatile float* out = intermediate + blockIdx.y * features; - - // Flag to trigger last reduction. - __shared__ bool isLastBlock; - - // Accumulate db in FP32 always - float db_local[ILP]; - T r_y[ILP]; - T r_dy[ILP]; -#pragma unroll - for(int ii=0;ii -size_t get_mlp_bp_workspace_in_bytes(int batch_size, int num_layers, const int* output_features) { - size_t work_space = 0; - - // Store each intermediate dY explicitly. Need 2 dYs per MLP layer (one for o/p - // of biasReLU_bp and one for o/p of dgrad GEMM). - work_space += 2 * get_all_activations_size(batch_size, num_layers, output_features) * sizeof(T); - work_space += - get_reduction_scratch_space(batch_size, num_layers, output_features) * sizeof(float); - work_space += get_semaphores_size(num_layers, output_features) * sizeof(int); - - return work_space; -} - -// Returns pointers to each segment of the workspace -template -void partition_mlp_bp_workspace( - int batch_size, - int num_layers, - const int* output_features, - void* work_space, - T** dy_gemms, - T** dx_gemms, - float** db_scratch, - int** semaphores) { - /* - Workspace is partitioned as - DY_GEMMs : DX_GEMMs : DB_SCRATCH : SEMAPHORES - */ - // Start address where dy_gemm tensors are stored - *dy_gemms = reinterpret_cast(work_space); - // Start address where dx_gemm tensors are stored - *dx_gemms = *dy_gemms + get_all_activations_size(batch_size, num_layers, output_features); - // Start address where db intermediate tensors are stored - *db_scratch = reinterpret_cast( - *dx_gemms + get_all_activations_size(batch_size, num_layers, output_features)); - // Start address of semaphores - *semaphores = reinterpret_cast( - *db_scratch + get_reduction_scratch_space(batch_size, num_layers, output_features)); - - return; -} - -// Does a simple MLP fprop (GEMM+bias+ReLU). -// Can handle num_layers number of layers, each with its own shape. Output of layer i is assumed -// to be input of layer i+1. output_features, WPtr and BPtr are arrays of length num_layers, and -// must be in the same order i.e. WPtr[i] and BPtr[i] are respectively the weight and bias of layer -// 'i'. -template -int mlp_fp( - T* X, - int input_features, - int batch_size, - T** WPtr, - int num_layers, - int* output_features, - T** BPtr, - T* Y, - T* reserved_space, - int use_bias, - int activation, - void* lt_workspace) { - T *weight, *input, *output, *bias; - T *reserved_space_x, *reserved_space_y; - reserved_space_x = NULL; - reserved_space_y = reserved_space; - - // Get cublas handle from Pytorch - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - // Get the stream from cublas handle to reuse for biasReLU kernel. - cudaStream_t stream; - cublasGetStream(handle, &stream); - - for (int layer = 0; layer < num_layers; layer++) { - weight = WPtr[layer]; - input = (layer == 0) ? X : reserved_space_x; - output = (layer == num_layers - 1) ? Y : reserved_space_y; - if (use_bias) { - bias = BPtr[layer]; - } - int ifeat = (layer == 0) ? input_features : output_features[layer - 1]; - int ofeat = output_features[layer]; - - float one = 1.f; - float zero = 0.f; - - // try with cublaslt first for supported case with valid handle - int cublaslt_status = 1; -#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 - if(activation < 1){ - cublaslt_status = mlp_gemm_lt( - //ltHandle, - (cublasLtHandle_t)handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - ofeat, - batch_size, - ifeat, - &one, - weight, - ifeat, - input, - ifeat, - &zero, - output, - ofeat, - lt_workspace, - 1 << 22, - stream, - use_bias == 1, - activation == 1, - bias); - } -#endif - - // if cublaslt failed or not executed, fallback to cublas - if (cublaslt_status != 0) { - cublasStatus_t cublas_status; - // Call GEMM: fprop is Y = W'X - cublas_status = mlp_gemm( - handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - ofeat, - batch_size, - ifeat, - &one, - weight, - ifeat, - input, - ifeat, - &zero, - output, - ofeat, - int(0)); // Do nothing for forward prop - - if (cublas_status != CUBLAS_STATUS_SUCCESS) { - printf("GEMM fprop failed with %d\n", cublas_status); - return 1; - } - - const uint &input_size = ofeat; - int num_blocks = 0; - int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; - // Call biasReLU - if(use_bias == 1) { - if (activation == 0) { // no activation - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop, BIAS_RELU_FW_NTHREADS, 0); - biasAdd_fprop<<>>(output, bias, batch_size, input_size); - } else if (activation == 1) { // relu - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAddRelu_fprop, BIAS_RELU_FW_NTHREADS, 0); - biasAddRelu_fprop<<>>(output, bias, batch_size, input_size); - } else if (activation == 2) { // sigmoid - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop, BIAS_RELU_FW_NTHREADS, 0); - biasAdd_fprop<<>>(output, bias, batch_size, input_size); - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop, BIAS_RELU_FW_NTHREADS, 0); - Sigmoid_fprop<<>>(output, batch_size, input_size); - } - } else { - // don't need to do anything in case of no activation and no bias - if (activation == 1) { // relu - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_fprop, BIAS_RELU_FW_NTHREADS, 0); - Relu_fprop<<>>(output, batch_size, input_size); - } else if (activation == 2) { // sigmoid - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop, BIAS_RELU_FW_NTHREADS, 0); - Sigmoid_fprop<<>>(output, batch_size, input_size); - } - } - } - // Set current output as next layer input - reserved_space_x = reserved_space_y; - // Set next layer output - reserved_space_y += ofeat * batch_size; - } - - return 0; -} - -// Does a simple MLP bprop (GEMM+bias+ReLU). -// Needs reserved space to come back exactly as it was populated in fprop. -// Does dgrad and wgrad sequentially. -template -int mlp_bp( - T* X, - T* Y, - int input_features, - int batch_size, - T** WPtr, - int num_layers, - int* output_features, - T* dY, - T* reserved_space, - T* work_space, - T* dX, - T** dwPtr, - T** dbPtr, - bool requires_grad, - int use_bias, - int activation) { - T* weight; - T *dweight, *dx, *dy, *dbias; - T *x, *y; - - // Where the dx of the biasReLU (== dy of gemm) is stored. Can be thrown away - // after bp call. - T* dy_gemm_base; - // Where the dx after GEMM is stored. - T* dx_gemm_base; - // Where partial reduction results are stored. - float* db_scratch; - // Semaphores for reduction. - int* semaphores; - - partition_mlp_bp_workspace( - batch_size, - num_layers, - output_features, - work_space, - &dy_gemm_base, - &dx_gemm_base, - &db_scratch, - &semaphores); - - size_t semaphore_size = get_semaphores_size(num_layers, output_features) * sizeof(int); - - // Get cublas handle from Pytorch - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); - // Get the stream from cublas handle to reuse for biasReLU kernel. - cudaStream_t stream; - cublasGetStream(handle, &stream); - int flag = 0; - #ifdef __HIP_PLATFORM_HCC__ - #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) - #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) - #if USE_GEMM_FLAGS_FP16_ALT_IMPL - #ifdef BACKWARD_PASS_GUARD - flag = at::BACKWARD_PASS_GUARD_CLASS::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0; - #endif - #endif - #endif - - int* y_offsets = (int*)malloc(num_layers * sizeof(int)); - get_y_offsets(batch_size, num_layers, output_features, y_offsets); - - for (int layer = num_layers - 1; layer >= 0; layer--) { - weight = WPtr[layer]; - dweight = dwPtr[layer]; - - // x is read from reserved space - x = (layer == 0) ? X : reserved_space + y_offsets[layer - 1]; - // dx is written in workspace for all but layer==0 - dx = (layer == 0) ? dX : dx_gemm_base + y_offsets[layer - 1]; - - // y is read from reserved space - y = (layer == num_layers - 1) ? Y : reserved_space + y_offsets[layer]; - // dx from layer+1 - dy = (layer == num_layers - 1) ? dY : dx_gemm_base + y_offsets[layer]; - // dy_gemm is written to and read immediately - T* dy_gemm = dy_gemm_base + y_offsets[layer]; - - dbias = dbPtr[layer]; - int xfeat = (layer == 0) ? input_features : output_features[layer - 1]; - int yfeat = output_features[layer]; - - float one = 1.f; - float zero = 0.f; - - if (use_bias == 1) { - if (activation == 0) { // no acitvation - // bgrad - dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y); - int grid_x, grid_y; - cudaMemsetAsync(semaphores, 0, semaphore_size, stream); - - int block_x = BIAS_RELU_BW_NTHREADS_X; - int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y; - get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y); - dim3 grid(grid_x, grid_y); - biasAdd_bprop<<>>( - dy, yfeat, batch_size, db_scratch, semaphores, dbias); - // bypass dgrad through reset pointer - dy_gemm = dy; - } else if (activation == 1) { // relu - dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y); - int grid_x, grid_y; - cudaMemsetAsync(semaphores, 0, semaphore_size, stream); - - if(yfeat % (ILP * BIAS_RELU_BW_NTHREADS_X) == 0 && - is_aligned(y) && - is_aligned(dy) && - is_aligned(dy_gemm) && - is_aligned(dbias)){ - int block_x = ILP * BIAS_RELU_BW_NTHREADS_X; - int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y; - get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y); - dim3 grid(grid_x, grid_y); - biasAddRelu_bprop_aligned<<>>( - y, dy, yfeat, batch_size, dy_gemm, db_scratch, semaphores, dbias); - } else { - int block_x = BIAS_RELU_BW_NTHREADS_X; - int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y; - get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y); - dim3 grid(grid_x, grid_y); - biasAddRelu_bprop<<>>( - y, dy, yfeat, batch_size, dy_gemm, db_scratch, semaphores, dbias); - } - } else if (activation == 2) { // sigmoid - // activation backward - int num_blocks = 0; - int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_bprop, BIAS_RELU_FW_NTHREADS, 0); - Sigmoid_bprop<<>>(dy, y, batch_size, yfeat, dy_gemm); - - // bgrad, from dy_gemm - dim3 block(BIAS_RELU_BW_NTHREADS_X, BIAS_RELU_BW_NTHREADS_Y); - int grid_x, grid_y; - cudaMemsetAsync(semaphores, 0, semaphore_size, stream); - - int block_x = BIAS_RELU_BW_NTHREADS_X; - int block_y = BIAS_RELU_RED_PER_THREAD * BIAS_RELU_BW_NTHREADS_Y; - get_biasAddRelu_bprop_grid_size(yfeat, batch_size, block_x, block_y, &grid_x, &grid_y); - dim3 grid(grid_x, grid_y); - biasAdd_bprop<<>>( - dy_gemm, yfeat, batch_size, db_scratch, semaphores, dbias); - } - } else { // no bias below - if (activation == 0) { - // bypass dgrad through reset pointer - dy_gemm = dy; - } else if (activation == 1) { // relu - int num_blocks = 0; - int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_bprop, BIAS_RELU_FW_NTHREADS, 0); - Relu_bprop<<>>(dy, y, batch_size, yfeat, dy_gemm); - } else if (activation == 2) { // sigmoid - int num_blocks = 0; - int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_bprop, BIAS_RELU_FW_NTHREADS, 0); - Sigmoid_bprop<<>>(dy, y, batch_size, yfeat, dy_gemm); - } - } - - cublasStatus_t cublas_status; - // Call GEMM dgrad - if (layer > 0 || requires_grad == 1) { - cublas_status = mlp_gemm( - handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - xfeat, - batch_size, - yfeat, - &one, - weight, - xfeat, - dy_gemm, - yfeat, - &zero, - dx, - xfeat, - flag); // - - if (cublas_status != CUBLAS_STATUS_SUCCESS) { - printf("GEMM dgrad failed with %d\n", cublas_status); - return 1; - } - } - - // Call GEMM wgrad - cublas_status = mlp_gemm( - handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - xfeat, - yfeat, - batch_size, - &one, - x, - xfeat, - dy_gemm, - yfeat, - &zero, - dweight, - xfeat, - flag); // - - if (cublas_status != CUBLAS_STATUS_SUCCESS) { - printf("GEMM wgrad failed with %d\n", cublas_status); - return 1; - } - } - - return 0; -} - -// Instantiate for floating point types -template int mlp_fp( - float* X, - int input_features, - int batch_size, - float** WPtr, - int num_layers, - int* output_features, - float** BPtr, - float* Y, - float* reserved_space, - int use_bias, - int activation, - void* lt_workspace); - -template int mlp_bp( - float* X, - float* Y, - int input_features, - int batch_size, - float** WPtr, - int num_layers, - int* output_features, - float* dY, - float* reserved_space, - float* work_space, - float* dX, - float** dwPtr, - float** dbPtr, - bool requires_grad, - int use_bias, - int activation); - -template int mlp_fp( - at::Half* X, - int input_features, - int batch_size, - at::Half** WPtr, - int num_layers, - int* output_features, - at::Half** BPtr, - at::Half* Y, - at::Half* reserved_space, - int use_bias, - int activation, - void* lt_workspace); - -template int mlp_bp( - at::Half* X, - at::Half* Y, - int input_features, - int batch_size, - at::Half** WPtr, - int num_layers, - int* output_features, - at::Half* dY, - at::Half* reserved_space, - at::Half* work_space, - at::Half* dX, - at::Half** dwPtr, - at::Half** dbPtr, - bool requires_grad, - int use_bias, - int activation); - -template int mlp_fp( - double* X, - int input_features, - int batch_size, - double** WPtr, - int num_layers, - int* output_features, - double** BPtr, - double* Y, - double* reserved_space, - int use_bias, - int activation, - void* lt_workspace); - -template int mlp_bp( - double* X, - double* Y, - int input_features, - int batch_size, - double** WPtr, - int num_layers, - int* output_features, - double* dY, - double* reserved_space, - double* work_space, - double* dX, - double** dwPtr, - double** dbPtr, - bool requires_grad, - int use_bias, - int activation); - -template size_t get_mlp_bp_workspace_in_bytes( - int batch_size, - int num_layers, - const int* output_features); -template size_t get_mlp_bp_workspace_in_bytes( - int batch_size, - int num_layers, - const int* output_features); -template size_t get_mlp_bp_workspace_in_bytes( - int batch_size, - int num_layers, - const int* output_features); diff --git a/csrc/multi_tensor_adagrad.cu b/csrc/multi_tensor_adagrad.cu deleted file mode 100644 index 7bdb621..0000000 --- a/csrc/multi_tensor_adagrad.cu +++ /dev/null @@ -1,100 +0,0 @@ -#include -#include -#include -#include -// Another possibility: -// #include - -#include - -#include "multi_tensor_apply.cuh" -#include "type_shim.h" - -#define BLOCK_SIZE 1024 -#define ILP 4 - -typedef enum { - ADAGRAD_MODE_0 = 0, // L2 regularization mode. - ADAGRAD_MODE_1 = 1, // AdamW-style weight decay. - -} adagradMode_t; - -using MATH_T = float; - -template struct AdagradFunctor { - __device__ __forceinline__ void - operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> &tl, - const float epsilon, const float lr, adagradMode_t mode, - const float weight_decay) { - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - T *g = (T *)tl.addresses[0][tensor_loc]; - g += chunk_idx * chunk_size; - - T *p = (T *)tl.addresses[1][tensor_loc]; - p += chunk_idx * chunk_size; - - T *h = (T *)tl.addresses[2][tensor_loc]; - h += chunk_idx * chunk_size; - - n -= chunk_idx * chunk_size; - - // see note in multi_tensor_scale_kernel.cu - for (int i_start = 0; i_start < n && i_start < chunk_size; - i_start += blockDim.x * ILP) { - MATH_T r_g[ILP]; - MATH_T r_p[ILP]; - MATH_T r_h[ILP]; -#pragma unroll - for (int ii = 0; ii < ILP; ii++) { - int i = i_start + threadIdx.x + ii * blockDim.x; - if (i < n && i < chunk_size) { - r_g[ii] = g[i]; - r_p[ii] = p[i]; - r_h[ii] = h[i]; - } else { - r_g[ii] = MATH_T(0); - r_p[ii] = MATH_T(0); - r_h[ii] = MATH_T(0); - } - } -#pragma unroll - for (int ii = 0; ii < ILP; ii++) { - if (mode == ADAGRAD_MODE_0) { // L2 - r_g[ii] = r_g[ii] + weight_decay * r_p[ii]; - r_h[ii] = r_h[ii] + r_g[ii] * r_g[ii]; - r_p[ii] = r_p[ii] - lr * (r_g[ii] / (sqrtf(r_h[ii]) + epsilon)); - } else { // AdamW-style - r_h[ii] = r_h[ii] + r_g[ii] * r_g[ii]; - r_p[ii] = r_p[ii] - lr * (r_g[ii] / (sqrtf(r_h[ii]) + epsilon) + weight_decay * r_p[ii]); - } - } -#pragma unroll - for (int ii = 0; ii < ILP; ii++) { - int i = i_start + threadIdx.x + ii * blockDim.x; - if (i < n && i < chunk_size) { - p[i] = r_p[ii]; - h[i] = r_h[ii]; - } - } - } - } -}; - -void multi_tensor_adagrad_cuda( - int chunk_size, at::Tensor noop_flag, - std::vector> tensor_lists, const float lr, - const float epsilon, const int mode, const float weight_decay) { - using namespace at; - - // Assume single type across p,g,h now - DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16( - tensor_lists[0][0].scalar_type(), 0, "adagrad", - multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - AdagradFunctor(), epsilon, lr, - (adagradMode_t)mode, weight_decay);) - - AT_CUDA_CHECK(cudaGetLastError()); -} diff --git a/csrc/multi_tensor_adam.cu b/csrc/multi_tensor_adam.cu deleted file mode 100644 index 8aa3170..0000000 --- a/csrc/multi_tensor_adam.cu +++ /dev/null @@ -1,171 +0,0 @@ -#include -#include -#include -#include -// Another possibility: -// #include - -#include - -#include "type_shim.h" -#include "multi_tensor_apply.cuh" - -#define BLOCK_SIZE 1024 -#define ILP 4 - -typedef enum{ - ADAM_MODE_0 =0, // L2 regularization mode - ADAM_MODE_1 =1 // Decoupled weight decay mode(AdamW) -} adamMode_t; - -using MATH_T = float; - -template -struct AdamFunctor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<4>& tl, - const float beta1, - const float beta2, - const float beta1_correction, - const float beta2_correction, - const float epsilon, - const float lr, - adamMode_t mode, - const float decay) - { - // I'd like this kernel to propagate infs/nans. - // if(*noop_gmem == 1) - // return; - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - - // potentially use to pass in list of scalar - // int tensor_num = tl.start_tensor_this_launch + tensor_loc; - - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - T* g = (T*)tl.addresses[0][tensor_loc]; - g += chunk_idx*chunk_size; - - T* p = (T*)tl.addresses[1][tensor_loc]; - p += chunk_idx*chunk_size; - - T* m = (T*)tl.addresses[2][tensor_loc]; - m += chunk_idx*chunk_size; - - T* v = (T*)tl.addresses[3][tensor_loc]; - v += chunk_idx*chunk_size; - - n -= chunk_idx*chunk_size; - - // see note in multi_tensor_scale_kernel.cu - for(int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) - { - MATH_T r_g[ILP]; - MATH_T r_p[ILP]; - MATH_T r_m[ILP]; - MATH_T r_v[ILP]; -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - r_g[ii] = g[i]; - r_p[ii] = p[i]; - r_m[ii] = m[i]; - r_v[ii] = v[i]; - } else { - r_g[ii] = MATH_T(0); - r_p[ii] = MATH_T(0); - r_m[ii] = MATH_T(0); - r_v[ii] = MATH_T(0); - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - if(mode == ADAM_MODE_0) { // L2 - r_g[ii] = r_g[ii] + (decay * r_p[ii]); - r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii]; - r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii]; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - MATH_T next_v_unbiased = r_v[ii] / beta2_correction; - MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - MATH_T update = next_m_unbiased / denom; - r_p[ii] = r_p[ii] - (lr * update); - } - else { // weight decay - r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii]; - r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii]; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - MATH_T next_v_unbiased = r_v[ii] / beta2_correction; - MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]); - r_p[ii] = r_p[ii] - (lr * update); - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - p[i] = r_p[ii]; - m[i] = r_m[ii]; - v[i] = r_v[ii]; - } - } - } - } -}; - -void multi_tensor_adam_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - const float lr, - const float beta1, - const float beta2, - const float epsilon, - const int step, - const int mode, - const int bias_correction, - const float weight_decay) -{ - using namespace at; - - // Handle bias correction mode - float bias_correction1 = 1.0f, bias_correction2 = 1.0f; - if (bias_correction == 1) { - bias_correction1 = 1 - std::pow(beta1, step); - bias_correction2 = 1 - std::pow(beta2, step); - } - - // Assume single type across p,g,m1,m2 now - DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( - tensor_lists[0][0].scalar_type(), 0, "adam", - multi_tensor_apply<4>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - AdamFunctor(), - beta1, - beta2, - bias_correction1, - bias_correction2, - epsilon, - lr, - (adamMode_t) mode, - weight_decay); ) - - AT_CUDA_CHECK(cudaGetLastError()); - -} diff --git a/csrc/multi_tensor_apply.cuh b/csrc/multi_tensor_apply.cuh deleted file mode 100644 index aaaee3f..0000000 --- a/csrc/multi_tensor_apply.cuh +++ /dev/null @@ -1,147 +0,0 @@ -#include -#include -#include -#include -#include -#include "compat.h" - -#include - -// #include - -// This header is the one-stop shop for all your multi-tensor apply needs. - - -// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson) -constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; -constexpr int depth_to_max_blocks[5] = {2560, 2560, 2560, 2560, 2560}; - -template struct TensorListMetadata -{ - void* addresses[n][depth_to_max_tensors[n-1]]; - int sizes[depth_to_max_tensors[n-1]]; - unsigned char block_to_tensor[depth_to_max_blocks[n-1]]; - int block_to_chunk[depth_to_max_blocks[n-1]]; // I fear this needs to be a full int. - int start_tensor_this_launch; -}; - - -template -#ifdef __HIP_PLATFORM_HCC__ -__launch_bounds__(1024) -#endif -__global__ void multi_tensor_apply_kernel( - int chunk_size, - volatile int* noop_flag, - T tl, - U callable, - ArgTypes... args) -{ - // Hand the chunk information to the user-supplied functor to process however it likes. - callable(chunk_size, noop_flag, tl, args...); -} - -template -void multi_tensor_apply( - int block_size, - int chunk_size, - const at::Tensor& noop_flag, - const std::vector>& tensor_lists, - T callable, - ArgTypes... args) -{ - TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); - int len0 = tensor_lists[0].size(); - TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0"); - auto ref_device = tensor_lists[0][0].device(); - TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda"); - for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices - { - TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists"); - for(int t = 0; t < tensor_lists[l].size(); t++) - { - // TODO: Print which tensor fails. - bool contiguous_memory = (tensor_lists[l][t].is_sparse()) ? tensor_lists[l][t]._values().is_contiguous() : tensor_lists[l][t].is_contiguous(); -#ifdef VERSION_GE_1_5 - contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast) || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast3d)); -#endif - TORCH_CHECK(contiguous_memory, "A tensor was not contiguous."); - TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor"); - TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch"); - } - } - - int ntensors = tensor_lists[0].size(); - - TensorListMetadata tl; - - const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0])); - auto stream = at::cuda::getCurrentCUDAStream(); - - tl.start_tensor_this_launch = 0; - int loc_block_info = 0; - int loc_tensor_info = 0; - for(int t = 0; t < ntensors; t++) - { - tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); - // skip empty tensors - if (tl.sizes[loc_tensor_info] == 0) { - continue; - } - for(int d = 0; d < depth; d++) { - if (tensor_lists[d][t].is_sparse()) { - at::Tensor dst = at::zeros(tensor_lists[d][t].sizes(), tensor_lists[d][t].options().layout(at::kStrided)); - dst.add_(tensor_lists[d][t]); - tl.addresses[d][loc_tensor_info] = dst.data_ptr(); - } else { - tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); - } - } - loc_tensor_info++; - - int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size; - - for(int chunk = 0; chunk < chunks_this_tensor; chunk++) - { - // std::cout << chunks_this_tensor << std::endl; - tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; - tl.block_to_chunk[loc_block_info] = chunk; - loc_block_info++; - - bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth-1] && - chunk == chunks_this_tensor - 1); - bool blocks_full = (loc_block_info == depth_to_max_blocks[depth-1]); - bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); - if(tensors_full || blocks_full || last_chunk) - { - // using accscalar_t = acc_type; - multi_tensor_apply_kernel<<>>( - chunk_size, - noop_flag.DATA_PTR(), - tl, - callable, - args...); - - AT_CUDA_CHECK(cudaGetLastError()); - - // Reset. The control flow possibilities here make my brain hurt. - loc_block_info = 0; - if(chunk == chunks_this_tensor - 1) - { - // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << std::endl; - loc_tensor_info = 0; - tl.start_tensor_this_launch = t + 1; - } - else - { - // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << std::endl; - tl.sizes[0] = tl.sizes[loc_tensor_info-1]; - for(int d = 0; d < depth; d++) - tl.addresses[d][0] = tl.addresses[d][loc_tensor_info-1]; - loc_tensor_info = 1; - tl.start_tensor_this_launch = t; - } - } - } - } -} diff --git a/csrc/multi_tensor_apply_base.cuh b/csrc/multi_tensor_apply_base.cuh deleted file mode 100644 index b6a9f17..0000000 --- a/csrc/multi_tensor_apply_base.cuh +++ /dev/null @@ -1,147 +0,0 @@ -#include -#include -#include -#include -#include -#include "compat.h" - -#include - -// #include - -// This header is the one-stop shop for all your multi-tensor apply needs. - - -// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson) -constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; -constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; - -template struct TensorListMetadata -{ - void* addresses[n][depth_to_max_tensors[n-1]]; - int sizes[depth_to_max_tensors[n-1]]; - unsigned char block_to_tensor[depth_to_max_blocks[n-1]]; - int block_to_chunk[depth_to_max_blocks[n-1]]; // I fear this needs to be a full int. - int start_tensor_this_launch; -}; - - -template -#ifdef __HIP_PLATFORM_HCC__ -__launch_bounds__(1024) -#endif -__global__ void multi_tensor_apply_kernel( - int chunk_size, - volatile int* noop_flag, - T tl, - U callable, - ArgTypes... args) -{ - // Hand the chunk information to the user-supplied functor to process however it likes. - callable(chunk_size, noop_flag, tl, args...); -} - -template -void multi_tensor_apply( - int block_size, - int chunk_size, - const at::Tensor& noop_flag, - const std::vector>& tensor_lists, - T callable, - ArgTypes... args) -{ - TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth"); - int len0 = tensor_lists[0].size(); - TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0"); - auto ref_device = tensor_lists[0][0].device(); - TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda"); - for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices - { - TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists"); - for(int t = 0; t < tensor_lists[l].size(); t++) - { - // TODO: Print which tensor fails. - bool contiguous_memory = (tensor_lists[l][t].is_sparse()) ? tensor_lists[l][t]._values().is_contiguous() : tensor_lists[l][t].is_contiguous(); -#ifdef VERSION_GE_1_5 - contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast) || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast3d)); -#endif - TORCH_CHECK(contiguous_memory, "A tensor was not contiguous."); - TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor"); - TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch"); - } - } - - int ntensors = tensor_lists[0].size(); - - TensorListMetadata tl; - - const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0])); - auto stream = at::cuda::getCurrentCUDAStream(); - - tl.start_tensor_this_launch = 0; - int loc_block_info = 0; - int loc_tensor_info = 0; - for(int t = 0; t < ntensors; t++) - { - tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel(); - // skip empty tensors - if (tl.sizes[loc_tensor_info] == 0) { - continue; - } - for(int d = 0; d < depth; d++) { - if (tensor_lists[d][t].is_sparse()) { - at::Tensor dst = at::zeros(tensor_lists[d][t].sizes(), tensor_lists[d][t].options().layout(at::kStrided)); - dst.add_(tensor_lists[d][t]); - tl.addresses[d][loc_tensor_info] = dst.data_ptr(); - } else { - tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr(); - } - } - loc_tensor_info++; - - int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size; - - for(int chunk = 0; chunk < chunks_this_tensor; chunk++) - { - // std::cout << chunks_this_tensor << std::endl; - tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1; - tl.block_to_chunk[loc_block_info] = chunk; - loc_block_info++; - - bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth-1] && - chunk == chunks_this_tensor - 1); - bool blocks_full = (loc_block_info == depth_to_max_blocks[depth-1]); - bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1); - if(tensors_full || blocks_full || last_chunk) - { - // using accscalar_t = acc_type; - multi_tensor_apply_kernel<<>>( - chunk_size, - noop_flag.DATA_PTR(), - tl, - callable, - args...); - - AT_CUDA_CHECK(cudaGetLastError()); - - // Reset. The control flow possibilities here make my brain hurt. - loc_block_info = 0; - if(chunk == chunks_this_tensor - 1) - { - // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << std::endl; - loc_tensor_info = 0; - tl.start_tensor_this_launch = t + 1; - } - else - { - // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << std::endl; - tl.sizes[0] = tl.sizes[loc_tensor_info-1]; - for(int d = 0; d < depth; d++) - tl.addresses[d][0] = tl.addresses[d][loc_tensor_info-1]; - loc_tensor_info = 1; - tl.start_tensor_this_launch = t; - } - } - } - } -} diff --git a/csrc/multi_tensor_axpby_kernel.cu b/csrc/multi_tensor_axpby_kernel.cu deleted file mode 100644 index 87f536b..0000000 --- a/csrc/multi_tensor_axpby_kernel.cu +++ /dev/null @@ -1,157 +0,0 @@ -#include -#include -#include -#include -// Another possibility: -// #include - -#include - -#include "type_shim.h" -#include "multi_tensor_apply.cuh" - -#define BLOCK_SIZE 1024 -#define ILP 4 - -template -__device__ __forceinline__ bool is_aligned(T* p){ - return ((uint64_t)p) % (ILP*sizeof(T)) == 0; -} - -template -__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){ - typedef typename std::aligned_storage::type LT; - ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; -} - -template -struct AxpbyFunctor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<3>& tl, - float a, - float b, - int arg_to_check) - { - // I'd like this kernel to propagate infs/nans. - // if(*noop_gmem == 1) - // return; - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - x_t* x = (x_t*)tl.addresses[0][tensor_loc]; - x += chunk_idx*chunk_size; - - y_t* y = (y_t*)tl.addresses[1][tensor_loc]; - y += chunk_idx*chunk_size; - - out_t* out = (out_t*)tl.addresses[2][tensor_loc]; - out += chunk_idx*chunk_size; - - n -= chunk_idx*chunk_size; - - bool finite = true; - x_t r_x[ILP]; - y_t r_y[ILP]; - out_t r_out[ILP]; - - // to make things simple, we put aligned case in a different code path - if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x) && is_aligned(y) && is_aligned(out)) - { - for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) - { - // load - load_store(r_x, x, 0 , i_start); - load_store(r_y, y, 0 , i_start); -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_out[ii] = a*static_cast(r_x[ii]) + b*static_cast(r_y[ii]); - if(arg_to_check == -1) - finite = finite && (isfinite(r_x[ii]) && isfinite(r_y[ii])); - if(arg_to_check == 0) - finite = finite && isfinite(r_x[ii]); - if(arg_to_check == 1) - finite = finite && isfinite(r_y[ii]); - } - // store - load_store(out, r_out, i_start , 0); - } - } - else - { - // Non-divergent exit condition for __syncthreads, not necessary here - for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP) - { -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_x[ii] = 0; - r_y[ii] = 0; - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - r_x[ii] = x[i]; - r_y[ii] = y[i]; - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_out[ii] = a*static_cast(r_x[ii]) + b*static_cast(r_y[ii]); - if(arg_to_check == -1) - finite = finite && (isfinite(r_x[ii]) && isfinite(r_y[ii])); - if(arg_to_check == 0) - finite = finite && isfinite(r_x[ii]); - if(arg_to_check == 1) - finite = finite && isfinite(r_y[ii]); - } - // see note in multi_tensor_scale_kernel.cu -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - out[i] = r_out[ii]; - } - } - } - if(!finite) - *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok. - } -}; - -void multi_tensor_axpby_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - float a, - float b, - int arg_to_check) -{ - using namespace at; - // The output (downscaled) type is always float. - // If build times suffer, think about where to put this dispatch, - // and what logic should be moved out of multi_tensor_apply. - - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda", - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda", - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda", - multi_tensor_apply<3>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - AxpbyFunctor(), - a, - b, - arg_to_check); ))) - - AT_CUDA_CHECK(cudaGetLastError()); - - // AT_CUDA_CHECK(cudaDeviceSynchronize()); -} diff --git a/csrc/multi_tensor_l2norm_kernel.cu b/csrc/multi_tensor_l2norm_kernel.cu deleted file mode 100644 index db713c2..0000000 --- a/csrc/multi_tensor_l2norm_kernel.cu +++ /dev/null @@ -1,456 +0,0 @@ -#include -#include -#include -#include -#include -// Another possibility: -// #include - -#include - -#include "type_shim.h" -#include "multi_tensor_apply_base.cuh" - -#define BLOCK_SIZE 512 -#define ILP 4 - -template -__device__ __forceinline__ bool is_aligned(T* p){ - return ((uint64_t)p) % (ILP*sizeof(T)) == 0; -} - -template -__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){ - typedef typename std::aligned_storage::type LT; - ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; -} - -template -struct L2NormFunctor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<1>& tl, - float* output, - float* output_per_tensor, - bool per_tensor, - int max_chunks_per_tensor) - { - // I'd like this kernel to propagate infs/nans. - // if(*noop_gmem == 1) - // return; - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - x_t* x = (x_t*)tl.addresses[0][tensor_loc]; - x += chunk_idx*chunk_size; - - n -= chunk_idx*chunk_size; - - __shared__ float s_vals[512]; - - float vals[ILP]; // = {0}; // this probably works too but I want to be sure... - x_t r_x[ILP]; - for(int i = 0; i < ILP; i++) - { - vals[i] = 0.f; - r_x[i] = 0; - } - - // to make things simple, we put aligned case in a different code path - if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) - { - for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) - { - // load - load_store(r_x, x, 0 , i_start); -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - float next = static_cast(r_x[ii]); - vals[ii] += next*next; - } - } - } - else - { - for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP) - { -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - float next = static_cast(x[i]); - vals[ii] += next*next; - } - } - } - } - - float val = 0.f; - for(int i = 0; i < ILP; i++) - val += vals[i]; - - float final = reduce_block_into_lanes(s_vals, val); - - if(threadIdx.x == 0) - { - if(!isfinite(final)) - *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok. - output[blockIdx.x] += final; - if(per_tensor) - output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final; - } - } -}; - -// Probably better to template, but since we are not likely to support other norm -template -struct MaxNormFunctor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<1>& tl, - float* output, - float* output_per_tensor, - bool per_tensor, - int max_chunks_per_tensor) - { - // I'd like this kernel to propagate infs/nans. - // if(*noop_gmem == 1) - // return; - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - x_t* x = (x_t*)tl.addresses[0][tensor_loc]; - x += chunk_idx*chunk_size; - - n -= chunk_idx*chunk_size; - - __shared__ float s_vals[512]; - - float vals[ILP]; // = {0}; // this probably works too but I want to be sure... - x_t r_x[ILP]; - for(int i = 0; i < ILP; i++) - { - vals[i] = 0.f; - r_x[i] = 0; - } - - // to make things simple, we put aligned case in a different code path - if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) - { - for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) - { - // load - load_store(r_x, x, 0 , i_start); -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - float next = static_cast(r_x[ii]); - vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next)); - } - } - } - else - { - for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP) - { -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - float next = static_cast(x[i]); - vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next)); - } - } - } - } - - float val = 0.f; - for(int i = 0; i < ILP; i++) - val = fmaxf(fabsf(val), fabsf(vals[i])); - - float final = reduce_block_into_lanes_max_op(s_vals, val); - - if(threadIdx.x == 0) - { - if(!isfinite(final)) - *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok. - output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final)); - if(per_tensor) - output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final; - } - } -}; - - -__global__ void -#ifdef __HIP_PLATFORM_HCC__ -__launch_bounds__(1024) -#endif -cleanup( - float* output, - float* output_per_tensor, - float* ret, - float* ret_per_tensor, - bool per_tensor, - int max_chunks_per_tensor) -{ - __shared__ float vals[512]; - - if(blockIdx.x == 0) - { - float val = 0; - if(threadIdx.x < 320) - val = output[threadIdx.x]; - - float final = reduce_block_into_lanes(vals, val); - - if(threadIdx.x == 0) - *ret = sqrt(final); - } - - if(per_tensor) - { - float* output_this_tensor = output_per_tensor + blockIdx.x*max_chunks_per_tensor; - - float val = 0; - for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) - val += output_this_tensor[i]; - - float final = reduce_block_into_lanes(vals, val); - - if(threadIdx.x == 0) - ret_per_tensor[blockIdx.x] = sqrt(final); - } -} - -__global__ void -#ifdef __HIP_PLATFORM_HCC__ -__launch_bounds__(1024) -#endif -cleanup_v2( - float* output, - float* output_per_tensor, - float* ret, - float* ret_per_tensor, - bool per_tensor, - int max_chunks_per_tensor, - int norm_type, - float alpha, - float beta) -{ - __shared__ float vals[512]; - - if(blockIdx.x == 0) - { - float val = 0; - if(threadIdx.x < 320) - val = output[threadIdx.x]; - - if (norm_type == 0) { - float final = reduce_block_into_lanes_max_op(vals, val); - if(threadIdx.x == 0) - *ret = alpha * (*ret) + beta * final; - } - else { - float final = reduce_block_into_lanes(vals, val); - if(threadIdx.x == 0) - *ret = sqrt(alpha * (*ret) * (*ret) + beta * final); - } - } - - if(per_tensor) - { - float* output_this_tensor = output_per_tensor + blockIdx.x*max_chunks_per_tensor; - - if (norm_type == 0) { - float val = 0; - for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) - val = fmaxf(fabsf(val), fabsf(output_this_tensor[i])); - - float final = reduce_block_into_lanes_max_op(vals, val); - - if(threadIdx.x == 0) - ret_per_tensor[blockIdx.x] = alpha * ret_per_tensor[blockIdx.x] + beta * final; - } - else { - float val = 0; - for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) - val += output_this_tensor[i]; - - float final = reduce_block_into_lanes(vals, val); - - if(threadIdx.x == 0) - ret_per_tensor[blockIdx.x] = sqrt(alpha * ret_per_tensor[blockIdx.x] * ret_per_tensor[blockIdx.x] + beta * final); - } - } -} - -std::tuple multi_tensor_l2norm_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::optional per_tensor_python) -{ - bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false; - - auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); - auto output = at::zeros({320}, float_options); - - at::Tensor output_per_tensor; - at::Tensor ret_per_tensor; - - int ntensors = tensor_lists[0].size(); - int max_chunks_per_tensor = -1; - - if(per_tensor) - { - for(int t = 0; t < ntensors; t++) - { - int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size; - if(max_chunks_this_tensor > max_chunks_per_tensor) - max_chunks_per_tensor = max_chunks_this_tensor; - } - output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options); - ret_per_tensor = at::empty({ntensors}, float_options); - } - else - { - ret_per_tensor = at::empty({0}, float_options); - } - - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda", - multi_tensor_apply<1>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - L2NormFunctor(), - output.DATA_PTR(), - per_tensor ? output_per_tensor.DATA_PTR() : nullptr, - per_tensor, - max_chunks_per_tensor);) - - AT_CUDA_CHECK(cudaGetLastError()); - // AT_CUDA_CHECK(cudaDeviceSynchronize()); - - // This involves one more small kernel launches, but will be negligible end to end. - // I could get rid of these by hacking the functor + multi tensor harness with persistence - // logic, but keeping it simple for now - auto ret = at::empty({1}, output.options()); - const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); - auto stream = at::cuda::getCurrentCUDAStream(); - cleanup<<>>( - output.DATA_PTR(), - per_tensor ? output_per_tensor.DATA_PTR() : nullptr, - ret.DATA_PTR(), - per_tensor ? ret_per_tensor.DATA_PTR() : nullptr, - per_tensor, - max_chunks_per_tensor); - - return std::tuple(ret, ret_per_tensor); -} - - -// Compute and update grad norm -// Here use a per tensor norm, and blend new norm(n) and old norm(gn) by -// L-2: gn = sqrt(a * gn^2 + b * n^2) -// L-inf: gn = a * gn + b * n -void multi_tensor_norm_out_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor out, - const float alpha, - const float beta, - const int norm_type) -{ - auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); - TORCH_CHECK(tensor_lists[0][0].device() == noop_flag.device(), "noop flag should be on the same device as tensors"); - // we don't need global thus uses empty here - auto output = at::empty({320}, float_options); - - at::Tensor output_per_tensor; - at::Tensor ret_per_tensor; - - int ntensors = tensor_lists[0].size(); - int max_chunks_per_tensor = -1; - - for(int t = 0; t < ntensors; t++) - { - int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size; - if(max_chunks_this_tensor > max_chunks_per_tensor) - max_chunks_per_tensor = max_chunks_this_tensor; - } - - // Although it is single write then read, still need to be zero - // Since tailing element also participate cleanup - output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options); - - if (norm_type == 0) { - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16( - tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda", - multi_tensor_apply<1>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - MaxNormFunctor(), - output.DATA_PTR(), - output_per_tensor.DATA_PTR(), - true, - max_chunks_per_tensor);) - } - else { - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16( - tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda", - multi_tensor_apply<1>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - L2NormFunctor(), - output.DATA_PTR(), - output_per_tensor.DATA_PTR(), - true, - max_chunks_per_tensor);) - } - AT_CUDA_CHECK(cudaGetLastError()); - - // AT_CUDA_CHECK(cudaDeviceSynchronize()); - - // This involves one more small kernel launches, but will be negligible end to end. - // I could get rid of these by hacking the functor + multi tensor harness with persistence - // logic, but keeping it simple for now - auto ret = at::empty({1}, output.options()); - - // Adding the following device guard since it happens sometimes that the - // tensors are on one device and the cuda stream is on another device which - // results in ILLEGAL MEM ACCESS error. - const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); - auto stream = at::cuda::getCurrentCUDAStream(); - cleanup_v2<<>>( - output.DATA_PTR(), - output_per_tensor.DATA_PTR(), - ret.DATA_PTR(), - out.DATA_PTR(), - true, - max_chunks_per_tensor, - norm_type, - alpha, - beta); - - return ; -} diff --git a/csrc/multi_tensor_l2norm_kernel_mp.cu b/csrc/multi_tensor_l2norm_kernel_mp.cu deleted file mode 100644 index 5b2299b..0000000 --- a/csrc/multi_tensor_l2norm_kernel_mp.cu +++ /dev/null @@ -1,220 +0,0 @@ -#include -#include -#include -#include -#include -// Another possibility: -// #include - -#include - -#include "type_shim.h" -#include "multi_tensor_apply_base.cuh" - -#define BLOCK_SIZE 512 -#define ILP 4 - -template -__device__ __forceinline__ bool is_aligned(T* p){ - return ((uint64_t)p) % (ILP*sizeof(T)) == 0; -} - -template -__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){ - typedef typename std::aligned_storage::type LT; - ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; -} - -template -struct L2NormFunctor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<1>& tl, - float* output, - float* output_per_tensor, - bool per_tensor, - int max_chunks_per_tensor) - { - if (*noop_gmem) { - return; - } - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - x_t* x = (x_t*)tl.addresses[0][tensor_loc]; - x += chunk_idx*chunk_size; - - n -= chunk_idx*chunk_size; - - __shared__ float s_vals[512]; - - float vals[ILP]; // = {0}; // this probably works too but I want to be sure... - x_t r_x[ILP]; - for(int i = 0; i < ILP; i++) - { - vals[i] = 0.f; - r_x[i] = 0; - } - - // to make things simple, we put aligned case in a different code path - if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) - { - for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) - { - // load - load_store(r_x, x, 0 , i_start); -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - float next = static_cast(r_x[ii]); - vals[ii] += next*next; - } - } - } - else - { - for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP) - { -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - float next = static_cast(x[i]); - vals[ii] += next*next; - } - } - } - } - - float val = 0.f; - for(int i = 0; i < ILP; i++) - val += vals[i]; - - float final = reduce_block_into_lanes(s_vals, val); - - if(threadIdx.x == 0) - { - if(!isfinite(final)) - *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok. - output[blockIdx.x] += final; - if(per_tensor) - output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final; - } - } -}; - -__global__ void -#ifdef __HIP_PLATFORM_HCC__ -__launch_bounds__(1024) -#endif -cleanup( - float* output, - float* output_per_tensor, - float* ret, - float* ret_per_tensor, - bool per_tensor, - int max_chunks_per_tensor, - volatile int* noop_gmem) -{ - if (*noop_gmem) { - return; - } - __shared__ float vals[512]; - - if(blockIdx.x == 0) - { - float val = 0; - if(threadIdx.x < 320) - val = output[threadIdx.x]; - - float final = reduce_block_into_lanes(vals, val); - - if(threadIdx.x == 0) - *ret = sqrt(final); - } - - if(per_tensor) - { - float* output_this_tensor = output_per_tensor + blockIdx.x*max_chunks_per_tensor; - - float val = 0; - for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) - val += output_this_tensor[i]; - - float final = reduce_block_into_lanes(vals, val); - - if(threadIdx.x == 0) - ret_per_tensor[blockIdx.x] = sqrt(final); - } -} - -std::tuple multi_tensor_l2norm_mp_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::optional per_tensor_python) -{ - bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false; - - auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); - auto output = at::zeros({320}, float_options); - - at::Tensor output_per_tensor; - at::Tensor ret_per_tensor; - - int ntensors = tensor_lists[0].size(); - int max_chunks_per_tensor = -1; - - if(per_tensor) - { - for(int t = 0; t < ntensors; t++) - { - int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size; - if(max_chunks_this_tensor > max_chunks_per_tensor) - max_chunks_per_tensor = max_chunks_this_tensor; - } - output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options); - ret_per_tensor = at::empty({ntensors}, float_options); - } - else - { - ret_per_tensor = at::empty({0}, float_options); - } - - DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_mp_cuda", - multi_tensor_apply<1>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - L2NormFunctor(), - output.data_ptr(), - per_tensor ? output_per_tensor.data_ptr() : nullptr, - per_tensor, - max_chunks_per_tensor);) - - AT_CUDA_CHECK(cudaGetLastError()); - // AT_CUDA_CHECK(cudaDeviceSynchronize()); - - // This involves one more small kernel launches, but will be negligible end to end. - // I could get rid of these by hacking the functor + multi tensor harness with persistence - // logic, but keeping it simple for now - auto ret = at::empty({1}, output.options()); - const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); - auto stream = at::cuda::getCurrentCUDAStream(); - cleanup<<>>( - output.data_ptr(), - per_tensor ? output_per_tensor.data_ptr() : nullptr, - ret.data_ptr(), - per_tensor ? ret_per_tensor.data_ptr() : nullptr, - per_tensor, - max_chunks_per_tensor, noop_flag.data_ptr()); - - return std::tuple(ret, ret_per_tensor); -} diff --git a/csrc/multi_tensor_l2norm_scale_kernel.cu b/csrc/multi_tensor_l2norm_scale_kernel.cu deleted file mode 100644 index f856a52..0000000 --- a/csrc/multi_tensor_l2norm_scale_kernel.cu +++ /dev/null @@ -1,326 +0,0 @@ -#include -#include -#include -#include -#include -// Another possibility: -// #include - -#include - -#include "type_shim.h" -#include "multi_tensor_apply_base.cuh" - -#define BLOCK_SIZE 512 -#define ILP 4 - -template -__device__ __forceinline__ bool is_aligned(T* p){ - return ((uint64_t)p) % (ILP*sizeof(T)) == 0; -} - -template -__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){ - typedef typename std::aligned_storage::type LT; - ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; -} - -template -struct L2NormScaleFunctor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<2>& tl, - float* output, - float* output_per_tensor, - float scale, - bool per_tensor, - int max_chunks_per_tensor) - { - // I'd like this kernel to propagate infs/nans. - // if(*noop_gmem == 1) - // return; - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - in_t* in = (in_t*)tl.addresses[0][tensor_loc]; - in += chunk_idx*chunk_size; - - out_t* out = (out_t*)tl.addresses[1][tensor_loc]; - out += chunk_idx*chunk_size; - - n -= chunk_idx*chunk_size; - - __shared__ float s_vals[512]; - - float vals[ILP]; // = {0}; // this probably works too but I want to be sure... - in_t r_in[ILP]; - for(int i = 0; i < ILP; i++) - { - vals[i] = 0.f; - r_in[i] = 0; - } - //bool finite = true; - out_t r_out[ILP]; - - // to make things simple, we put aligned case in a different code path - if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out)) - { - for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) - { - // load - load_store(r_in, in, 0 , i_start); -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - float next = static_cast(r_in[ii]); - r_out[ii] = next*scale; - vals[ii] += next*next; - //finite = finite && isfinite(r_in[ii]); - } - load_store(out, r_out, i_start, 0); - } - } - else - { - for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP) - { -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_in[ii] = 0; - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - r_in[ii] = in[i]; - float next = static_cast(in[i]); - vals[ii] += next*next; - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_out[ii] = static_cast(r_in[ii]) * scale; - // finite = finite && isfinite(r_in[ii]); - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - out[i] = r_out[ii]; - } - } - } - - float val = 0.f; - for(int i = 0; i < ILP; i++) - val += vals[i]; - - float final = reduce_block_into_lanes(s_vals, val); - - if(threadIdx.x == 0) - { - if(!isfinite(final)) - *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok. - output[blockIdx.x] += final; - if(per_tensor) - output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final; - } - } -}; -// Probably better to template, but since we are not likely to support other norm -template -struct MaxNormFunctor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<1>& tl, - float* output, - float* output_per_tensor, - bool per_tensor, - int max_chunks_per_tensor) - { - // I'd like this kernel to propagate infs/nans. - // if(*noop_gmem == 1) - // return; - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - x_t* x = (x_t*)tl.addresses[0][tensor_loc]; - x += chunk_idx*chunk_size; - - n -= chunk_idx*chunk_size; - - __shared__ float s_vals[512]; - - float vals[ILP]; // = {0}; // this probably works too but I want to be sure... - x_t r_x[ILP]; - for(int i = 0; i < ILP; i++) - { - vals[i] = 0.f; - r_x[i] = 0; - } - - // to make things simple, we put aligned case in a different code path - if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x)) - { - for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) - { - // load - load_store(r_x, x, 0 , i_start); -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - float next = static_cast(r_x[ii]); - vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next)); - } - } - } - else - { - for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP) - { -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - float next = static_cast(x[i]); - vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next)); - } - } - } - } - - float val = 0.f; - for(int i = 0; i < ILP; i++) - val = fmaxf(fabsf(val), fabsf(vals[i])); - - float final = reduce_block_into_lanes_max_op(s_vals, val); - - if(threadIdx.x == 0) - { - if(!isfinite(final)) - *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok. - output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final)); - if(per_tensor) - output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final; - } - } -}; - -__global__ void cleanup_v3( - float* output, - float* output_per_tensor, - float* ret, - float* ret_per_tensor, - bool per_tensor, - int max_chunks_per_tensor) -{ - __shared__ float vals[512]; - - if(blockIdx.x == 0) - { - float val = 0; - if(threadIdx.x < 320) - val = output[threadIdx.x]; - - float final = reduce_block_into_lanes(vals, val); - - if(threadIdx.x == 0) - *ret = sqrt(final); - } - - if(per_tensor) - { - float* output_this_tensor = output_per_tensor + blockIdx.x*max_chunks_per_tensor; - - float val = 0; - for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) - val += output_this_tensor[i]; - - float final = reduce_block_into_lanes(vals, val); - - if(threadIdx.x == 0) - ret_per_tensor[blockIdx.x] = sqrt(final); - } -} - - -std::tuple multi_tensor_l2norm_scale_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - float scale, - at::optional per_tensor_python) -{ - bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false; - - auto float_options = tensor_lists[0][0].options().dtype(at::kFloat); - auto output = at::zeros({320}, float_options); - - at::Tensor output_per_tensor; - at::Tensor ret_per_tensor; - - int ntensors = tensor_lists[0].size(); - int max_chunks_per_tensor = -1; - - if(per_tensor) - { - for(int t = 0; t < ntensors; t++) - { - int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size; - if(max_chunks_this_tensor > max_chunks_per_tensor) - max_chunks_per_tensor = max_chunks_this_tensor; - } - output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options); - ret_per_tensor = at::empty({ntensors}, float_options); - } - else - { - ret_per_tensor = at::empty({0}, float_options); - } - - DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_scale_cuda", - DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_l2norm_scale_cuda", - multi_tensor_apply<2>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - L2NormScaleFunctor(), - output.DATA_PTR(), - per_tensor ? output_per_tensor.DATA_PTR() : nullptr, - scale, - per_tensor, - max_chunks_per_tensor);)) - - AT_CUDA_CHECK(cudaGetLastError()); - // AT_CUDA_CHECK(cudaDeviceSynchronize()); - - // This involves one more small kernel launches, but will be negligible end to end. - // I could get rid of these by hacking the functor + multi tensor harness with persistence - // logic, but keeping it simple for now - auto ret = at::empty({1}, output.options()); - const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); - auto stream = at::cuda::getCurrentCUDAStream(); - cleanup_v3<<>>( - output.DATA_PTR(), - per_tensor ? output_per_tensor.DATA_PTR() : nullptr, - ret.DATA_PTR(), - per_tensor ? ret_per_tensor.DATA_PTR() : nullptr, - per_tensor, - max_chunks_per_tensor); - - return std::tuple(ret, ret_per_tensor); -} - - diff --git a/csrc/multi_tensor_lamb.cu b/csrc/multi_tensor_lamb.cu deleted file mode 100644 index 54a05a7..0000000 --- a/csrc/multi_tensor_lamb.cu +++ /dev/null @@ -1,413 +0,0 @@ -#include -#include -#include -#include -// Another possibility: -// #include - -#include - -#include "type_shim.h" -#include "multi_tensor_apply.cuh" - -#define BLOCK_SIZE 1024 -#define ILP 4 - -template -__device__ __forceinline__ bool is_aligned(T* p){ - return ((uint64_t)p) % (ILP*sizeof(T)) == 0; -} - -template -__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){ - typedef typename std::aligned_storage::type LT; - ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; -} - -typedef enum{ - MOMENT_MODE_0 =0, // L2 regularization mode - MOMENT_MODE_1 =1 // Decoupled weight decay mode -} adamMode_t; - -std::tuple multi_tensor_l2norm_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::optional per_tensor_python); - -using MATH_T = float; - -template -struct LAMBStage1Functor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<4>& tl, - const float beta1, - const float beta2, - const float beta3, - const float beta1_correction, - const float beta2_correction, - const float epsilon, - adamMode_t mode, - const float decay, - const float* global_grad_norm, - const float max_global_grad_norm) - { - // I'd like this kernel to propagate infs/nans. - // if(*noop_gmem == 1) - // return; - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - float clipped_global_grad_norm = (*global_grad_norm) > max_global_grad_norm ? (*global_grad_norm) / max_global_grad_norm : 1.0f; - - T* g = (T*)tl.addresses[0][tensor_loc]; - g += chunk_idx*chunk_size; - - T* p = (T*)tl.addresses[1][tensor_loc]; - p += chunk_idx*chunk_size; - - T* m = (T*)tl.addresses[2][tensor_loc]; - m += chunk_idx*chunk_size; - - T* v = (T*)tl.addresses[3][tensor_loc]; - v += chunk_idx*chunk_size; - - n -= chunk_idx*chunk_size; - - MATH_T r_g[ILP]; - MATH_T r_p[ILP]; - MATH_T r_m[ILP]; - MATH_T r_v[ILP]; - // to make things simple, we put aligned case in a different code path - if(n % ILP == 0 && - chunk_size % ILP == 0 && - is_aligned(g) && - is_aligned(p) && - is_aligned(m) && - is_aligned(v)) - { - T l_g[ILP]; - T l_p[ILP]; - T l_m[ILP]; - T l_v[ILP]; - for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) - { - // load - load_store(l_g, g, 0, i_start); - if (decay != 0) - load_store(l_p, p, 0, i_start); - load_store(l_m, m, 0, i_start); - load_store(l_v, v, 0, i_start); - // unpack -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_g[ii] = l_g[ii]; - if (decay == 0) { - r_p[ii] = MATH_T(0); - } - else { - r_p[ii] = l_p[ii]; - } - r_m[ii] = l_m[ii]; - r_v[ii] = l_v[ii]; - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - if (mode == MOMENT_MODE_0) { - MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; - // L2 on scaled grad - scaled_grad = scaled_grad + decay*r_p[ii]; - r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; - r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - MATH_T next_v_unbiased = r_v[ii] / beta2_correction; - MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - r_p[ii] = next_m_unbiased / denom; - } - else { - MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; - r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; - r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - MATH_T next_v_unbiased = r_v[ii] / beta2_correction; - MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]); - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - l_p[ii] = r_p[ii]; - l_m[ii] = r_m[ii]; - l_v[ii] = r_v[ii]; - } - // store - load_store(g, l_p, i_start, 0); - load_store(m, l_m, i_start, 0); - load_store(v, l_v, i_start, 0); - } - } - else - { - // see note in multi_tensor_scale_kernel.cu - for(int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) - { - MATH_T r_g[ILP]; - MATH_T r_p[ILP]; - MATH_T r_m[ILP]; - MATH_T r_v[ILP]; -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - r_g[ii] = g[i]; - // special ?optimization? for lamb stage 1 - if (decay == 0) { - r_p[ii] = MATH_T(0); - } - else { - r_p[ii] = p[i]; - } - r_m[ii] = m[i]; - r_v[ii] = v[i]; - } else { - r_g[ii] = MATH_T(0); - r_p[ii] = MATH_T(0); - r_m[ii] = MATH_T(0); - r_v[ii] = MATH_T(0); - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - if (mode == MOMENT_MODE_0) { - MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; - // L2 on scaled grad - scaled_grad = scaled_grad + decay*r_p[ii]; - r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; - r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - MATH_T next_v_unbiased = r_v[ii] / beta2_correction; - MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - r_p[ii] = next_m_unbiased / denom; - } - else { - MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; - r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; - r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - MATH_T next_v_unbiased = r_v[ii] / beta2_correction; - MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]); - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - g[i] = r_p[ii]; - m[i] = r_m[ii]; - v[i] = r_v[ii]; - } - } - } - } - } -}; - -// Step 2 reads in 'update' value and per-tensor param_norm and update_norm. -// It computes new parameter value. -template -struct LAMBStage2Functor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<2>& tl, - const float* per_tensor_param_norm, - const float* per_tensor_update_norm, - const float learning_rate, - const float decay, - bool use_nvlamb) - { - // I'd like this kernel to propagate infs/nans. - // if(*noop_gmem == 1) - // return; - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int tensor_num = tl.start_tensor_this_launch + tensor_loc; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - MATH_T ratio = learning_rate; - // nvlamb: apply adaptive learning rate to all parameters - // otherwise, only apply to those with non-zero weight decay - if (use_nvlamb || (decay != 0.0)) - { - float param_norm = per_tensor_param_norm[tensor_num]; - float update_norm = per_tensor_update_norm[tensor_num]; - ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate; - } - - T* update = (T*)tl.addresses[0][tensor_loc]; - update += chunk_idx*chunk_size; - - T* p = (T*)tl.addresses[1][tensor_loc]; - p += chunk_idx*chunk_size; - - n -= chunk_idx*chunk_size; - - // to make things simple, we put aligned case in a different code path - if(n % ILP == 0 && - chunk_size % ILP == 0 && - is_aligned(p) && - is_aligned(update)) - { - T r_p[ILP]; - T r_update[ILP]; - for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) - { - // load - load_store(r_p, p, 0, i_start); - load_store(r_update, update, 0, i_start); -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_p[ii] = static_cast(r_p[ii]) - (ratio * static_cast(r_update[ii])); - } - load_store(p, r_p, i_start, 0); - } - } - else - { - for(int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) - { - MATH_T r_p[ILP]; - MATH_T r_update[ILP]; -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - r_p[ii] = p[i]; - r_update[ii] = update[i]; - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_p[ii] = r_p[ii] - (ratio * r_update[ii]); - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - p[i] = r_p[ii]; - } - } - } - } - } -}; - - -void multi_tensor_lamb_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - const float lr, - const float beta1, - const float beta2, - const float epsilon, - const int step, - const int bias_correction, - const float weight_decay, - const int grad_averaging, - const int mode, - at::Tensor global_grad_norm, - const float max_grad_norm, - at::optional use_nvlamb_python) -{ - using namespace at; - // Master weight and 32bit momentum(potentially changing) is not handled by this - // So we assume every tensor are all in the same type - - bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false; - - // Handle bias correction mode - float bias_correction1 = 1.0f, bias_correction2 = 1.0f; - if (bias_correction == 1) { - bias_correction1 = 1 - std::pow(beta1, step); - bias_correction2 = 1 - std::pow(beta2, step); - } - - // Handle grad averaging mode - float beta3 = 1.0f; - if (grad_averaging == 1) beta3 = 1 - beta1; - - std::vector> grad_list(tensor_lists.begin(), tensor_lists.begin()+1); - std::vector> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2); - - // Compute per tensor param norm - auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true); - - // We now in-place modify grad to store update before compute its norm - // Generally this is not a issue since people modify grad in step() method all the time - // We can also grab list of empty tensor to avoid this, but I'd like to save space/cpu code - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1", - multi_tensor_apply<4>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - LAMBStage1Functor(), - beta1, - beta2, - beta3, // 1-beta1 or 1 depends on averaging mode - bias_correction1, - bias_correction2, - epsilon, - (adamMode_t) mode, - weight_decay, - global_grad_norm.DATA_PTR(), - max_grad_norm); ) - - // Compute update norms - auto update_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true); - - std::vector> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2); - - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2", - multi_tensor_apply<2>( - BLOCK_SIZE, - chunk_size, - noop_flag, - grad_param_list, - LAMBStage2Functor(), - std::get<1>(param_norm_tuple).DATA_PTR(), - std::get<1>(update_norm_tuple).DATA_PTR(), - lr, - weight_decay, - use_nvlamb); ) - - AT_CUDA_CHECK(cudaGetLastError()); - -} diff --git a/csrc/multi_tensor_lamb_mp.cu b/csrc/multi_tensor_lamb_mp.cu deleted file mode 100644 index a213c18..0000000 --- a/csrc/multi_tensor_lamb_mp.cu +++ /dev/null @@ -1,496 +0,0 @@ -#include -#include -#include -#include -// Another possibility: -// #include - -#include - -#include "type_shim.h" -#include "multi_tensor_apply.cuh" - -#define BLOCK_SIZE 1024 -#define ILP 4 - -template -__device__ __forceinline__ bool is_aligned(T* p){ - return ((uint64_t)p) % (ILP*sizeof(T)) == 0; -} - -template -__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){ - typedef typename std::aligned_storage::type LT; - ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; -} - -typedef enum{ - MOMENT_MODE_0 =0, // L2 regularization mode - MOMENT_MODE_1 =1 // Decoupled weight decay mode -} adamMode_t; - -std::tuple multi_tensor_l2norm_mp_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::optional per_tensor_python); - -using MATH_T = float; - -template -struct LAMBStage1Functor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<4>& tl, - const float beta1, - const float beta2, - const float beta3, - const int* step_ptr, - const int bias_correction, - const float epsilon, - adamMode_t mode, - const float decay, - const float* global_grad_norm, - const float* max_global_grad_norm, - const float* found_inf, - const float* inv_scale) - { - if (*noop_gmem) { - return; - } - - float beta1_correction = 1.0f; - float beta2_correction = 1.0f; - if (bias_correction == 1) { - int step = *step_ptr; - beta1_correction = 1 - std::pow(beta1, step); - beta2_correction = 1 - std::pow(beta2, step); - } - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - float clipped_global_grad_norm = (*global_grad_norm) > (*max_global_grad_norm) ? (*global_grad_norm) / (*max_global_grad_norm) : 1.0f; - - T* g = (T*)tl.addresses[0][tensor_loc]; - g += chunk_idx*chunk_size; - - param_t* p = (param_t*)tl.addresses[1][tensor_loc]; - p += chunk_idx*chunk_size; - - param_t* m = (param_t*)tl.addresses[2][tensor_loc]; - m += chunk_idx*chunk_size; - - param_t* v = (param_t*)tl.addresses[3][tensor_loc]; - v += chunk_idx*chunk_size; - - n -= chunk_idx*chunk_size; - - MATH_T r_g[ILP]; - MATH_T r_p[ILP]; - MATH_T r_m[ILP]; - MATH_T r_v[ILP]; - // to make things simple, we put aligned case in a different code path - if(n % ILP == 0 && - chunk_size % ILP == 0 && - is_aligned(g) && - is_aligned(p) && - is_aligned(m) && - is_aligned(v)) - { - T l_g[ILP]; - param_t l_p[ILP]; - param_t l_m[ILP]; - param_t l_v[ILP]; - for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) - { - // load - load_store(l_g, g, 0, i_start); - if (decay != 0) - load_store(l_p, p, 0, i_start); - load_store(l_m, m, 0, i_start); - load_store(l_v, v, 0, i_start); - // unpack -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_g[ii] = l_g[ii] * (*inv_scale); - if (decay == 0) { - r_p[ii] = MATH_T(0); - } - else { - r_p[ii] = l_p[ii]; - } - r_m[ii] = l_m[ii]; - r_v[ii] = l_v[ii]; - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - if (mode == MOMENT_MODE_0) { - MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; - // L2 on scaled grad - scaled_grad = scaled_grad + decay*r_p[ii]; - r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; - r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - MATH_T next_v_unbiased = r_v[ii] / beta2_correction; - MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - r_p[ii] = next_m_unbiased / denom; - } - else { - MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; - r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; - r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - MATH_T next_v_unbiased = r_v[ii] / beta2_correction; - MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]); - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - l_p[ii] = r_p[ii]; - // Difference from APEX's LAMB kernel. `g` and `p` can be different dtypes. - l_g[ii] = r_p[ii]; - l_m[ii] = r_m[ii]; - l_v[ii] = r_v[ii]; - } - // store - load_store(g, l_g, i_start, 0); - load_store(m, l_m, i_start, 0); - load_store(v, l_v, i_start, 0); - } - } - else - { - // see note in multi_tensor_scale_kernel.cu - for(int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) - { - MATH_T r_g[ILP]; - MATH_T r_p[ILP]; - MATH_T r_m[ILP]; - MATH_T r_v[ILP]; -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - r_g[ii] = g[i] * (*inv_scale); - // special ?optimization? for lamb stage 1 - if (decay == 0) { - r_p[ii] = MATH_T(0); - } - else { - r_p[ii] = p[i]; - } - r_m[ii] = m[i]; - r_v[ii] = v[i]; - } else { - r_g[ii] = MATH_T(0); - r_p[ii] = MATH_T(0); - r_m[ii] = MATH_T(0); - r_v[ii] = MATH_T(0); - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - if (mode == MOMENT_MODE_0) { - MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; - // L2 on scaled grad - scaled_grad = scaled_grad + decay*r_p[ii]; - r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; - r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - MATH_T next_v_unbiased = r_v[ii] / beta2_correction; - MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - r_p[ii] = next_m_unbiased / denom; - } - else { - MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm; - r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad; - r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - MATH_T next_v_unbiased = r_v[ii] / beta2_correction; - MATH_T denom = sqrtf(next_v_unbiased) + epsilon; - r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]); - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - g[i] = r_p[ii]; - m[i] = r_m[ii]; - v[i] = r_v[ii]; - } - } - } - } - } -}; - -// Step 2 reads in 'update' value and per-tensor param_norm and update_norm. -// It computes new parameter value. -// N == 2: FP32 params, no master params -// N == 3: FP16 params, FP32 master params. -template -struct LAMBStage2Functor -{ - static_assert((N == 2 && std::is_same::value) || (N == 3 && std::is_same::value), ""); - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata& tl, - const float* per_tensor_param_norm, - const float* per_tensor_update_norm, - const float* learning_rate, - const float decay, - bool use_nvlamb) - { - if (*noop_gmem) { - return; - } - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int tensor_num = tl.start_tensor_this_launch + tensor_loc; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - MATH_T ratio = *learning_rate; - // nvlamb: apply adaptive learning rate to all parameters - // otherwise, only apply to those with non-zero weight decay - if (use_nvlamb || (decay != 0.0)) - { - float param_norm = per_tensor_param_norm[tensor_num]; - float update_norm = per_tensor_update_norm[tensor_num]; - ratio = (update_norm != 0.0f && param_norm != 0.0f) ? *learning_rate * (param_norm / update_norm) : *learning_rate; - } - - T* update = (T*)tl.addresses[0][tensor_loc]; - update += chunk_idx*chunk_size; - - param_t* p = (param_t*)tl.addresses[1][tensor_loc]; - p += chunk_idx*chunk_size; - - T* out_p; - if (N == 3) { - out_p = (T*)tl.addresses[2][tensor_loc]; - out_p += chunk_idx*chunk_size; - } - - n -= chunk_idx*chunk_size; - - // to make things simple, we put aligned case in a different code path - bool can_use_aligned_path = n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(p) && is_aligned(update); - if (N == 3) { - can_use_aligned_path = can_use_aligned_path && is_aligned(out_p); - } - if(can_use_aligned_path) - { - param_t r_p[ILP]; - T r_update[ILP]; - T r_out_p[ILP]; - for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) - { - // load - load_store(r_p, p, 0, i_start); - load_store(r_update, update, 0, i_start); - if (N == 3) { - load_store(r_out_p, out_p, 0, i_start); - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_p[ii] = static_cast(r_p[ii]) - (ratio * static_cast(r_update[ii])); - if (N == 3) { - r_out_p[ii] = r_p[ii]; - } - } - load_store(p, r_p, i_start, 0); - if (N == 3) { - load_store(out_p, r_out_p, i_start, 0); - } - } - } - else - { - for(int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) - { - MATH_T r_p[ILP]; - MATH_T r_update[ILP]; -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - r_p[ii] = p[i]; - r_update[ii] = update[i]; - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_p[ii] = r_p[ii] - (ratio * r_update[ii]); - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - p[i] = r_p[ii]; - if (N == 3) { - out_p[i] = r_p[ii]; - } - } - } - } - } - } -}; - - -void multi_tensor_lamb_mp_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor lr, - const float beta1, - const float beta2, - const float epsilon, - at::Tensor step, - const int bias_correction, - const float weight_decay, - const int grad_averaging, - const int mode, - at::Tensor global_grad_norm, - at::Tensor max_grad_norm, - at::optional use_nvlamb_python, - at::Tensor found_inf, - at::Tensor inv_scale) -{ - // n_tensors == 5: FP16 model params & FP32 master params - // n_tensors == 4: FP32 model params & NO FP32 master params - const auto n_tensors = tensor_lists.size(); - assert(n_tensors == 4 || n_tensors == 5); - using namespace at; - - bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false; - - // note(mkozuki): move bias handling below to functor - // Handle bias correction mode - // float bias_correction1 = 1.0f, bias_correction2 = 1.0f; - // if (bias_correction == 1) { - // bias_correction1 = 1 - std::pow(beta1, step); - // bias_correction2 = 1 - std::pow(beta2, step); - // } - - // Handle grad averaging mode - float beta3 = 1.0f; - if (grad_averaging == 1) beta3 = 1 - beta1; - - std::vector> stage1_tensor_lists(tensor_lists.begin(), tensor_lists.begin() + 4); - std::vector> grad_list(tensor_lists.begin(), tensor_lists.begin()+1); - std::vector> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2); - - // Compute per tensor param norm - auto param_norm_tuple = multi_tensor_l2norm_mp_cuda(chunk_size, noop_flag, param_list, true); - - // We now in-place modify grad to store update before compute its norm - // Generally this is not a issue since people modify grad in step() method all the time - // We can also grab list of empty tensor to avoid this, but I'd like to save space/cpu code - if (n_tensors == 4) { - DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1", - multi_tensor_apply<4>( - BLOCK_SIZE, - chunk_size, - noop_flag, - stage1_tensor_lists, - LAMBStage1Functor(), - beta1, - beta2, - beta3, // 1-beta1 or 1 depends on averaging mode - // bias_correction1, - // bias_correction2, - step.data_ptr(), - bias_correction, - epsilon, - (adamMode_t) mode, - weight_decay, - global_grad_norm.data_ptr(), - max_grad_norm.data_ptr(), - found_inf.data_ptr(), - inv_scale.data_ptr()); ) - } else { - DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1", - multi_tensor_apply<4>( - BLOCK_SIZE, - chunk_size, - noop_flag, - stage1_tensor_lists, - LAMBStage1Functor(), - beta1, - beta2, - beta3, // 1-beta1 or 1 depends on averaging mode - // bias_correction1, - // bias_correction2, - step.data_ptr(), - bias_correction, - epsilon, - (adamMode_t) mode, - weight_decay, - global_grad_norm.data_ptr(), - max_grad_norm.data_ptr(), - found_inf.data_ptr(), - inv_scale.data_ptr()); ) - } - - // Compute update norms - auto update_norm_tuple = multi_tensor_l2norm_mp_cuda(chunk_size, noop_flag, grad_list, true); - - std::vector> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2); - if (n_tensors == 4) { - DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2", - multi_tensor_apply<2>( - BLOCK_SIZE, - chunk_size, - noop_flag, - grad_param_list, - LAMBStage2Functor(), - std::get<1>(param_norm_tuple).data_ptr(), - std::get<1>(update_norm_tuple).data_ptr(), - lr.data_ptr(), - weight_decay, - use_nvlamb); ) - } else { - grad_param_list.push_back(tensor_lists[4]); - DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2", - multi_tensor_apply<3>( - BLOCK_SIZE, - chunk_size, - noop_flag, - grad_param_list, - LAMBStage2Functor(), - std::get<1>(param_norm_tuple).data_ptr(), - std::get<1>(update_norm_tuple).data_ptr(), - lr.data_ptr(), - weight_decay, - use_nvlamb); ) - } - AT_CUDA_CHECK(cudaGetLastError()); - -} diff --git a/csrc/multi_tensor_lamb_stage_1.cu b/csrc/multi_tensor_lamb_stage_1.cu deleted file mode 100644 index 1d5e398..0000000 --- a/csrc/multi_tensor_lamb_stage_1.cu +++ /dev/null @@ -1,151 +0,0 @@ -#include -#include -#include -#include -// Another possibility: -// #include - -#include - -#include "type_shim.h" -#include "multi_tensor_apply.cuh" - -#define BLOCK_SIZE 1024 -#define ILP 4 - -// Step 1 computes the 'update' value of regular Adam optimizer. -template -struct LAMBStage1Functor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<5>& tl, - const float* per_tensor_decay, - const float beta1, - const float beta2, - const float beta1_correction, - const float beta2_correction, - const float epsilon, - const float clipped_global_grad_norm) - { - // I'd like this kernel to propagate infs/nans. - // if(*noop_gmem == 1) - // return; - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int tensor_num = tl.start_tensor_this_launch + tensor_loc; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - float decay = per_tensor_decay[tensor_num]; - - GRAD_T* g = (GRAD_T*)tl.addresses[0][tensor_loc]; - g += chunk_idx*chunk_size; - - T* p = (T*)tl.addresses[1][tensor_loc]; - p += chunk_idx*chunk_size; - - T* m = (T*)tl.addresses[2][tensor_loc]; - m += chunk_idx*chunk_size; - - T* v = (T*)tl.addresses[3][tensor_loc]; - v += chunk_idx*chunk_size; - - UPD_T* update = (UPD_T*)tl.addresses[4][tensor_loc]; - update += chunk_idx*chunk_size; - - n -= chunk_idx*chunk_size; - - // see note in multi_tensor_scale_kernel.cu - for(int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) - { - GRAD_T r_g[ILP]; - T r_p[ILP]; - T r_m[ILP]; - T r_v[ILP]; -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - r_g[ii] = g[i]; - r_p[ii] = p[i]; - r_m[ii] = m[i]; - r_v[ii] = v[i]; - } else { - r_g[ii] = GRAD_T(0); - r_p[ii] = T(0); - r_m[ii] = T(0); - r_v[ii] = T(0); - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - T scaled_grad = r_g[ii] / clipped_global_grad_norm; - r_m[ii] = r_m[ii] * beta1 + (1-beta1) * scaled_grad; - r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad; - T next_m_unbiased = r_m[ii] / beta1_correction; - T next_v_unbiased = r_v[ii] / beta2_correction; - T denom = std::sqrt(next_v_unbiased) + epsilon; - r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]); - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - update[i] = (UPD_T)r_p[ii]; - m[i] = r_m[ii]; - v[i] = r_v[ii]; - } - } - } - } -}; - -void multi_tensor_lamb_stage1_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor per_tensor_decay, - const int step, - const float beta1, - const float beta2, - const float epsilon, - at::Tensor global_grad_norm, - const float max_global_grad_norm) -{ - using namespace at; - - const float* g_grad_norm = global_grad_norm.DATA_PTR(); - float clipped_global_grad_norm = *(g_grad_norm) > max_global_grad_norm ? *(g_grad_norm) / max_global_grad_norm : 1.0f; - float next_step = float(step+1); - float beta1_correction = 1.0f - std::pow(beta1, next_step); - float beta2_correction = 1.0f - std::pow(beta2, next_step); - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1", - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1", - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1", - multi_tensor_apply<5>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - LAMBStage1Functor(), - per_tensor_decay.DATA_PTR(), - beta1, - beta2, - beta1_correction, - beta2_correction, - epsilon, - clipped_global_grad_norm); ))) - - AT_CUDA_CHECK(cudaGetLastError()); - - // AT_CUDA_CHECK(cudaDeviceSynchronize()); -} diff --git a/csrc/multi_tensor_lamb_stage_2.cu b/csrc/multi_tensor_lamb_stage_2.cu deleted file mode 100644 index e1999ef..0000000 --- a/csrc/multi_tensor_lamb_stage_2.cu +++ /dev/null @@ -1,125 +0,0 @@ -#include -#include -#include -#include -// Another possibility: -// #include - -#include - -#include "type_shim.h" -#include "multi_tensor_apply.cuh" - -#define BLOCK_SIZE 1024 -#define ILP 4 - -using MATH_T = float; - -// Step 2 reads in 'update' value and per-tensor param_norm and update_norm. -// It computes new parameter value. -template -struct LAMBStage2Functor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<2>& tl, - const float* per_tensor_param_norm, - const float* per_tensor_update_norm, - const float learning_rate, - const float decay, - bool use_nvlamb) - { - // I'd like this kernel to propagate infs/nans. - // if(*noop_gmem == 1) - // return; - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int tensor_num = tl.start_tensor_this_launch + tensor_loc; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - MATH_T ratio = learning_rate; - // nvlamb: apply adaptive learning rate to all parameters - // otherwise, only apply to those with non-zero weight decay - if (use_nvlamb || (decay != 0.0)) - { - float param_norm = per_tensor_param_norm[tensor_num]; - float update_norm = per_tensor_update_norm[tensor_num]; - ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate; - } - - T* p = (T*)tl.addresses[0][tensor_loc]; - p += chunk_idx*chunk_size; - - UPD_T* update = (UPD_T*)tl.addresses[1][tensor_loc]; - update += chunk_idx*chunk_size; - - n -= chunk_idx*chunk_size; - - for(int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) - { - T r_p[ILP]; - UPD_T r_update[ILP]; -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - r_p[ii] = p[i]; - r_update[ii] = update[i]; - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_p[ii] = r_p[ii] - (ratio*(T)r_update[ii]); - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - p[i] = r_p[ii]; - } - } - } - } -}; - -void multi_tensor_lamb_stage2_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor per_tensor_param_norm, - at::Tensor per_tensor_update_norm, - const float lr, - const float weight_decay, - at::optional use_nvlamb_python) -{ - bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false; - - using namespace at; - - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2", - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2", - multi_tensor_apply<2>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - LAMBStage2Functor(), - per_tensor_param_norm.DATA_PTR(), - per_tensor_update_norm.DATA_PTR(), - lr, - weight_decay, - use_nvlamb); )) - - AT_CUDA_CHECK(cudaGetLastError()); - - // AT_CUDA_CHECK(cudaDeviceSynchronize()); -} diff --git a/csrc/multi_tensor_lars.cu b/csrc/multi_tensor_lars.cu deleted file mode 100644 index bc9bbee..0000000 --- a/csrc/multi_tensor_lars.cu +++ /dev/null @@ -1,354 +0,0 @@ -#include -#include -#include -#include - -#include "type_shim.h" -#include "compat.h" -#include "multi_tensor_apply.cuh" - -#include -#include - -#define BLOCK_SIZE 512 -#define ILP 4 - -/** - * Perform fused SGD on multiple buffers - * N: number of tensors - * tl[0] : gradients - * tl[1] : weights - * tl[2] : momentum buffers - * tl[3] : fp16 weights (if appropriate) - * wd : weight_decay (scalar) - * momentum : momentum (scalar) - * dampening : momentum dampening (scalar) - * lr : learning rate (scalar) - * nesterov : enable nesterov (bool) - * first run : necessary for proper momentum handling & init - * wd_after_momentum : apply weight decay _after_ momentum instead of before - **/ - -template -struct LARSFunctor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata& tl, - float *grad_norms, - float *param_norms, - float lr, - float trust_coefficient, - float epsilon, - float weight_decay, - float momentum, - float dampening, - bool nesterov, - bool first_run, - bool wd_after_momentum, - float scale, - const bool is_skipped) { - - // Early exit if we don't need to do anything - if (*noop_gmem) return; - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - n -= chunk_idx * chunk_size; - //n = min(n, chunk_size); - - T_grad* grad_in = (T_grad*) tl.addresses[0][tensor_loc]; - grad_in += chunk_idx * chunk_size; - - T_weight* weight_in = (T_weight*) tl.addresses[1][tensor_loc]; - weight_in += chunk_idx * chunk_size; - - T_weight* mom_in = (T_weight*)tl.addresses[2][tensor_loc]; - mom_in += chunk_idx*chunk_size; - - at::Half *model_weights_out = nullptr; - if(N == 4) - { - model_weights_out = (at::Half*)tl.addresses[3][tensor_loc]; - model_weights_out += chunk_idx*chunk_size; - } - - float scaled_lr; - if (is_skipped) { - scaled_lr = lr; - } - else { - int tensor_offset = tl.start_tensor_this_launch + tensor_loc; - float p_norm = param_norms[tensor_offset]; - float trust_ratio = 1.0; - float g_norm = grad_norms[tensor_offset]; - if (g_norm > 0.0f && p_norm > 0.0f) { - trust_ratio = trust_coefficient * p_norm / (g_norm + p_norm * weight_decay + epsilon); - } - scaled_lr = lr * trust_ratio; - } - - // Non-divergent exit condition for the __syncthreads - float incoming_grads[ILP]; - float incoming_weights[ILP]; - float incoming_moms[ILP]; - for(int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) - { - #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - incoming_grads[ii] = 0; - incoming_weights[ii] = 0; - incoming_moms[ii] = 0; - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - incoming_grads[ii] = static_cast(grad_in[i]); - incoming_weights[ii] = static_cast(weight_in[i]); - incoming_moms[ii] = static_cast(mom_in[i]); - } - } - - // note for clarification to future michael: - // From a pure memory dependency perspective, there's likely no point unrolling - // the write loop, since writes just fire off once their LDGs arrive. - // Put another way, the STGs are dependent on the LDGs, but not on each other. - // There is still compute ILP benefit from unrolling the loop though. - #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - // apply weight decay before momentum - incoming_grads[ii] += weight_decay * incoming_weights[ii]; - incoming_moms[ii] = incoming_moms[ii] * momentum - scaled_lr * incoming_grads[ii]; - - // adjust the weight and write out - if (nesterov) { - incoming_weights[ii] += incoming_moms[ii] * momentum - scaled_lr * incoming_grads[ii]; - } else { - incoming_weights[ii] += incoming_moms[ii]; - } - - weight_in[i] = static_cast(incoming_weights[ii]); - - // if necessary, write out an fp16 copy of the weights - if(N == 4) - model_weights_out[i] = static_cast(weight_in[i]); - - // also write out the new momentum - //if(momentum != 0.f) - mom_in[i] = static_cast(incoming_moms[ii]); - } - } - } - } -}; - -void multi_tensor_lars_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor grad_norms, - at::Tensor param_norms, - float lr, - float trust_coefficient, - float epsilon, - float weight_decay, - float momentum, - float dampening, - bool nesterov, - bool first_run, - bool wd_after_momentum, - float scale, - const bool is_skipped) -{ - auto num_tensors = tensor_lists.size(); - auto grad_type = tensor_lists[0][0].scalar_type(); - auto weight_type = tensor_lists[1][0].scalar_type(); - - if(num_tensors == 4) { - for(int i = 0; i < tensor_lists[3].size(); i++) { - TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half, - "Additional output tensors should always be fp16."); - } - } - - TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), "expected noop flag to be on the same device as tensors"); - - // We have 3 possibilities to handle here, in terms of - // grad_type, param_type, momentum_type, requires_fp16_copy - // 1. fp16, fp16, fp16, No - // 2. fp32, fp32, fp32, No - // 3. fp16, fp32, fp32, Yes - // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case - // 5. bfp16, bfp16, bfp16, No - // 6. bfp16, fp32, fp32, Yes - // It's easier to hardcode these possibilities than to use - // switches etc. to handle the cross-product of cases where - // we don't want the majority of them. - - // Case 1. fp16, fp16, fp16, No - if(grad_type == at::ScalarType::Half && - weight_type == at::ScalarType::Half && - num_tensors == 3) - { - multi_tensor_apply<3>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - LARSFunctor<3, at::Half, at::Half>(), - grad_norms.DATA_PTR(), - param_norms.DATA_PTR(), - lr, - trust_coefficient, - epsilon, - weight_decay, - momentum, - dampening, - nesterov, - first_run, - wd_after_momentum, - scale, - is_skipped); - } - // Case 2. fp32, fp32, fp32, No - else if(grad_type == at::ScalarType::Float && - weight_type == at::ScalarType::Float && - num_tensors == 3) - { - multi_tensor_apply<3>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - LARSFunctor<3, float, float>(), - grad_norms.DATA_PTR(), - param_norms.DATA_PTR(), - lr, - trust_coefficient, - epsilon, - weight_decay, - momentum, - dampening, - nesterov, - first_run, - wd_after_momentum, - scale, - is_skipped); - } - // Case 3. fp16, fp32, fp32, Yes - else if(grad_type == at::ScalarType::Half && - weight_type == at::ScalarType::Float && - num_tensors == 4) - { - multi_tensor_apply<4>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - LARSFunctor<4, at::Half, float>(), - grad_norms.DATA_PTR(), - param_norms.DATA_PTR(), - lr, - trust_coefficient, - epsilon, - weight_decay, - momentum, - dampening, - nesterov, - first_run, - wd_after_momentum, - scale, - is_skipped); - } - // Case 4. fp32, fp32, fp32, Yes - else if(grad_type == at::ScalarType::Float && - weight_type == at::ScalarType::Float && - num_tensors == 4) - { - multi_tensor_apply<4>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - LARSFunctor<4, float, float>(), - grad_norms.DATA_PTR(), - param_norms.DATA_PTR(), - lr, - trust_coefficient, - epsilon, - weight_decay, - momentum, - dampening, - nesterov, - first_run, - wd_after_momentum, - scale, - is_skipped); - } - // Case 5. bfp16, bfp16, bfp16, No - else if(grad_type == at::ScalarType::BFloat16 && - weight_type == at::ScalarType::BFloat16 && - num_tensors == 3) - { - multi_tensor_apply<3>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - LARSFunctor<3, at::BFloat16, at::BFloat16>(), - grad_norms.DATA_PTR(), - param_norms.DATA_PTR(), - lr, - trust_coefficient, - epsilon, - weight_decay, - momentum, - dampening, - nesterov, - first_run, - wd_after_momentum, - scale, - is_skipped); - } - // Case 6. bfp16, fp32, fp32, Yes - else if(grad_type == at::ScalarType::BFloat16 && - weight_type == at::ScalarType::Float && - num_tensors == 4) - { - multi_tensor_apply<4>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - LARSFunctor<4, at::BFloat16, float>(), - grad_norms.DATA_PTR(), - param_norms.DATA_PTR(), - lr, - trust_coefficient, - epsilon, - weight_decay, - momentum, - dampening, - nesterov, - first_run, - wd_after_momentum, - scale, - is_skipped); - } - else - { - AT_ERROR("multi_tensor_lars only supports some combinations of gradient & weight types. Given: ", - "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors); - } - - AT_CUDA_CHECK(cudaGetLastError()); -} diff --git a/csrc/multi_tensor_novograd.cu b/csrc/multi_tensor_novograd.cu deleted file mode 100644 index 4da815d..0000000 --- a/csrc/multi_tensor_novograd.cu +++ /dev/null @@ -1,188 +0,0 @@ -#include -#include -#include -#include -// Another possibility: -// #include - -#include - -#include "type_shim.h" -#include "multi_tensor_apply.cuh" - -#define BLOCK_SIZE 1024 -#define ILP 4 - -typedef enum{ - MOMENT_MODE_0 =0, // Novograd paper mode, momentum caculation with denom then decay inside - MOMENT_MODE_1 =1 // Decoupled weight decay mode -} momentMode_t; - -void multi_tensor_norm_out_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor out, - const float alpha, - const float beta, - const int norm_type); - -using MATH_T = float; - -template -struct NovoGradFunctor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<3>& tl, - const float beta1, - const float beta2, - const float beta3, - const float beta1_correction, - const float beta2_correction, - const float epsilon, - const float lr, - momentMode_t m_mode, - const float decay, - const float* per_tensor_grad_norm) - { - // I'd like this kernel to propagate infs/nans. - // if(*noop_gmem == 1) - // return; - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int tensor_num = tl.start_tensor_this_launch + tensor_loc; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - float grad_norm = per_tensor_grad_norm[tensor_num]; - - T* g = (T*)tl.addresses[0][tensor_loc]; - g += chunk_idx*chunk_size; - - T* p = (T*)tl.addresses[1][tensor_loc]; - p += chunk_idx*chunk_size; - - T* m = (T*)tl.addresses[2][tensor_loc]; - m += chunk_idx*chunk_size; - - n -= chunk_idx*chunk_size; - - // see note in multi_tensor_scale_kernel.cu - for(int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) - { - MATH_T r_g[ILP]; - MATH_T r_p[ILP]; - MATH_T r_m[ILP]; -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - r_g[ii] = g[i]; - r_p[ii] = p[i]; - r_m[ii] = m[i]; - } else { - r_g[ii] = MATH_T(0); - r_p[ii] = MATH_T(0); - r_m[ii] = MATH_T(0); - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - if (m_mode == MOMENT_MODE_0) { - MATH_T next_v_unbiased = grad_norm / beta2_correction; - MATH_T denom = next_v_unbiased + epsilon; - r_g[ii] = (r_g[ii] / denom) + (decay * r_p[ii]); - r_m[ii] = beta1 * r_m[ii] + beta3 * r_g[ii]; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - r_p[ii] = r_p[ii] - (lr * next_m_unbiased); - } - else { - r_m[ii] = beta1 * r_m[ii] + beta3 * r_g[ii]; - MATH_T next_m_unbiased = r_m[ii] / beta1_correction; - MATH_T next_v_unbiased = grad_norm / beta2_correction; - MATH_T denom = next_v_unbiased + epsilon; - MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]); - r_p[ii] = r_p[ii] - (lr * update); - } - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - p[i] = r_p[ii]; - m[i] = r_m[ii]; - } - } - } - } -}; - -void multi_tensor_novograd_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - at::Tensor grad_norms, - const float lr, - const float beta1, - const float beta2, - const float epsilon, - const int step, - const int bias_correction, - const float weight_decay, - const int grad_averaging, - const int moment_mode, - const int norm_type) -{ - using namespace at; - - // Handle bias correction mode - float bias_correction1 = 1.0f, bias_correction2 = 1.0f; - if (bias_correction == 1) { - bias_correction1 = 1 - std::pow(beta1, step); - bias_correction2 = std::sqrt(1 - std::pow(beta2, step)); - } - - // Handle grad averaging mode - float beta3 = 1; - if (grad_averaging == 1) beta3 = 1 - beta1; - - std::vector> grad_list(tensor_lists.begin(), tensor_lists.begin()+1); - - // Compute and update grad norm - // Here use a per tensor norm, and blend new norm(n) and old norm(gn) by - // L-2: gn = sqrt(a * gn^2 + b * n^2) - // L-inf: gn = a * gn + b * n - multi_tensor_norm_out_cuda(chunk_size, noop_flag, grad_list, grad_norms, beta2, (1.0f - beta2), norm_type); - - // Assume single type across p,g,m1,m2 now - DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16( - tensor_lists[0][0].scalar_type(), 0, "novograd", - multi_tensor_apply<3>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - NovoGradFunctor(), - beta1, - beta2, - beta3, // 1-beta1 or 1 depends on averaging mode - bias_correction1, - bias_correction2, - epsilon, - lr, - (momentMode_t) moment_mode, - weight_decay, - grad_norms.DATA_PTR()); ) - - AT_CUDA_CHECK(cudaGetLastError()); - -} diff --git a/csrc/multi_tensor_scale_kernel.cu b/csrc/multi_tensor_scale_kernel.cu deleted file mode 100644 index 5386f4d..0000000 --- a/csrc/multi_tensor_scale_kernel.cu +++ /dev/null @@ -1,136 +0,0 @@ -#include -#include -#include -#include -// Another possibility: -// #include - -#include -// Stringstream is a big hammer, but I want to rely on operator<< for dtype. -#include - -#include "type_shim.h" -#include "multi_tensor_apply.cuh" - -#define BLOCK_SIZE 1024 -#define ILP 4 - -template -__device__ __forceinline__ bool is_aligned(T* p){ - return ((uint64_t)p) % (ILP*sizeof(T)) == 0; -} - -template -__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){ - typedef typename std::aligned_storage::type LT; - ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; -} - -template -struct ScaleFunctor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata<2>& tl, - float scale) - { - // I'd like this kernel to propagate infs/nans. - // if(*noop_gmem == 1) - // return; - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - in_t* in = (in_t*)tl.addresses[0][tensor_loc]; - in += chunk_idx*chunk_size; - - out_t* out = (out_t*)tl.addresses[1][tensor_loc]; - out += chunk_idx*chunk_size; - - n -= chunk_idx*chunk_size; - - bool finite = true; - in_t r_in[ILP]; - out_t r_out[ILP]; - - // to make things simple, we put aligned case in a different code path - if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out)) - { - for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x) - { - // load - load_store(r_in, in, 0 , i_start); -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_out[ii] = static_cast(r_in[ii]) * scale; - finite = finite && isfinite(r_in[ii]); - } - // store - load_store(out, r_out, i_start, 0); - } - } - else - { - // Non-divergent exit condition for __syncthreads, not necessary here - for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP) - { -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_in[ii] = 0; - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - r_in[ii] = in[i]; - } - // note for clarification to future michael: - // From a pure memory dependency perspective, there's likely no point unrolling - // the write loop, since writes just fire off once their LDGs arrive. - // Put another way, the STGs are dependent on the LDGs, but not on each other. - // There is still compute ILP benefit from unrolling the loop though. -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - r_out[ii] = static_cast(r_in[ii]) * scale; - finite = finite && isfinite(r_in[ii]); - } -#pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - out[i] = r_out[ii]; - } - } - } - if(!finite) - *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok. - } -}; - -void multi_tensor_scale_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - float scale) -{ - using namespace at; - // The output (downscaled) type is always float. - // If build times suffer, think about where to put this dispatch, - // and what logic should be moved out of multi_tensor_apply. - - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda", - DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda", - multi_tensor_apply<2>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - ScaleFunctor(), - scale); )) - AT_CUDA_CHECK(cudaGetLastError()); - - // AT_CUDA_CHECK(cudaDeviceSynchronize()); -} diff --git a/csrc/multi_tensor_sgd_kernel.cu b/csrc/multi_tensor_sgd_kernel.cu deleted file mode 100644 index 5d1f685..0000000 --- a/csrc/multi_tensor_sgd_kernel.cu +++ /dev/null @@ -1,322 +0,0 @@ -#include -#include -#include -#include -#include "multi_tensor_apply.cuh" -#include "compat.h" - -#include -#include - -#define BLOCK_SIZE 1024 -#define ILP 4 - -/** - * Perform fused SGD on multiple buffers - * N: number of tensors - * tl[0] : gradients - * tl[1] : weights - * tl[2] : momentum buffers - * tl[3] : fp16 weights (if appropriate) - * wd : weight_decay (scalar) - * momentum : momentum (scalar) - * dampening : momentum dampening (scalar) - * lr : learning rate (scalar) - * nesterov : enable nesterov (bool) - * first run : necessary for proper momentum handling & init - * wd_after_momentum : apply weight decay _after_ momentum instead of before - **/ -template -struct SGDFunctor -{ - __device__ __forceinline__ void operator()( - int chunk_size, - volatile int* noop_gmem, - TensorListMetadata& tl, - float wd, - float momentum, - float dampening, - float lr, - bool nesterov, - bool first_run, - bool wd_after_momentum, - float scale) - { - // Early exit if we don't need to do anything - if (*noop_gmem) return; - - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc]; - grad_in += chunk_idx*chunk_size; - - T_weight* weight_in = (T_weight*)tl.addresses[1][tensor_loc]; - weight_in += chunk_idx*chunk_size; - - T_weight* mom_in = (T_weight*)tl.addresses[2][tensor_loc]; - mom_in += chunk_idx*chunk_size; - - at::Half *model_weights_out = nullptr; - if(N == 4) - { - model_weights_out = (at::Half*)tl.addresses[3][tensor_loc]; - model_weights_out += chunk_idx*chunk_size; - } - - n -= chunk_idx*chunk_size; - - // Non-divergent exit condition for the __syncthreads - float incoming_grads[ILP]; - float incoming_weights[ILP]; - float incoming_moms[ILP]; - for(int i_start = 0; - i_start < n && i_start < chunk_size; - i_start += blockDim.x*ILP) - { - #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - incoming_grads[ii] = 0; - incoming_weights[ii] = 0; - incoming_moms[ii] = 0; - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - incoming_grads[ii] = static_cast(grad_in[i])*scale; - incoming_weights[ii] = static_cast(weight_in[i]); - incoming_moms[ii] = static_cast(mom_in[i]); - } - } - - // note for clarification to future michael: - // From a pure memory dependency perspective, there's likely no point unrolling - // the write loop, since writes just fire off once their LDGs arrive. - // Put another way, the STGs are dependent on the LDGs, but not on each other. - // There is still compute ILP benefit from unrolling the loop though. - #pragma unroll - for(int ii = 0; ii < ILP; ii++) - { - int i = i_start + threadIdx.x + ii*blockDim.x; - if(i < n && i < chunk_size) - { - // apply weight decay before momentum if necessary - if(wd != 0.f && !wd_after_momentum) - incoming_grads[ii] += wd * incoming_weights[ii]; - - if(momentum != 0.f) - { - if(!first_run) - incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii]; - else // initialize momentums to current incoming grads - incoming_moms[ii] = incoming_grads[ii]; - - if(nesterov) - incoming_grads[ii] += momentum * incoming_moms[ii]; - else - incoming_grads[ii] = incoming_moms[ii]; - } - - // Apply WD after momentum if desired - if(wd != 0.f && wd_after_momentum) - incoming_grads[ii] += wd * incoming_weights[ii]; - - // adjust the weight and write out - weight_in[i] += (-lr * incoming_grads[ii]); - - // if necessary, write out an fp16 copy of the weights - if(N == 4) - model_weights_out[i] = static_cast(weight_in[i]); - - // also write out the new momentum - if(momentum != 0.f) - mom_in[i] = incoming_moms[ii]; - } - } - } - } -}; - -void multi_tensor_sgd_cuda( - int chunk_size, - at::Tensor noop_flag, - std::vector> tensor_lists, - float wd, - float momentum, - float dampening, - float lr, - bool nesterov, - bool first_run, - bool wd_after_momentum, - float scale) -{ - auto num_tensors = tensor_lists.size(); - auto grad_type = tensor_lists[0][0].scalar_type(); - auto weight_type = tensor_lists[1][0].scalar_type(); - - if(num_tensors == 4) - for(int i = 0; i < tensor_lists[3].size(); i++) - TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half, - "Additional output tensors should always be fp16."); - - TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), "expected noop flag to be on the same device as tensors"); - - // We have 3 possibilities to handle here, in terms of - // grad_type, param_type, momentum_type, requires_fp16_copy - // 1. fp16, fp16, fp16, No - // 2. fp32, fp32, fp32, No - // 3. fp16, fp32, fp32, Yes - // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case - // 5. bfp16, bfp16, bfp16, No - // 6. bfp16, fp32, fp32, Yes - // It's easier to hardcode these possibilities than to use - // switches etc. to handle the cross-product of cases where - // we don't want the majority of them. - - // Case 1. fp16, fp16, fp16, No - if(grad_type == at::ScalarType::Half && - weight_type == at::ScalarType::Half && - num_tensors == 3) - { - multi_tensor_apply<3>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - SGDFunctor<3, at::Half, at::Half>(), - wd, - momentum, - dampening, - lr, - nesterov, - first_run, - wd_after_momentum, - scale); - } - // Case 2. fp16, fp32, fp32, No - // else if (grad_type == at::ScalarType::Half && - // weight_type == at::ScalarType::Float && - // num_tensors == 3) { - // multi_tensor_apply<3>( - // BLOCK_SIZE, - // chunk_size, - // noop_flag, - // tensor_lists, - // SGDFunctor<3, at::Half, float>(), - // wd, - // momentum, - // dampening, - // lr, - // nesterov, - // first_run, - // wd_after_momentum); - // } - // Case 2. fp32, fp32, fp32, No - else if(grad_type == at::ScalarType::Float && - weight_type == at::ScalarType::Float && - num_tensors == 3) - { - multi_tensor_apply<3>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - SGDFunctor<3, float, float>(), - wd, - momentum, - dampening, - lr, - nesterov, - first_run, - wd_after_momentum, - scale); - } - // Case 3. fp16, fp32, fp32, Yes - else if(grad_type == at::ScalarType::Half && - weight_type == at::ScalarType::Float && - num_tensors == 4) - { - multi_tensor_apply<4>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - SGDFunctor<4, at::Half, float>(), - wd, - momentum, - dampening, - lr, - nesterov, - first_run, - wd_after_momentum, - scale); - } - // Case 4. fp32, fp32, fp32, Yes - else if(grad_type == at::ScalarType::Float && - weight_type == at::ScalarType::Float && - num_tensors == 4) - { - multi_tensor_apply<4>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - SGDFunctor<4, float, float>(), - wd, - momentum, - dampening, - lr, - nesterov, - first_run, - wd_after_momentum, - scale); - } - // Case 5. bfp16, bfp16, bfp16, No - else if(grad_type == at::ScalarType::BFloat16 && - weight_type == at::ScalarType::BFloat16 && - num_tensors == 3) - { - multi_tensor_apply<3>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - SGDFunctor<3, at::BFloat16, at::BFloat16>(), - wd, - momentum, - dampening, - lr, - nesterov, - first_run, - wd_after_momentum, - scale); - } - // Case 6. bfp16, fp32, fp32, Yes - else if(grad_type == at::ScalarType::BFloat16 && - weight_type == at::ScalarType::Float && - num_tensors == 4) - { - multi_tensor_apply<4>( - BLOCK_SIZE, - chunk_size, - noop_flag, - tensor_lists, - SGDFunctor<4, at::BFloat16, float>(), - wd, - momentum, - dampening, - lr, - nesterov, - first_run, - wd_after_momentum, - scale); - } - else - { - AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ", - "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors); - } - - AT_CUDA_CHECK(cudaGetLastError()); -} diff --git a/csrc/syncbn.cpp b/csrc/syncbn.cpp deleted file mode 100644 index 578a6e6..0000000 --- a/csrc/syncbn.cpp +++ /dev/null @@ -1,109 +0,0 @@ -#include -#include - -#include - -// returns {mean,biased_var} -// implemented using welford -std::vector welford_mean_var_CUDA(const at::Tensor input); - -// reduces array of mean/var across processes -// returns global {mean,inv_std,biased_var} -// implemented using welford -std::vector welford_parallel_CUDA(const at::Tensor mean_feature_nodes, - const at::Tensor var_biased_feature_nodes, - const at::Tensor numel, - const float eps); - -// elementwise BN operation, returns output -// input/weight/shift should have identical data type; -// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype) -at::Tensor batchnorm_forward_CUDA(const at::Tensor input, - const at::Tensor mean, - const at::Tensor inv_std, - const at::optional weight, - const at::optional shift); - -// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias} -// grad_output/input should have identical data type; -// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype) -// implemented using kahan summation -std::vector reduce_bn_CUDA(const at::Tensor grad_output, - const at::Tensor input, - const at::Tensor mean, - const at::Tensor inv_std, - const at::optional weight); - -// elementwise backward BN operation, returns grad_input -// grad_output/input/weight precision could be fp16/fp32; -// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32 -at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output, - const at::Tensor input, - const at::Tensor mean, - const at::Tensor inv_std, - const at::optional weight, - const at::Tensor sum_dy, - const at::Tensor sum_dy_xmu, - const at::Tensor count); - -// returns {mean, biased_var} -// implemented using welford -// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL -std::vector welford_mean_var_c_last_CUDA(const at::Tensor input); - -// elementwise BN operation, returns output -// input/weight/shift should have identical data type; -// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype) -// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL -at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input, - const at::optional z, - const at::Tensor mean, - const at::Tensor inv_std, - const at::optional weight, - const at::optional shift, - const bool fuse_relu); - -// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias} -// grad_output/input should have identical data type; -// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype) -// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL -std::vector reduce_bn_c_last_CUDA(const at::Tensor grad_output, - const at::Tensor input, - const at::Tensor mean, - const at::Tensor inv_std, - const at::optional weight); - -// elementwise backward BN operation, returns grad_input -// grad_output/input/weight precision could be fp16/fp32; -// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32 -// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL -at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output, - const at::Tensor input, - const at::Tensor mean, - const at::Tensor inv_std, - const at::optional weight, - const at::Tensor sum_dy, - const at::Tensor sum_dy_xmu, - const at::Tensor count); - -at::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output, - const at::Tensor input, - const at::optional z, - const at::Tensor mean, - const at::Tensor inv_std, - const at::optional weight, - const at::optional shift); - - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance"); - m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance"); - m.def("batchnorm_forward", &batchnorm_forward_CUDA, "batchnorm forward"); - m.def("reduce_bn", &reduce_bn_CUDA, "batchnorm backward reduce grad sum and bias/weight grad"); - m.def("batchnorm_backward", &batchnorm_backward_CUDA, "batchnorm backward dgrad"); - m.def("welford_mean_var_c_last", &welford_mean_var_c_last_CUDA, "welford mean variance nhwc"); - m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc"); - m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backwards reduce grad sum and bias/weight grad nhwc"); - m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc"); - m.def("relu_bw_c_last", &relu_backward_c_last_CUDA, "relu_bw_c_last"); -} diff --git a/csrc/type_shim.h b/csrc/type_shim.h deleted file mode 100644 index b4df933..0000000 --- a/csrc/type_shim.h +++ /dev/null @@ -1,491 +0,0 @@ -#include -#include "compat.h" - -// Forward/backward compatiblity hack around -// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288 -// pending more future-proof guidance from upstream. -// struct TypeShim -// { -// const at::Type& payload; -// TypeShim(const at::Type& type) : payload(type) {} -// // Enable trivial conversion to a const at::Type& for pre-3aeb78 -// operator const at::Type&(){ return payload; }; -// // Enable dispatch switch statements to take *this directly for post-3aeb78 -// //operator at::ScalarType(){ return payload.; }; -// }; - -#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ - switch(TYPE) \ - { \ - case at::ScalarType::Float: \ - { \ - using scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_##LEVEL = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } - - -#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...) \ - switch(TYPE) \ - { \ - case at::ScalarType::Float: \ - { \ - using scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_##LEVEL = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t_##LEVEL = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } - - -#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \ - switch(TYPE) \ - { \ - case at::ScalarType::Float: \ - { \ - using scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_##LEVEL = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Byte: \ - { \ - using scalar_t_##LEVEL = uint8_t; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } - - -#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \ - switch(TYPE) \ - { \ - case at::ScalarType::Double: \ - { \ - using scalar_t_##LEVEL = double; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Float: \ - { \ - using scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_##LEVEL = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } - - -#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...) \ - switch(TYPE) \ - { \ - case at::ScalarType::Double: \ - { \ - using scalar_t_##LEVEL = double; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Float: \ - { \ - using scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_##LEVEL = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t_##LEVEL = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } - - - #define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \ - switch(TYPE) \ - { \ - case at::ScalarType::Double: \ - { \ - using scalar_t_##LEVEL = double; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Float: \ - { \ - using scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } - -// TODO: We might have come up with an optimal set of dispatch macros by -// changing the signature to have an integer suffix of number of types -// to dispatch for as defined in upstream (e.g AT_DISPATCH_FLOATING_TYPES_AND2) -// Refactor once all the extension ops are enabled. -#define DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(TYPE, LEVEL, NAME, ...) \ - switch(TYPE) \ - { \ - case at::ScalarType::Float: \ - { \ - using scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_##LEVEL = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t_##LEVEL = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } - - -#define DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(TYPE, LEVEL, NAME, ...) \ - switch(TYPE) \ - { \ - case at::ScalarType::Double: \ - { \ - using scalar_t_##LEVEL = double; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Float: \ - { \ - using scalar_t_##LEVEL = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_##LEVEL = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t_##LEVEL = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } - - #define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ - switch(TYPE) \ - { \ - case at::ScalarType::Half: \ - { \ - using scalar_t = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } - - - #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ - switch(TYPEIN) \ - { \ - case at::ScalarType::Float: \ - { \ - using scalar_t_in = float; \ - switch(TYPEOUT) \ - { \ - case at::ScalarType::Float: \ - { \ - using scalar_t_out = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_out = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t_out = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ - } \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_in = at::Half; \ - using scalar_t_out = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t_in = at::BFloat16; \ - using scalar_t_out = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ - } - - - #define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ - switch(TYPEIN) \ - { \ - case at::ScalarType::Double: \ - { \ - using scalar_t_in = double; \ - switch(TYPEOUT) \ - { \ - case at::ScalarType::Double: \ - { \ - using scalar_t_out = double; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Float: \ - { \ - using scalar_t_out = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_out = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t_out = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ - } \ - break; \ - } \ - case at::ScalarType::Float: \ - { \ - using scalar_t_in = float; \ - switch(TYPEOUT) \ - { \ - case at::ScalarType::Float: \ - { \ - using scalar_t_out = float; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_out = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t_out = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ - } \ - break; \ - } \ - case at::ScalarType::Half: \ - { \ - using scalar_t_in = at::Half; \ - using scalar_t_out = at::Half; \ - __VA_ARGS__; \ - break; \ - } \ - case at::ScalarType::BFloat16: \ - { \ - using scalar_t_in = at::BFloat16; \ - using scalar_t_out = at::BFloat16; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ - } - - -template -__device__ __forceinline__ T reduce_block_into_lanes - (T *x, - T val, - int lanes=1, - bool share_result=false) // lanes is intended to be <= 32. -{ - int tid = threadIdx.x + threadIdx.y*blockDim.x; - int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32. - auto double_warp_size = warpSize * 2; - - if(blockSize >= double_warp_size) - { - x[tid] = val; - __syncthreads(); - } - - #pragma unroll - for(int i = (blockSize >> 1); i >= double_warp_size; i >>= 1) - { - if(tid < i) - x[tid] = x[tid] + x[tid+i]; - __syncthreads(); - } - - T final; - - if(tid < warpSize) - { - if(blockSize >= double_warp_size) - final = x[tid] + x[tid + warpSize]; - else - final = val; - // __SYNCWARP(); - - #pragma unroll - for(int i = warpSize / 2; i >= lanes; i >>= 1) { -#ifdef __HIP_PLATFORM_HCC__ - final = final + __shfl_down(final, i); -#else - final = final + __shfl_down_sync(0xffffffff, final, i); -#endif - } - } - - if(share_result) - { - if(tid < lanes) - x[tid] = final; // EpilogueOp - // Make sure the smem result is visible to all warps. - __syncthreads(); - } - - return final; -} - -template -__device__ __forceinline__ T reduce_block_into_lanes_max_op - (T *x, - T val, - int lanes=1, - bool share_result=false) // lanes is intended to be <= 32. -{ - int tid = threadIdx.x + threadIdx.y*blockDim.x; - int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32. - auto double_warp_size = warpSize * 2; - - if(blockSize >= double_warp_size) - { - x[tid] = val; - __syncthreads(); - } - - #pragma unroll - for(int i = (blockSize >> 1); i >= double_warp_size; i >>= 1) - { - if(tid < i) - x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid+i])); - __syncthreads(); - } - - T final; - - if(tid < warpSize) - { - if(blockSize >= double_warp_size) - final = fmaxf(fabsf(x[tid]), fabsf(x[tid + warpSize])); - else - final = val; - // __SYNCWARP(); - - #pragma unroll - for(int i = 16; i >= lanes; i >>= 1) { -#ifdef __HIP_PLATFORM_HCC__ - final = fmaxf(fabsf(final), fabsf(__shfl_down(final, i))); -#else - final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i))); -#endif - } - } - - if(share_result) - { - if(tid < lanes) - x[tid] = final; // EpilogueOp - // Make sure the smem result is visible to all warps. - __syncthreads(); - } - - return final; -} diff --git a/csrc/utils.h b/csrc/utils.h deleted file mode 100644 index ef2bc57..0000000 --- a/csrc/utils.h +++ /dev/null @@ -1,27 +0,0 @@ - -#pragma once - -#include -#include - -inline bool parseEnvVarFlag(const char* envVarName) { - char* stringValue = std::getenv(envVarName); - if (stringValue != nullptr) { - int val; - try { - val = std::stoi(stringValue); - } catch (std::exception& e) { - TORCH_CHECK(false, - "Invalid value for environment variable: " + std::string(envVarName)); - } - if (val == 1) { - return true; - } else if (val == 0) { - return false; - } else { - TORCH_CHECK(false, - "Invalid value for environment variable: " + std::string(envVarName)); - } - } - return false; -} \ No newline at end of file diff --git a/csrc/welford.cu b/csrc/welford.cu deleted file mode 100644 index 92dc5c1..0000000 --- a/csrc/welford.cu +++ /dev/null @@ -1,1550 +0,0 @@ -#include -#include -#include -#include - -#include -#include - -#include - -#include "type_shim.h" -#include "compat.h" - -#if defined __HIP_PLATFORM_HCC__ -#define SHFL_DOWN(mask,val,i) __shfl_down(val, i) -#else -#define SHFL_DOWN __shfl_down_sync -#endif - -__device__ __forceinline__ int lastpow2(int n) -{ - int out = 1 << (31 - __clz(n)); - if(n == out) - out >>= 1; - return out; -} - -__host__ __forceinline__ int h_next_pow2(unsigned int n) { - n--; - n |= (n >> 1); - n |= (n >> 2); - n |= (n >> 4); - n |= (n >> 8); - n |= (n >> 16); - return ++n; -} - -__host__ __forceinline__ int h_last_pow2(unsigned int n) { - n |= (n >> 1); - n |= (n >> 2); - n |= (n >> 4); - n |= (n >> 8); - n |= (n >> 16); - return n - (n >> 1); -} - -#ifdef __HIP_PLATFORM_HCC__ -#define WARP_SIZE 64 -#else -#define WARP_SIZE 32 -#endif - -template -__device__ __forceinline__ T warp_reduce_sum(T val) -{ - #pragma unroll - for(int i = WARP_SIZE/2; i > 0; i >>= 1) - val = val + SHFL_DOWN(0xffffffff, val, i); - return val; -} - -template -__device__ __forceinline__ T reduce_block(T *x, T val) -{ - int tid = threadIdx.y*blockDim.x + threadIdx.x; - int blockSize = blockDim.x * blockDim.y; - int lane = tid % WARP_SIZE; - int wid = tid / WARP_SIZE; - - if (blockSize > WARP_SIZE) { - val = warp_reduce_sum(val); - if (lane == 0) - x[wid] = val; - - __syncthreads(); - - val = (tid < blockSize / WARP_SIZE? x[lane] : T(0)); - } - - if(wid==0) val = warp_reduce_sum(val); - - return val; -} - -#define ELEMENTS_PER_ITER 4 // enables concurrency within each thread to hide latency -#define ELEMENTS_PER_THREAD 16 -#define OPTIMAL_TILE_W WARP_SIZE -#define MAX_H_BLOCK 128 -#define MAX_BLOCK_SIZE 512 - -__host__ int div_ru(int x, int y) { - return h_last_pow2(1 + (x-1)/y); -} - -__host__ void flexible_launch_configs( - const int reduction, - const int stride, - dim3 &block, - dim3 &grid, - const bool coop_flag = false) { - int block_x = std::min(h_last_pow2(stride), OPTIMAL_TILE_W); - int block_y = std::min(h_last_pow2(div_ru(reduction , ELEMENTS_PER_THREAD)), - MAX_BLOCK_SIZE / block_x); - if (block_x * block_y != MAX_BLOCK_SIZE) { - block_x = std::min(h_last_pow2(stride), MAX_BLOCK_SIZE / block_y); - } - - int grid_x = div_ru(stride, block_x); - int grid_y = std::min(div_ru(reduction, block_y * ELEMENTS_PER_THREAD), MAX_H_BLOCK); - if (coop_flag) { - // it's not worth having a grid reduction if the reduction dimension is not big enough - grid_y = grid_y < 8 ? 1 : grid_y; - } - - block.x = block_x; - block.y = block_y; - block.z = 1; - grid.x = grid_x; - grid.y = grid_y; - grid.z = 1; -} - -template -__device__ __forceinline__ void welford_merge_element(C& count, - T& mean, - T& m2n, - const C& num_new, - const T& mean_new, - const T& m2n_new) { - T factor = T(1.0) / max(1, (count + num_new)); - T delta0 = mean - mean_new; - mean = (mean_new * num_new + mean * count) * factor; - m2n += m2n_new + delta0 * delta0 * num_new * count * factor; - count += num_new; -} - -template -__device__ __forceinline__ void warp_reduce_mean_m2n(T &mean, T &m2n, int &num) -{ - #pragma unroll - for(int i = WARP_SIZE/2; i > 0; i >>= 1) { - auto num_new = SHFL_DOWN(0xffffffff, num, i); - auto mean_new = SHFL_DOWN(0xffffffff, mean, i); - auto m2n_new = SHFL_DOWN(0xffffffff, m2n, i); - welford_merge_element(num, mean, m2n, num_new, mean_new, m2n_new); - } -} - -template -__device__ void welford_reduce_mean_m2n( - T* __restrict__ x, - int* __restrict__ count, - T &mean, - T &m2n, - int &num, - int block_size, - int thread_id) -{ - int lane = thread_id % WARP_SIZE; - int wid = thread_id / WARP_SIZE; - - if (block_size > WARP_SIZE) { - warp_reduce_mean_m2n(mean, m2n, num); - if (lane == 0) { - x[wid*2] = mean; - x[wid*2+1] = m2n; - count[wid] = num; - } - __syncthreads(); - - if (wid == 0) { - mean = (thread_id < block_size / WARP_SIZE)? x[lane*2] : T(0); - m2n = (thread_id < block_size / WARP_SIZE)? x[lane*2+1] : T(0); - num = (thread_id < block_size / WARP_SIZE)? count[lane] : int(0); - } - } - - if (wid==0) warp_reduce_mean_m2n(mean, m2n, num); - - return; -} - -// return spatial size for NC+ Tensors -__host__ int get_tensor_spatial_size(const at::Tensor& input) -{ - auto space_size = input.size(2); - for (int i = 3; i < input.ndimension(); i++) { - space_size *= input.size(i); - } - return space_size; -} - -// promote accumulation scalar type. promote half to float. -__host__ at::ScalarType promote_scalartype(const at::Tensor& input) -{ - return input.scalar_type() == at::ScalarType::Half ? - at::ScalarType::Float : input.scalar_type(); -} - -// return single element size, optional accumulation type promotion. -__host__ size_t get_element_data_size(const at::Tensor& input, bool accumulation = false) -{ - auto scalar_type = accumulation ? promote_scalartype(input) : input.scalar_type(); - return at::elementSize(scalar_type); -} - -template -__device__ __forceinline__ void welford_merge_block_vertical(C& count, - T& mean, - T& m2n, - C* shmem_count, - T* shmem_mean, - T* shmem_m2n) { - // write to shared memory - auto address_base = threadIdx.x + threadIdx.y * blockDim.x; - shmem_mean[address_base] = mean; - shmem_m2n[address_base] = m2n; - shmem_count[address_base] = count; - -#pragma unroll - for (int offset = blockDim.y/2; offset > 0; offset >>= 1) { - __syncthreads(); - if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { - auto address = address_base + offset * blockDim.x; - // read shared memory back to register for reduction - auto num_new = shmem_count[address]; - auto mean_new = shmem_mean[address]; - auto m2n_new = shmem_m2n[address]; - - welford_merge_element(count, mean, m2n, num_new, mean_new, m2n_new); - - // last write is not necessary - shmem_mean[address_base] = mean; - shmem_m2n[address_base] = m2n; - shmem_count[address_base] = count; - } - } -} - -template -__device__ __forceinline__ void merge_block_vertical(T& sum_dy, - T& sum_dy_xmu, - T* shmem_sum_dy, - T* shmem_sum_dy_xmu) { - // write to shared memory - auto address_base = threadIdx.x + threadIdx.y * blockDim.x; - shmem_sum_dy[address_base] = sum_dy; - shmem_sum_dy_xmu[address_base] = sum_dy_xmu; - -#pragma unroll - for (int offset = blockDim.y/2; offset > 0; offset >>= 1) { - __syncthreads(); - if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) { - auto address = address_base + offset * blockDim.x; - - sum_dy += shmem_sum_dy[address]; - sum_dy_xmu += shmem_sum_dy_xmu[address]; - - // last write is not necessary - shmem_sum_dy[address_base] = sum_dy; - shmem_sum_dy_xmu[address_base] = sum_dy_xmu; - } - } -} - - -// welford kernel calculating mean/biased_variance/unbiased_variance -template -#ifdef __HIP_PLATFORM_HCC__ -__launch_bounds__(MAX_BLOCK_SIZE) -#endif -__global__ void welford_kernel( - const scalar_t* __restrict__ input, - outscalar_t* __restrict__ out_mean, - outscalar_t* __restrict__ out_var_biased, - const int bs, - const int fs, - const int ss) { - int block_size = blockDim.x * blockDim.y; - int count = 0; - accscalar_t x_mean = accscalar_t(0); - accscalar_t m_2_n = accscalar_t(0); - - int thread_id = threadIdx.y*blockDim.x + threadIdx.x; - - for (int batch_id = threadIdx.y; batch_id < bs; batch_id += blockDim.y) { - int input_base = blockIdx.x*ss + batch_id*ss*fs; - // sequential welford - for (int offset = threadIdx.x; offset < ss ; offset += blockDim.x) { - count++; - auto x_n = static_cast(input[offset+input_base]); - auto d = x_n - x_mean; - x_mean += d / count; - m_2_n += d * (x_n - x_mean); - } - } - - static __shared__ int s_mem[WARP_SIZE]; - static __shared__ accscalar_t s_mem_ac[WARP_SIZE*2]; - - welford_reduce_mean_m2n(s_mem_ac, s_mem, x_mean, m_2_n, count, block_size, thread_id); - - if (thread_id == 0) { - out_mean[blockIdx.x] = static_cast(x_mean); - out_var_biased[blockIdx.x] = static_cast(m_2_n/count); - } -} - -// elementwise BN kernel -template -#ifdef __HIP_PLATFORM_HCC__ -__launch_bounds__(MAX_BLOCK_SIZE) -#endif -__global__ void batchnorm_forward_kernel( - const scalar_t* __restrict__ input, - const accscalar_t* __restrict__ mean, - const accscalar_t* __restrict__ inv_std, - const layerscalar_t* __restrict__ weight, - const layerscalar_t* __restrict__ shift, - scalar_t* __restrict__ out, - const int ss, - const int bs) { - auto m_c = mean[blockIdx.x]; - auto inv_std_c = inv_std[blockIdx.x]; - auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast(weight[blockIdx.x]); - auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast(shift[blockIdx.x]); - - for (int batch_offset = blockIdx.y*blockDim.y + threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) { - int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss; - for (int offset = threadIdx.x + blockIdx.z*blockDim.x; offset < ss ; offset+= gridDim.z*blockDim.x) { - out[address_base+offset] = static_cast(w_c * (static_cast(input[address_base+offset]) - m_c ) * inv_std_c + s_c); - } - } -} - -// Backward BN kernel, calculates grad_bias, grad_weight as well as intermediate -// results to calculating grad_input. -// Breaking the grad_input to two step to support sync BN, which requires all -// reduce of the intermediate results across processes. -template -#ifdef __HIP_PLATFORM_HCC__ -__launch_bounds__(MAX_BLOCK_SIZE) -#endif -__global__ void reduce_bn_kernel( - const scalar_t* __restrict__ input, - const scalar_t* __restrict__ grad_output, - const accscalar_t* __restrict__ mean, - const accscalar_t* __restrict__ inv_std, - accscalar_t* __restrict__ sum_dy_o, - accscalar_t* __restrict__ sum_dy_xmu_o, - layerscalar_t* __restrict__ grad_weight, - layerscalar_t* __restrict__ grad_bias, - const int bs, - const int fs, - const int ss) { - static __shared__ int s_mem[WARP_SIZE]; - //int total_item_num = bs * ss; - - int thread_id = threadIdx.y*blockDim.x + threadIdx.x; - - auto r_mean = mean[blockIdx.x]; - auto factor = inv_std[blockIdx.x]; - - // Kahan sum - accscalar_t sum_dy = 0.0; - accscalar_t sum_dy_xmu = 0.0; - accscalar_t sum_dy_c = 0.0; - accscalar_t sum_dy_xmu_c = 0.0; - for (int batch_id = threadIdx.y; batch_id < bs; batch_id += blockDim.y) { - int input_base = blockIdx.x*ss + batch_id*ss*fs; - for (int offset = threadIdx.x; offset < ss ; offset += blockDim.x) { - auto e_grad = static_cast(grad_output[offset+input_base]); - auto e_input = static_cast(input[offset+input_base]); - // calculating sum_dy - auto sum_dy_y = e_grad - sum_dy_c; - auto sum_dy_t = sum_dy + sum_dy_y; - sum_dy_c = (sum_dy_t - sum_dy) - sum_dy_y; - sum_dy = sum_dy_t; - - // calculating sum_dy_xmu - auto sum_dy_xmu_y = e_grad * (e_input - r_mean) - sum_dy_xmu_c; - auto sum_dy_xmu_t = sum_dy_xmu + sum_dy_xmu_y; - sum_dy_xmu_c = (sum_dy_xmu_t - sum_dy_xmu) - sum_dy_xmu_y; - sum_dy_xmu = sum_dy_xmu_t; - } - } - - sum_dy = reduce_block((accscalar_t*)s_mem, sum_dy); - __syncthreads(); - sum_dy_xmu = reduce_block((accscalar_t*)s_mem, sum_dy_xmu); - - if (thread_id == 0) { - if (grad_bias != NULL) { - grad_bias[blockIdx.x] = static_cast(sum_dy); - } - if (grad_weight != NULL) { - grad_weight[blockIdx.x] = static_cast(sum_dy_xmu * factor); - } - //mean_dy[blockIdx.x] = sum_dy / total_item_num; - //mean_dy_xmu[blockIdx.x] = sum_dy_xmu / total_item_num; - sum_dy_o[blockIdx.x] = sum_dy; - sum_dy_xmu_o[blockIdx.x] = sum_dy_xmu; - } -} - -// elementwise backward BN kernel -template -#ifdef __HIP_PLATFORM_HCC__ -__launch_bounds__(MAX_BLOCK_SIZE) -#endif -__global__ void batchnorm_backward_kernel( - const scalar_t* __restrict__ grad_output, - const scalar_t* __restrict__ input, - const accscalar_t* __restrict__ mean, - const accscalar_t* __restrict__ inv_std, - const layerscalar_t* __restrict__ weight, - const accscalar_t* __restrict__ sum_dy, - const accscalar_t* __restrict__ sum_dy_xmu, - const int* __restrict__ numel, - scalar_t* __restrict__ grad_input, - const int64_t world_size, - const int ss, - const int bs) { - int64_t div = 0; - for (int i = 0; i < world_size; i++) { - div += numel[i]; - } - auto m_c = static_cast(mean[blockIdx.x]); - //auto m_dy_c = static_cast(mean_dy[blockIdx.x]); - auto m_dy_c = static_cast(sum_dy[blockIdx.x]) / div; - auto factor_1_c = inv_std[blockIdx.x]; - auto factor_2_c = (weight == NULL ? accscalar_t(1.0) : static_cast(weight[blockIdx.x])) * factor_1_c; - //factor_1_c = factor_1_c * factor_1_c * mean_dy_xmu[blockIdx.x]; - factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[blockIdx.x] / div; - - for (int batch_offset = blockIdx.y*blockDim.y+threadIdx.y; batch_offset < bs; batch_offset += gridDim.y*blockDim.y) { - int address_base = blockIdx.x*ss + batch_offset*gridDim.x*ss; - for (int offset = threadIdx.x + blockIdx.z*blockDim.x; offset < ss ; offset+= gridDim.z*blockDim.x) { - grad_input[address_base+offset] = (static_cast(grad_output[address_base+offset]) - m_dy_c - (static_cast(input[address_base+offset]) - m_c) * factor_1_c) * factor_2_c; - } - } -} - -// welford kernel for c last tensor calculating mean/biased_variance/unbiased_variance -template - -#ifdef __HIP_PLATFORM_HCC__ -__launch_bounds__(MAX_BLOCK_SIZE) -#endif -__global__ void -welford_kernel_c_last( - const scalar_t* __restrict__ input, - outscalar_t* __restrict__ out_mean, - outscalar_t* __restrict__ out_var_biased, - volatile accscalar_t* staging_data, - int* semaphores, - const int reduction_size, - const int stride) { - // hide latency with concurrency - accscalar_t x_mean[PARALLEL_LOADS]; - accscalar_t m_2_n[PARALLEL_LOADS]; - int count[PARALLEL_LOADS]; - -#pragma unroll - for (int i = 0; i < PARALLEL_LOADS; i++) { - x_mean[i] = accscalar_t(0); - m_2_n[i] = accscalar_t(0); - count[i] = accscalar_t(0); - } - // tensor dimension (m,c) - - // loop along m dimension - int inner_loop_stride = blockDim.y * gridDim.y; - - // offset along m dimension - int m_offset = blockIdx.y * blockDim.y + threadIdx.y; - int c_offset = blockIdx.x * blockDim.x + threadIdx.x; - - int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); - int address_base = m_offset * stride + c_offset; - int address_increment = inner_loop_stride * stride; - - for (int i = 0; i < loop_count; i++) { - accscalar_t x_math[PARALLEL_LOADS]; - accscalar_t x_count_inv[PARALLEL_LOADS]; - accscalar_t is_valid[PARALLEL_LOADS]; - - // load multiple data in -#pragma unroll - for (int j = 0; j < PARALLEL_LOADS; j++) { - if (c_offset < stride && m_offset < reduction_size) { - x_math[j] = input[address_base]; - count[j]++; - x_count_inv[j] = accscalar_t(1) / count[j]; - is_valid[j] = accscalar_t(1); - } else { - x_math[j] = accscalar_t(0); - x_count_inv[j] = accscalar_t(0); - is_valid[j] = accscalar_t(0); - } - m_offset += inner_loop_stride; - address_base += address_increment; - } - - // calculate mean/m2n with welford -#pragma unroll - for (int j = 0; j < PARALLEL_LOADS; j++) { - accscalar_t delta0 = x_math[j] - x_mean[j]; - x_mean[j] += delta0 * x_count_inv[j]; - accscalar_t delta1 = x_math[j] - x_mean[j]; - m_2_n[j] += delta0 * delta1 * is_valid[j]; - } - } - - // thread reduction to accumulate mean/m_2_n/count between PARALLEL_LOADS -#pragma unroll - for (int j = 1; j < PARALLEL_LOADS; j++) { - welford_merge_element(count[0], x_mean[0], m_2_n[0], count[j], x_mean[j], m_2_n[j]); - } - - // release x_mean / m_2_n - auto mean_th = x_mean[0]; - auto m2_th = m_2_n[0]; - auto count_th = count[0]; - - // block-wise reduction with shared memory (since reduction cannot be done within a warp) - static __shared__ accscalar_t shmem_mean[MAX_BLOCK_SIZE]; - static __shared__ accscalar_t shmem_m2n[MAX_BLOCK_SIZE]; - static __shared__ int shmem_count[MAX_BLOCK_SIZE]; - - welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n); - - // grid reduction if needed (coop launch used at the first place) - if (gridDim.y > 1) { - volatile accscalar_t* staging_mean = staging_data; - volatile accscalar_t* staging_m2n = &staging_data[stride*gridDim.y]; - volatile int* staging_count = reinterpret_cast(&staging_m2n[stride*gridDim.y]); - - address_base = c_offset + blockIdx.y * stride; - // write data to staging_data; - if (threadIdx.y == 0 && c_offset < stride) { - staging_mean[address_base] = mean_th; - staging_m2n[address_base] = m2_th; - staging_count[address_base] = count_th; - } - - __threadfence(); - __syncthreads(); // ensuring writes to staging_ is visible to all blocks - - __shared__ bool is_last_block_done; - // mark block done - if (threadIdx.x == 0 && threadIdx.y == 0) { - int old = atomicAdd(&semaphores[blockIdx.x], 1); - is_last_block_done = (old == (gridDim.y-1)); - } - - __syncthreads(); - - // check that all data is now available in global memory - if (is_last_block_done) { - count_th = 0; - mean_th = accscalar_t(0.0); - m2_th = accscalar_t(0.0); - - for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { - address_base = c_offset + y * stride; - int num_new = c_offset < stride ? staging_count[address_base] : 0; - accscalar_t mean_new = c_offset < stride ? staging_mean[address_base] : accscalar_t(0.0); - accscalar_t m2n_new = c_offset < stride ? staging_m2n[address_base] : accscalar_t(0.0); - - welford_merge_element(count_th, mean_th, m2_th, num_new, mean_new, m2n_new); - } - - welford_merge_block_vertical(count_th, mean_th, m2_th, shmem_count, shmem_mean, shmem_m2n); - if (threadIdx.y == 0 && c_offset < stride) { - out_mean[c_offset] = static_cast(mean_th); - out_var_biased[c_offset] = static_cast(m2_th / count_th); - } - } - } else { - if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) { - out_mean[c_offset] = static_cast(mean_th); - out_var_biased[c_offset] = static_cast(m2_th / count_th); - } - } -} - -// parallel welford kernel to further reduce mean / biased_var -// into mean / unbiased_var / inv_std across multiple processes. -template -#ifdef __HIP_PLATFORM_HCC__ -__launch_bounds__(MAX_BLOCK_SIZE) -#endif -__global__ void welford_kernel_parallel( - const scalar_t* __restrict__ mean, - const scalar_t* __restrict__ var_biased, - const int* __restrict__ numel, - scalar_t* __restrict__ out_mean, - scalar_t* __restrict__ out_var, - scalar_t* __restrict__ inv_std, - const int world_size, - const int feature_size, - const float eps) { - - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < feature_size; i += gridDim.x * blockDim.x) { - // load data; - int address = i; - scalar_t x_mean = 0; - scalar_t m_2_n = 0; - int count = 0; - for (int j = 0; j < world_size; j++) { - welford_merge_element(count, x_mean, m_2_n, numel[j], mean[address], var_biased[address]*numel[j]); - address += feature_size; - } - out_mean[i] = x_mean; - out_var[i] = m_2_n/ (count - 1); - inv_std[i] = scalar_t(1) / sqrt(m_2_n/count + eps); - } -} - -// elementwise BN kernel -template < - typename scalar_t, - typename accscalar_t, - typename layerscalar_t, - int PARALLEL_LOADS> -#ifdef __HIP_PLATFORM_HCC__ -__launch_bounds__(MAX_BLOCK_SIZE) -#endif -__global__ void batchnorm_forward_c_last_kernel( - const scalar_t* __restrict__ input, - const scalar_t* __restrict__ z, - const accscalar_t* __restrict__ mean, - const accscalar_t* __restrict__ inv_std, - const layerscalar_t* __restrict__ weight, - const layerscalar_t* __restrict__ shift, - scalar_t* __restrict__ out, - const int reduction_size, - const int stride, - const bool fuse_relu) { - // tensor dimension (m,c) - // loop along m dimension - int inner_loop_stride = blockDim.y * gridDim.y; - - // offset along m dimension - int m_offset = blockIdx.y * blockDim.y + threadIdx.y; - int c_offset = blockIdx.x * blockDim.x + threadIdx.x; - - auto m_c = mean[c_offset]; - auto inv_std_c = static_cast(inv_std[c_offset]); - auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast(weight[c_offset]); - auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast(shift[c_offset]); - - int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); - int address_base = m_offset * stride + c_offset; - int address_increment = inner_loop_stride * stride; - - for (int i = 0; i < loop_count; i++) { -#pragma unroll - for (int j = 0; j < PARALLEL_LOADS; j++) { - if (c_offset < stride && m_offset < reduction_size) { - auto tmp = w_c * (static_cast(input[address_base]) - m_c ) * inv_std_c + s_c; - if (z != NULL) { - tmp += z[address_base]; - } - out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? scalar_t(0.0) : static_cast(tmp)); - } - m_offset += inner_loop_stride; - address_base += address_increment; - } - } -} - -// elementwise BN kernel -template < - typename scalar_t, - typename accscalar_t, - typename layerscalar_t, - int PARALLEL_LOADS> -#ifdef __HIP_PLATFORM_HCC__ -__launch_bounds__(MAX_BLOCK_SIZE) -#endif -__global__ void relu_backward_c_last_kernel( - const scalar_t* __restrict__ grad_output, - const scalar_t* __restrict__ input, - const scalar_t* __restrict__ z, - const accscalar_t* __restrict__ mean, - const accscalar_t* __restrict__ inv_std, - const layerscalar_t* __restrict__ weight, - const layerscalar_t* __restrict__ shift, - scalar_t* __restrict__ out, - const int reduction_size, - const int stride) { - // tensor dimension (m,c) - // loop along m dimension - int inner_loop_stride = blockDim.y * gridDim.y; - - // offset along m dimension - int m_offset = blockIdx.y * blockDim.y + threadIdx.y; - int c_offset = blockIdx.x * blockDim.x + threadIdx.x; - - auto m_c = mean[c_offset]; - auto inv_std_c = static_cast(inv_std[c_offset]); - auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast(weight[c_offset]); - auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast(shift[c_offset]); - - int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); - int address_base = m_offset * stride + c_offset; - int address_increment = inner_loop_stride * stride; - - for (int i = 0; i < loop_count; i++) { -#pragma unroll - for (int j = 0; j < PARALLEL_LOADS; j++) { - if (c_offset < stride && m_offset < reduction_size) { - auto tmp = w_c * (static_cast(input[address_base]) - m_c ) * inv_std_c + s_c; - if (z != NULL) { - tmp += z[address_base]; - } - out[address_base] = (tmp <= accscalar_t(0.0) ? scalar_t(0.0) : grad_output[address_base]); - } - m_offset += inner_loop_stride; - address_base += address_increment; - } - } -} - -// batchnorm backward kernel for c last tensor -template - -#ifdef __HIP_PLATFORM_HCC__ -__launch_bounds__(MAX_BLOCK_SIZE) -#endif -__global__ void reduce_bn_c_last_kernel( - const scalar_t* __restrict__ input, - const scalar_t* __restrict__ grad_output, - const accscalar_t* __restrict__ mean, - const accscalar_t* __restrict__ inv_std, - accscalar_t* __restrict__ sum_dy_o, - accscalar_t* __restrict__ sum_dy_xmu_o, - layerscalar_t* __restrict__ grad_weight, - layerscalar_t* __restrict__ grad_bias, - volatile accscalar_t* staging_data, - int* semaphores, - const int reduction_size, - const int stride) { - - // hide latency with concurrency - accscalar_t sum_dy[PARALLEL_LOADS]; - accscalar_t sum_dy_xmu[PARALLEL_LOADS]; - -#pragma unroll - for (int i = 0; i < PARALLEL_LOADS; i++) { - sum_dy[i] = accscalar_t(0); - sum_dy_xmu[i] = accscalar_t(0); - } - // tensor dimension (m,c) - - // loop along m dimension - int inner_loop_stride = blockDim.y * gridDim.y; - - // offset along m dimension - int m_offset = blockIdx.y * blockDim.y + threadIdx.y; - int c_offset = blockIdx.x * blockDim.x + threadIdx.x; - - int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); - int address_base = m_offset * stride + c_offset; - int address_increment = inner_loop_stride * stride; - - auto r_mean = mean[c_offset]; - auto factor = inv_std[c_offset]; - - for (int i = 0; i < loop_count; i++) { - accscalar_t x_input[PARALLEL_LOADS]; - accscalar_t x_grad_output[PARALLEL_LOADS]; - - // load multiple data in -#pragma unroll - for (int j = 0; j < PARALLEL_LOADS; j++) { - if (c_offset < stride && m_offset < reduction_size) { - x_input[j] = input[address_base]; - x_grad_output[j] = grad_output[address_base]; - } else { - x_input[j] = accscalar_t(0); - x_grad_output[j] = accscalar_t(0); - } - m_offset += inner_loop_stride; - address_base += address_increment; - } - - // calculate sum_dy / sum_dy_xmu -#pragma unroll - for (int j = 0; j < PARALLEL_LOADS; j++) { - sum_dy[j] += x_grad_output[j]; - sum_dy_xmu[j] += x_grad_output[j] * (x_input[j] - r_mean); - } - } - - // thread reduction to accumulate sum_dy / sum_dy_xmu between PARALLEL_LOADS -#pragma unroll - for (int j = 1; j < PARALLEL_LOADS; j++) { - sum_dy[0] += sum_dy[j]; - sum_dy_xmu[0] += sum_dy_xmu[j]; - } - - // release array of registers - auto sum_dy_th = sum_dy[0]; - auto sum_dy_xmu_th = sum_dy_xmu[0]; - - // block-wise reduction with shared memory (since reduction cannot be done within a warp) - static __shared__ accscalar_t shmem_sum_dy[MAX_BLOCK_SIZE]; - static __shared__ accscalar_t shmem_sum_dy_xmu[MAX_BLOCK_SIZE]; - - merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu); - - // grid reduction if needed (coop launch used at the first place) - if (gridDim.y > 1) { - volatile accscalar_t* staging_sum_dy = staging_data; - volatile accscalar_t* staging_sum_dy_xmu = &staging_data[stride*gridDim.y]; - - address_base = c_offset + blockIdx.y * stride; - // write data to staging_data; - if (threadIdx.y == 0 && c_offset < stride) { - staging_sum_dy[address_base] = sum_dy_th; - staging_sum_dy_xmu[address_base] = sum_dy_xmu_th; - } - - __threadfence(); - __syncthreads(); // ensuring writes to staging_ is visible to all blocks - - __shared__ bool is_last_block_done; - // mark block done - if (threadIdx.x == 0 && threadIdx.y == 0) { - int old = atomicAdd(&semaphores[blockIdx.x], 1); - is_last_block_done = (old == (gridDim.y-1)); - } - - __syncthreads(); - - // check that all data is now available in global memory - if (is_last_block_done) { - sum_dy_th = accscalar_t(0.0); - sum_dy_xmu_th = accscalar_t(0.0); - - for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { - address_base = c_offset + y * stride; - sum_dy_th += (c_offset < stride ? staging_sum_dy[address_base] : accscalar_t(0.0)); - sum_dy_xmu_th += (c_offset < stride ? staging_sum_dy_xmu[address_base] : accscalar_t(0.0)); - } - - merge_block_vertical(sum_dy_th, sum_dy_xmu_th, shmem_sum_dy, shmem_sum_dy_xmu); - if (threadIdx.y == 0 && c_offset < stride) { - if (grad_bias != NULL) { - grad_bias[c_offset] = static_cast(sum_dy_th); - } - if (grad_weight != NULL) { - grad_weight[c_offset] = static_cast(sum_dy_xmu_th * factor); - } - //mean_dy[c_offset] = sum_dy_th / reduction_size; - //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size; - sum_dy_o[c_offset] = sum_dy_th; - sum_dy_xmu_o[c_offset] = sum_dy_xmu_th; - } - } - } else { - if (blockIdx.y == 0 && threadIdx.y == 0 && c_offset < stride) { - if (grad_bias != NULL) { - grad_bias[c_offset] = static_cast(sum_dy_th); - } - if (grad_weight != NULL) { - grad_weight[c_offset] = static_cast(sum_dy_xmu_th * factor); - } - //mean_dy[c_offset] = sum_dy_th / reduction_size; - //mean_dy_xmu[c_offset] = sum_dy_xmu_th / reduction_size; - sum_dy_o[c_offset] = sum_dy_th; - sum_dy_xmu_o[c_offset] = sum_dy_xmu_th; - } - } -} - -// elementwise BN kernel -template < - typename scalar_t, - typename accscalar_t, - typename layerscalar_t, - int PARALLEL_LOADS> -#ifdef __HIP_PLATFORM_HCC__ -__launch_bounds__(MAX_BLOCK_SIZE) -#endif -__global__ void batchnorm_backward_c_last_kernel( - const scalar_t* __restrict__ grad_output, - const scalar_t* __restrict__ input, - const accscalar_t* __restrict__ mean, - const accscalar_t* __restrict__ inv_std, - const layerscalar_t* __restrict__ weight, - const accscalar_t* __restrict__ sum_dy, - const accscalar_t* __restrict__ sum_dy_xmu, - const int* __restrict__ numel, - scalar_t* __restrict__ grad_input, - const int64_t world_size, - const int reduction_size, - const int stride) { - int64_t div = 0; - for (int i = 0; i < world_size; i++) { - div += numel[i]; - } - // tensor dimension (m,c) - // loop along m dimension - int inner_loop_stride = blockDim.y * gridDim.y; - - // offset along m dimension - int m_offset = blockIdx.y * blockDim.y + threadIdx.y; - int c_offset = blockIdx.x * blockDim.x + threadIdx.x; - - auto m_c = mean[c_offset]; - auto m_dy_c = sum_dy[c_offset] / div; - auto factor_1_c = inv_std[c_offset]; - auto factor_2_c = (weight == NULL? accscalar_t(1.0) : static_cast(weight[c_offset])) * factor_1_c; - factor_1_c = factor_1_c * factor_1_c * sum_dy_xmu[c_offset] / div; - - int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS); - int address_base = m_offset * stride + c_offset; - int address_increment = inner_loop_stride * stride; - - for (int i = 0; i < loop_count; i++) { -#pragma unroll - for (int j = 0; j < PARALLEL_LOADS; j++) { - if (c_offset < stride && m_offset < reduction_size) { - grad_input[address_base] = static_cast( - (static_cast(grad_output[address_base]) - m_dy_c - - (static_cast(input[address_base]) - m_c) * factor_1_c) - * factor_2_c); - } - m_offset += inner_loop_stride; - address_base += address_increment; - } - } -} - -std::vector welford_mean_var_CUDA(const at::Tensor input) { - const auto batch_size = input.size(0); - const auto feature_size = input.size(1); - - auto space_size = get_tensor_spatial_size(input); - auto scalar_type = promote_scalartype(input); - - at::Tensor out_var_biased = at::empty({feature_size}, input.options().dtype(scalar_type)); - at::Tensor out_mean = at::empty({feature_size}, input.options().dtype(scalar_type)); - - int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE / WARP_SIZE)); - int block_x = max(1, min(MAX_BLOCK_SIZE / block_y, h_last_pow2(space_size))); - const dim3 block(block_x, block_y); - const dim3 grid(feature_size); - - auto stream = at::cuda::getCurrentCUDAStream(); - - { - using namespace at; - DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "welford_mean_var_kernel", - using accscalar_t = at::acc_type; - welford_kernel<<>>( - input.DATA_PTR(), - out_mean.DATA_PTR(), - out_var_biased.DATA_PTR(), - batch_size, - feature_size, - space_size); - ); - } - - return {out_mean, out_var_biased}; -} - -at::Tensor batchnorm_forward_CUDA( - const at::Tensor input, - const at::Tensor mean, - const at::Tensor inv_std, - const at::optional weight, - const at::optional shift) { - const auto batch_size = input.size(0); - const auto feature_size = input.size(1); - at::Tensor out = at::empty_like(input); - - auto space_size = get_tensor_spatial_size(input); - - int block_x = max(WARP_SIZE, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4)); - int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4)); - const dim3 block(block_x, block_y); - int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x)); - int batch_group_size = max(1, min(65535, h_last_pow2(batch_size)/block_y)); - const dim3 grid(feature_size, batch_group_size, grid_z); - auto stream = at::cuda::getCurrentCUDAStream(); - - if (input.scalar_type() == at::ScalarType::Half - && weight.has_value() && - weight.value().scalar_type() == at::ScalarType::Float) { - using namespace at; - DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward", - using accscalar_t = at::acc_type; - batchnorm_forward_kernel<<>>( - input.DATA_PTR(), - mean.DATA_PTR(), - inv_std.DATA_PTR(), - weight.has_value() ? weight.value().DATA_PTR() : NULL, - shift.has_value() ? shift.value().DATA_PTR() : NULL, - out.DATA_PTR(), - space_size, - batch_size); - ); - } else { - if (weight.has_value()) { - TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(), - "input.scalar_type() is not supported with weight.scalar_type()"); - } - using namespace at; - DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward", - using accscalar_t = at::acc_type; - batchnorm_forward_kernel<<>>( - input.DATA_PTR(), - mean.DATA_PTR(), - inv_std.DATA_PTR(), - weight.has_value() ? weight.value().DATA_PTR() : NULL, - shift.has_value() ? shift.value().DATA_PTR() : NULL, - out.DATA_PTR(), - space_size, - batch_size); - ); - } - return out; -} - -std::vector reduce_bn_CUDA( - const at::Tensor grad_output, - const at::Tensor input, - const at::Tensor mean, - const at::Tensor inv_std, - const at::optional weight) -{ - const auto batch_size = input.size(0); - const auto feature_size = input.size(1); - - auto scalar_type = promote_scalartype(input); - - at::Tensor sum_dy = at::empty({feature_size}, mean.options()); - at::Tensor sum_dy_xmu = at::empty({feature_size}, mean.options()); - - at::Tensor grad_weight; - at::Tensor grad_bias; - if (weight.has_value()) { - grad_weight = at::empty({feature_size}, weight.value().options()); - grad_bias = at::empty({feature_size}, weight.value().options()); - } else { - grad_weight = at::empty({0}, mean.options()); - grad_bias = at::empty({0}, mean.options()); - } - - auto space_size = get_tensor_spatial_size(input); - - int block_y = min(h_last_pow2(batch_size), int(MAX_BLOCK_SIZE/ WARP_SIZE)); - int block_x = max(1, min(MAX_BLOCK_SIZE/ block_y, h_last_pow2(space_size))); - const dim3 block(block_x, block_y); - const dim3 grid(feature_size); - auto stream = at::cuda::getCurrentCUDAStream(); - - if (input.scalar_type() == at::ScalarType::Half - && weight.has_value() && - weight.value().scalar_type() == at::ScalarType::Float) { - using namespace at; - DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce", - using accscalar_t = at::acc_type; - reduce_bn_kernel<<>>( - input.DATA_PTR(), - grad_output.DATA_PTR(), - mean.DATA_PTR(), - inv_std.DATA_PTR(), - sum_dy.DATA_PTR(), - sum_dy_xmu.DATA_PTR(), - weight.has_value() ? grad_weight.DATA_PTR() : NULL, - weight.has_value() ? grad_bias.DATA_PTR() : NULL, - batch_size, - feature_size, - space_size); - ); - } else { - if (weight.has_value()) { - TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(), - "input.scalar_type() is not supported with weight.scalar_type()"); - } - using namespace at; - DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce", - using accscalar_t = at::acc_type; - reduce_bn_kernel<<>>( - input.DATA_PTR(), - grad_output.DATA_PTR(), - mean.DATA_PTR(), - inv_std.DATA_PTR(), - sum_dy.DATA_PTR(), - sum_dy_xmu.DATA_PTR(), - weight.has_value() ? grad_weight.DATA_PTR() : NULL, - weight.has_value() ? grad_bias.DATA_PTR() : NULL, - batch_size, - feature_size, - space_size); - ); - } - - return {sum_dy, sum_dy_xmu, grad_weight, grad_bias}; -} - -at::Tensor batchnorm_backward_CUDA( - const at::Tensor grad_output, - const at::Tensor input, - const at::Tensor mean, - const at::Tensor inv_std, - const at::optional weight, - const at::Tensor sum_dy, - const at::Tensor sum_dy_xmu, - const at::Tensor count) { - const auto batch_size = input.size(0); - const auto feature_size = input.size(1); - - at::Tensor grad_input = at::empty_like(input); - - auto space_size = get_tensor_spatial_size(input); - - int block_x = max(WARP_SIZE, min(MAX_BLOCK_SIZE, h_last_pow2(space_size)/4)); - int block_y = max(1, min(MAX_BLOCK_SIZE/block_x, h_last_pow2(batch_size)/4)); - const dim3 block(block_x, block_y); - int grid_z = max(1, min(65535, h_last_pow2(space_size)/4/block_x)); - int batch_group_size = max(1, min(65535, h_last_pow2(batch_size)/block_y)); - const dim3 grid(feature_size, batch_group_size, grid_z); - - auto stream = at::cuda::getCurrentCUDAStream(); - - if (input.scalar_type() == at::ScalarType::Half - && weight.has_value() && - weight.value().scalar_type() == at::ScalarType::Float) { - using namespace at; - DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward", - using accscalar_t = at::acc_type; - batchnorm_backward_kernel<<>>( - grad_output.DATA_PTR(), - input.DATA_PTR(), - mean.DATA_PTR(), - inv_std.DATA_PTR(), - weight.has_value() ? weight.value().DATA_PTR() : NULL, - sum_dy.DATA_PTR(), - sum_dy_xmu.DATA_PTR(), - count.DATA_PTR(), - grad_input.DATA_PTR(), - count.numel(), - space_size, - batch_size); - ); - } else { - if (weight.has_value()) { - TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(), - "input.scalar_type() is not supported with weight.scalar_type()"); - } - using namespace at; - DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward", - using accscalar_t = at::acc_type; - batchnorm_backward_kernel<<>>( - grad_output.DATA_PTR(), - input.DATA_PTR(), - mean.DATA_PTR(), - inv_std.DATA_PTR(), - weight.has_value() ? weight.value().DATA_PTR() : NULL, - sum_dy.DATA_PTR(), - sum_dy_xmu.DATA_PTR(), - count.DATA_PTR(), - grad_input.DATA_PTR(), - count.numel(), - space_size, - batch_size); - ); - } - - return grad_input; -} - -std::vector welford_parallel_CUDA(const at::Tensor mean_feature_nodes, - const at::Tensor var_biased, - const at::Tensor numel, - const float eps) { - const auto world_size = mean_feature_nodes.size(0); - const auto feature_size = mean_feature_nodes.size(1); - - at::Tensor out_var = at::empty({feature_size}, var_biased.options()); - at::Tensor inv_std = at::empty_like(out_var); - at::Tensor out_mean = at::empty_like(out_var); - - at::Tensor mean_feature_nodes_ = mean_feature_nodes.contiguous(); - at::Tensor var_biased_ = var_biased.contiguous(); - at::Tensor numel_ = numel.contiguous(); - - // TODO(jie): tile this for memory coalescing! - const int block = std::min(h_last_pow2(feature_size), MAX_BLOCK_SIZE); - const int grid = std::max(1, feature_size / block); - - auto stream = at::cuda::getCurrentCUDAStream(); - - { - using namespace at; - DISPATCH_FLOAT_AND_HALF(mean_feature_nodes.scalar_type(), 0, "welford_parallel_kernel", - welford_kernel_parallel<<>>( - mean_feature_nodes_.DATA_PTR(), - var_biased_.DATA_PTR(), - numel_.DATA_PTR(), - out_mean.DATA_PTR(), - out_var.DATA_PTR(), - inv_std.DATA_PTR(), - world_size, - feature_size, - eps); - ); - } - - return {out_mean, out_var, inv_std}; -} - -std::vector welford_mean_var_c_last_CUDA(const at::Tensor input) { - const auto stride = input.size(input.ndimension()-1); - const auto reduction_size = input.numel() / stride; - - auto scalar_type = promote_scalartype(input); - auto option = input.options().dtype(scalar_type); - - at::Tensor out_var_biased = at::empty({stride}, option); - at::Tensor out_mean = at::empty({stride}, option); - - dim3 block; - dim3 grid; - flexible_launch_configs(reduction_size, stride, block, grid, true); - - at::Tensor staging_data; - at::Tensor semaphores; - if (grid.y > 1) { - staging_data = at::empty({4*stride*grid.y}, option); - semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt)); - } - - auto stream = at::cuda::getCurrentCUDAStream(); - - { - using namespace at; - DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "welford_mean_var_c_last", - using accscalar_t = at::acc_type; - accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.DATA_PTR() : nullptr; - int* semaphores_ptr = grid.y > 1 ? semaphores.DATA_PTR() : nullptr; - welford_kernel_c_last - <<>>( - input.DATA_PTR(), - out_mean.DATA_PTR(), - out_var_biased.DATA_PTR(), - staging_data_ptr, - semaphores_ptr, - reduction_size, - stride); - ); - } - - return {out_mean, out_var_biased}; -} - -at::Tensor batchnorm_forward_c_last_CUDA( - const at::Tensor input, - const at::optional z, - const at::Tensor mean, - const at::Tensor inv_std, - const at::optional weight, - const at::optional shift, - const bool fuse_relu) { - const auto stride = input.size(input.ndimension()-1); - const auto reduction_size = input.numel() / stride; - - at::Tensor out = at::empty_like(input); - - dim3 block; - dim3 grid; - flexible_launch_configs(reduction_size, stride, block, grid); - - auto stream = at::cuda::getCurrentCUDAStream(); - - if (input.scalar_type() == at::ScalarType::Half - && weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) { - using namespace at; - DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward", - using accscalar_t = at::acc_type; - batchnorm_forward_c_last_kernel - <<>>( - input.DATA_PTR(), - z.has_value() ? z.value().DATA_PTR() : NULL, - mean.DATA_PTR(), - inv_std.DATA_PTR(), - weight.has_value() ? weight.value().DATA_PTR() : NULL, - shift.has_value() ? shift.value().DATA_PTR(): NULL, - out.DATA_PTR(), - reduction_size, - stride, - fuse_relu); - ); - } else { - if (weight.has_value()) { - TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(), - "input.scalar_type() is not supported with weight.scalar_type()"); - } - using namespace at; - DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward", - using accscalar_t = at::acc_type; - batchnorm_forward_c_last_kernel - <<>>( - input.DATA_PTR(), - z.has_value() ? z.value().DATA_PTR() : NULL, - mean.DATA_PTR(), - inv_std.DATA_PTR(), - weight.has_value() ? weight.value().DATA_PTR() : NULL, - shift.has_value() ? shift.value().DATA_PTR(): NULL, - out.DATA_PTR(), - reduction_size, - stride, - fuse_relu); - ); - } - return out; -} - -std::vector reduce_bn_c_last_CUDA( - const at::Tensor grad_output, - const at::Tensor input, - const at::Tensor mean, - const at::Tensor inv_std, - const at::optional weight) { - const auto stride = input.size(input.ndimension()-1); - const auto reduction_size = input.numel() / stride; - - at::Tensor sumn_dy = at::empty({stride}, mean.options()); - at::Tensor sum_dy_xmu = at::empty({stride}, mean.options()); - - at::Tensor grad_weight; - at::Tensor grad_bias; - if (weight.has_value()) { - grad_weight = at::empty({stride}, weight.value().options()); - grad_bias = at::empty({stride}, weight.value().options()); - } else { - // because I cannot return an uninitialized at::Tensor - grad_weight = at::empty({0}, mean.options()); - grad_bias = at::empty({0}, mean.options()); - } - - dim3 block; - dim3 grid; - flexible_launch_configs(reduction_size, stride, block, grid, true); - - at::Tensor staging_data; - at::Tensor semaphores; - if (grid.y > 1) { - staging_data = at::empty({2*stride*grid.y}, mean.options()); - semaphores = at::zeros({grid.x}, input.options().dtype(at::kInt)); - } - auto stream = at::cuda::getCurrentCUDAStream(); - - if (input.scalar_type() == at::ScalarType::Half - && weight.has_value() - && weight.value().scalar_type() == at::ScalarType::Float) { - using namespace at; - DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce", - using accscalar_t = at::acc_type; - accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.DATA_PTR() : nullptr; - int* semaphores_ptr = grid.y > 1 ? semaphores.DATA_PTR() : nullptr; - reduce_bn_c_last_kernel - <<>>( - input.DATA_PTR(), - grad_output.DATA_PTR(), - mean.DATA_PTR(), - inv_std.DATA_PTR(), - sumn_dy.DATA_PTR(), - sum_dy_xmu.DATA_PTR(), - weight.has_value() ? grad_weight.DATA_PTR() : NULL, - weight.has_value() ?grad_bias.DATA_PTR() : NULL, - staging_data_ptr, - semaphores_ptr, - reduction_size, - stride); - ); - } else { - if (weight.has_value()) { - TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(), - "input.scalar_type() is not supported with weight.scalar_type()"); - } - using namespace at; - DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_backward_reduce", - using accscalar_t = at::acc_type; - accscalar_t* staging_data_ptr = grid.y > 1 ? staging_data.DATA_PTR() : nullptr; - int* semaphores_ptr = grid.y > 1 ? semaphores.DATA_PTR() : nullptr; - reduce_bn_c_last_kernel - <<>>( - input.DATA_PTR(), - grad_output.DATA_PTR(), - mean.DATA_PTR(), - inv_std.DATA_PTR(), - sumn_dy.DATA_PTR(), - sum_dy_xmu.DATA_PTR(), - weight.has_value() ? grad_weight.DATA_PTR() : NULL, - weight.has_value() ?grad_bias.DATA_PTR() : NULL, - staging_data_ptr, - semaphores_ptr, - reduction_size, - stride); - ); - } - - return {sumn_dy, sum_dy_xmu, grad_weight, grad_bias}; -} - -at::Tensor batchnorm_backward_c_last_CUDA( - const at::Tensor grad_output, - const at::Tensor input, - const at::Tensor mean, - const at::Tensor inv_std, - const at::optional weight, - const at::Tensor sum_dy, - const at::Tensor sum_dy_xmu, - const at::Tensor count) { - const auto stride = input.size(input.ndimension()-1); - const auto reduction_size = input.numel() / stride; - - at::Tensor grad_input = at::empty_like(input); - - dim3 block; - dim3 grid; - flexible_launch_configs(reduction_size, stride, block, grid); - - auto stream = at::cuda::getCurrentCUDAStream(); - - if (input.scalar_type() == at::ScalarType::Half - && weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) { - using namespace at; - DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward", - using accscalar_t = at::acc_type; - batchnorm_backward_c_last_kernel - <<>>( - grad_output.DATA_PTR(), - input.DATA_PTR(), - mean.DATA_PTR(), - inv_std.DATA_PTR(), - weight.has_value() ? weight.value().DATA_PTR() : NULL, - sum_dy.DATA_PTR(), - sum_dy_xmu.DATA_PTR(), - count.DATA_PTR(), - grad_input.DATA_PTR(), - count.numel(), - reduction_size, - stride); - ); - } else { - if (weight.has_value()) { - TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(), - "input.scalar_type() is not supported with weight.scalar_type()"); - } - using namespace at; - DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward", - using accscalar_t = at::acc_type; - batchnorm_backward_c_last_kernel - <<>>( - grad_output.DATA_PTR(), - input.DATA_PTR(), - mean.DATA_PTR(), - inv_std.DATA_PTR(), - weight.has_value() ? weight.value().DATA_PTR() : NULL, - sum_dy.DATA_PTR(), - sum_dy_xmu.DATA_PTR(), - count.DATA_PTR(), - grad_input.DATA_PTR(), - count.numel(), - reduction_size, - stride); - ); - } - - return grad_input; -} - -at::Tensor relu_backward_c_last_CUDA( - const at::Tensor grad_output, - const at::Tensor input, - const at::optional z, - const at::Tensor mean, - const at::Tensor inv_std, - const at::optional weight, - const at::optional shift) { - - const auto stride = input.size(input.ndimension()-1); - const auto reduction_size = input.numel() / stride; - - at::Tensor out = at::empty_like(input); - - dim3 block; - dim3 grid; - flexible_launch_configs(reduction_size, stride, block, grid); - - auto stream = at::cuda::getCurrentCUDAStream(); - - if (input.scalar_type() == at::ScalarType::Half - && weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) { - using namespace at; - DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward", - using accscalar_t = at::acc_type; - relu_backward_c_last_kernel - <<>>( - grad_output.DATA_PTR(), - input.DATA_PTR(), - z.has_value() ? z.value().DATA_PTR() : NULL, - mean.DATA_PTR(), - inv_std.DATA_PTR(), - weight.has_value() ? weight.value().DATA_PTR() : NULL, - shift.has_value() ? shift.value().DATA_PTR(): NULL, - out.DATA_PTR(), - reduction_size, - stride); - ); - } else { - if (weight.has_value()) { - TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(), - "input.scalar_type() is not supported with weight.scalar_type()"); - } - using namespace at; - DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward", - using accscalar_t = at::acc_type; - relu_backward_c_last_kernel - <<>>( - grad_output.DATA_PTR(), - input.DATA_PTR(), - z.has_value() ? z.value().DATA_PTR() : NULL, - mean.DATA_PTR(), - inv_std.DATA_PTR(), - weight.has_value() ? weight.value().DATA_PTR() : NULL, - shift.has_value() ? shift.value().DATA_PTR(): NULL, - out.DATA_PTR(), - reduction_size, - stride); - ); - } - return out; -} diff --git a/docs/Makefile b/docs/Makefile deleted file mode 100644 index 86cc249..0000000 --- a/docs/Makefile +++ /dev/null @@ -1,32 +0,0 @@ -# Minimal makefile for Sphinx documentation -# - -# You can set these variables from the command line. -SPHINXOPTS = -SPHINXBUILD = sphinx-build -SPHINXPROJ = NVIDIAAPEX -SOURCEDIR = source -BUILDDIR = build - -# Put it first so that "make" without argument is like "make help". -help: - @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - -gh-pages: - git checkout gh-pages - rm -rf build - rm -rf source - git checkout master -- . - make html - rm -rf ../_modules ../_sources ../_static - mv -fv build/html/* ../ - rm -rf build - git add -A - git commit -m "Generated gh-pages for `git log master -1 --pretty=short --abbrev-commit`" && git push origin gh-pages ; git checkout master - -.PHONY: help Makefile - -# Catch-all target: route all unknown targets to Sphinx using the new -# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). -%: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/source/_static/css/pytorch_theme.css b/docs/source/_static/css/pytorch_theme.css deleted file mode 100644 index 45e984c..0000000 --- a/docs/source/_static/css/pytorch_theme.css +++ /dev/null @@ -1,118 +0,0 @@ -body { - font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; -} - -/* Default header fonts are ugly */ -h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption { - font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; -} - -/* Use white for docs background */ -.wy-side-nav-search { - background-color: #fff; -} - -.wy-nav-content-wrap, .wy-menu li.current > a { - background-color: #fff; -} - -@media screen and (min-width: 1400px) { - .wy-nav-content-wrap { - background-color: rgba(0, 0, 0, 0.0470588); - } - - .wy-nav-content { - background-color: #fff; - } -} - -/* Fixes for mobile */ -.wy-nav-top { - background-color: #fff; - background-image: url('../img/apex.jpg'); - background-repeat: no-repeat; - background-position: center; - padding: 0; - margin: 0.4045em 0.809em; - color: #333; -} - -.wy-nav-top > a { - display: none; -} - -@media screen and (max-width: 768px) { - .wy-side-nav-search>a img.logo { - height: 60px; - } -} - -/* This is needed to ensure that logo above search scales properly */ -.wy-side-nav-search a { - display: block; -} - -/* This ensures that multiple constructors will remain in separate lines. */ -.rst-content dl:not(.docutils) dt { - display: table; -} - -/* Use our red for literals (it's very similar to the original color) */ -.rst-content tt.literal, .rst-content tt.literal, .rst-content code.literal { - color: #F05732; -} - -.rst-content tt.xref, a .rst-content tt, .rst-content tt.xref, -.rst-content code.xref, a .rst-content tt, a .rst-content code { - color: #404040; -} - -/* Change link colors (except for the menu) */ - -a { - color: #F05732; -} - -a:hover { - color: #F05732; -} - - -a:visited { - color: #D44D2C; -} - -.wy-menu a { - color: #b3b3b3; -} - -.wy-menu a:hover { - color: #b3b3b3; -} - -/* Default footer text is quite big */ -footer { - font-size: 80%; -} - -footer .rst-footer-buttons { - font-size: 125%; /* revert footer settings - 1/80% = 125% */ -} - -footer p { - font-size: 100%; -} - -/* For hidden headers that appear in TOC tree */ -/* see http://stackoverflow.com/a/32363545/3343043 */ -.rst-content .hidden-section { - display: none; -} - -nav .hidden-section { - display: inherit; -} - -.wy-side-nav-search>div.version { - color: #000; -} diff --git a/docs/source/_static/img/nv-pytorch2.png b/docs/source/_static/img/nv-pytorch2.png deleted file mode 100644 index 981268c60032b463b387dd3c9f56ebda9929a266..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6502 zcmX|Gby!s2)89qFMOaw6Vd(}*rMtVNy95-LhFxOuOC#N)Agv%J-LW)CcSuQtNQ2aG zegAl$`<#28GxyASX3pIC%$zyVI$Fx$XH?Gs003B3ML`eU5BzH&Z1lU_xFQGLVcV)J zD*zt9 zvjlfUR|7_~B+euN>?F^^24BjKs^ci=a ziBDP<>6!TFwU5c!%j$O_GNjBgN|jJ3I{-!s!{QLETW$ICGAHP9OXN+nV`UkVOYYb@ zX6u|P`QWx*gJ48dY`ndNYBzGy~qE7Ol;#+Nh-Ihy(g4yWRC3%q@!3Ib! zOD@@7{&(aP02J=M7K6hV>0YZ2-WUZ}p9B`@bbIyZD9C(KDKqy>@1qNxU&Z^{?DWJ| zGS0x#YMTp^^8C~b(>8|+K@4QDn6PWj-S7-}r+go@++9`aUoow@*Ju&a00zKSUzg(k zn-g@d5+ZV15E@T7T=6h$W*PhrN>SBidnqc}PIjk?>z)M>1 z?R&l~8B(<_$`yOZ!~?c?QLRh}5s3BD(rNW=F4|$ym~N9x(`1RH_MNtLzFlUlT12RI zJQfNNBS?S@h79ud&a)rc(p6Ey>l-qqhhpamLtl3;=cIiqTN-*t@~&L?R&Ul0gtz#Q zcvjxd7xaYZS`&@;^b(pa7cok^jvnnbPr)wiy?}=GGmp!b2A8Yb1n_2rYuImz_H6Mv zTmU9u{O7`~SKhYgQu?O2Yc7ZA^7UBG@1Jgmn!{mJIX_hL*M_&<=@woe-AGEV_&vUS z5C_7E^toSgI@Vr!n~m;x2Drgfl$qf~KpO{-90{JovHE&J?^3!BEM=Mb5~T5-j(YH| zf2z0A%*!0E;ND1Nl?JN#wh~HMZ+y;<(M^TJA`!|tfM7N`+Ow1FD=;F2%(64iaAi;1 zN;B;`oO7b5^d?-HX*eRxd9`8HHK;Tcmu#~P4v1f`eB>p;ki(S}k$=Rl z1Hc_WXAnY^Vs?hsU!x2b3@lME z#e%U_8R67Go76BOI&JD=Du)DDv=UYLvND-0O|PtL=j!BA45ubGL7${6~}A2o;{ z1(#IZsejBTZb6_U($uvp>d5!z)pQAYSA-B#E@svAwJKMFr1e8I$4|AP86LmsI{B3Q z9UBlMwAOHSxvVm)gbU2e z4I|3YA^j!UJ>eVWFln7DlSC#@2+T{{J;&x?))gWAzrjYI3;9GZnSS~u*3fYGv8M z>~+oc{Po@mvnT$UYO17`AGNY1b!&QQJsUU9V~g<47I&N^x9*=6LP+{>YSAZkKWZPP z8RopOyRnchs!-0(!uow+*VgH>r&Zi?II>E-?UZ!5a9yLy2Ub6Z_kHPO>vYwR^fN=* zpL-JpX-p8vWwB98maXe)Z54rk1bg(w19TDLL^%MUP2Rc3vpzsK zdfC2Kic=}o@Fv$wMtvz0UrkHL`FCh<46?#678u+{Zvy$?7e!C27jJev2NC-R?@I?i zz&rpL-Heoar0>mXZ2o5VxxW#8JedR2yP^yVwTs3qDe7^7Hc9_}2*C+A!^EqBkpPBQ zZ+BBf`Ro;STGNwz_kOwX6|kDWk&!Wm=7seeVFi>tp9$74ehB$R6$;#nq3*BpOUiPTUaveF$4J?FG&(lfM)R!3H)))DF-^&>qGYARxcfI-k zdlXRzSj+oEtb|8*oKb>vLAe7$_?Am$nVOFqBazQ_&>LIOzug;*%Rc1YQgO0j72+j; zMmxu1BlYKD9F>ZAgdC@NAP|ZfZAP>qnE1oK8vU*53XzwgV3jYmrHQCmv1ANWE*#J& zgkVR<;|f-ppqw!8aDX;CaK-*`%6aalMXvZbu9+Glh(;w%0bvcs90L?D(Ibp%itDM0 zlR+G*nm;3ESRav^zg$rV2AaqLe}C$MrgR5A(v*coHyc49Vt(V8E4Fz0{i@t45xbgn zNgxoIMfcBErlqJQs+a(d_T)UguQ4=u+Wo_0pGTEsUVq8t|IkE{bDsF~7hxE3oU1K) zXvVTpok){JKtY5Sm}dkR2BMOc3IdeI1u(;u1mgb$V55?y&{6jvx?^I<5p&!JMc5cm z^8Wbn9ywkB#p{I?I()czh9UQaMf1cgeqCi)NPXa(9=yQ+uIF`bC#wKiOs3>6n}3ieHqm-(VyN%VX=lY_nufrlX%1h2!BkIOn0bYW$HP8S=Q zzs=;NmXB)+Z&v6cheD@|j5sJ9@haut_2_~?K6|9?ZMKcfYIDyq! zA5pdJS2I?6nAP@&=_`t3$R#V#Nz@7M99^NlAD1Mot1NU-21zq&w_^h;tdq1}h0^T| z4m23)wsXMRi;>eJg55cxh9 zKKtBQ52^nYaBI}h45lJjz&>aY+@m+l%nUwRF;4-mLCN;)A%{cj+&@MC(pelKxS`|% z-E?iNUEkdWKu-gGd(6v}!>7`ot_|l48-EdBLhG$l6wjsAct{TAf=-=-{W$Ad8UE$fq*QUVi$+&i~xrA4w6-sPNJ zLP#>RZe#KbzmkmWJQN1&KZe9F1^xZ8OJnU%4sg#vCEAQe+L_^b-;3N2C>mqd?I z$VBF^4N+Wn5RLn`gzsPAl zyOsuJgk0Zq%x&ur@9A|4I=;d@vZ~`*5iHKwcQw9^NmGHD36Kx#882H%+H9|-ku3yC zcg80VnfdP)U0X0^0zes~eHQLgT#`PLs~ZW0a0DV0#n2c0o)zra+X1A*2v-bky*Q_a z{W&FTT$y;jR_)x?GpBwAVTb!9UQRh`RKLkEj26^=S=@IA%%A=mwsh~l6+Ewd?lk1SI5}4rz9pVl zSK!fU@H=B%3QP0o#k%(vP=>!ng%aU>Gy$NkE4%rLs=F5v z&rcB8BQ~#pn$&iCfsbL!W}@fb*aGi9j&NrtktpH3>>fQk0M${K(1KI)$sYC}l`*A( zOJl8I#UMfW+bnm|A}y?@fw?zBh6@JIvbA|0ANNkSQ@@RcW|z;c+xJ=`e9gr5H)2l^ zBt+p)0M@LbYzLo49gVFrl|k_jrt3Da<=np#nWaI&zIB=$8BMD~oY`tNSG@!{4dN4- zPIFc(?jp6UQ0T8OY2PBSZ{zr>G5y4{*V|BIS(;y$Zzt(oI|g4G1wm%Xx4US$w$-A| zC2gGl`quQ*-rjJ| ztF28BsJegH6IvkTOSVNErHPk7AjyW8tIf3b4Wc^8*uRc{lGO}H>pD**clTI09aXE z$6Dx@D`Op_^h(j#3hS0bdKFkjpy3JojkL2){fx7b^Lm{&(#94O3QnWlGN`@X`?~c+ zi-LwO(NzR8QYY)!X2a`$to@Z6IF}4BrZ;fiPl<^wzX5?rkp{IHIumYYPBjZS|84Il ziQoqB)jLc>iK1?U|MqO84y6xfT>U3!Z<41jEkb8c1vO`-2&jB2L(islsuQKUy87E! z`H4!nnhrvu)mh32;ITLXYd!^dUn~dh^54g@nU{*=v?n92xB(t$oGR<`7w{D=#2N!g z8Cm;I&%^(M8?Dz3D9VLMEGbT%vUf59KqQVq|FXGmf?UIIiZbkaw@0I{gd87M^(K4= z(7|m@?ug*kf3o&PX?(G>oxfx0GcAByUAmiD zJu_izOq^)Z60m+rorc0}=VsF8#u`4Nbs1UU8cjoGwQ%HVaHsqg-#;4v)Smi2#(V!H zlf`l8o#aawF}=${o+oA%-v0O~?NYgXXQTf!s~;I^O9DkFzqQt8rqna;@1xq)SOh*8 zw+!=SS36JBbGqCFauqFm`7NluD<;w;#*saQzx`5SR?1;A4 z_&@m38{LsoaN-lk<)GKL{G^*^Qayeev)X zyZAh-b*nY`b8-cpocWQQ1l!mRfe{vC=`*Vd>BrZ-i^kqHNkbFg%A4B!SBN9Qoz@1@y)^UCw|Z8`A}|1=G&dl`KTS$ z4g!I$-MRgb_q-ZGUcFQ@0Dxmf^v^P|zQfDOh|rwfdJNk=I@b8Gd@JdDos+@%hS*#n zS=0$qg(UE*DLdi?tVv?=5HUwEQhsU5Uctj`Ug)Ge63+a1lDepLLH6ctgS*wO^=x(L z@6n$K5J;0GnVDK(YbH#td10m_nk`FTb5}@DLSxBN(ve$yB@k%LNJs6qF~Jux}IdUZIx+e{q6Ox-V4+W)f65Goeb&BD_eSfbfw=ioc`jK&oY24 z2Gy56^pX5~cj{LwEpo(kPFEtC&=c%IaHrS0yca_9W7!RUF!v)ttzQHG;#`SuMq7Od zYAdZ%&to2LmvtsLn<;RC3slS+-LgWsk1bAOy=hnGHFG~4ijKHPr)9i`pM^q){X9@G zRwJ~skM(Ee8v$=Rw0X1LfA8cMGX*wS#`EeWMKEaQ{Bw@W+}JVE-r@-J`qi~DMNQ=% z&x(N;#v&y0OTIG}Wvd+l6Imk&$x)9EbEn2ZE;g!D$sm8dT!1rCM5CtjOh)d*!=b-Z zD9Rb_GL7Ef*3I%Fag}K~)q8as%!`z{L;|ZfHn^Qy+~aFjs#whrr313oKO3{o*v%79 zeO+;=JYIzbgbV%@@27M(WnahnP9E;|EM`ut{||i?#iQHUZIPh6d7J*xH>tIkUI5G< zctnC3Mg3Fd6st1!$Ho21xpZC+UuoFK6(G8s6cTr&wb^?N68HyRFh#N<)@% zhg~om=Old*LK1e|-IZ+xyJy_)TV^a%(PtmdJOi=tpb0!`89|SAk1WUiq2j*VP9V8qFba?xsO^$OZhG7e4-~V=# zmpfk@!z9^t?SU41%qXokr{7XZW&ofl>ZtW65}pEohU$6=_r9$QCnpg8*lCF_on|O> zT?sk(kG)qp`NU;}22}KXNXIWLy9=aP7}!vl&m4q(v(TfHV{K8UEnD tJlJ8e8gOw?=K$W-{QtX#aNCjx@)CT9tZGP)Ao@lGpsJ{)P$6gi?tiFYTyg*a diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html deleted file mode 100644 index 63dfed9..0000000 --- a/docs/source/_templates/layout.html +++ /dev/null @@ -1,51 +0,0 @@ -{% extends "!layout.html" %} - {% block sidebartitle %} {{ super() }} - - - {% endblock %} - - {% block footer %} {{ super() }} - - - {% endblock %} diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst deleted file mode 100644 index d6623e6..0000000 --- a/docs/source/advanced.rst +++ /dev/null @@ -1,219 +0,0 @@ -.. role:: hidden - :class: hidden-section - -Advanced Amp Usage -=================================== - -GANs ----- - -GANs are an interesting synthesis of several topics below. A `comprehensive example`_ -is under construction. - -.. _`comprehensive example`: - https://github.com/NVIDIA/apex/tree/master/examples/dcgan - -Gradient clipping ------------------ -Amp calls the params owned directly by the optimizer's ``param_groups`` the "master params." - -These master params may be fully or partially distinct from ``model.parameters()``. -For example, with `opt_level="O2"`_, ``amp.initialize`` casts most model params to FP16, -creates an FP32 master param outside the model for each newly-FP16 model param, -and updates the optimizer's ``param_groups`` to point to these FP32 params. - -The master params owned by the optimizer's ``param_groups`` may also fully coincide with the -model params, which is typically true for ``opt_level``\s ``O0``, ``O1``, and ``O3``. - -In all cases, correct practice is to clip the gradients of the params that are guaranteed to be -owned **by the optimizer's** ``param_groups``, instead of those retrieved via ``model.parameters()``. - -Also, if Amp uses loss scaling, gradients must be clipped after they have been unscaled -(which occurs during exit from the ``amp.scale_loss`` context manager). - -The following pattern should be correct for any ``opt_level``:: - - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - # Gradients are unscaled during context manager exit. - # Now it's safe to clip. Replace - # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) - # with - torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm) - # or - torch.nn.utils.clip_grad_value_(amp.master_params(optimizer), max_) - -Note the use of the utility function ``amp.master_params(optimizer)``, -which returns a generator-expression that iterates over the -params in the optimizer's ``param_groups``. - -Also note that ``clip_grad_norm_(amp.master_params(optimizer), max_norm)`` is invoked -*instead of*, not *in addition to*, ``clip_grad_norm_(model.parameters(), max_norm)``. - -.. _`opt_level="O2"`: - https://nvidia.github.io/apex/amp.html#o2-fast-mixed-precision - -Custom/user-defined autograd functions --------------------------------------- - -The old Amp API for `registering user functions`_ is still considered correct. Functions must -be registered before calling ``amp.initialize``. - -.. _`registering user functions`: - https://github.com/NVIDIA/apex/tree/master/apex/amp#annotating-user-functions - -Forcing particular layers/functions to a desired type ------------------------------------------------------ - -I'm still working on a generalizable exposure for this that won't require user-side code divergence -across different ``opt-level``\ s. - -Multiple models/optimizers/losses ---------------------------------- - -Initialization with multiple models/optimizers -********************************************** - -``amp.initialize``'s optimizer argument may be a single optimizer or a list of optimizers, -as long as the output you accept has the same type. -Similarly, the ``model`` argument may be a single model or a list of models, as long as the accepted -output matches. The following calls are all legal:: - - model, optim = amp.initialize(model, optim,...) - model, [optim0, optim1] = amp.initialize(model, [optim0, optim1],...) - [model0, model1], optim = amp.initialize([model0, model1], optim,...) - [model0, model1], [optim0, optim1] = amp.initialize([model0, model1], [optim0, optim1],...) - -Backward passes with multiple optimizers -**************************************** - -Whenever you invoke a backward pass, the ``amp.scale_loss`` context manager must receive -**all the optimizers that own any params for which the current backward pass is creating gradients.** -This is true even if each optimizer owns only some, but not all, of the params that are about to -receive gradients. - -If, for a given backward pass, there's only one optimizer whose params are about to receive gradients, -you may pass that optimizer directly to ``amp.scale_loss``. Otherwise, you must pass the -list of optimizers whose params are about to receive gradients. Example with 3 losses and 2 optimizers:: - - # loss0 accumulates gradients only into params owned by optim0: - with amp.scale_loss(loss0, optim0) as scaled_loss: - scaled_loss.backward() - - # loss1 accumulates gradients only into params owned by optim1: - with amp.scale_loss(loss1, optim1) as scaled_loss: - scaled_loss.backward() - - # loss2 accumulates gradients into some params owned by optim0 - # and some params owned by optim1 - with amp.scale_loss(loss2, [optim0, optim1]) as scaled_loss: - scaled_loss.backward() - -Optionally have Amp use a different loss scaler per-loss -******************************************************** - -By default, Amp maintains a single global loss scaler that will be used for all backward passes -(all invocations of ``with amp.scale_loss(...)``). No additional arguments to ``amp.initialize`` -or ``amp.scale_loss`` are required to use the global loss scaler. The code snippets above with -multiple optimizers/backward passes use the single global loss scaler under the hood, -and they should "just work." - -However, you can optionally tell Amp to maintain a loss scaler per-loss, which gives Amp increased -numerical flexibility. This is accomplished by supplying the ``num_losses`` argument to -``amp.initialize`` (which tells Amp how many backward passes you plan to invoke, and therefore -how many loss scalers Amp should create), then supplying the ``loss_id`` argument to each of your -backward passes (which tells Amp the loss scaler to use for this particular backward pass):: - - model, [optim0, optim1] = amp.initialize(model, [optim0, optim1], ..., num_losses=3) - - with amp.scale_loss(loss0, optim0, loss_id=0) as scaled_loss: - scaled_loss.backward() - - with amp.scale_loss(loss1, optim1, loss_id=1) as scaled_loss: - scaled_loss.backward() - - with amp.scale_loss(loss2, [optim0, optim1], loss_id=2) as scaled_loss: - scaled_loss.backward() - -``num_losses`` and ``loss_id``\ s should be specified purely based on the set of -losses/backward passes. The use of multiple optimizers, or association of single or -multiple optimizers with each backward pass, is unrelated. - -Gradient accumulation across iterations ---------------------------------------- - -The following should "just work," and properly accommodate multiple models/optimizers/losses, as well as -gradient clipping via the `instructions above`_:: - - # If your intent is to simulate a larger batch size using gradient accumulation, - # you can divide the loss by the number of accumulation iterations (so that gradients - # will be averaged over that many iterations): - loss = loss/iters_to_accumulate - - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - - # Every iters_to_accumulate iterations, call step() and reset gradients: - if iter%iters_to_accumulate == 0: - # Gradient clipping if desired: - # torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm) - optimizer.step() - optimizer.zero_grad() - -As a minor performance optimization, you can pass ``delay_unscale=True`` -to ``amp.scale_loss`` until you're ready to ``step()``. You should only attempt ``delay_unscale=True`` -if you're sure you know what you're doing, because the interaction with gradient clipping and -multiple models/optimizers/losses can become tricky.:: - - if iter%iters_to_accumulate == 0: - # Every iters_to_accumulate iterations, unscale and step - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - optimizer.step() - optimizer.zero_grad() - else: - # Otherwise, accumulate gradients, don't unscale or step. - with amp.scale_loss(loss, optimizer, delay_unscale=True) as scaled_loss: - scaled_loss.backward() - -.. _`instructions above`: - https://nvidia.github.io/apex/advanced.html#gradient-clipping - -Custom data batch types ------------------------ - -The intention of Amp is that you never need to cast your input data manually, regardless of -``opt_level``. Amp accomplishes this by patching any models' ``forward`` methods to cast -incoming data appropriately for the ``opt_level``. But to cast incoming data, -Amp needs to know how. The patched ``forward`` will recognize and cast floating-point Tensors -(non-floating-point Tensors like IntTensors are not touched) and -Python containers of floating-point Tensors. However, if you wrap your Tensors in a custom class, -the casting logic doesn't know how to drill -through the tough custom shell to access and cast the juicy Tensor meat within. You need to tell -Amp how to cast your custom batch class, by assigning it a ``to`` method that accepts a ``torch.dtype`` -(e.g., ``torch.float16`` or ``torch.float32``) and returns an instance of the custom batch cast to -``dtype``. The patched ``forward`` checks for the presence of your ``to`` method, and will -invoke it with the correct type for the ``opt_level``. - -Example:: - - class CustomData(object): - def __init__(self): - self.tensor = torch.cuda.FloatTensor([1,2,3]) - - def to(self, dtype): - self.tensor = self.tensor.to(dtype) - return self - -.. warning:: - - Amp also forwards numpy ndarrays without casting them. If you send input data as a raw, unwrapped - ndarray, then later use it to create a Tensor within your ``model.forward``, this Tensor's type will - not depend on the ``opt_level``, and may or may not be correct. Users are encouraged to pass - castable data inputs (Tensors, collections of Tensors, or custom classes with a ``to`` method) - wherever possible. - -.. note:: - - Amp does not call ``.cuda()`` on any Tensors for you. Amp assumes that your original script - is already set up to move Tensors from the host to the device as needed. diff --git a/docs/source/amp.rst b/docs/source/amp.rst deleted file mode 100644 index 4bc1405..0000000 --- a/docs/source/amp.rst +++ /dev/null @@ -1,288 +0,0 @@ -.. role:: hidden - :class: hidden-section - -apex.amp -=================================== - -This page documents the updated API for Amp (Automatic Mixed Precision), -a tool to enable Tensor Core-accelerated training in only 3 lines of Python. - -A `runnable, comprehensive Imagenet example`_ demonstrating good practices can be found -on the Github page. - -GANs are a tricky case that many people have requested. A `comprehensive DCGAN example`_ -is under construction. - -If you already implemented Amp based on the instructions below, but it isn't behaving as expected, -please review `Advanced Amp Usage`_ to see if any topics match your use case. If that doesn't help, -`file an issue`_. - -.. _`file an issue`: - https://github.com/NVIDIA/apex/issues - -``opt_level``\ s and Properties -------------------------------- - -Amp allows users to easily experiment with different pure and mixed precision modes. -Commonly-used default modes are chosen by -selecting an "optimization level" or ``opt_level``; each ``opt_level`` establishes a set of -properties that govern Amp's implementation of pure or mixed precision training. -Finer-grained control of how a given ``opt_level`` behaves can be achieved by passing values for -particular properties directly to ``amp.initialize``. These manually specified values -override the defaults established by the ``opt_level``. - -Example:: - - # Declare model and optimizer as usual, with default (FP32) precision - model = torch.nn.Linear(D_in, D_out).cuda() - optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) - - # Allow Amp to perform casts as required by the opt_level - model, optimizer = amp.initialize(model, optimizer, opt_level="O1") - ... - # loss.backward() becomes: - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - ... - -Users **should not** manually cast their model or data to ``.half()``, regardless of what ``opt_level`` -or properties are chosen. Amp intends that users start with an existing default (FP32) script, -add the three lines corresponding to the Amp API, and begin training with mixed precision. -Amp can also be disabled, in which case the original script will behave exactly as it used to. -In this way, there's no risk adhering to the Amp API, and a lot of potential performance benefit. - -.. note:: - Because it's never necessary to manually cast your model (aside from the call ``amp.initialize``) - or input data, a script that adheres to the new API - can switch between different ``opt-level``\ s without having to make any other changes. - -.. _`runnable, comprehensive Imagenet example`: - https://github.com/NVIDIA/apex/tree/master/examples/imagenet - -.. _`comprehensive DCGAN example`: - https://github.com/NVIDIA/apex/tree/master/examples/dcgan - -.. _`Advanced Amp Usage`: - https://nvidia.github.io/apex/advanced.html - -Properties -********** - -Currently, the under-the-hood properties that govern pure or mixed precision training are the following: - -- ``cast_model_type``: Casts your model's parameters and buffers to the desired type. -- ``patch_torch_functions``: Patch all Torch functions and Tensor methods to perform Tensor Core-friendly ops like GEMMs and convolutions in FP16, and any ops that benefit from FP32 precision in FP32. -- ``keep_batchnorm_fp32``: To enhance precision and enable cudnn batchnorm (which improves performance), it's often beneficial to keep batchnorm weights in FP32 even if the rest of the model is FP16. -- ``master_weights``: Maintain FP32 master weights to accompany any FP16 model weights. FP32 master weights are stepped by the optimizer to enhance precision and capture small gradients. -- ``loss_scale``: If ``loss_scale`` is a float value, use this value as the static (fixed) loss scale. If ``loss_scale`` is the string ``"dynamic"``, adaptively adjust the loss scale over time. Dynamic loss scale adjustments are performed by Amp automatically. - -Again, you often don't need to specify these properties by hand. Instead, select an ``opt_level``, -which will set them up for you. After selecting an ``opt_level``, you can optionally pass property -kwargs as manual overrides. - -If you attempt to override a property that does not make sense for the selected ``opt_level``, -Amp will raise an error with an explanation. For example, selecting ``opt_level="O1"`` combined with -the override ``master_weights=True`` does not make sense. ``O1`` inserts casts -around Torch functions rather than model weights. Data, activations, and weights are recast -out-of-place on the fly as they flow through patched functions. Therefore, the model weights themselves -can (and should) remain FP32, and there is no need to maintain separate FP32 master weights. - -``opt_level``\ s -**************** - -Recognized ``opt_level``\ s are ``"O0"``, ``"O1"``, ``"O2"``, and ``"O3"``. - -``O0`` and ``O3`` are not true mixed precision, but they are useful for establishing accuracy and -speed baselines, respectively. - -``O1`` and ``O2`` are different implementations of mixed precision. Try both, and see -what gives the best speedup and accuracy for your model. - -``O0``: FP32 training -^^^^^^^^^^^^^^^^^^^^^^ -Your incoming model should be FP32 already, so this is likely a no-op. -``O0`` can be useful to establish an accuracy baseline. - -| Default properties set by ``O0``: -| ``cast_model_type=torch.float32`` -| ``patch_torch_functions=False`` -| ``keep_batchnorm_fp32=None`` (effectively, "not applicable," everything is FP32) -| ``master_weights=False`` -| ``loss_scale=1.0`` -| -| - -``O1``: Mixed Precision (recommended for typical use) -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Patch all Torch functions and Tensor methods to cast their inputs according to a whitelist-blacklist -model. Whitelist ops (for example, Tensor Core-friendly ops like GEMMs and convolutions) are performed -in FP16. Blacklist ops that benefit from FP32 precision (for example, softmax) -are performed in FP32. ``O1`` also uses dynamic loss scaling, unless overridden. - -| Default properties set by ``O1``: -| ``cast_model_type=None`` (not applicable) -| ``patch_torch_functions=True`` -| ``keep_batchnorm_fp32=None`` (again, not applicable, all model weights remain FP32) -| ``master_weights=None`` (not applicable, model weights remain FP32) -| ``loss_scale="dynamic"`` -| -| - -``O2``: "Almost FP16" Mixed Precision -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -``O2`` casts the model weights to FP16, -patches the model's ``forward`` method to cast input -data to FP16, keeps batchnorms in FP32, maintains FP32 master weights, -updates the optimizer's ``param_groups`` so that the ``optimizer.step()`` -acts directly on the FP32 weights (followed by FP32 master weight->FP16 model weight -copies if necessary), -and implements dynamic loss scaling (unless overridden). -Unlike ``O1``, ``O2`` does not patch Torch functions or Tensor methods. - -| Default properties set by ``O2``: -| ``cast_model_type=torch.float16`` -| ``patch_torch_functions=False`` -| ``keep_batchnorm_fp32=True`` -| ``master_weights=True`` -| ``loss_scale="dynamic"`` -| -| - -``O3``: FP16 training -^^^^^^^^^^^^^^^^^^^^^^ -``O3`` may not achieve the stability of the true mixed precision options ``O1`` and ``O2``. -However, it can be useful to establish a speed baseline for your model, against which -the performance of ``O1`` and ``O2`` can be compared. If your model uses batch normalization, -to establish "speed of light" you can try ``O3`` with the additional property override -``keep_batchnorm_fp32=True`` (which enables cudnn batchnorm, as stated earlier). - -| Default properties set by ``O3``: -| ``cast_model_type=torch.float16`` -| ``patch_torch_functions=False`` -| ``keep_batchnorm_fp32=False`` -| ``master_weights=False`` -| ``loss_scale=1.0`` -| -| - -Unified API ------------ - -.. automodule:: apex.amp -.. currentmodule:: apex.amp - -.. autofunction:: initialize - -.. autofunction:: scale_loss - -.. autofunction:: master_params - -Checkpointing -------------- - -To properly save and load your amp training, we introduce the ``amp.state_dict()``, which contains all ``loss_scaler``\ s and their corresponding unskipped steps, as well as ``amp.load_state_dict()`` to restore these attributes. - -In order to get bitwise accuracy, we recommend the following workflow:: - - # Initialization - opt_level = 'O1' - model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level) - - # Train your model - ... - - # Save checkpoint - checkpoint = { - 'model': model.state_dict(), - 'optimizer': optimizer.state_dict(), - 'amp': amp.state_dict() - } - torch.save(checkpoint, 'amp_checkpoint.pt') - ... - - # Restore - model = ... - optimizer = ... - checkpoint = torch.load('amp_checkpoint.pt') - - model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level) - model.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - amp.load_state_dict(checkpoint['amp']) - - # Continue training - ... - -Note that we recommend restoring the model using the same ``opt_level``. Also note that we recommend calling the ``load_state_dict`` methods after ``amp.initialize``. - -Advanced use cases ------------------- - -The unified Amp API supports gradient accumulation across iterations, -multiple backward passes per iteration, multiple models/optimizers, -custom/user-defined autograd functions, and custom data batch classes. Gradient clipping and GANs also -require special treatment, but this treatment does not need to change -for different ``opt_level``\ s. Further details can be found here: - -.. toctree:: - :maxdepth: 1 - - advanced - -Transition guide for old API users ----------------------------------- - -We strongly encourage moving to the new Amp API, because it's more versatile, easier to use, and future proof. The original :class:`FP16_Optimizer` and the old "Amp" API are deprecated, and subject to removal at at any time. - -For users of the old "Amp" API -****************************** - -In the new API, ``opt-level O1`` performs the same patching of the Torch namespace as the old thing -called "Amp." -However, the new API allows static or dynamic loss scaling, while the old API only allowed dynamic loss scaling. - -In the new API, the old call to ``amp_handle = amp.init()``, and the returned ``amp_handle``, are no -longer exposed or necessary. The new ``amp.initialize()`` does the duty of ``amp.init()`` (and more). -Therefore, any existing calls to ``amp_handle = amp.init()`` should be deleted. - -The functions formerly exposed through ``amp_handle`` are now free -functions accessible through the ``amp`` module. - -The backward context manager must be changed accordingly:: - - # old API - with amp_handle.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - -> - # new API - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - -For now, the deprecated "Amp" API documentation can still be found on the Github README: https://github.com/NVIDIA/apex/tree/master/apex/amp. The old API calls that `annotate user functions`_ to run -with a particular precision are still honored by the new API. - -.. _`annotate user functions`: - https://github.com/NVIDIA/apex/tree/master/apex/amp#annotating-user-functions - - -For users of the old FP16_Optimizer -*********************************** - -``opt-level O2`` is equivalent to :class:`FP16_Optimizer` with ``dynamic_loss_scale=True``. -Once again, the backward pass must be changed to the unified version:: - - optimizer.backward(loss) - -> - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - -One annoying aspect of FP16_Optimizer was that the user had to manually convert their model to half -(either by calling ``.half()`` on it, or using a function or module wrapper from -``apex.fp16_utils``), and also manually call ``.half()`` on input data. **Neither of these are -necessary in the new API. No matter what --opt-level -you choose, you can and should simply build your model and pass input data in the default FP32 format.** -The new Amp API will perform the right conversions during -``model, optimizer = amp.initialize(model, optimizer, opt_level=....)`` based on the ``--opt-level`` -and any overridden flags. Floating point input data may be FP32 or FP16, but you may as well just -let it be FP16, because the ``model`` returned by ``amp.initialize`` will have its ``forward`` -method patched to cast the input data appropriately. diff --git a/docs/source/conf.py b/docs/source/conf.py deleted file mode 100644 index 4477a28..0000000 --- a/docs/source/conf.py +++ /dev/null @@ -1,248 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# -# PyTorch documentation build configuration file, created by -# sphinx-quickstart on Fri Dec 23 13:31:47 2016. -# -# This file is execfile()d with the current directory set to its -# containing dir. -# -# Note that not all possible configuration values are present in this -# autogenerated file. -# -# All configuration values have a default; values that are commented out -# serve to show the default. - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -import os -import sys -sys.path.insert(0, os.path.abspath('.')) -# sys.path.insert(0, os.path.abspath('../../apex/parallel/')) -import apex -# import multiproc -import sphinx_rtd_theme - - -# -- General configuration ------------------------------------------------ - -# If your documentation needs a minimal Sphinx version, state it here. -# -# needs_sphinx = '1.0' - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', - 'sphinx.ext.coverage', - 'sphinx.ext.mathjax', - 'sphinx.ext.napoleon', - 'sphinx.ext.viewcode', - 'sphinx.ext.extlinks', -] - -napoleon_use_ivar = True - -# Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] - -# The suffix(es) of source filenames. -# You can specify multiple suffix as a list of string: -# -# source_suffix = ['.rst', '.md'] -source_suffix = '.rst' - -# The master toctree document. -master_doc = 'index' - -# General information about the project. -project = 'Apex' -copyright = '2018' -author = 'Christian Sarofeen, Natalia Gimelshein, Michael Carilli, Raul Puri' - -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# -# The short X.Y version. -# TODO: change to [:2] at v1.0 -# version = 'master (' + torch.__version__ + ' )' -version = '0.1' -# The full version, including alpha/beta/rc tags. -# TODO: verify this works as expected -release = '0.1.0' - -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -# -# This is also used if you do content translation via gettext catalogs. -# Usually you set "language" from the command line for these cases. -language = None - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path -exclude_patterns = [] - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' - -# If true, `todo` and `todoList` produce output, else they produce nothing. -todo_include_todos = True - - -# -- Options for HTML output ------------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = 'sphinx_rtd_theme' -html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] - -# Theme options are theme-specific and customize the look and feel of a theme -# further. For a list of options available for each theme, see the -# documentation. -# -html_theme_options = { - 'collapse_navigation': False, - 'display_version': True, - 'logo_only': True, -} - -# html_logo = '_static/img/nv-pytorch2.png' - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] - -# html_style_path = 'css/pytorch_theme.css' -html_context = { - 'css_files': [ - 'https://fonts.googleapis.com/css?family=Lato', - '_static/css/pytorch_theme.css' - ], -} - - -# -- Options for HTMLHelp output --------------------------------------------- - -# Output file base name for HTML help builder. -htmlhelp_basename = 'PyTorchdoc' - - -# -- Options for LaTeX output ------------------------------------------------ - -latex_elements = { - # The paper size ('letterpaper' or 'a4paper'). - # - # 'papersize': 'letterpaper', - - # The font size ('10pt', '11pt' or '12pt'). - # - # 'pointsize': '10pt', - - # Additional stuff for the LaTeX preamble. - # - # 'preamble': '', - - # Latex figure (float) alignment - # - # 'figure_align': 'htbp', -} - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, -# author, documentclass [howto, manual, or own class]). -latex_documents = [ - (master_doc, 'apex.tex', 'Apex Documentation', - 'Torch Contributors', 'manual'), -] - - -# -- Options for manual page output ------------------------------------------ - -# One entry per manual page. List of tuples -# (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'Apex', 'Apex Documentation', - [author], 1) -] - - -# -- Options for Texinfo output ---------------------------------------------- - -# Grouping the document tree into Texinfo files. List of tuples -# (source start file, target name, title, author, -# dir menu entry, description, category) -texinfo_documents = [ - (master_doc, 'Apex', 'Apex Documentation', - author, 'Apex', 'One line description of project.', - 'Miscellaneous'), -] - - -# Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = { - 'python': ('https://docs.python.org/', None), - 'numpy': ('http://docs.scipy.org/doc/numpy/', None), -} - -# -- A patch that prevents Sphinx from cross-referencing ivar tags ------- -# See http://stackoverflow.com/a/41184353/3343043 - -from docutils import nodes -from sphinx.util.docfields import TypedField -from sphinx import addnodes - - -def patched_make_field(self, types, domain, items, **kw): - # `kw` catches `env=None` needed for newer sphinx while maintaining - # backwards compatibility when passed along further down! - - # type: (List, unicode, Tuple) -> nodes.field - def handle_item(fieldarg, content): - par = nodes.paragraph() - par += addnodes.literal_strong('', fieldarg) # Patch: this line added - # par.extend(self.make_xrefs(self.rolename, domain, fieldarg, - # addnodes.literal_strong)) - if fieldarg in types: - par += nodes.Text(' (') - # NOTE: using .pop() here to prevent a single type node to be - # inserted twice into the doctree, which leads to - # inconsistencies later when references are resolved - fieldtype = types.pop(fieldarg) - if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text): - typename = u''.join(n.astext() for n in fieldtype) - typename = typename.replace('int', 'python:int') - typename = typename.replace('long', 'python:long') - typename = typename.replace('float', 'python:float') - typename = typename.replace('type', 'python:type') - par.extend(self.make_xrefs(self.typerolename, domain, typename, - addnodes.literal_emphasis, **kw)) - else: - par += fieldtype - par += nodes.Text(')') - par += nodes.Text(' -- ') - par += content - return par - - fieldname = nodes.field_name('', self.label) - if len(items) == 1 and self.can_collapse: - fieldarg, content = items[0] - bodynode = handle_item(fieldarg, content) - else: - bodynode = self.list_type() - for fieldarg, content in items: - bodynode += nodes.list_item('', handle_item(fieldarg, content)) - fieldbody = nodes.field_body('', bodynode) - return nodes.field('', fieldname, fieldbody) - -TypedField.make_field = patched_make_field diff --git a/docs/source/fp16_utils.rst b/docs/source/fp16_utils.rst deleted file mode 100644 index b6b3da5..0000000 --- a/docs/source/fp16_utils.rst +++ /dev/null @@ -1,59 +0,0 @@ -.. role:: hidden - :class: hidden-section - -apex.fp16_utils -=================================== - -This submodule contains utilities designed to streamline the mixed precision training recipe -presented by NVIDIA `on Parallel Forall`_ and in GTC 2018 Sessions -`Training Neural Networks with Mixed Precision: Theory and Practice`_ and -`Training Neural Networks with Mixed Precision: Real Examples`_. -For Pytorch users, Real Examples in particular is recommended. - -Full runnable Python scripts demonstrating ``apex.fp16_utils`` -can be found on the Github page: - -| `Simple FP16_Optimizer demos`_ -| -| `Distributed Mixed Precision Training with imagenet`_ -| -| `Mixed Precision Training with word_language_model`_ -| -| - -.. _`on Parallel Forall`: - https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/ -.. _`Training Neural Networks with Mixed Precision: Theory and Practice`: - http://on-demand.gputechconf.com/gtc/2018/video/S8923/ -.. _`Training Neural Networks with Mixed Precision: Real Examples`: - http://on-demand.gputechconf.com/gtc/2018/video/S81012/ -.. _`Simple FP16_Optimizer demos`: - https://github.com/NVIDIA/apex/tree/master/examples/FP16_Optimizer_simple -.. _`Distributed Mixed Precision Training with imagenet`: - https://github.com/NVIDIA/apex/tree/master/examples/imagenet -.. _`Mixed Precision Training with word_language_model`: - https://github.com/NVIDIA/apex/tree/master/examples/word_language_model - -.. automodule:: apex.fp16_utils -.. currentmodule:: apex.fp16_utils - -Automatic management of master params + loss scaling ----------------------------------------------------- - -.. autoclass:: FP16_Optimizer - :members: - -.. autoclass:: LossScaler - :members: - -.. autoclass:: DynamicLossScaler - :members: - -Manual master parameter management ----------------------------------- - -.. autofunction:: prep_param_lists - -.. autofunction:: master_params_to_model_params - -.. autofunction:: model_grads_to_master_grads diff --git a/docs/source/index.rst b/docs/source/index.rst deleted file mode 100644 index c7efc16..0000000 --- a/docs/source/index.rst +++ /dev/null @@ -1,53 +0,0 @@ -.. PyTorch documentation master file, created by - sphinx-quickstart on Fri Dec 23 13:31:47 2016. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. - -:github_url: https://github.com/nvidia/apex - -Apex (A PyTorch Extension) -=================================== - -This site contains the API documentation for Apex (https://github.com/nvidia/apex), -a Pytorch extension with NVIDIA-maintained utilities to streamline mixed precision and distributed training. Some of the code here will be included in upstream Pytorch eventually. The intention of Apex is to make up-to-date utilities available to users as quickly as possible. - -Installation instructions can be found here: https://github.com/NVIDIA/apex#quick-start. - -Some other useful material, including GTC 2019 and Pytorch DevCon 2019 Slides, can be found here: https://github.com/mcarilli/mixed_precision_references. - -.. toctree:: - :maxdepth: 1 - :caption: AMP: Automatic Mixed Precision - - amp - -.. toctree:: - :maxdepth: 1 - :caption: Distributed Training - - parallel - -.. toctree:: - :maxdepth: 1 - :caption: Fused Optimizers - - optimizers - -.. toctree:: - :maxdepth: 1 - :caption: Fused Layer Norm - - layernorm - -.. .. toctree:: - :maxdepth: 1 - :caption: Deprecated mixed precision API - fp16_util - -.. RNN - -Indices and tables -================== - -* :ref:`genindex` -* :ref:`modindex` diff --git a/docs/source/layernorm.rst b/docs/source/layernorm.rst deleted file mode 100644 index 6eedb4e..0000000 --- a/docs/source/layernorm.rst +++ /dev/null @@ -1,17 +0,0 @@ -.. role:: hidden - :class: hidden-section - -apex.normalization.fused_layer_norm -=================================== - -.. automodule:: apex.normalization -.. currentmodule:: apex.normalization - -.. FusedAdam - ---------- - -.. autoclass:: FusedLayerNorm - :members: - -.. autoclass:: FusedRMSNorm - :members: diff --git a/docs/source/optimizers.rst b/docs/source/optimizers.rst deleted file mode 100644 index 407f077..0000000 --- a/docs/source/optimizers.rst +++ /dev/null @@ -1,23 +0,0 @@ -.. role:: hidden - :class: hidden-section - -apex.optimizers -=================================== - -.. automodule:: apex.optimizers -.. currentmodule:: apex.optimizers - -.. FusedAdam - ---------- - -.. autoclass:: FusedAdam - :members: - -.. autoclass:: FusedLAMB - :members: - -.. autoclass:: FusedNovoGrad - :members: - -.. autoclass:: FusedSGD - :members: diff --git a/docs/source/parallel.rst b/docs/source/parallel.rst deleted file mode 100644 index 73759ee..0000000 --- a/docs/source/parallel.rst +++ /dev/null @@ -1,25 +0,0 @@ -.. role:: hidden - :class: hidden-section - -apex.parallel -=================================== - -.. automodule:: apex.parallel -.. currentmodule:: apex.parallel - -.. DistributedDataParallel - ---------- - -.. autoclass:: DistributedDataParallel - :members: - -.. autoclass:: Reducer - :members: - -.. autoclass:: SyncBatchNorm - :members: - -Utility functions ----------------------------------- - -.. autofunction:: convert_syncbn_model diff --git a/examples/README.md b/examples/README.md deleted file mode 100644 index 6cb9231..0000000 --- a/examples/README.md +++ /dev/null @@ -1,4 +0,0 @@ -This directory contains examples illustrating Apex mixed precision and distributed tools. - -**Note for users of the pre-unification API**: -`deprecated_api` contains examples illustrating the old (pre-unified) APIs. These APIs will be removed soon, and users are strongly encouraged to switch. The separate mixed precision tools called `Amp` and `FP16_Optimizer` in the old API are exposed via different flags/optimization levels in the new API. diff --git a/examples/dcgan/README.md b/examples/dcgan/README.md deleted file mode 100644 index 9fc896c..0000000 --- a/examples/dcgan/README.md +++ /dev/null @@ -1,41 +0,0 @@ -# Mixed Precision DCGAN Training in PyTorch - -`main_amp.py` is based on [https://github.com/pytorch/examples/tree/master/dcgan](https://github.com/pytorch/examples/tree/master/dcgan). -It implements Automatic Mixed Precision (Amp) training of the DCGAN example for different datasets. Command-line flags forwarded to `amp.initialize` are used to easily manipulate and switch between various pure and mixed precision "optimization levels" or `opt_level`s. For a detailed explanation of `opt_level`s, see the [updated API guide](https://nvidia.github.io/apex/amp.html). - -We introduce these changes to the PyTorch DCGAN example as described in the [Multiple models/optimizers/losses](https://nvidia.github.io/apex/advanced.html#multiple-models-optimizers-losses) section of the documentation:: -``` -# Added after models and optimizers construction -[netD, netG], [optimizerD, optimizerG] = amp.initialize( - [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level, num_losses=3) -... -# loss.backward() changed to: -with amp.scale_loss(errD_real, optimizerD, loss_id=0) as errD_real_scaled: - errD_real_scaled.backward() -... -with amp.scale_loss(errD_fake, optimizerD, loss_id=1) as errD_fake_scaled: - errD_fake_scaled.backward() -... -with amp.scale_loss(errG, optimizerG, loss_id=2) as errG_scaled: - errG_scaled.backward() -``` - -Note that we use different `loss_scalers` for each computed loss. -Using a separate loss scaler per loss is [optional, not required](https://nvidia.github.io/apex/advanced.html#optionally-have-amp-use-a-different-loss-scaler-per-loss). - -To improve the numerical stability, we swapped `nn.Sigmoid() + nn.BCELoss()` to `nn.BCEWithLogitsLoss()`. - -With the new Amp API **you never need to explicitly convert your model, or the input data, to half().** - -"Pure FP32" training: -``` -$ python main_amp.py --opt_level O0 -``` -Recommended mixed precision training: -``` -$ python main_amp.py --opt_level O1 -``` - -Have a look at the original [DCGAN example](https://github.com/pytorch/examples/tree/master/dcgan) for more information about the used arguments. - -To enable mixed precision training, we introduce the `--opt_level` argument. diff --git a/examples/dcgan/main_amp.py b/examples/dcgan/main_amp.py deleted file mode 100644 index be1a289..0000000 --- a/examples/dcgan/main_amp.py +++ /dev/null @@ -1,274 +0,0 @@ -from __future__ import print_function -import argparse -import os -import random -import torch -import torch.nn as nn -import torch.nn.parallel -import torch.backends.cudnn as cudnn -import torch.optim as optim -import torch.utils.data -import torchvision.datasets as dset -import torchvision.transforms as transforms -import torchvision.utils as vutils - -try: - from apex import amp -except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") - - -parser = argparse.ArgumentParser() -parser.add_argument('--dataset', default='cifar10', help='cifar10 | lsun | mnist |imagenet | folder | lfw | fake') -parser.add_argument('--dataroot', default='./', help='path to dataset') -parser.add_argument('--workers', type=int, help='number of data loading workers', default=2) -parser.add_argument('--batchSize', type=int, default=64, help='input batch size') -parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network') -parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector') -parser.add_argument('--ngf', type=int, default=64) -parser.add_argument('--ndf', type=int, default=64) -parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for') -parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002') -parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5') -parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use') -parser.add_argument('--netG', default='', help="path to netG (to continue training)") -parser.add_argument('--netD', default='', help="path to netD (to continue training)") -parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints') -parser.add_argument('--manualSeed', type=int, help='manual seed') -parser.add_argument('--classes', default='bedroom', help='comma separated list of classes for the lsun data set') -parser.add_argument('--opt_level', default='O1', help='amp opt_level, default="O1"') - -opt = parser.parse_args() -print(opt) - - -try: - os.makedirs(opt.outf) -except OSError: - pass - -if opt.manualSeed is None: - opt.manualSeed = 2809 -print("Random Seed: ", opt.manualSeed) -random.seed(opt.manualSeed) -torch.manual_seed(opt.manualSeed) - -cudnn.benchmark = True - - -if opt.dataset in ['imagenet', 'folder', 'lfw']: - # folder dataset - dataset = dset.ImageFolder(root=opt.dataroot, - transform=transforms.Compose([ - transforms.Resize(opt.imageSize), - transforms.CenterCrop(opt.imageSize), - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), - ])) - nc=3 -elif opt.dataset == 'lsun': - classes = [ c + '_train' for c in opt.classes.split(',')] - dataset = dset.LSUN(root=opt.dataroot, classes=classes, - transform=transforms.Compose([ - transforms.Resize(opt.imageSize), - transforms.CenterCrop(opt.imageSize), - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), - ])) - nc=3 -elif opt.dataset == 'cifar10': - dataset = dset.CIFAR10(root=opt.dataroot, download=True, - transform=transforms.Compose([ - transforms.Resize(opt.imageSize), - transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), - ])) - nc=3 - -elif opt.dataset == 'mnist': - dataset = dset.MNIST(root=opt.dataroot, download=True, - transform=transforms.Compose([ - transforms.Resize(opt.imageSize), - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)), - ])) - nc=1 - -elif opt.dataset == 'fake': - dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize), - transform=transforms.ToTensor()) - nc=3 - -assert dataset -dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize, - shuffle=True, num_workers=int(opt.workers)) - -device = torch.device("cuda:0") -ngpu = int(opt.ngpu) -nz = int(opt.nz) -ngf = int(opt.ngf) -ndf = int(opt.ndf) - - -# custom weights initialization called on netG and netD -def weights_init(m): - classname = m.__class__.__name__ - if classname.find('Conv') != -1: - m.weight.data.normal_(0.0, 0.02) - elif classname.find('BatchNorm') != -1: - m.weight.data.normal_(1.0, 0.02) - m.bias.data.fill_(0) - - -class Generator(nn.Module): - def __init__(self, ngpu): - super(Generator, self).__init__() - self.ngpu = ngpu - self.main = nn.Sequential( - # input is Z, going into a convolution - nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False), - nn.BatchNorm2d(ngf * 8), - nn.ReLU(True), - # state size. (ngf*8) x 4 x 4 - nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), - nn.BatchNorm2d(ngf * 4), - nn.ReLU(True), - # state size. (ngf*4) x 8 x 8 - nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), - nn.BatchNorm2d(ngf * 2), - nn.ReLU(True), - # state size. (ngf*2) x 16 x 16 - nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), - nn.BatchNorm2d(ngf), - nn.ReLU(True), - # state size. (ngf) x 32 x 32 - nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False), - nn.Tanh() - # state size. (nc) x 64 x 64 - ) - - def forward(self, input): - if input.is_cuda and self.ngpu > 1: - output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) - else: - output = self.main(input) - return output - - -netG = Generator(ngpu).to(device) -netG.apply(weights_init) -if opt.netG != '': - netG.load_state_dict(torch.load(opt.netG)) -print(netG) - - -class Discriminator(nn.Module): - def __init__(self, ngpu): - super(Discriminator, self).__init__() - self.ngpu = ngpu - self.main = nn.Sequential( - # input is (nc) x 64 x 64 - nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), - nn.LeakyReLU(0.2, inplace=True), - # state size. (ndf) x 32 x 32 - nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), - nn.BatchNorm2d(ndf * 2), - nn.LeakyReLU(0.2, inplace=True), - # state size. (ndf*2) x 16 x 16 - nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), - nn.BatchNorm2d(ndf * 4), - nn.LeakyReLU(0.2, inplace=True), - # state size. (ndf*4) x 8 x 8 - nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), - nn.BatchNorm2d(ndf * 8), - nn.LeakyReLU(0.2, inplace=True), - # state size. (ndf*8) x 4 x 4 - nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), - ) - - def forward(self, input): - if input.is_cuda and self.ngpu > 1: - output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) - else: - output = self.main(input) - - return output.view(-1, 1).squeeze(1) - - -netD = Discriminator(ngpu).to(device) -netD.apply(weights_init) -if opt.netD != '': - netD.load_state_dict(torch.load(opt.netD)) -print(netD) - -criterion = nn.BCEWithLogitsLoss() - -fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device) -real_label = 1 -fake_label = 0 - -# setup optimizer -optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) -optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) - -[netD, netG], [optimizerD, optimizerG] = amp.initialize( - [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level, num_losses=3) - -for epoch in range(opt.niter): - for i, data in enumerate(dataloader, 0): - ############################ - # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) - ########################### - # train with real - netD.zero_grad() - real_cpu = data[0].to(device) - batch_size = real_cpu.size(0) - label = torch.full((batch_size,), real_label, device=device) - - output = netD(real_cpu) - errD_real = criterion(output, label) - with amp.scale_loss(errD_real, optimizerD, loss_id=0) as errD_real_scaled: - errD_real_scaled.backward() - D_x = output.mean().item() - - # train with fake - noise = torch.randn(batch_size, nz, 1, 1, device=device) - fake = netG(noise) - label.fill_(fake_label) - output = netD(fake.detach()) - errD_fake = criterion(output, label) - with amp.scale_loss(errD_fake, optimizerD, loss_id=1) as errD_fake_scaled: - errD_fake_scaled.backward() - D_G_z1 = output.mean().item() - errD = errD_real + errD_fake - optimizerD.step() - - ############################ - # (2) Update G network: maximize log(D(G(z))) - ########################### - netG.zero_grad() - label.fill_(real_label) # fake labels are real for generator cost - output = netD(fake) - errG = criterion(output, label) - with amp.scale_loss(errG, optimizerG, loss_id=2) as errG_scaled: - errG_scaled.backward() - D_G_z2 = output.mean().item() - optimizerG.step() - - print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f' - % (epoch, opt.niter, i, len(dataloader), - errD.item(), errG.item(), D_x, D_G_z1, D_G_z2)) - if i % 100 == 0: - vutils.save_image(real_cpu, - '%s/real_samples.png' % opt.outf, - normalize=True) - fake = netG(fixed_noise) - vutils.save_image(fake.detach(), - '%s/amp_fake_samples_epoch_%03d.png' % (opt.outf, epoch), - normalize=True) - - # do checkpointing - torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch)) - torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch)) - - diff --git a/examples/docker/Dockerfile b/examples/docker/Dockerfile deleted file mode 100644 index 88a3bc7..0000000 --- a/examples/docker/Dockerfile +++ /dev/null @@ -1,16 +0,0 @@ -# Base image must at least have pytorch and CUDA installed. -ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:22.02-py3 -FROM $BASE_IMAGE -ARG BASE_IMAGE -RUN echo "Installing Apex on top of ${BASE_IMAGE}" -# make sure we don't overwrite some existing directory called "apex" -WORKDIR /tmp/unique_for_apex -# uninstall Apex if present, twice to make absolutely sure :) -RUN pip uninstall -y apex || : -RUN pip uninstall -y apex || : -# SHA is something the user can touch to force recreation of this Docker layer, -# and therefore force cloning of the latest version of Apex -RUN SHA=ToUcHMe git clone https://github.com/NVIDIA/apex.git -WORKDIR /tmp/unique_for_apex/apex -RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . -WORKDIR /workspace diff --git a/examples/docker/README.md b/examples/docker/README.md deleted file mode 100644 index 3969af6..0000000 --- a/examples/docker/README.md +++ /dev/null @@ -1,40 +0,0 @@ -## Option 1: Create a new container with Apex - -**Dockerfile** installs the latest Apex on top of an existing image. Run -``` -docker build -t new_image_with_apex . -``` -By default, **Dockerfile** uses NVIDIA's Pytorch container as the base image, -which requires an NVIDIA GPU Cloud (NGC) account. If you don't have an NGC account, you can sign up for free by following the instructions [here](https://docs.nvidia.com/ngc/ngc-getting-started-guide/index.html#generating-api-key). - -Alternatively, you can supply your own base image via the `BASE_IMAGE` build-arg. -`BASE_IMAGE` must have Pytorch and Cuda installed. For example, any -`-devel` image for Pytorch 1.0 and later from the -[official Pytorch Dockerhub](https://hub.docker.com/r/pytorch/pytorch) may be used: -``` -docker build --build-arg BASE_IMAGE=1.3-cuda10.1-cudnn7-devel -t new_image_with_apex . -``` - -If you want to rebuild your image, and force the latest Apex to be cloned and installed, make any small change to the `SHA` variable in **Dockerfile**. - -**Warning:** -Currently, the non-`-devel` images on Pytorch Dockerhub do not contain the Cuda compiler `nvcc`. Therefore, -images whose name does not contain `-devel` are not eligible candidates for `BASE_IMAGE`. - -### Running your Apex container - -Like any Cuda-enabled Pytorch container, a container with Apex should be run via [nvidia-docker](https://github.com/NVIDIA/nvidia-docker), for example: -``` -docker run --runtime=nvidia -it --rm --ipc=host new_image_with_apex -``` - -## Option 2: Install Apex in a running container - -Instead of building a new container, it is also a viable option to `git clone https://github.com/NVIDIA/apex.git` on bare metal, mount the Apex repo into your container at launch by running, for example, -``` -docker run --runtime=nvidia -it --rm --ipc=host -v /bare/metal/apex:/apex/in/container -``` -then go to /apex/in/container within the running container and -``` -pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . -``` diff --git a/examples/imagenet/README.md b/examples/imagenet/README.md deleted file mode 100644 index 257d4a7..0000000 --- a/examples/imagenet/README.md +++ /dev/null @@ -1,183 +0,0 @@ -# Mixed Precision ImageNet Training in PyTorch - -`main_amp.py` is based on [https://github.com/pytorch/examples/tree/master/imagenet](https://github.com/pytorch/examples/tree/master/imagenet). -It implements Automatic Mixed Precision (Amp) training of popular model architectures, such as ResNet, AlexNet, and VGG, on the ImageNet dataset. Command-line flags forwarded to `amp.initialize` are used to easily manipulate and switch between various pure and mixed precision "optimization levels" or `opt_level`s. For a detailed explanation of `opt_level`s, see the [updated API guide](https://nvidia.github.io/apex/amp.html). - -Three lines enable Amp: -``` -# Added after model and optimizer construction -model, optimizer = amp.initialize(model, optimizer, flags...) -... -# loss.backward() changed to: -with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() -``` - -With the new Amp API **you never need to explicitly convert your model, or the input data, to half().** - -## Requirements - -- Download the ImageNet dataset and move validation images to labeled subfolders - - The following script may be helpful: https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh - -## Training - -To train a model, create softlinks to the Imagenet dataset, then run `main.py` with the desired model architecture, as shown in `Example commands` below. - -The default learning rate schedule is set for ResNet50. `main_amp.py` script rescales the learning rate according to the global batch size (number of distributed processes \* per-process minibatch size). - -## Example commands - -**Note:** batch size `--b 224` assumes your GPUs have >=16GB of onboard memory. You may be able to increase this to 256, but that's cutting it close, so it may out-of-memory for different Pytorch versions. - -**Note:** All of the following use 4 dataloader subprocesses (`--workers 4`) to reduce potential -CPU data loading bottlenecks. - -**Note:** `--opt-level` `O1` and `O2` both use dynamic loss scaling by default unless manually overridden. -`--opt-level` `O0` and `O3` (the "pure" training modes) do not use loss scaling by default. -`O0` and `O3` can be told to use loss scaling via manual overrides, but using loss scaling with `O0` -(pure FP32 training) does not really make sense, and will trigger a warning. - -Softlink training and validation datasets into the current directory: -``` -$ ln -sf /data/imagenet/train-jpeg/ train -$ ln -sf /data/imagenet/val-jpeg/ val -``` - -### Summary - -Amp allows easy experimentation with various pure and mixed precision options. -``` -$ python main_amp.py -a resnet50 --b 128 --workers 4 --opt-level O0 ./ -$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O3 ./ -$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O3 --keep-batchnorm-fp32 True ./ -$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 ./ -$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 --loss-scale 128.0 ./ -$ python -m torch.distributed.launch --nproc_per_node=2 main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 ./ -$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 ./ -$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 --loss-scale 128.0 ./ -$ python -m torch.distributed.launch --nproc_per_node=2 main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 ./ -``` -Options are explained below. Again, the [updated API guide](https://nvidia.github.io/apex/amp.html) provides more detail. - -#### `--opt-level O0` (FP32 training) and `O3` (FP16 training) - -"Pure FP32" training: -``` -$ python main_amp.py -a resnet50 --b 128 --workers 4 --opt-level O0 ./ -``` -"Pure FP16" training: -``` -$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O3 ./ -``` -FP16 training with FP32 batchnorm: -``` -$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O3 --keep-batchnorm-fp32 True ./ -``` -Keeping the batchnorms in FP32 improves stability and allows Pytorch -to use cudnn batchnorms, which significantly increases speed in Resnet50. - -The `O3` options might not converge, because they are not true mixed precision. -However, they can be useful to establish "speed of light" performance for -your model, which provides a baseline for comparison with `O1` and `O2`. -For Resnet50 in particular, `--opt-level O3 --keep-batchnorm-fp32 True` establishes -the "speed of light." (Without `--keep-batchnorm-fp32`, it's slower, because it does -not use cudnn batchnorm.) - -#### `--opt-level O1` (Official Mixed Precision recipe, recommended for typical use) - -`O1` patches Torch functions to cast inputs according to a whitelist-blacklist model. -FP16-friendly (Tensor Core) ops like gemms and convolutions run in FP16, while ops -that benefit from FP32, like batchnorm and softmax, run in FP32. -Also, dynamic loss scaling is used by default. -``` -$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 ./ -``` -`O1` overridden to use static loss scaling: -``` -$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 --loss-scale 128.0 -``` -Distributed training with 2 processes (1 GPU per process, see **Distributed training** below -for more detail) -``` -$ python -m torch.distributed.launch --nproc_per_node=2 main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O1 ./ -``` -For best performance, set `--nproc_per_node` equal to the total number of GPUs on the node -to use all available resources. - -#### `--opt-level O2` ("Almost FP16" mixed precision. More dangerous than O1.) - -`O2` exists mainly to support some internal use cases. Please prefer `O1`. - -`O2` casts the model to FP16, keeps batchnorms in FP32, -maintains master weights in FP32, and implements -dynamic loss scaling by default. (Unlike --opt-level O1, --opt-level O2 -does not patch Torch functions.) -``` -$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 ./ -``` -"Fast mixed precision" overridden to use static loss scaling: -``` -$ python main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 --loss-scale 128.0 ./ -``` -Distributed training with 2 processes (1 GPU per process) -``` -$ python -m torch.distributed.launch --nproc_per_node=2 main_amp.py -a resnet50 --b 224 --workers 4 --opt-level O2 ./ -``` - -## Distributed training - -`main_amp.py` optionally uses `apex.parallel.DistributedDataParallel` (DDP) for multiprocess training with one GPU per process. -``` -model = apex.parallel.DistributedDataParallel(model) -``` -is a drop-in replacement for -``` -model = torch.nn.parallel.DistributedDataParallel(model, - device_ids=[arg.local_rank], - output_device=arg.local_rank) -``` -(because Torch DDP permits multiple GPUs per process, with Torch DDP you are required to -manually specify the device to run on and the output device. -With Apex DDP, it uses only the current device by default). - -The choice of DDP wrapper (Torch or Apex) is orthogonal to the use of Amp and other Apex tools. It is safe to use `apex.amp` with either `torch.nn.parallel.DistributedDataParallel` or `apex.parallel.DistributedDataParallel`. In the future, I may add some features that permit optional tighter integration between `Amp` and `apex.parallel.DistributedDataParallel` for marginal performance benefits, but currently, there's no compelling reason to use Apex DDP versus Torch DDP for most models. - -To use DDP with `apex.amp`, the only gotcha is that -``` -model, optimizer = amp.initialize(model, optimizer, flags...) -``` -must precede -``` -model = DDP(model) -``` -If DDP wrapping occurs before `amp.initialize`, `amp.initialize` will raise an error. - -With both Apex DDP and Torch DDP, you must also call `torch.cuda.set_device(args.local_rank)` within -each process prior to initializing your model or any other tensors. -More information can be found in the docs for the -Pytorch multiprocess launcher module [torch.distributed.launch](https://pytorch.org/docs/stable/distributed.html#launch-utility). - -`main_amp.py` is written to interact with -[torch.distributed.launch](https://pytorch.org/docs/master/distributed.html#launch-utility), -which spawns multiprocess jobs using the following syntax: -``` -python -m torch.distributed.launch --nproc_per_node=NUM_GPUS main_amp.py args... -``` -`NUM_GPUS` should be less than or equal to the number of visible GPU devices on the node. The use of `torch.distributed.launch` is unrelated to the choice of DDP wrapper. It is safe to use either apex DDP or torch DDP with `torch.distributed.launch`. - -Optionally, one can run imagenet with synchronized batch normalization across processes by adding -`--sync_bn` to the `args...` - -## Deterministic training (for debugging purposes) - -Running with the `--deterministic` flag should produce bitwise identical outputs run-to-run, -regardless of what other options are used (see [Pytorch docs on reproducibility](https://pytorch.org/docs/stable/notes/randomness.html)). -Since `--deterministic` disables `torch.backends.cudnn.benchmark`, `--deterministic` may -cause a modest performance decrease. - -## Profiling - -If you're curious how the network actually looks on the CPU and GPU timelines (for example, how good is the overall utilization? -Is the prefetcher really overlapping data transfers?) try profiling `main_amp.py`. -[Detailed instructions can be found here](https://gist.github.com/mcarilli/213a4e698e4a0ae2234ddee56f4f3f95). diff --git a/examples/imagenet/main_amp.py b/examples/imagenet/main_amp.py deleted file mode 100644 index c4b0fdf..0000000 --- a/examples/imagenet/main_amp.py +++ /dev/null @@ -1,543 +0,0 @@ -import argparse -import os -import shutil -import time - -import torch -import torch.nn as nn -import torch.nn.parallel -import torch.backends.cudnn as cudnn -import torch.distributed as dist -import torch.optim -import torch.utils.data -import torch.utils.data.distributed -import torchvision.transforms as transforms -import torchvision.datasets as datasets -import torchvision.models as models - -import numpy as np - -try: - from apex.parallel import DistributedDataParallel as DDP - from apex.fp16_utils import * - from apex import amp, optimizers - from apex.multi_tensor_apply import multi_tensor_applier -except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") - -def fast_collate(batch, memory_format): - - imgs = [img[0] for img in batch] - targets = torch.tensor([target[1] for target in batch], dtype=torch.int64) - w = imgs[0].size[0] - h = imgs[0].size[1] - tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8).contiguous(memory_format=memory_format) - for i, img in enumerate(imgs): - nump_array = np.asarray(img, dtype=np.uint8) - if(nump_array.ndim < 3): - nump_array = np.expand_dims(nump_array, axis=-1) - nump_array = np.rollaxis(nump_array, 2) - tensor[i] += torch.from_numpy(nump_array) - return tensor, targets - - -def parse(): - model_names = sorted(name for name in models.__dict__ - if name.islower() and not name.startswith("__") - and callable(models.__dict__[name])) - - parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') - parser.add_argument('data', metavar='DIR', - help='path to dataset') - parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', - choices=model_names, - help='model architecture: ' + - ' | '.join(model_names) + - ' (default: resnet18)') - parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', - help='number of data loading workers (default: 4)') - parser.add_argument('--epochs', default=90, type=int, metavar='N', - help='number of total epochs to run') - parser.add_argument('--start-epoch', default=0, type=int, metavar='N', - help='manual epoch number (useful on restarts)') - parser.add_argument('-b', '--batch-size', default=256, type=int, - metavar='N', help='mini-batch size per process (default: 256)') - parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, - metavar='LR', help='Initial learning rate. Will be scaled by /256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.') - parser.add_argument('--momentum', default=0.9, type=float, metavar='M', - help='momentum') - parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, - metavar='W', help='weight decay (default: 1e-4)') - parser.add_argument('--print-freq', '-p', default=10, type=int, - metavar='N', help='print frequency (default: 10)') - parser.add_argument('--resume', default='', type=str, metavar='PATH', - help='path to latest checkpoint (default: none)') - parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', - help='evaluate model on validation set') - parser.add_argument('--pretrained', dest='pretrained', action='store_true', - help='use pre-trained model') - - parser.add_argument('--prof', default=-1, type=int, - help='Only run 10 iterations for profiling.') - parser.add_argument('--deterministic', action='store_true') - - parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', 0), type=int) - parser.add_argument('--sync_bn', action='store_true', - help='enabling apex sync BN.') - - parser.add_argument('--opt-level', type=str) - parser.add_argument('--keep-batchnorm-fp32', type=str, default=None) - parser.add_argument('--loss-scale', type=str, default=None) - parser.add_argument('--channels-last', type=bool, default=False) - args = parser.parse_args() - return args - -def main(): - global best_prec1, args - - args = parse() - print("opt_level = {}".format(args.opt_level)) - print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32)) - print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale)) - - print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version())) - - cudnn.benchmark = True - best_prec1 = 0 - if args.deterministic: - cudnn.benchmark = False - cudnn.deterministic = True - torch.manual_seed(args.local_rank) - torch.set_printoptions(precision=10) - - args.distributed = False - if 'WORLD_SIZE' in os.environ: - args.distributed = int(os.environ['WORLD_SIZE']) > 1 - - args.gpu = 0 - args.world_size = 1 - - if args.distributed: - args.gpu = args.local_rank - torch.cuda.set_device(args.gpu) - torch.distributed.init_process_group(backend='nccl', - init_method='env://') - args.world_size = torch.distributed.get_world_size() - - assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled." - - if args.channels_last: - memory_format = torch.channels_last - else: - memory_format = torch.contiguous_format - - # create model - if args.pretrained: - print("=> using pre-trained model '{}'".format(args.arch)) - model = models.__dict__[args.arch](pretrained=True) - else: - print("=> creating model '{}'".format(args.arch)) - model = models.__dict__[args.arch]() - - if args.sync_bn: - import apex - print("using apex synced BN") - model = apex.parallel.convert_syncbn_model(model) - - model = model.cuda().to(memory_format=memory_format) - - # Scale learning rate based on global batch size - args.lr = args.lr*float(args.batch_size*args.world_size)/256. - optimizer = torch.optim.SGD(model.parameters(), args.lr, - momentum=args.momentum, - weight_decay=args.weight_decay) - - # Initialize Amp. Amp accepts either values or strings for the optional override arguments, - # for convenient interoperation with argparse. - model, optimizer = amp.initialize(model, optimizer, - opt_level=args.opt_level, - keep_batchnorm_fp32=args.keep_batchnorm_fp32, - loss_scale=args.loss_scale - ) - - # For distributed training, wrap the model with apex.parallel.DistributedDataParallel. - # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called - # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter - # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks. - if args.distributed: - # By default, apex.parallel.DistributedDataParallel overlaps communication with - # computation in the backward pass. - # model = DDP(model) - # delay_allreduce delays all communication to the end of the backward pass. - model = DDP(model, delay_allreduce=True) - - # define loss function (criterion) and optimizer - criterion = nn.CrossEntropyLoss().cuda() - - # Optionally resume from a checkpoint - if args.resume: - # Use a local scope to avoid dangling references - def resume(): - if os.path.isfile(args.resume): - print("=> loading checkpoint '{}'".format(args.resume)) - checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu)) - args.start_epoch = checkpoint['epoch'] - global best_prec1 - best_prec1 = checkpoint['best_prec1'] - model.load_state_dict(checkpoint['state_dict']) - optimizer.load_state_dict(checkpoint['optimizer']) - print("=> loaded checkpoint '{}' (epoch {})" - .format(args.resume, checkpoint['epoch'])) - else: - print("=> no checkpoint found at '{}'".format(args.resume)) - resume() - - # Data loading code - traindir = os.path.join(args.data, 'train') - valdir = os.path.join(args.data, 'val') - - if(args.arch == "inception_v3"): - raise RuntimeError("Currently, inception_v3 is not supported by this example.") - # crop_size = 299 - # val_size = 320 # I chose this value arbitrarily, we can adjust. - else: - crop_size = 224 - val_size = 256 - - train_dataset = datasets.ImageFolder( - traindir, - transforms.Compose([ - transforms.RandomResizedCrop(crop_size), - transforms.RandomHorizontalFlip(), - # transforms.ToTensor(), Too slow - # normalize, - ])) - val_dataset = datasets.ImageFolder(valdir, transforms.Compose([ - transforms.Resize(val_size), - transforms.CenterCrop(crop_size), - ])) - - train_sampler = None - val_sampler = None - if args.distributed: - train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) - val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) - - collate_fn = lambda b: fast_collate(b, memory_format) - - train_loader = torch.utils.data.DataLoader( - train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), - num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=collate_fn) - - val_loader = torch.utils.data.DataLoader( - val_dataset, - batch_size=args.batch_size, shuffle=False, - num_workers=args.workers, pin_memory=True, - sampler=val_sampler, - collate_fn=collate_fn) - - if args.evaluate: - validate(val_loader, model, criterion) - return - - for epoch in range(args.start_epoch, args.epochs): - if args.distributed: - train_sampler.set_epoch(epoch) - - # train for one epoch - train(train_loader, model, criterion, optimizer, epoch) - - # evaluate on validation set - prec1 = validate(val_loader, model, criterion) - - # remember best prec@1 and save checkpoint - if args.local_rank == 0: - is_best = prec1 > best_prec1 - best_prec1 = max(prec1, best_prec1) - save_checkpoint({ - 'epoch': epoch + 1, - 'arch': args.arch, - 'state_dict': model.state_dict(), - 'best_prec1': best_prec1, - 'optimizer' : optimizer.state_dict(), - }, is_best) - -class data_prefetcher(): - def __init__(self, loader): - self.loader = iter(loader) - self.stream = torch.cuda.Stream() - self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1) - self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1) - # With Amp, it isn't necessary to manually convert data to half. - # if args.fp16: - # self.mean = self.mean.half() - # self.std = self.std.half() - self.preload() - - def preload(self): - try: - self.next_input, self.next_target = next(self.loader) - except StopIteration: - self.next_input = None - self.next_target = None - return - # if record_stream() doesn't work, another option is to make sure device inputs are created - # on the main stream. - # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda') - # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda') - # Need to make sure the memory allocated for next_* is not still in use by the main stream - # at the time we start copying to next_*: - # self.stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(self.stream): - self.next_input = self.next_input.cuda(non_blocking=True) - self.next_target = self.next_target.cuda(non_blocking=True) - # more code for the alternative if record_stream() doesn't work: - # copy_ will record the use of the pinned source tensor in this side stream. - # self.next_input_gpu.copy_(self.next_input, non_blocking=True) - # self.next_target_gpu.copy_(self.next_target, non_blocking=True) - # self.next_input = self.next_input_gpu - # self.next_target = self.next_target_gpu - - # With Amp, it isn't necessary to manually convert data to half. - # if args.fp16: - # self.next_input = self.next_input.half() - # else: - self.next_input = self.next_input.float() - self.next_input = self.next_input.sub_(self.mean).div_(self.std) - - def next(self): - torch.cuda.current_stream().wait_stream(self.stream) - input = self.next_input - target = self.next_target - if input is not None: - input.record_stream(torch.cuda.current_stream()) - if target is not None: - target.record_stream(torch.cuda.current_stream()) - self.preload() - return input, target - - -def train(train_loader, model, criterion, optimizer, epoch): - batch_time = AverageMeter() - losses = AverageMeter() - top1 = AverageMeter() - top5 = AverageMeter() - - # switch to train mode - model.train() - end = time.time() - - prefetcher = data_prefetcher(train_loader) - input, target = prefetcher.next() - i = 0 - while input is not None: - i += 1 - if args.prof >= 0 and i == args.prof: - print("Profiling begun at iteration {}".format(i)) - torch.cuda.cudart().cudaProfilerStart() - - if args.prof >= 0: torch.cuda.nvtx.range_push("Body of iteration {}".format(i)) - - adjust_learning_rate(optimizer, epoch, i, len(train_loader)) - - # compute output - if args.prof >= 0: torch.cuda.nvtx.range_push("forward") - output = model(input) - if args.prof >= 0: torch.cuda.nvtx.range_pop() - loss = criterion(output, target) - - # compute gradient and do SGD step - optimizer.zero_grad() - - if args.prof >= 0: torch.cuda.nvtx.range_push("backward") - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - if args.prof >= 0: torch.cuda.nvtx.range_pop() - - # for param in model.parameters(): - # print(param.data.double().sum().item(), param.grad.data.double().sum().item()) - - if args.prof >= 0: torch.cuda.nvtx.range_push("optimizer.step()") - optimizer.step() - if args.prof >= 0: torch.cuda.nvtx.range_pop() - - if i%args.print_freq == 0: - # Every print_freq iterations, check the loss, accuracy, and speed. - # For best performance, it doesn't make sense to print these metrics every - # iteration, since they incur an allreduce and some host<->device syncs. - - # Measure accuracy - prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) - - # Average loss and accuracy across processes for logging - if args.distributed: - reduced_loss = reduce_tensor(loss.data) - prec1 = reduce_tensor(prec1) - prec5 = reduce_tensor(prec5) - else: - reduced_loss = loss.data - - # to_python_float incurs a host<->device sync - losses.update(to_python_float(reduced_loss), input.size(0)) - top1.update(to_python_float(prec1), input.size(0)) - top5.update(to_python_float(prec5), input.size(0)) - - torch.cuda.synchronize() - batch_time.update((time.time() - end)/args.print_freq) - end = time.time() - - if args.local_rank == 0: - print('Epoch: [{0}][{1}/{2}]\t' - 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - 'Speed {3:.3f} ({4:.3f})\t' - 'Loss {loss.val:.10f} ({loss.avg:.4f})\t' - 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' - 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( - epoch, i, len(train_loader), - args.world_size*args.batch_size/batch_time.val, - args.world_size*args.batch_size/batch_time.avg, - batch_time=batch_time, - loss=losses, top1=top1, top5=top5)) - if args.prof >= 0: torch.cuda.nvtx.range_push("prefetcher.next()") - input, target = prefetcher.next() - if args.prof >= 0: torch.cuda.nvtx.range_pop() - - # Pop range "Body of iteration {}".format(i) - if args.prof >= 0: torch.cuda.nvtx.range_pop() - - if args.prof >= 0 and i == args.prof + 10: - print("Profiling ended at iteration {}".format(i)) - torch.cuda.cudart().cudaProfilerStop() - quit() - - -def validate(val_loader, model, criterion): - batch_time = AverageMeter() - losses = AverageMeter() - top1 = AverageMeter() - top5 = AverageMeter() - - # switch to evaluate mode - model.eval() - - end = time.time() - - prefetcher = data_prefetcher(val_loader) - input, target = prefetcher.next() - i = 0 - while input is not None: - i += 1 - - # compute output - with torch.no_grad(): - output = model(input) - loss = criterion(output, target) - - # measure accuracy and record loss - prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) - - if args.distributed: - reduced_loss = reduce_tensor(loss.data) - prec1 = reduce_tensor(prec1) - prec5 = reduce_tensor(prec5) - else: - reduced_loss = loss.data - - losses.update(to_python_float(reduced_loss), input.size(0)) - top1.update(to_python_float(prec1), input.size(0)) - top5.update(to_python_float(prec5), input.size(0)) - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - # TODO: Change timings to mirror train(). - if args.local_rank == 0 and i % args.print_freq == 0: - print('Test: [{0}/{1}]\t' - 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - 'Speed {2:.3f} ({3:.3f})\t' - 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' - 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' - 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( - i, len(val_loader), - args.world_size * args.batch_size / batch_time.val, - args.world_size * args.batch_size / batch_time.avg, - batch_time=batch_time, loss=losses, - top1=top1, top5=top5)) - - input, target = prefetcher.next() - - print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' - .format(top1=top1, top5=top5)) - - return top1.avg - - -def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): - torch.save(state, filename) - if is_best: - shutil.copyfile(filename, 'model_best.pth.tar') - - -class AverageMeter(object): - """Computes and stores the average and current value""" - def __init__(self): - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - -def adjust_learning_rate(optimizer, epoch, step, len_epoch): - """LR schedule that should yield 76% converged accuracy with batch size 256""" - factor = epoch // 30 - - if epoch >= 80: - factor = factor + 1 - - lr = args.lr*(0.1**factor) - - """Warmup""" - if epoch < 5: - lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch) - - # if(args.local_rank == 0): - # print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr)) - - for param_group in optimizer.param_groups: - param_group['lr'] = lr - - -def accuracy(output, target, topk=(1,)): - """Computes the precision@k for the specified values of k""" - maxk = max(topk) - batch_size = target.size(0) - - _, pred = output.topk(maxk, 1, True, True) - pred = pred.t() - correct = pred.eq(target.view(1, -1).expand_as(pred)) - - res = [] - for k in topk: - correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) - res.append(correct_k.mul_(100.0 / batch_size)) - return res - - -def reduce_tensor(tensor): - rt = tensor.clone() - dist.all_reduce(rt, op=dist.reduce_op.SUM) - rt /= args.world_size - return rt - -if __name__ == '__main__': - main() diff --git a/examples/simple/distributed/README.md b/examples/simple/distributed/README.md deleted file mode 100644 index 0d939cb..0000000 --- a/examples/simple/distributed/README.md +++ /dev/null @@ -1,13 +0,0 @@ -**distributed_data_parallel.py** and **run.sh** show an example using Amp with -[apex.parallel.DistributedDataParallel](https://nvidia.github.io/apex/parallel.html) or -[torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/nn.html#distributeddataparallel) -and the Pytorch multiprocess launcher script, -[torch.distributed.launch](https://pytorch.org/docs/master/distributed.html#launch-utility). -The use of `Amp` with DistributedDataParallel does not need to change from ordinary -single-process use. The only gotcha is that wrapping your model with `DistributedDataParallel` must -come after the call to `amp.initialize`. Test via -```bash -bash run.sh -``` - -**This is intended purely as an instructional example, not a performance showcase.** diff --git a/examples/simple/distributed/distributed_data_parallel.py b/examples/simple/distributed/distributed_data_parallel.py deleted file mode 100644 index b364405..0000000 --- a/examples/simple/distributed/distributed_data_parallel.py +++ /dev/null @@ -1,65 +0,0 @@ -import torch -import argparse -import os -from apex import amp -# FOR DISTRIBUTED: (can also use torch.nn.parallel.DistributedDataParallel instead) -from apex.parallel import DistributedDataParallel - -parser = argparse.ArgumentParser() -# FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied -# automatically by torch.distributed.launch. -parser.add_argument("--local_rank", default=0, type=int) -args = parser.parse_args() - -# FOR DISTRIBUTED: If we are running under torch.distributed.launch, -# the 'WORLD_SIZE' environment variable will also be set automatically. -args.distributed = False -if 'WORLD_SIZE' in os.environ: - args.distributed = int(os.environ['WORLD_SIZE']) > 1 - -if args.distributed: - # FOR DISTRIBUTED: Set the device according to local_rank. - torch.cuda.set_device(args.local_rank) - - # FOR DISTRIBUTED: Initialize the backend. torch.distributed.launch will provide - # environment variables, and requires that you use init_method=`env://`. - torch.distributed.init_process_group(backend='nccl', - init_method='env://') - -torch.backends.cudnn.benchmark = True - -N, D_in, D_out = 64, 1024, 16 - -# Each process receives its own batch of "fake input data" and "fake target data." -# The "training loop" in each process just uses this fake batch over and over. -# https://github.com/NVIDIA/apex/tree/master/examples/imagenet provides a more realistic -# example of distributed data sampling for both training and validation. -x = torch.randn(N, D_in, device='cuda') -y = torch.randn(N, D_out, device='cuda') - -model = torch.nn.Linear(D_in, D_out).cuda() -optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) - -model, optimizer = amp.initialize(model, optimizer, opt_level="O1") - -if args.distributed: - # FOR DISTRIBUTED: After amp.initialize, wrap the model with - # apex.parallel.DistributedDataParallel. - model = DistributedDataParallel(model) - # torch.nn.parallel.DistributedDataParallel is also fine, with some added args: - # model = torch.nn.parallel.DistributedDataParallel(model, - # device_ids=[args.local_rank], - # output_device=args.local_rank) - -loss_fn = torch.nn.MSELoss() - -for t in range(500): - optimizer.zero_grad() - y_pred = model(x) - loss = loss_fn(y_pred, y) - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - optimizer.step() - -if args.local_rank == 0: - print("final loss = ", loss) diff --git a/examples/simple/distributed/run.sh b/examples/simple/distributed/run.sh deleted file mode 100644 index 7a2d85f..0000000 --- a/examples/simple/distributed/run.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -python -m torch.distributed.launch --nproc_per_node=2 distributed_data_parallel.py diff --git a/get_version.py b/get_version.py deleted file mode 100644 index 1f83110..0000000 --- a/get_version.py +++ /dev/null @@ -1,67 +0,0 @@ -import os -import subprocess -from pathlib import Path - -import torch - -ROOT_DIR = Path(__file__).parent.resolve() - - -def _run_cmd(cmd, shell=False): - try: - return subprocess.check_output(cmd, cwd=ROOT_DIR, stderr=subprocess.DEVNULL, shell=shell).decode("ascii").strip() - except Exception: - return None - - -def _get_version(): - if os.path.exists(ROOT_DIR / "version.txt"): - with open(ROOT_DIR / "version.txt", "r") as f: - version = f.read().strip() - else: - version = '0.1' - if os.getenv("BUILD_VERSION"): - version = os.getenv("BUILD_VERSION") - return version - - -def _make_version_file(version, sha, abi, dtk, torch_version, branch): - sha = "Unknown" if sha is None else sha - torch_version = '.'.join(torch_version.split('.')[:2]) - dcu_version = f"{version}+{sha}.abi{abi}.dtk{dtk}.torch{torch_version}" - version_path = ROOT_DIR / "apex" / "version.py" - with open(version_path, "w") as f: - f.write(f"version = '{version}'\n") - f.write(f"git_hash = '{sha}'\n") - f.write(f"git_branch = '{branch}'\n") - f.write(f"abi = 'abi{abi}'\n") - f.write(f"dtk = '{dtk}'\n") - f.write(f"torch_version = '{torch_version}'\n") - f.write(f"dcu_version = '{dcu_version}'\n") - return dcu_version - - -def _get_pytorch_version(): - if "PYTORCH_VERSION" in os.environ: - return f"{os.environ['PYTORCH_VERSION']}" - return torch.__version__ - -def get_version(ROCM_HOME): - sha = _run_cmd(["git", "rev-parse", "HEAD"]) - sha = sha[:7] - branch = _run_cmd(["git", "rev-parse", "--abbrev-ref", "HEAD"]) - tag = _run_cmd(["git", "describe", "--tags", "--exact-match", "@"]) - print("-- Git branch:", branch) - print("-- Git SHA:", sha) - print("-- Git tag:", tag) - torch_version = _get_pytorch_version() - print("-- PyTorch:", torch_version) - version = _get_version() - print("-- Building version", version) - abi = _run_cmd(["echo '#include ' | gcc -x c++ -E -dM - | fgrep _GLIBCXX_USE_CXX11_ABI | awk '{print $3}'"], shell=True) - print("-- _GLIBCXX_USE_CXX11_ABI:", abi) - dtk = _run_cmd(["cat", os.path.join(ROCM_HOME, '.info/rocm_version')]) - dtk = ''.join(dtk.split('.')[:2]) - print("-- DTK:", dtk) - - return _make_version_file(version, sha, abi, dtk, torch_version, branch) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index f29f03d..0000000 --- a/pyproject.toml +++ /dev/null @@ -1,7 +0,0 @@ -[build-system] -requires = [ - "setuptools", - "wheel", -] -build-backend = "setuptools.build_meta" - diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index fd202d9..0000000 --- a/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -cxxfilt>=0.2.0 -tqdm>=4.28.1 -numpy>=1.15.3 -PyYAML>=5.1 -pytest>=3.5.1 -packaging>=14.0 diff --git a/requirements_dev.txt b/requirements_dev.txt deleted file mode 100644 index e1086e0..0000000 --- a/requirements_dev.txt +++ /dev/null @@ -1,3 +0,0 @@ --r requirements.txt -flake8>=3.7.9 -Sphinx>=3.0.3 \ No newline at end of file diff --git a/setup.py b/setup.py deleted file mode 100644 index ab0e185..0000000 --- a/setup.py +++ /dev/null @@ -1,709 +0,0 @@ -import torch -from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME, ROCM_HOME -from setuptools import setup, find_packages -import subprocess - -import sys -import warnings -import os - -from get_version import get_version - -dcu_version = get_version(ROCM_HOME) - -# ninja build does not work unless include_dirs are abs path -this_dir = os.path.dirname(os.path.abspath(__file__)) -torch_dir = torch.__path__[0] - -# https://github.com/pytorch/pytorch/pull/71881 -# For the extensions which have rocblas_gemm_flags_fp16_alt_impl we need to make sure if at::BackwardPassGuard exists. -# It helps the extensions be backward compatible with old PyTorch versions. -# The check and ROCM_BACKWARD_PASS_GUARD in nvcc/hipcc args can be retired once the PR is merged into PyTorch upstream. - -context_file = os.path.join(torch_dir, "include", "ATen", "Context.h") -if os.path.exists(context_file): - lines = open(context_file, 'r').readlines() - found_Backward_Pass_Guard = False - found_ROCmBackward_Pass_Guard = False - for line in lines: - if "BackwardPassGuard" in line: - # BackwardPassGuard has been renamed to ROCmBackwardPassGuard - # https://github.com/pytorch/pytorch/pull/71881/commits/4b82f5a67a35406ffb5691c69e6b4c9086316a43 - if "ROCmBackwardPassGuard" in line: - found_ROCmBackward_Pass_Guard = True - else: - found_Backward_Pass_Guard = True - break - -found_aten_atomic_header = False -if os.path.exists(os.path.join(torch_dir, "include", "ATen", "Atomic.cuh")): - found_aten_atomic_header = True - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - release = output[release_idx].split(".") - bare_metal_major = release[0] - bare_metal_minor = release[1][0] - - return raw_output, bare_metal_major, bare_metal_minor - - -def check_cuda_torch_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir) - torch_binary_major = torch.version.cuda.split(".")[0] - torch_binary_minor = torch.version.cuda.split(".")[1] - - print("\nCompiling cuda extensions with") - print(raw_output + "from " + cuda_dir + "/bin\n") - - if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor): - raise RuntimeError( - "Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pytorch binaries. " - "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) - + "In some cases, a minor-version mismatch will not cause later errors: " - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk)." - ) - - -def raise_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - raise RuntimeError( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " - "only images whose names contain 'devel' will provide nvcc." - ) - - -def append_nvcc_threads(nvcc_extra_args): - _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) - if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2: - return nvcc_extra_args + ["--threads", "4"] - return nvcc_extra_args - - -def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool: - cudnn_available = torch.backends.cudnn.is_available() - cudnn_version = torch.backends.cudnn.version() if cudnn_available else None - if not (cudnn_available and (cudnn_version >= required_cudnn_version)): - warnings.warn( - f"Skip `{global_option}` as it requires cuDNN {required_cudnn_version} or later, " - f"but {'cuDNN is not available' if not cudnn_available else cudnn_version}" - ) - return False - return True - - -print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) -TORCH_MAJOR = int(torch.__version__.split('.')[0]) -TORCH_MINOR = int(torch.__version__.split('.')[1]) - -def check_if_rocm_pytorch(): - is_rocm_pytorch = False - if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5): - from torch.utils.cpp_extension import ROCM_HOME - is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False - - return is_rocm_pytorch - -IS_ROCM_PYTORCH = check_if_rocm_pytorch() - -if not torch.cuda.is_available() and not IS_ROCM_PYTORCH: - # https://github.com/NVIDIA/apex/issues/486 - # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), - # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). - print( - "\nWarning: Torch did not find available GPUs on this system.\n", - "If your intention is to cross-compile, this is not an error.\n" - "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" - "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" - "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" - "If you wish to cross-compile for a single specific architecture,\n" - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', - ) - if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None: - _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME) - if int(bare_metal_major) == 11: - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" - if int(bare_metal_minor) > 0: - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" - else: - os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" -elif not torch.cuda.is_available() and IS_ROCM_PYTORCH: - print('\nWarning: Torch did not find available GPUs on this system.\n', - 'If your intention is to cross-compile, this is not an error.\n' - 'By default, Apex will cross-compile for the same gfx targets\n' - 'used by default in ROCm PyTorch\n') - -if TORCH_MAJOR == 0 and TORCH_MINOR < 4: - raise RuntimeError( - "Apex requires Pytorch 0.4 or newer.\nThe latest stable release can be obtained from https://pytorch.org/" - ) - -# cmdclass = {} -ext_modules = [] - -extras = {} - -if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv: - if TORCH_MAJOR == 0: - raise RuntimeError("--cpp_ext requires Pytorch 1.0 or later, " - "found torch.__version__ = {}".format(torch.__version__)) -if "--cpp_ext" in sys.argv: - sys.argv.remove("--cpp_ext") - ext_modules.append(CppExtension("apex_C", ["csrc/flatten_unflatten.cpp"])) - -# Set up macros for forward/backward compatibility hack around -# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e -# and -# https://github.com/NVIDIA/apex/issues/456 -# https://github.com/pytorch/pytorch/commit/eb7b39e02f7d75c26d8a795ea8c7fd911334da7e#diff-4632522f237f1e4e728cb824300403ac -version_ge_1_1 = [] -if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0): - version_ge_1_1 = ["-DVERSION_GE_1_1"] -version_ge_1_3 = [] -if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): - version_ge_1_3 = ["-DVERSION_GE_1_3"] -version_ge_1_5 = [] -if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4): - version_ge_1_5 = ["-DVERSION_GE_1_5"] -version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5 - -if IS_ROCM_PYTORCH: - version_dependent_macros += ['--gpu-max-threads-per-block=1024'] - -if "--distributed_adam" in sys.argv or "--cuda_ext" in sys.argv: - if "--distributed_adam" in sys.argv: - sys.argv.remove("--distributed_adam") - - if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH: - raise RuntimeError("--distributed_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") - else: - nvcc_args_adam = ['-O3', '--use_fast_math'] + version_dependent_macros - hipcc_args_adam = ['-O3'] + version_dependent_macros - ext_modules.append( - CUDAExtension(name='distributed_adam_cuda', - sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp', - 'apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu'], - include_dirs=[os.path.join(this_dir, 'csrc'), - os.path.join(this_dir, 'apex/contrib/csrc/optimizers')], - extra_compile_args={'cxx': ['-O3',] + version_dependent_macros, - 'nvcc':nvcc_args_adam if not IS_ROCM_PYTORCH else hipcc_args_adam})) - -if "--distributed_lamb" in sys.argv or "--cuda_ext" in sys.argv: - if "--distributed_lamb" in sys.argv: - sys.argv.remove("--distributed_lamb") - - if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH: - raise RuntimeError("--distributed_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") - else: - print ("INFO: Building the distributed_lamb extension.") - nvcc_args_distributed_lamb = ['-O3', '--use_fast_math'] + version_dependent_macros - hipcc_args_distributed_lamb = ['-O3'] + version_dependent_macros - ext_modules.append( - CUDAExtension(name='distributed_lamb_cuda', - sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp', - 'apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu'], - include_dirs=[os.path.join(this_dir, 'csrc')], - extra_compile_args={'cxx': ['-O3',] + version_dependent_macros, - 'nvcc': nvcc_args_distributed_lamb if not IS_ROCM_PYTORCH else hipcc_args_distributed_lamb})) - -if "--cuda_ext" in sys.argv: - if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH: - raise RuntimeError("--cuda_ext was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") - else: - if not IS_ROCM_PYTORCH: - check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME) - - print ("INFO: Building the multi-tensor apply extension.") - nvcc_args_multi_tensor = ['-lineinfo', '-O3', '--use_fast_math'] + version_dependent_macros - hipcc_args_multi_tensor = ['-O3'] + version_dependent_macros - ext_modules.append( - CUDAExtension(name='amp_C', - sources=['csrc/amp_C_frontend.cpp', - 'csrc/multi_tensor_sgd_kernel.cu', - 'csrc/multi_tensor_scale_kernel.cu', - 'csrc/multi_tensor_axpby_kernel.cu', - 'csrc/multi_tensor_l2norm_kernel.cu', - 'csrc/multi_tensor_l2norm_kernel_mp.cu', - 'csrc/multi_tensor_l2norm_scale_kernel.cu', - 'csrc/multi_tensor_lamb_stage_1.cu', - 'csrc/multi_tensor_lamb_stage_2.cu', - 'csrc/multi_tensor_adam.cu', - 'csrc/multi_tensor_adagrad.cu', - 'csrc/multi_tensor_novograd.cu', - 'csrc/multi_tensor_lars.cu', - 'csrc/multi_tensor_lamb.cu', - 'csrc/multi_tensor_lamb_mp.cu'], - include_dirs=[os.path.join(this_dir, 'csrc')], - extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, - 'nvcc': nvcc_args_multi_tensor if not IS_ROCM_PYTORCH else hipcc_args_multi_tensor})) - - print ("INFO: Building syncbn extension.") - ext_modules.append( - CUDAExtension(name='syncbn', - sources=['csrc/syncbn.cpp', - 'csrc/welford.cu'], - include_dirs=[os.path.join(this_dir, 'csrc')], - extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, - 'nvcc':['-O3'] + version_dependent_macros})) - - nvcc_args_layer_norm = ['-maxrregcount=50', '-O3', '--use_fast_math'] + version_dependent_macros - hipcc_args_layer_norm = ['-O3'] + version_dependent_macros - print ("INFO: Building fused layernorm extension.") - ext_modules.append( - CUDAExtension(name='fused_layer_norm_cuda', - sources=['csrc/layer_norm_cuda.cpp', - 'csrc/layer_norm_cuda_kernel.cu'], - include_dirs=[os.path.join(this_dir, 'csrc')], - extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, - 'nvcc': nvcc_args_layer_norm if not IS_ROCM_PYTORCH else hipcc_args_layer_norm})) - - hipcc_args_mlp = ['-O3'] + version_dependent_macros - if found_Backward_Pass_Guard: - hipcc_args_mlp = hipcc_args_mlp + ['-DBACKWARD_PASS_GUARD'] + ['-DBACKWARD_PASS_GUARD_CLASS=BackwardPassGuard'] - if found_ROCmBackward_Pass_Guard: - hipcc_args_mlp = hipcc_args_mlp + ['-DBACKWARD_PASS_GUARD'] + ['-DBACKWARD_PASS_GUARD_CLASS=ROCmBackwardPassGuard'] - - print ("INFO: Building the MLP Extension.") - ext_modules.append( - CUDAExtension(name='mlp_cuda', - sources=['csrc/mlp.cpp', - 'csrc/mlp_cuda.cu'], - include_dirs=[os.path.join(this_dir, 'csrc')], - extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, - 'nvcc':['-O3'] + version_dependent_macros - if not IS_ROCM_PYTORCH else hipcc_args_mlp})) - - ext_modules.append( - CUDAExtension(name='fused_dense_cuda', - sources=['csrc/fused_dense.cpp', - 'csrc/fused_dense_cuda.cu'], - include_dirs=[os.path.join(this_dir, 'csrc')], - extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, - 'nvcc':['-O3'] + version_dependent_macros})) - nvcc_args_transformer = ['-O3', - '-U__CUDA_NO_HALF_OPERATORS__', - '-U__CUDA_NO_HALF_CONVERSIONS__', - '--expt-relaxed-constexpr', - '--expt-extended-lambda'] + version_dependent_macros - hipcc_args_transformer = ['-O3', - '-U__CUDA_NO_HALF_OPERATORS__', - '-U__CUDA_NO_HALF_CONVERSIONS__'] + version_dependent_macros - ext_modules.append( - CUDAExtension(name='scaled_upper_triang_masked_softmax_cuda', - sources=['csrc/megatron/scaled_upper_triang_masked_softmax.cpp', - 'csrc/megatron/scaled_upper_triang_masked_softmax_cuda.cu'], - include_dirs=[os.path.join(this_dir, 'csrc')], - extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, - 'nvcc':nvcc_args_transformer if not IS_ROCM_PYTORCH else hipcc_args_transformer})) - ext_modules.append( - CUDAExtension(name='scaled_masked_softmax_cuda', - sources=['csrc/megatron/scaled_masked_softmax.cpp', - 'csrc/megatron/scaled_masked_softmax_cuda.cu'], - include_dirs=[os.path.join(this_dir, 'csrc'), - os.path.join(this_dir, 'csrc/megatron')], - extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, - 'nvcc':nvcc_args_transformer if not IS_ROCM_PYTORCH else hipcc_args_transformer})) - - -if "--bnp" in sys.argv or "--cuda_ext" in sys.argv: - if "--bnp" in sys.argv: - sys.argv.remove("--bnp") - - if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH: - raise RuntimeError("--bnp was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") - else: - ext_modules.append( - CUDAExtension(name='bnp', - sources=['apex/contrib/csrc/groupbn/batch_norm.cu', - 'apex/contrib/csrc/groupbn/ipc.cu', - 'apex/contrib/csrc/groupbn/interface.cpp', - 'apex/contrib/csrc/groupbn/batch_norm_add_relu.cu'], - include_dirs=[os.path.join(this_dir, 'csrc'), - os.path.join(this_dir, 'apex/contrib/csrc/groupbn')], - extra_compile_args={'cxx': [] + version_dependent_macros, - 'nvcc':['-DCUDA_HAS_FP16=1', - '-D__CUDA_NO_HALF_OPERATORS__', - '-D__CUDA_NO_HALF_CONVERSIONS__', - '-D__CUDA_NO_HALF2_OPERATORS__'] + version_dependent_macros})) - -if "--xentropy" in sys.argv or "--cuda_ext" in sys.argv: - if "--xentropy" in sys.argv: - sys.argv.remove("--xentropy") - - if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH: - raise RuntimeError("--xentropy was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") - else: - print ("INFO: Building the xentropy extension.") - ext_modules.append( - CUDAExtension(name='xentropy_cuda', - sources=['apex/contrib/csrc/xentropy/interface.cpp', - 'apex/contrib/csrc/xentropy/xentropy_kernel.cu'], - include_dirs=[os.path.join(this_dir, 'csrc'), - os.path.join(this_dir, 'apex/contrib/csrc/xentropy')], - extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, - 'nvcc':['-O3'] + version_dependent_macros})) - -if "--focal_loss" in sys.argv or "--cuda_ext" in sys.argv: - if "--focal_loss" in sys.argv: - sys.argv.remove("--focal_loss") - ext_modules.append( - CUDAExtension( - name='focal_loss_cuda', - sources=[ - 'apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp', - 'apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu', - ], - include_dirs=[os.path.join(this_dir, 'csrc')], - extra_compile_args={ - 'cxx': ['-O3'] + version_dependent_macros, - 'nvcc':(['-O3', '--use_fast_math', '--ftz=false'] if not IS_ROCM_PYTORCH else ['-O3']) + version_dependent_macros, - }, - ) - ) - -if "--index_mul_2d" in sys.argv or "--cuda_ext" in sys.argv: - if "--index_mul_2d" in sys.argv: - sys.argv.remove("--index_mul_2d") - - args_index_mul_2d = ['-O3'] - if not IS_ROCM_PYTORCH: - args_index_mul_2d += ['--use_fast_math', '--ftz=false'] - if found_aten_atomic_header: - args_index_mul_2d += ['-DATEN_ATOMIC_HEADER'] - - ext_modules.append( - CUDAExtension( - name='fused_index_mul_2d', - sources=[ - 'apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp', - 'apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu', - ], - include_dirs=[os.path.join(this_dir, 'csrc')], - extra_compile_args={ - 'cxx': ['-O3'] + version_dependent_macros, - 'nvcc': args_index_mul_2d + version_dependent_macros, - }, - ) - ) - -if "--deprecated_fused_adam" in sys.argv or "--cuda_ext" in sys.argv: - if "--deprecated_fused_adam" in sys.argv: - sys.argv.remove("--deprecated_fused_adam") - - if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH: - raise RuntimeError("--deprecated_fused_adam was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") - else: - print ("INFO: Building deprecated fused adam extension.") - nvcc_args_fused_adam = ['-O3', '--use_fast_math'] + version_dependent_macros - hipcc_args_fused_adam = ['-O3'] + version_dependent_macros - ext_modules.append( - CUDAExtension(name='fused_adam_cuda', - sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp', - 'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'], - include_dirs=[os.path.join(this_dir, 'csrc'), - os.path.join(this_dir, 'apex/contrib/csrc/optimizers')], - extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, - 'nvcc' : nvcc_args_fused_adam if not IS_ROCM_PYTORCH else hipcc_args_fused_adam})) - -if "--deprecated_fused_lamb" in sys.argv or "--cuda_ext" in sys.argv: - if "--deprecated_fused_lamb" in sys.argv: - sys.argv.remove("--deprecated_fused_lamb") - - if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH: - raise RuntimeError("--deprecated_fused_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") - else: - print ("INFO: Building deprecated fused lamb extension.") - nvcc_args_fused_lamb = ['-O3', '--use_fast_math'] + version_dependent_macros - hipcc_args_fused_lamb = ['-O3'] + version_dependent_macros - ext_modules.append( - CUDAExtension(name='fused_lamb_cuda', - sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp', - 'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu', - 'csrc/multi_tensor_l2norm_kernel.cu'], - include_dirs=[os.path.join(this_dir, 'csrc')], - extra_compile_args = nvcc_args_fused_lamb if not IS_ROCM_PYTORCH else hipcc_args_fused_lamb)) - -# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h -# See https://github.com/pytorch/pytorch/pull/70650 -generator_flag = [] -torch_dir = torch.__path__[0] -if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): - generator_flag = ["-DOLD_GENERATOR_PATH"] - -if "--fast_layer_norm" in sys.argv: - sys.argv.remove("--fast_layer_norm") - # raise_if_cuda_home_none("--fast_layer_norm") - # Check, if CUDA11 is installed for compute capability 8.0 - # cc_flag = [] - # _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME) - # if int(bare_metal_major) >= 11: - # cc_flag.append("-gencode") - # cc_flag.append("arch=compute_80,code=sm_80") - - if CUDA_HOME is None and not IS_ROCM_PYTORCH: - raise RuntimeError("--fast_layer_norm was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") - else: - if CUDA_HOME is not None: - # Check, if CUDA11 is installed for compute capability 8.0 - cc_flag = [] - _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME) - if int(bare_metal_major) >= 11: - cc_flag.append('-gencode') - cc_flag.append('arch=compute_80,code=sm_80') - else: - hipcc_args_fused_lamb = ['-O3'] + version_dependent_macros - ext_modules.append( - CUDAExtension( - name="fast_layer_norm", - sources=[ - "apex/contrib/csrc/layer_norm/ln_api.cpp", - "apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu", - "apex/contrib/csrc/layer_norm/ln_bwd_semi_cuda_kernel.cu", - ], - extra_compile_args={ - "cxx": ["-O3"] + version_dependent_macros + generator_flag, - "nvcc": [ - "-O3", - '-U__HIP_NO_HALF_OPERATORS__', - '-U__HIP_NO_HALF_CONVERSIONS__', - "-I./apex/contrib/csrc/layer_norm/", - ] + version_dependent_macros + generator_flag, - }, - include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/layer_norm")], - ) - ) - -if "--fmha" in sys.argv: - sys.argv.remove("--fmha") - raise_if_cuda_home_none("--fmha") - # Check, if CUDA11 is installed for compute capability 8.0 - cc_flag = [] - _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME) - if int(bare_metal_major) < 11: - raise RuntimeError("--fmha only supported on SM80") - cc_flag.append("-gencode") - cc_flag.append("arch=compute_80,code=sm_80") - - if CUDA_HOME is None and not IS_ROCM_PYTORCH: - raise RuntimeError("--fmha was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") - else: - # Check, if CUDA11 is installed for compute capability 8.0 - cc_flag = [] - _, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME) - if int(bare_metal_major) < 11: - raise RuntimeError("--fmha only supported on SM80") - - ext_modules.append( - CUDAExtension(name='fmhalib', - sources=[ - 'apex/contrib/csrc/fmha/fmha_api.cpp', - 'apex/contrib/csrc/fmha/src/fmha_noloop_reduce.cu', - 'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu', - 'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu', - 'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu', - 'apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu', - 'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu', - 'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu', - 'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu', - 'apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu', - ], - extra_compile_args={'cxx': ['-O3', - ] + version_dependent_macros + generator_flag, - 'nvcc':['-O3', - '-gencode', 'arch=compute_80,code=sm_80', - '-U__CUDA_NO_HALF_OPERATORS__', - '-U__CUDA_NO_HALF_CONVERSIONS__', - '--expt-relaxed-constexpr', - '--expt-extended-lambda', - '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}, - include_dirs=[os.path.join(this_dir, "apex/contrib/csrc"), os.path.join(this_dir, "apex/contrib/csrc/fmha/src")])) - - -if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv: - if "--fast_multihead_attn" in sys.argv: - sys.argv.remove("--fast_multihead_attn") - - if torch.utils.cpp_extension.CUDA_HOME is None and not IS_ROCM_PYTORCH: - raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") - else: - # Check, if CUDA11 is installed for compute capability 8.0 - cc_flag = [] - if not IS_ROCM_PYTORCH: - _, bare_metal_major, _ = get_cuda_bare_metal_version(torch.utils.cpp_extension.CUDA_HOME) - if int(bare_metal_major) >= 11: - cc_flag.append('-gencode') - cc_flag.append('arch=compute_80,code=sm_80') - cc_flag.append('-gencode') - cc_flag.append('arch=compute_86,code=sm_86') - - #subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"]) - nvcc_args_mha = ['-O3', - '-gencode', - 'arch=compute_70,code=sm_70', - '-Iapex/contrib/csrc/multihead_attn/cutlass', - '-U__CUDA_NO_HALF_OPERATORS__', - '-U__CUDA_NO_HALF_CONVERSIONS__', - '--expt-relaxed-constexpr', - '--expt-extended-lambda', - '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag - hipcc_args_mha = ['-O3', - '-Iapex/contrib/csrc/multihead_attn/cutlass', - '-I' + os.path.join(ROCM_HOME, 'include/hiprand'), - '-I' + os.path.join(ROCM_HOME, 'include/rocrand'), - '-U__HIP_NO_HALF_OPERATORS__', - '-U__HIP_NO_HALF_CONVERSIONS__'] + version_dependent_macros + generator_flag - if found_Backward_Pass_Guard: - hipcc_args_mha = hipcc_args_mha + ['-DBACKWARD_PASS_GUARD'] + ['-DBACKWARD_PASS_GUARD_CLASS=BackwardPassGuard'] - if found_ROCmBackward_Pass_Guard: - hipcc_args_mha = hipcc_args_mha + ['-DBACKWARD_PASS_GUARD'] + ['-DBACKWARD_PASS_GUARD_CLASS=ROCmBackwardPassGuard'] - - ext_modules.append( - CUDAExtension( - name='fast_multihead_attn', - sources=[ - 'apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp', - 'apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu', - "apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu", - "apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu", - "apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu", - "apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu", - "apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu", - "apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu", - "apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu", - ], - include_dirs=[os.path.join(this_dir, 'csrc'), - os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')], - extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag, - 'nvcc':nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha} - ) - ) - -if "--transducer" in sys.argv or "--cuda_ext" in sys.argv: - if "--transducer" in sys.argv: - sys.argv.remove("--transducer") - - if not IS_ROCM_PYTORCH: - raise_if_cuda_home_none("--transducer") - - hipcc_args_mha = ['-O3', - '-I' + os.path.join(ROCM_HOME, 'include/hiprand'), - '-I' + os.path.join(ROCM_HOME, 'include/rocrand'),] + version_dependent_macros + generator_flag - ext_modules.append( - CUDAExtension( - name="transducer_joint_cuda", - sources=[ - "apex/contrib/csrc/transducer/transducer_joint.cpp", - "apex/contrib/csrc/transducer/transducer_joint_kernel.cu", - ], - extra_compile_args={ - "cxx": ["-O3"] + version_dependent_macros + generator_flag, - "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros + generator_flag) if not IS_ROCM_PYTORCH - else hipcc_args_mha, - }, - include_dirs=[os.path.join(this_dir, "csrc"), os.path.join(this_dir, "apex/contrib/csrc/multihead_attn")], - ) - ) - ext_modules.append( - CUDAExtension( - name="transducer_loss_cuda", - sources=[ - "apex/contrib/csrc/transducer/transducer_loss.cpp", - "apex/contrib/csrc/transducer/transducer_loss_kernel.cu", - ], - include_dirs=[os.path.join(this_dir, "csrc")], - extra_compile_args={ - "cxx": ["-O3"] + version_dependent_macros, - "nvcc": append_nvcc_threads(["-O3"] + version_dependent_macros) if not IS_ROCM_PYTORCH - else ["-O3"] + version_dependent_macros, - }, - ) - ) - -# note (mkozuki): Now `--fast_bottleneck` option (i.e. apex/contrib/bottleneck) depends on `--peer_memory` and `--nccl_p2p`. -if "--fast_bottleneck" in sys.argv: - sys.argv.remove("--fast_bottleneck") - raise_if_cuda_home_none("--fast_bottleneck") - if check_cudnn_version_and_warn("--fast_bottleneck", 8400): - subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"]) - ext_modules.append( - CUDAExtension( - name="fast_bottleneck", - sources=["apex/contrib/csrc/bottleneck/bottleneck.cpp"], - include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")], - extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag}, - ) - ) - -if "--peer_memory" in sys.argv or "--cuda_ext" in sys.argv: - if "--peer_memory" in sys.argv: - sys.argv.remove("--peer_memory") - - if not IS_ROCM_PYTORCH: - raise_if_cuda_home_none("--peer_memory") - - ext_modules.append( - CUDAExtension( - name="peer_memory_cuda", - sources=[ - "apex/contrib/csrc/peer_memory/peer_memory_cuda.cu", - "apex/contrib/csrc/peer_memory/peer_memory.cpp", - ], - include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/nccl_p2p")], - extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag}, - ) - ) - -if "--nccl_p2p" in sys.argv or "--cuda_ext" in sys.argv: - if "--nccl_p2p" in sys.argv: - sys.argv.remove("--nccl_p2p") - - if not IS_ROCM_PYTORCH: - raise_if_cuda_home_none("--nccl_p2p") - - ext_modules.append( - CUDAExtension( - name="nccl_p2p_cuda", - sources=[ - "apex/contrib/csrc/nccl_p2p/nccl_p2p_cuda.cu", - "apex/contrib/csrc/nccl_p2p/nccl_p2p.cpp", - ], - include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/nccl_p2p")], - extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag}, - ) - ) - - -if "--fused_conv_bias_relu" in sys.argv: - sys.argv.remove("--fused_conv_bias_relu") - raise_if_cuda_home_none("--fused_conv_bias_relu") - if check_cudnn_version_and_warn("--fused_conv_bias_relu", 8400): - subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"]) - ext_modules.append( - CUDAExtension( - name="fused_conv_bias_relu", - sources=["apex/contrib/csrc/conv_bias_relu/conv_bias_relu.cpp"], - include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")], - extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag}, - ) - ) - -if "--cuda_ext" in sys.argv: - sys.argv.remove("--cuda_ext") - -setup( - name="apex", - version=dcu_version, - packages=find_packages( - exclude=("build", "csrc", "include", "tests", "dist", "docs", "tests", "examples", "apex.egg-info",) - ), - description="PyTorch Extensions written by NVIDIA", - ext_modules=ext_modules, - cmdclass={'build_ext': BuildExtension} if ext_modules else {}, - extras_require=extras, -) diff --git a/tests/L0/run_amp/__init__.py b/tests/L0/run_amp/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/L0/run_amp/test_add_param_group.py b/tests/L0/run_amp/test_add_param_group.py deleted file mode 100644 index 3bdd702..0000000 --- a/tests/L0/run_amp/test_add_param_group.py +++ /dev/null @@ -1,159 +0,0 @@ -import unittest - -import functools as ft -import itertools as it - -from apex import amp -from apex.amp import _amp_state -import torch -from torch import nn -import torch.nn.functional as F -from torch.nn import Parameter - -from utils import common_init, HALF, FLOAT,\ - ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT - -class MyModel(torch.nn.Module): - def __init__(self, unique, dtype=torch.float16): - super(MyModel, self).__init__() - self.weight0 = Parameter(unique + - torch.arange(2, device='cuda', dtype=torch.float32)) - self.weight1 = Parameter(1. + unique + torch.arange(2, device='cuda', dtype=dtype)) - - @staticmethod - def ops(input, weight0, weight1): - return ((input*(weight0.float()))*(weight1.float())).sum() - - def forward(self, input): - return self.ops(input, self.weight0, self.weight1) - - -# Abandon all hope, ye who enter here. - - -class TestAddParamGroup(unittest.TestCase): - def setUp(self): - self.x = torch.ones((2), device='cuda', dtype=torch.float32) - common_init(self) - - def tearDown(self): - pass - - def zero_grad(self, models, optimizer, how_to_zero): - if how_to_zero == "none": - for model in models: - for param in model.parameters(): - param.grad = None - elif how_to_zero == "model": - for model in models: - model.zero_grad() - elif how_to_zero == "optimizer": - optimizer.zero_grad() - - def test_add_param_group(self): - for opt_level in ("O0", "O1", "O2", "O3", "O4", "O5"): - for zero_before_add in (True, False): - for try_accumulation in (True, False): - if opt_level in {"O4", "O5"}: - model0 = MyModel(1, torch.bfloat16) - model1 = MyModel(2, torch.bfloat16) - else: - model0 = MyModel(1) - model1 = MyModel(2) - - optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}], - momentum=0.125) - - optimizer.zero_grad() - loss = model0(self.x) - loss.backward() - optimizer.step() - - if zero_before_add: - optimizer.zero_grad() - optimizer.add_param_group({'params' : model1.parameters(), 'lr' : 0.5}) - if not zero_before_add: - optimizer.zero_grad() - - loss = model0(self.x) + model1(self.x) - loss.backward(retain_graph=try_accumulation) - if try_accumulation: - loss.backward() - optimizer.step() - - # Once more to make sure the new params pick up momemtums properly - optimizer.zero_grad() - loss = model0(self.x) + model1(self.x) - loss.backward(retain_graph=try_accumulation) - if try_accumulation: - loss.backward() - optimizer.step() - - reference_params = [param.data.clone() for param in model0.parameters()] + \ - [param.data.clone() for param in model1.parameters()] - - for how_to_zero in "none", "model", "optimizer": - if opt_level in {"O4", "O5"}: - model0 = MyModel(1, torch.bfloat16) - model1 = MyModel(2, torch.bfloat16) - else: - model0 = MyModel(1) - model1 = MyModel(2) - - optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}], - momentum=0.125) - - _amp_state.allow_incoming_model_not_fp32 = True - [model0, model1], optimizer = amp.initialize([model0, model1], - optimizer, - opt_level=opt_level, - verbosity=0, - cast_model_type=False) - _amp_state.allow_incoming_model_not_fp32 = False - - _amp_state.loss_scalers[0]._loss_scale = 4.0 - - self.zero_grad([model0, model1], optimizer, how_to_zero) - loss = model0(self.x) - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - optimizer.step() - - if zero_before_add: - self.zero_grad([model0, model1], optimizer, how_to_zero) - optimizer.add_param_group({'params' : model1.parameters(), 'lr' : 0.5}) - if not zero_before_add: - self.zero_grad([model0, model1], optimizer, how_to_zero) - - loss = model0(self.x) + model1(self.x) - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward(retain_graph=try_accumulation) - if try_accumulation: - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - optimizer.step() - - # Once more to make sure the new params pick up momentums properly - self.zero_grad([model0, model1], optimizer, how_to_zero) - loss = model0(self.x) + model1(self.x) - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward(retain_graph=try_accumulation) - if try_accumulation: - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - optimizer.step() - - final_params = [param.data.clone() for param in model0.parameters()] + \ - [param.data.clone() for param in model1.parameters()] - - for reference, final in zip(reference_params, final_params): - # TODO: remove the conversion once allclose supports bfloat16 type. - if final.dtype == torch.bfloat16: - final = final.float() - self.assertTrue(torch.allclose(reference.to(final.dtype), final), - "opt_level = {}, how_to_zero = {}, zero_before_add = {}".format( - opt_level, how_to_zero, zero_before_add)) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/L0/run_amp/test_basic_casts.py b/tests/L0/run_amp/test_basic_casts.py deleted file mode 100644 index 75fbb51..0000000 --- a/tests/L0/run_amp/test_basic_casts.py +++ /dev/null @@ -1,258 +0,0 @@ -import unittest - -import functools as ft -import itertools as it - -from apex import amp -import torch -from torch import nn -import torch.nn.functional as F - -from utils import common_init, HALF, FLOAT,\ - ALWAYS_HALF, ALWAYS_BFLOAT16, ALWAYS_FLOAT, MATCH_INPUT - -from apex.testing.common_utils import skipIfRocm - -def run_layer_test(test_case, fns, expected, input_shape, test_backward=True): - for fn, typ in it.product(fns, expected.keys()): - x = torch.randn(input_shape, dtype=typ).requires_grad_() - y = fn(x) - test_case.assertEqual(y.type(), expected[typ]) - if test_backward: - y.float().sum().backward() - test_case.assertEqual(x.grad.type(), MATCH_INPUT[typ]) - -class _TestBasicCasts(unittest.TestCase): - def _test_linear(self, expected): - m = nn.Linear(self.h, self.h) - f = ft.partial(F.linear, weight=m.weight, bias=m.bias) - run_layer_test(self, [m, f], expected, (self.b, self.h)) - - def _test_conv2d(self, expected): - m = nn.Conv2d(self.c, self.c, self.k) - f = ft.partial(F.conv2d, weight=m.weight, bias=m.bias) - run_layer_test(self, [m, f], expected, (self.b, self.c, self.h, self.h)) - - def _test_softmax(self, expected): - m = nn.Softmax(dim=1) - f = ft.partial(F.softmax, dim=1) - run_layer_test(self, [m, f], expected, (self.b, self.h)) - - def _test_group_norm(self, expected): - m = nn.GroupNorm(num_groups=4, num_channels=self.c) - run_layer_test(self, [m], expected, (self.b, self.c, self.h, self.h)) - - def _test_mse_loss(self, expected): - shape = (self.b, self.h) - target = torch.randn(shape) - mod = nn.MSELoss() - m = lambda x: mod(x, target) - f = ft.partial(F.mse_loss, target=target) - run_layer_test(self, [m], expected, shape) - - def _test_relu(self, expected): - run_layer_test(self, [nn.ReLU(), F.relu], expected, (self.b, self.h)) - - def _test_batch_norm(self, expected): - m = nn.BatchNorm2d(num_features=self.c) - f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var, - weight=m.weight, bias=m.bias, training=True) - run_layer_test(self, [m], expected, (self.b, self.c, self.h, self.h)) - - # Test forward-only for BN inference - m.eval() - f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var, - weight=m.weight, bias=m.bias, training=False) - run_layer_test(self, [m, f], expected, (self.b, self.c, self.h, self.h), - test_backward=False) - -class TestBasicCastsHalf(_TestBasicCasts): - def setUp(self): - self.handle = amp.init(enabled=True, patch_type=torch.half) - common_init(self) - - def tearDown(self): - self.handle._deactivate() - - def test_linear_is_half(self): - self._test_linear(ALWAYS_HALF) - - def test_conv2d_is_half(self): - self._test_conv2d(ALWAYS_HALF) - - def test_softmax_is_float(self): - self._test_softmax(ALWAYS_FLOAT) - - def test_group_norm_is_float(self): - self._test_group_norm(ALWAYS_FLOAT) - - def test_mse_loss_is_float(self): - self._test_mse_loss(ALWAYS_FLOAT) - - def test_relu_is_match(self): - self._test_relu(MATCH_INPUT) - - def test_batch_norm_is_match(self): - self._test_batch_norm(MATCH_INPUT) - -class TestBasicCastsBFloat16(_TestBasicCasts): - def setUp(self): - self.handle = amp.init(enabled=True, patch_type=torch.bfloat16) - common_init(self) - - def tearDown(self): - self.handle._deactivate() - - @skipIfRocm - def test_linear_is_bfloat16(self): - self._test_linear(ALWAYS_BFLOAT16) - - @skipIfRocm - def test_conv2d_is_bfloat16(self): - self._test_conv2d(ALWAYS_BFLOAT16) - - def test_softmax_is_float(self): - self._test_softmax(ALWAYS_FLOAT) - - def test_group_norm_is_float(self): - self._test_group_norm(ALWAYS_FLOAT) - - def test_mse_loss_is_float(self): - self._test_mse_loss(ALWAYS_FLOAT) - - def test_relu_is_match(self): - self._test_relu(MATCH_INPUT) - - def test_batch_norm_is_match(self): - self._test_batch_norm(MATCH_INPUT) - -class TestBannedMethods(unittest.TestCase): - def setUp(self): - self.handle = amp.init(enabled=True, patch_type=torch.half) - common_init(self) - - def tearDown(self): - self.handle._deactivate() - - def bce_common(self, assertion, dtype=torch.half): - shape = (self.b, self.h) - target = torch.rand(shape) - mod = nn.BCELoss() - m = lambda x: mod(x, target) - f = ft.partial(F.binary_cross_entropy, target=target) - for fn in [m, f]: - x = torch.rand(shape, dtype=dtype) - assertion(fn, x) - - def test_bce_raises_by_default(self): - assertion = lambda fn, x: self.assertRaises(NotImplementedError, fn, x) - self.bce_common(assertion, dtype=torch.half) - - # handle with bfloat16 as patch_type - self.handle._deactivate() - self.handle = amp.init(enabled=True, patch_type=torch.bfloat16) - self.bce_common(assertion, dtype=torch.bfloat16) - - def test_bce_is_float_with_allow_banned(self): - self.handle._deactivate() - self.handle = amp.init(enabled=True, allow_banned=True, patch_type=torch.half) - assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT) - self.bce_common(assertion, dtype=torch.half) - - # handle with bfloat16 as patch_type - self.handle._deactivate() - self.handle = amp.init(enabled=True, allow_banned=True, patch_type=torch.bfloat16) - self.bce_common(assertion, dtype=torch.bfloat16) - -class _TestTensorCasts(unittest.TestCase): - def _test_matmul_method(self, expected): - other = torch.randn(self.h, self.h) - lhs = lambda x: x.matmul(other) - rhs = lambda x: other.matmul(x) - run_layer_test(self, [lhs, rhs], expected, (self.h, self.h)) - - def _test_matmul_op(self, expected): - other = torch.randn(self.h, self.h) - lhs = lambda x: x @ other - rhs = lambda x: other @ x - run_layer_test(self, [lhs, rhs], expected, (self.h, self.h)) - - def _test_pow_method(self, expected): - fn = lambda x: x.pow(2.) - run_layer_test(self, [fn], expected, (self.b, self.h)) - - def _test_pow_op(self, expected): - fn = lambda x: x ** 2. - run_layer_test(self, [fn], expected, (self.b, self.h)) - - def _test_cpu(self, expected): - fn = lambda x: x.cpu() - run_layer_test(self, [fn], expected, (self.b, self.h)) - - def _test_sum(self, expected): - fn = lambda x: x.sum() - run_layer_test(self, [fn], expected, (self.b, self.h)) - - # TODO: maybe more tests on disabled casting? - -class TestTensorCastsHalf(_TestTensorCasts): - def setUp(self): - self.handle = amp.init(enabled=True, patch_type=torch.half) - common_init(self) - - def tearDown(self): - self.handle._deactivate() - - def test_matmul_method_is_half(self): - self._test_matmul_method(ALWAYS_HALF) - - def test_matmul_op_is_half(self): - self._test_matmul_op(ALWAYS_HALF) - - def test_pow_method_is_float(self): - self._test_pow_method(ALWAYS_FLOAT) - - def test_pow_op_is_float(self): - self._test_pow_op(ALWAYS_FLOAT) - - def test_cpu_is_float(self): - always_cpu_float = {torch.float: 'torch.FloatTensor', - torch.half: 'torch.FloatTensor'} - self._test_cpu(always_cpu_float) - - def test_sum_is_float(self): - self._test_sum(ALWAYS_FLOAT) - -class TestTensorCastsBFloat16(_TestTensorCasts): - def setUp(self): - self.handle = amp.init(enabled=True, patch_type=torch.bfloat16) - common_init(self) - - def tearDown(self): - self.handle._deactivate() - - @skipIfRocm - def test_matmul_method_is_bfloat16(self): - self._test_matmul_method(ALWAYS_BFLOAT16) - - @skipIfRocm - def test_matmul_op_is_bfloat16(self): - self._test_matmul_op(ALWAYS_BFLOAT16) - - def test_pow_method_is_float(self): - self._test_pow_method(ALWAYS_FLOAT) - - def test_pow_op_is_float(self): - self._test_pow_op(ALWAYS_FLOAT) - - def test_cpu_is_float(self): - always_cpu_float = {torch.float: 'torch.FloatTensor', - torch.bfloat16: 'torch.FloatTensor'} - self._test_cpu(always_cpu_float) - - def test_sum_is_float(self): - self._test_sum(ALWAYS_FLOAT) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/L0/run_amp/test_cache.py b/tests/L0/run_amp/test_cache.py deleted file mode 100644 index ba26eaa..0000000 --- a/tests/L0/run_amp/test_cache.py +++ /dev/null @@ -1,158 +0,0 @@ -import unittest - -import functools as ft -import itertools as it - -from apex import amp -from apex.amp import _amp_state -import torch -from torch import nn -import torch.nn.functional as F - -from utils import common_init, HALF, FLOAT,\ - ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT - -def get_reference_grad(i, w, ops): - # Creating new tensors ensures, among other things, that the new tensors are not in the cache. - # In fact, they are guaranteed not to use the cache because they are not torch.nn.Parameters. - fp32_i = i.detach().clone().float() - fp32_w = w.detach().clone().float().requires_grad_() - loss = ops(fp32_i, fp32_w) - loss.backward() - return fp32_w.grad - -class WhitelistModule(torch.nn.Module): - def __init__(self, dtype): - super(WhitelistModule, self).__init__() - self.weight = torch.nn.Parameter(torch.arange(8*8, device='cuda', dtype=dtype).view(8,8)) - - @staticmethod - def ops(input, weight): - return (input.mm(weight)).mm(weight).sum() - - def forward(self, input): - return self.ops(input, self.weight) - - -class BlacklistModule(torch.nn.Module): - def __init__(self, dtype): - super(BlacklistModule, self).__init__() - self.weight = torch.nn.Parameter(torch.arange(2*8, device='cuda', dtype=dtype).view(2,8)) - - @staticmethod - def ops(input, weight): - return (input + torch.pow(weight, 2) + torch.pow(weight, 2)).sum() - - def forward(self, input): - return self.ops(input, self.weight) - - -class PromoteModule(torch.nn.Module): - def __init__(self, dtype): - super(PromoteModule, self).__init__() - self.weight = torch.nn.Parameter(torch.arange(2*8, device='cuda', dtype=dtype).view(2,8)) - - @staticmethod - def ops(input, weight): - return ((input*weight)*weight).sum() - - def forward(self, input): - return self.ops(input, self.weight) - -class TestCache(unittest.TestCase): - def setUp(self): - self.x = torch.ones((2, 8), device='cuda', dtype=torch.float32) - common_init(self) - - def tearDown(self): - pass - - def train_eval_train_test(self, module, t, opt_level): - model = module(t).cuda() - optimizer = torch.optim.SGD(model.parameters(), lr=1.0) - - _amp_state.allow_incoming_model_not_fp32 = True - model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level, verbosity=0) - _amp_state.allow_incoming_model_not_fp32 = False - - def training_step(): - for param in model.parameters(): - param.grad = None - - loss = model(self.x).sum() - _amp_state.loss_scalers[0]._loss_scale = 4.0 - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - - self.assertEqual(len([p.grad for p in model.parameters() if p.grad is not None]), 1) - self.assertEqual(model.weight.grad.type(), model.weight.type()) - - reference_grad = get_reference_grad(self.x, model.weight, model.ops) - - # Currently there's no difference in the allclose calls, so no need for branching, - # but I'm keeping this in case we want different tolerances for fp16 and fp32 checks. - if model.weight.grad.type() == "torch.cuda.HalfTensor": - self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad)) - elif model.weight.grad.type() == "torch.cuda.BFloat16Tensor": - self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad)) - elif model.weight.grad.type() == "torch.cuda.FloatTensor": - self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad)) - else: - raise RuntimeError("model.weight.grad.type = {}".format(model.weight.grad.type())) - - model.weight.data -= 1. - - # Simulates first epoch - training_step() - - # Simulates eval - with torch.no_grad(): - loss = model(self.x).sum() - - # Simulates resuming training after eval - training_step() - - _amp_state.handle._deactivate() - - # I could easily have these as a set of for loops in a single test, - # instead of going for granularity. - def test_whitelist_module_fp16_weight(self): - self.train_eval_train_test(WhitelistModule, torch.float16, "O1") - - def test_whitelist_module_fp32_weight(self): - self.train_eval_train_test(WhitelistModule, torch.float32, "O1") - - def test_blacklist_module_fp16_weight(self): - self.train_eval_train_test(BlacklistModule, torch.float16, "O1") - - def test_blacklist_module_fp32_weight(self): - self.train_eval_train_test(BlacklistModule, torch.float32, "O1") - - def test_promote_module_fp16_weight(self): - self.train_eval_train_test(PromoteModule, torch.float16, "O1") - - def test_promote_module_fp32_weight(self): - self.train_eval_train_test(PromoteModule, torch.float32, "O1") - - # opt_level = O4 - def test_whitelist_module_bfp16_weight(self): - self.train_eval_train_test(WhitelistModule, torch.bfloat16, "O4") - - def test_whitelist_module_fp32_weight(self): - self.train_eval_train_test(WhitelistModule, torch.float32, "O4") - - def test_blacklist_module_bfp16_weight(self): - self.train_eval_train_test(BlacklistModule, torch.bfloat16, "O4") - - def test_blacklist_module_fp32_weight(self): - self.train_eval_train_test(BlacklistModule, torch.float32, "O4") - - def test_promote_module_bfp16_weight(self): - self.train_eval_train_test(PromoteModule, torch.bfloat16, "O4") - - def test_promote_module_fp32_weight(self): - self.train_eval_train_test(PromoteModule, torch.float32, "O4") - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/L0/run_amp/test_checkpointing.py b/tests/L0/run_amp/test_checkpointing.py deleted file mode 100644 index f1080a4..0000000 --- a/tests/L0/run_amp/test_checkpointing.py +++ /dev/null @@ -1,273 +0,0 @@ -import unittest - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim - -from apex import amp - -from utils import common_init, FLOAT -from apex.testing.common_utils import skipFlakyTest - -class MyModel(torch.nn.Module): - def __init__(self): - super(MyModel, self).__init__() - self.conv1 = nn.Conv2d(3, 6, 3, 1, 1) - self.bn1 = nn.BatchNorm2d(6) - self.param = nn.Parameter(torch.randn(1)) - - def forward(self, x): - x = x * self.param - x = F.relu(self.conv1(x)) - x = self.bn1(x) - return x - - -class TestCheckpointing(unittest.TestCase): - def setUp(self): - self.initial_lr = 1e-3 - self.test_opt_levels = ("O0", "O1", "O2", "O3", "O4", "O5") - - def seed(self): - torch.manual_seed(2809) - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - - def check_state_dict_fp32(self, state_dict): - for key in state_dict: - if 'num_batches_tracked' in key: - continue - param = state_dict[key] - self.assertEqual(param.type(), FLOAT, - 'Parameter in state_dict not FLOAT') - - def train_step(self, model, optimizer, data, loss_ids): - optimizer.zero_grad() - - output = model(data) - - # Call backward for num_losses-1 - for idx in loss_ids: - loss = output.mean() - with amp.scale_loss(loss, optimizer, loss_id=idx) as scaled_loss: - scaled_loss.backward(retain_graph=True) - - optimizer.step() - return output - - def compare_models(self, modelA, modelB, test_setup=''): - state_dictA = modelA.state_dict() - state_dictB = modelB.state_dict() - self.assertEqual(len(state_dictA), len(state_dictB), - 'state_dicts have different lengths' + test_setup) - for key in state_dictA: - paramA = state_dictA[key] - paramB = state_dictB[key] - self.assertTrue((paramA==paramB).all(), - msg='Parameters in state_dices not equal.' + - 'key: {}\nparam: {}\nrestored: {}\ndiff: {} for {}'.format( - key, paramA, paramB, paramA - paramB, test_setup)) - - def test_restoring(self): - nb_epochs = 10 - nb_epochs_restore = nb_epochs // 2 - for opt_level in self.test_opt_levels: - for res_opt_level in self.test_opt_levels: - for amp_before_load in [True, False]: - for num_losses in range(1, 3): - test_setup = ('#' * 75 + '\n' + \ - f'opt_level {opt_level}\n' + \ - f'restore_opt_level {res_opt_level}\n' + \ - f'amp_before_load {amp_before_load}\n' + \ - f'num_losses {num_losses}\n') - - self.seed() - - # Create reference model - model = MyModel().to('cuda') - - optimizer = optim.SGD(model.parameters(), - lr=self.initial_lr) - - # Initialize with num_losses*2 for the original model and the restored one - model, optimizer = amp.initialize( - model, optimizer, opt_level=opt_level, - num_losses=num_losses*2, verbosity=0) - - # Compare training behavior for same restore option - # We cannot really generalize it, since a saved model in O0 - # would introduce a skipped step in O1, which will raise an error - if opt_level == res_opt_level: - # train for nb_epochs and restore after nb_epochs_restore - for epoch in range(nb_epochs): - - x = torch.randn(16, 3, 24, 24, device='cuda') - output = self.train_step( - model, optimizer, x, range(num_losses)) - # Initialize model one step before comparing. - # Otherwise the batchnorm layers will be updated - # additionally in restore_model - if epoch == (nb_epochs_restore - 1): - # Load model and optimizer - checkpoint = { - 'model': model.state_dict(), - 'optimizer': optimizer.state_dict(), - 'amp': amp.state_dict() - } - # Check state_dict for FP32 tensors - self.check_state_dict_fp32(checkpoint['model']) - - # Restore model - restore_model = MyModel().to('cuda') - restore_optimizer = optim.SGD( - restore_model.parameters(), - lr=self.initial_lr) - - if amp_before_load: - restore_model, restore_optimizer = amp.initialize( - restore_model, - restore_optimizer, - opt_level=res_opt_level, - num_losses=num_losses*2, - verbosity=0) - - restore_model.load_state_dict(checkpoint['model']) - restore_optimizer.load_state_dict(checkpoint['optimizer']) - # FIXME: We cannot test the amp.state_dict in the same script - # amp.load_state_dict(checkpoint['amp']) - - if not amp_before_load: - restore_model, restore_optimizer = amp.initialize( - restore_model, - restore_optimizer, - opt_level=res_opt_level, - num_losses=num_losses*2, - verbosity=0) - - elif epoch >= nb_epochs_restore: - restore_output = self.train_step( - restore_model, - restore_optimizer, - x, - range(num_losses, num_losses*2)) - self.assertTrue( - torch.allclose(output.float(), restore_output.float()), - 'Output of reference and restored models differ for ' + test_setup) - self.compare_models(model, restore_model, test_setup) - # if opt_level != res_opt_level - else: - # skip tests for different opt_levels - continue - - @skipFlakyTest - def test_loss_scale_decrease(self): - num_losses = 3 - nb_decrease_loss_scales = [0, 1, 2] - for opt_level in self.test_opt_levels: - #print('#' * 75 + f'\n opt_level {opt_level}\n') - # Create new tmp copy for this run - nb_decrease_loss_scales_tmp = list(nb_decrease_loss_scales) - - model = MyModel().to('cuda') - - optimizer = optim.SGD(model.parameters(), - lr=self.initial_lr) - - model, optimizer = amp.initialize( - model, optimizer, opt_level=opt_level, num_losses=num_losses, - verbosity=0) - - if amp._amp_state.opt_properties.loss_scale != 'dynamic': - #print('Static loss scale set. Skipping opt_level.') - continue - - # force to skip some updates to decrease the loss_scale - initial_loss_scales = [] - for idx in range(num_losses): - initial_loss_scales.append( - amp._amp_state.loss_scalers[idx].loss_scale()) - - for _ in range(len(nb_decrease_loss_scales)): - x = torch.randn(16, 3, 24, 24, device='cuda') - for idx in range(num_losses): - while nb_decrease_loss_scales_tmp[idx] > 0: - optimizer.zero_grad() - output = model(x * 2**17) - loss = output.mean() - - with amp.scale_loss(loss, optimizer, loss_id=idx) as scaled_loss: - scaled_loss.backward(retain_graph=True) - optimizer.step() - nb_decrease_loss_scales_tmp[idx] -= 1 - - # Check loss scales afterwards - updated_loss_scales = [] - for idx in range(num_losses): - updated_loss_scales.append( - amp._amp_state.loss_scalers[idx].loss_scale()) - for factor, update_ls, init_ls in zip(nb_decrease_loss_scales, - updated_loss_scales, - initial_loss_scales): - self.assertEqual(update_ls, init_ls / 2**factor) - - # Check state dict - amp_state_dict = amp.state_dict() - for scaler_idx, factor, init_ls in zip(amp_state_dict, - nb_decrease_loss_scales, - initial_loss_scales): - scaler = amp_state_dict[scaler_idx] - self.assertEqual(scaler['loss_scale'], init_ls / 2**factor) - unskipped_target = 0 - self.assertEqual(scaler['unskipped'], unskipped_target) - - def test_state_dict(self): - for opt_level in self.test_opt_levels: - # Skip O3 - if opt_level == 'O3': - continue - - model = MyModel().to('cuda') - torch_ver = torch.__version__.split('a0')[0] - optimizer = None - if torch_ver == '1.10.0': - optimizer = optim.Adam(model.parameters(), lr=1e-3) - else: - optimizer = optim.Adam(model.parameters(), lr=1e-3, capturable=True) - model, optimizer = amp.initialize( - model, optimizer, opt_level=opt_level, verbosity=0) - - # Export state_dict and check for Half - state_dict = model.state_dict() - for key in state_dict: - self.assertFalse('Half' in state_dict[key].type()) - self.assertFalse('BFloat16' in state_dict[key].type()) - - # Check, if model is still trainable - # Create dummy data - data = torch.randn(10, 3, 4, 4, device='cuda') - target = torch.randn(10, 6, 4, 4, device='cuda') - - # Get initnial loss - optimizer.zero_grad() - output = model(data) - loss = F.mse_loss(output, target) - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - optimizer.step() - last_loss = loss.item() - - # train for some epochs - for epoch in range(10): - optimizer.zero_grad() - output = model(data) - loss = F.mse_loss(output, target) - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - optimizer.step() - self.assertTrue(loss.item() < last_loss) - last_loss = loss.item() - -if __name__=='__main__': - unittest.main() - diff --git a/tests/L0/run_amp/test_fused_sgd.py b/tests/L0/run_amp/test_fused_sgd.py deleted file mode 100644 index 5084a60..0000000 --- a/tests/L0/run_amp/test_fused_sgd.py +++ /dev/null @@ -1,793 +0,0 @@ -import unittest - -import functools as ft -import itertools as it - -from apex import amp -from apex.amp import _amp_state -import torch -from torch import nn -import torch.nn.functional as F -from torch.nn import Parameter - -from utils import common_init, HALF, FLOAT,\ - ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT - -try: - import amp_C - disabled = False - from apex.optimizers import FusedSGD as FusedSGD -except ImportError as err: - print("amp_C fused kernels unavailable, disabling TestMultiTensorApply. ImportError was ", err) - disabled = True - - -class MyModel(torch.nn.Module): - def __init__(self, unique): - super(MyModel, self).__init__() - self.weight0 = Parameter(unique + - torch.arange(2, device='cuda', dtype=torch.float32)) - self.weight1 = Parameter(1. + unique + torch.arange(2, device='cuda', dtype=torch.float16)) - - @staticmethod - def ops(input, weight0, weight1): - return ((input*(weight0.float()))*(weight1.float())).sum() - - def forward(self, input): - return self.ops(input, self.weight0, self.weight1) - -# Abandon all hope, ye who enter here. - -# This is hands down the ugliest code I have ever written, but it succeeds in testing -# multiple models/optimizers/losses fairly thoroughly. Many of the different test cases -# require slightly divergent code in a way that seems near-impossible to genericize into a simple -# cross product or nested loops. - -class TestMultipleModelsOptimizersLosses(unittest.TestCase): - def setUp(self): - self.x = torch.ones((2), device='cuda', dtype=torch.float32) - common_init(self) - - def tearDown(self): - pass - - @unittest.skipIf(disabled, "amp_C is unavailable") - def test_2models2losses1optimizer(self): - model0 = MyModel(1) - model1 = MyModel(2) - - optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}, - {'params' : model1.parameters(), 'lr' : 0.5}], - momentum=0.125) - - reference_grads = [] - for i in range(2): - optimizer.zero_grad() - loss0 = model0(self.x) - loss1 = model1(self.x) - loss0.backward() - loss1.backward() - - reference_grads.append([param.grad.data.clone() for param in model0.parameters()] + - [param.grad.data.clone() for param in model1.parameters()]) - - optimizer.step() - - final_params = [param.data.clone() for param in model0.parameters()] + \ - [param.data.clone() for param in model1.parameters()] - - for materialize_master_grads in (False, True): - for opt_level in ("O0", "O1", "O2", "O3"): - for how_to_zero in ("none", "model", "optimizer"): - for use_multiple_loss_scalers in (False, True): - if opt_level == "O1" or opt_level == "O2": - inject_inf_iters = (-1, 0, 1) - else: - inject_inf_iters = (-1,) - - for inject_inf in inject_inf_iters: - if inject_inf >= 0: - inject_inf_locs = ("fp16", "fp32") - which_backwards = (0, 1) - else: - inject_inf_locs = ("fdsa",) - which_backwards = (None,) - - for inject_inf_loc in inject_inf_locs: - for which_backward in which_backwards: - if use_multiple_loss_scalers: - num_losses = 2 - loss_ids = [0, 1] - else: - num_losses = 1 - loss_ids = [0, 0] - - if inject_inf >= 0: - iters = 3 - else: - iters = 2 - - model0 = MyModel(1) - model1 = MyModel(2) - - models = [model0, model1] - - optimizer = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25}, - {'params' : model1.parameters(), 'lr' : 0.5}], - momentum=0.125, - materialize_master_grads=materialize_master_grads) - - _amp_state.allow_incoming_model_not_fp32 = True - [model0, model1], optimizer = amp.initialize( - [model0, model1], - optimizer, - opt_level=opt_level, - verbosity=0, - cast_model_type=False, - num_losses=num_losses) - _amp_state.allow_incoming_model_not_fp32 = False - - _amp_state.loss_scalers[0]._loss_scale = 4.0 - if use_multiple_loss_scalers: - _amp_state.loss_scalers[1]._loss_scale = 16.0 - - unskipped = 0 - for i in range(iters): - if how_to_zero == "none": - for model in models: - for param in model.parameters(): - param.grad = None - elif how_to_zero == "model": - for model in models: - model.zero_grad() - else: - optimizer.zero_grad() - - loss0 = model0(self.x) - loss1 = model1(self.x) - - with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss: - scaled_loss.backward() - if i == inject_inf and which_backward == 0: - if inject_inf_loc == "fp32": - model0.weight0.grad[0] = float('inf') - elif inject_inf_loc == "fp16": - model0.weight1.grad[0] = float('inf') - with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss: - scaled_loss.backward() - if i == inject_inf and which_backward == 1: - if inject_inf_loc == "fp32": - model1.weight0.grad[0] = float('inf') - elif inject_inf_loc == "fp16": - model1.weight1.grad[0] = float('inf') - - if i != inject_inf: - master_params = amp.master_params(optimizer) - for param, reference_grad in zip(master_params, reference_grads[unskipped]): - if opt_level == "O2" and not materialize_master_grads: - continue - else: - self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()), - "opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers)) - unskipped += 1 - optimizer.step() - - model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()] - for model, master, reference in zip( - model_params, - amp.master_params(optimizer), - final_params): - self.assertTrue(torch.allclose(model, reference)) - self.assertTrue(torch.allclose(model, master.to(model.dtype))) - - if opt_level == "O1": - _amp_state.handle._deactivate() - - @unittest.skipIf(disabled, "amp_C is unavailable") - def test_3models2losses1optimizer(self): - - model0 = MyModel(1) - model1 = MyModel(2) - model2 = MyModel(3) - - optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}, - {'params' : model1.parameters(), 'lr' : 0.5}, - {'params' : model2.parameters(), 'lr' : 0.125}], - momentum=0.125) - - reference_grads = [] - for i in range(2): - optimizer.zero_grad() - loss0 = model0(self.x) + model2(self.x) - loss1 = model1(self.x) + model2(self.x) - loss0.backward() - loss1.backward() - - reference_grads.append([param.grad.data.clone() for param in model0.parameters()] + - [param.grad.data.clone() for param in model1.parameters()] + - [param.grad.data.clone() for param in model2.parameters()]) - - optimizer.step() - - - final_params = [param.data.clone() for param in model0.parameters()] + \ - [param.data.clone() for param in model1.parameters()] + \ - [param.data.clone() for param in model2.parameters()] - - for materialize_master_grads in (False, True): - for opt_level in ("O0", "O1", "O2", "O3"): - for how_to_zero in ("none", "model", "optimizer"): - for use_multiple_loss_scalers in (False, True): - if opt_level == "O1" or opt_level == "O2": - inject_inf_iters = (-1, 0, 1) - else: - inject_inf_iters = (-1,) - - for inject_inf in inject_inf_iters: - if inject_inf >= 0: - inject_inf_locs = ("fp16", "fp32") - which_backwards = (0, 1) - else: - inject_inf_locs = ("fdsa",) - which_backwards = (None,) - - for inject_inf_loc in inject_inf_locs: - for which_backward in which_backwards: - if use_multiple_loss_scalers: - num_losses = 2 - loss_ids = [0, 1] - else: - num_losses = 1 - loss_ids = [0, 0] - - if inject_inf >= 0: - iters = 3 - if which_backward == 0: - which_models = (0, 2) - elif which_backward == 1: - which_models = (1, 2) - else: - iters = 2 - which_models = (None,) - - for which_model in which_models: - model0 = MyModel(1) - model1 = MyModel(2) - model2 = MyModel(3) - - models = [model0, model1, model2] - - optimizer = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25}, - {'params' : model1.parameters(), 'lr' : 0.5}, - {'params' : model2.parameters(), 'lr' : 0.125}], - momentum=0.125, - materialize_master_grads=materialize_master_grads) - - _amp_state.allow_incoming_model_not_fp32 = True - [model0, model1, model2], optimizer = amp.initialize( - [model0, model1, model2], - optimizer, - opt_level=opt_level, - verbosity=0, - cast_model_type=False, - num_losses=num_losses) - _amp_state.allow_incoming_model_not_fp32 = False - - _amp_state.loss_scalers[0]._loss_scale = 4.0 - if use_multiple_loss_scalers: - _amp_state.loss_scalers[1]._loss_scale = 16.0 - - unskipped = 0 - for i in range(iters): - if how_to_zero == "none": - for model in models: - for param in model.parameters(): - param.grad = None - elif how_to_zero == "model": - for model in models: - model.zero_grad() - else: - optimizer.zero_grad() - - loss0 = model0(self.x) + model2(self.x) - loss1 = model1(self.x) + model2(self.x) - - with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss: - scaled_loss.backward() - if i == inject_inf and which_backward == 0: - if which_model == 0: - inj_model = model0 - elif which_model == 2: - inj_model = model2 - else: - raise RuntimeError(which_model + " invalid for loss 0") - if inject_inf_loc == "fp32": - inj_model.weight0.grad[0] = float('inf') - elif inject_inf_loc == "fp16": - inj_model.weight1.grad[0] = float('inf') - with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss: - scaled_loss.backward() - if i == inject_inf and which_backward == 1: - if which_model == 1: - inj_model = model1 - elif which_model == 2: - inj_model = model2 - else: - raise RuntimeError(which_model + " invalid for loss 1 ") - if inject_inf_loc == "fp32": - inj_model.weight0.grad[0] = float('inf') - elif inject_inf_loc == "fp16": - inj_model.weight1.grad[0] = float('inf') - - if i != inject_inf: - master_params = amp.master_params(optimizer) - for param, reference_grad in zip(master_params, reference_grads[unskipped]): - if opt_level == "O2" and not materialize_master_grads: - continue - else: - self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()), - "opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} which_model {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, which_model, use_multiple_loss_scalers)) - unskipped += 1 - - optimizer.step() - - model_params = [p for p in model0.parameters()] + \ - [p for p in model1.parameters()] + \ - [p for p in model2.parameters()] - for model, master, reference in zip( - model_params, - amp.master_params(optimizer), - final_params): - self.assertTrue(torch.allclose(model, reference)) - self.assertTrue(torch.allclose(model, master.to(model.dtype))) - - if opt_level == "O1": - _amp_state.handle._deactivate() - - @unittest.skipIf(disabled, "amp_C is unavailable") - def test_2models2losses2optimizers(self): - model0 = MyModel(1) - model1 = MyModel(2) - - optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}], - momentum=0.125) - optimizer1 = torch.optim.SGD([{'params' : model1.parameters(), 'lr' : 0.5}], - momentum=0.25) - - # Don't do it like this: reference_grads = [[]]*5 - # because then it creates a list of 5 references to the same "[]" and appending - # to any of them effectively makes you append to all of them, which multiplies - # the resulting size of reference_grads by 5x and needless to say makes the test fail. - reference_grads = [[], [], [], [], []] - final_params = [None, None, None, None, None] - for i in range(2): - optimizer0.zero_grad() - optimizer1.zero_grad() - loss0 = model0(self.x) - loss1 = model1(self.x) - loss0.backward() - loss1.backward() - - reference_grads[0].append([param.grad.data.clone() for param in model0.parameters()] + - [param.grad.data.clone() for param in model1.parameters()]) - - optimizer0.step() - optimizer1.step() - - final_params[0] = [param.data.clone() for param in model0.parameters()] + \ - [param.data.clone() for param in model1.parameters()] - - def what_got_skipped(which_iter, which_backward): - if which_iter == 0 and which_backward == 0: - return 1 - if which_iter == 0 and which_backward == 1: - return 2 - if which_iter == 1 and which_backward == 0: - return 3 - if which_iter == 1 and which_backward == 1: - return 4 - return 0 - - for which_iter in (0,1): - for which_backward in (0,1): - model0 = MyModel(1) - model1 = MyModel(2) - - optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}], - momentum=0.125) - optimizer1 = torch.optim.SGD([{'params' : model1.parameters(), 'lr' : 0.5}], - momentum=0.25) - - for i in range(3): - optimizer0.zero_grad() - optimizer1.zero_grad() - loss0 = model0(self.x) - loss1 = model1(self.x) - loss0.backward() - loss1.backward() - - if i != which_iter: - reference_grads[what_got_skipped(which_iter, which_backward)].append( - [param.grad.data.clone() for param in model0.parameters()] + - [param.grad.data.clone() for param in model1.parameters()]) - - if i == which_iter: - if which_backward == 0: - optimizer1.step() - else: - optimizer0.step() - else: - optimizer0.step() - optimizer1.step() - - final_params[what_got_skipped(which_iter, which_backward)] = \ - [param.data.clone() for param in model0.parameters()] + \ - [param.data.clone() for param in model1.parameters()] - - for materialize_master_grads in (False, True): - for opt_level in ("O0", "O1", "O2", "O3"): - for how_to_zero in ("none", "model", "optimizer"): - for use_multiple_loss_scalers in (False, True): - if opt_level == "O1" or opt_level == "O2": - inject_inf_iters = (-1, 0, 1) - else: - inject_inf_iters = (-1,) - - for inject_inf in inject_inf_iters: - if inject_inf >= 0: - inject_inf_locs = ("fp16", "fp32") - which_backwards = (0, 1) - else: - inject_inf_locs = ("fdsa",) - which_backwards = (None,) - - for inject_inf_loc in inject_inf_locs: - for which_backward in which_backwards: - if use_multiple_loss_scalers: - num_losses = 2 - loss_ids = [0, 1] - else: - num_losses = 1 - loss_ids = [0, 0] - - if inject_inf >= 0: - iters = 3 - else: - iters = 2 - - model0 = MyModel(1) - model1 = MyModel(2) - - models = [model0, model1] - - optimizer0 = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25}], - momentum=0.125, materialize_master_grads=materialize_master_grads) - optimizer1 = FusedSGD([{'params' : model1.parameters(), 'lr' : 0.5}], - momentum=0.25, materialize_master_grads=materialize_master_grads) - - _amp_state.allow_incoming_model_not_fp32 = True - [model0, model1], [optimizer0, optimizer1] = amp.initialize( - [model0, model1], - [optimizer0, optimizer1], - opt_level=opt_level, - verbosity=0, - cast_model_type=False, - num_losses=num_losses) - _amp_state.allow_incoming_model_not_fp32 = False - - _amp_state.loss_scalers[0]._loss_scale = 4.0 - if use_multiple_loss_scalers: - _amp_state.loss_scalers[1]._loss_scale = 16.0 - - unskipped = 0 - for i in range(iters): - if how_to_zero == "none": - for model in models: - for param in model.parameters(): - param.grad = None - elif how_to_zero == "model": - for model in models: - model.zero_grad() - else: - optimizer0.zero_grad() - optimizer1.zero_grad() - - loss0 = model0(self.x) - loss1 = model1(self.x) - - with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss: - scaled_loss.backward() - if i == inject_inf and which_backward == 0: - if inject_inf_loc == "fp32": - model0.weight0.grad[0] = float('inf') - elif inject_inf_loc == "fp16": - model0.weight1.grad[0] = float('inf') - with amp.scale_loss(loss1, optimizer1, loss_id=loss_ids[1]) as scaled_loss: - scaled_loss.backward() - if i == inject_inf and which_backward == 1: - if inject_inf_loc == "fp32": - model1.weight0.grad[0] = float('inf') - elif inject_inf_loc == "fp16": - model1.weight1.grad[0] = float('inf') - - # print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers)) - - if i != inject_inf: - master_params = list(amp.master_params(optimizer0)) + \ - list(amp.master_params(optimizer1)) - for param, reference_grad in zip(master_params, - reference_grads[what_got_skipped(inject_inf, which_backward)][unskipped]): - if opt_level == "O2" and not materialize_master_grads: - continue - else: - self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float())) - unskipped += 1 - - optimizer0.step() - optimizer1.step() - - model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()] - master_params = [p for p in amp.master_params(optimizer0)] + \ - [p for p in amp.master_params(optimizer1)] - for model, master, reference in zip( - model_params, - master_params, - final_params[what_got_skipped(inject_inf, which_backward)]): - self.assertTrue(torch.allclose(model, reference)) - self.assertTrue(torch.allclose(model, master.to(model.dtype))) - - if opt_level == "O1": - _amp_state.handle._deactivate() - - @unittest.skipIf(disabled, "amp_C is unavailable") - def test_3models2losses2optimizers(self): - model0 = MyModel(1) - model1 = MyModel(2) - model2 = MyModel(3) - - optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}, - {'params' : model1.parameters(), 'lr' : 1.0}], - momentum=0.5) - optimizer1 = torch.optim.SGD([{'params' : model2.parameters(), 'lr' : 0.5}], - momentum=0.25) - - # Again, can't do this: reference_grads = [[]]*9 - reference_grads = [[], [], [], [], [], [], [], [], []] - final_params = [None, None, None, None, None, None, None, None, None] - for i in range(2): - optimizer0.zero_grad() - optimizer1.zero_grad() - loss0 = model0(self.x) + model1(self.x) - loss1 = model2(self.x) + model1(self.x) - loss0.backward() - loss1.backward() - - reference_grads[0].append([param.grad.data.clone() for param in model0.parameters()] + - [param.grad.data.clone() for param in model1.parameters()]) - - optimizer0.step() - optimizer1.step() - - final_params[0] = \ - [param.data.clone() for param in model0.parameters()] + \ - [param.data.clone() for param in model1.parameters()] + \ - [param.data.clone() for param in model2.parameters()] - - def what_got_skipped(which_iter, which_backward, which_model): - if which_iter == 0: - if which_backward == 0: - if which_model == 0: - return 1 - if which_model == 1: - return 2 - if which_backward == 1: - if which_model == 2: - return 3 - if which_model == 1: - return 4 - if which_iter == 1: - if which_backward == 0: - if which_model == 0: - return 5 - if which_model == 1: - return 6 - if which_backward == 1: - if which_model == 2: - return 7 - if which_model == 1: - return 8 - return 0 - - for which_iter in (0,1): - for which_backward in (0,1): - if which_backward == 0: - which_models = (0,1) - if which_backward == 1: - which_models = (2,1) - for which_model in which_models: - - model0 = MyModel(1) - model1 = MyModel(2) - model2 = MyModel(3) - - optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}, - {'params' : model1.parameters(), 'lr' : 1.0}], - momentum=0.5) - optimizer1 = torch.optim.SGD([{'params' : model2.parameters(), 'lr' : 0.5}], - momentum=0.25) - - for i in range(3): - optimizer0.zero_grad() - optimizer1.zero_grad() - loss0 = model0(self.x) + model1(self.x) - loss1 = model2(self.x) + model1(self.x) - loss0.backward() - loss1.backward() - - if i != which_iter: - reference_grads[what_got_skipped(which_iter, - which_backward, which_model)].append( - [param.grad.data.clone() for param in model0.parameters()] + - [param.grad.data.clone() for param in model1.parameters()]) - - if i == which_iter: - if which_backward == 0: - # if which_model == 0: - optimizer1.step() - # if which_model == 1: - # optimizer1.step() - if which_backward == 1: - # if which_model == 2: - # optimizer0.step() - # if which_model == 1: - continue - else: - optimizer0.step() - optimizer1.step() - - final_params[what_got_skipped(which_iter, which_backward, which_model)] = \ - [param.data.clone() for param in model0.parameters()] + \ - [param.data.clone() for param in model1.parameters()] + \ - [param.data.clone() for param in model2.parameters()] - - for materialize_master_grads in (False, True): - for opt_level in ("O0", "O1", "O2", "O3"): - for how_to_zero in ("none", "model", "optimizer"): - for use_multiple_loss_scalers in (False, True): - if opt_level == "O1" or opt_level == "O2": - inject_inf_iters = (-1, 0, 1) - else: - inject_inf_iters = (-1,) - - for inject_inf in inject_inf_iters: - if inject_inf >= 0: - inject_inf_locs = ("fp16", "fp32") - which_backwards = (0, 1) - else: - inject_inf_locs = ("fdsa",) - which_backwards = (None,) - - for inject_inf_loc in inject_inf_locs: - for which_backward in which_backwards: - if use_multiple_loss_scalers: - num_losses = 2 - loss_ids = [0, 1] - else: - num_losses = 1 - loss_ids = [0, 0] - - if inject_inf >= 0: - iters = 3 - if which_backward == 0: - which_models = (0, 1) - elif which_backward == 1: - which_models = (2, 1) - else: - iters = 2 - which_models = (None,) - - for which_model in which_models: - model0 = MyModel(1) - model1 = MyModel(2) - model2 = MyModel(3) - - models = [model0, model1, model2] - - optimizer0 = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25}, - {'params' : model1.parameters(), 'lr' : 1.0}], - momentum=0.5, materialize_master_grads=materialize_master_grads) - optimizer1 = FusedSGD([{'params' : model2.parameters(), 'lr' : 0.5}], - momentum=0.25, materialize_master_grads=materialize_master_grads) - - _amp_state.allow_incoming_model_not_fp32 = True - [model0, model1, model2], [optimizer0, optimizer1] = amp.initialize( - [model0, model1, model2], - [optimizer0, optimizer1], - opt_level=opt_level, - verbosity=0, - cast_model_type=False, - num_losses=num_losses) - _amp_state.allow_incoming_model_not_fp32 = False - - _amp_state.loss_scalers[0]._loss_scale = 4.0 - if use_multiple_loss_scalers: - _amp_state.loss_scalers[1]._loss_scale = 16.0 - - unskipped = 0 - for i in range(iters): - if how_to_zero == "none": - for model in models: - for param in model.parameters(): - param.grad = None - elif how_to_zero == "model": - for model in models: - model.zero_grad() - else: - optimizer0.zero_grad() - optimizer1.zero_grad() - - loss0 = model0(self.x) + model1(self.x) - loss1 = model2(self.x) + model1(self.x) - - with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss: - scaled_loss.backward() - if i == inject_inf and which_backward == 0: - if which_model == 0: - inj_model = model0 - elif which_model == 1: - inj_model = model1 - else: - raise RuntimeError(which_model + " invalid for loss 0") - if inject_inf_loc == "fp32": - inj_model.weight0.grad[0] = float('inf') - elif inject_inf_loc == "fp16": - inj_model.weight1.grad[0] = float('inf') - with amp.scale_loss(loss1, [optimizer0, optimizer1], loss_id=loss_ids[1]) as scaled_loss: - scaled_loss.backward() - if i == inject_inf and which_backward == 1: - if which_model == 2: - inj_model = model2 - elif which_model == 1: - inj_model = model1 - else: - raise RuntimeError(which_model + " invalid for loss 1 ") - if inject_inf_loc == "fp32": - inj_model.weight0.grad[0] = float('inf') - elif inject_inf_loc == "fp16": - inj_model.weight1.grad[0] = float('inf') - - if i != inject_inf: - master_params = list(amp.master_params(optimizer0)) + \ - list(amp.master_params(optimizer1)) - for param, reference_grad in zip(master_params, - reference_grads[what_got_skipped(inject_inf, - which_backward, which_model)][unskipped]): - if opt_level == "O2" and not materialize_master_grads: - continue - else: - self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float())) - unskipped += 1 - - optimizer0.step() - optimizer1.step() - - model_params = [p for p in model0.parameters()] + \ - [p for p in model1.parameters()] + \ - [p for p in model2.parameters()] - master_params = [p for p in amp.master_params(optimizer0)] + \ - [p for p in amp.master_params(optimizer1)] - - # print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {} which_model {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers, which_model)) - - for model, master, reference in zip( - model_params, - master_params, - final_params[what_got_skipped(inject_inf, which_backward, which_model)]): - self.assertTrue(torch.allclose(model, reference)) - self.assertTrue(torch.allclose(model, master.to(model.dtype))) - - if opt_level == "O1": - _amp_state.handle._deactivate() - -if __name__ == '__main__': - unittest.main() diff --git a/tests/L0/run_amp/test_larc.py b/tests/L0/run_amp/test_larc.py deleted file mode 100644 index f4f3e83..0000000 --- a/tests/L0/run_amp/test_larc.py +++ /dev/null @@ -1,53 +0,0 @@ -import unittest - -import torch -from torch import nn -from torch.nn import Parameter - -from apex import amp -from apex.parallel.LARC import LARC -from utils import common_init - - -class MyModel(torch.nn.Module): - def __init__(self, unique): - super(MyModel, self).__init__() - self.weight0 = Parameter( - unique + torch.arange(2, device="cuda", dtype=torch.float32) - ) - - def forward(self, input): - return (input * self.weight0).sum() - - -class TestLARC(unittest.TestCase): - def setUp(self): - self.x = torch.ones((2), device="cuda", dtype=torch.float32) - common_init(self) - - def tearDown(self): - pass - - def test_larc_mixed_precision(self): - for opt_level in ["O0", "O1", "O2", "O3"]: - model = MyModel(1) - - optimizer = LARC( - torch.optim.SGD( - [{"params": model.parameters(), "lr": 0.25}], momentum=0.125 - ) - ) - - model, optimizer = amp.initialize( - model, optimizer, opt_level=opt_level, verbosity=0 - ) - - optimizer.zero_grad() - loss = model(self.x) - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - optimizer.step() - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/L0/run_amp/test_multi_tensor_axpby.py b/tests/L0/run_amp/test_multi_tensor_axpby.py deleted file mode 100644 index a65660a..0000000 --- a/tests/L0/run_amp/test_multi_tensor_axpby.py +++ /dev/null @@ -1,183 +0,0 @@ -import unittest - -import functools as ft -import itertools as it - -from apex import amp -import torch -from torch import nn -import torch.nn.functional as F -from math import floor - -from utils import common_init, HALF, FLOAT,\ - ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT - -try: - import amp_C - from amp_C import multi_tensor_axpby - from apex.multi_tensor_apply import MultiTensorApply - disabled = False -except ImportError as err: - print("amp_C fused kernels unavailable, disabling TestMultiTensorApply. ImportError was ", err) - disabled = True - -TORCH_MAJOR = int(torch.__version__.split('.')[0]) -TORCH_MINOR = int(torch.__version__.split('.')[1]) -try_nhwc = (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4) - - -class TestMultiTensorAxpby(unittest.TestCase): - - def setUp(self): - common_init(self) - - self.a = 2.0 - self.b = 8.0 - self.xval = 4.0 - self.yval = 16.0 - self.overflow_buf = torch.cuda.IntTensor(1).zero_() - self.ref = torch.full((1,), 136.0, device="cuda", dtype=torch.float32) - - def tearDown(self): - pass - - # The tensor creation here is written for convenience, not speed. - def axpby(self, sizea, sizeb, applier, repeat_tensors, - x_type, y_type, out_type, inplace=False, nhwc=False): - self.overflow_buf.zero_() - sizea = sizea if isinstance(sizea, tuple) else (sizea,) - sizeb = sizeb if isinstance(sizeb, tuple) else (sizeb,) - t1 = torch.full(sizea, 1.0, device="cuda", dtype=torch.float32) - t2 = torch.full(sizeb, 1.0, device="cuda", dtype=torch.float32) - - def to_fmt(t, tp): - if nhwc: - return t.clone().to(tp, memory_format=torch.channels_last) - else: - return t.clone().to(tp) - - y_list = [] - for i in range(repeat_tensors): - y_list += [to_fmt(t1, y_type)*self.yval, to_fmt(t2, y_type)*self.yval] - - x_list = [to_fmt(x, x_type)*(self.xval/self.yval) for x in y_list] - - if inplace: - out_list = y_list - else: - out_list = [to_fmt(out, out_type)*3.0 for out in y_list] - - applier(multi_tensor_axpby, self.overflow_buf, [x_list, y_list, out_list], self.a, self.b, -1) - - # TODO: Remove this workaround for bfloat16 after torch.allcose() support bfloat16 - if out_type == torch.bfloat16: - out_list = [out.float() for out in out_list] - self.assertTrue(all([torch.allclose(out, self.ref.to(out.dtype)) for out in out_list]), - msg="{} {} {} {} {} {} {}".format(sizea, sizeb, repeat_tensors, - x_type, y_type, out_type, inplace)) - self.assertTrue(self.overflow_buf.item() == 0, - msg="{} {} {} {} {} {} {}".format(sizea, sizeb, repeat_tensors, - x_type, y_type, out_type, inplace)) - - # def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, t, ind, val, inplace=False): - # self.overflow_buf.zero_() - # a = torch.cuda.FloatTensor(sizea).fill_(self.scale) - # b = torch.cuda.FloatTensor(sizeb).fill_(self.scale) - - # out_list = [] - # for i in range(repeat_tensors): - # out_list += [a.clone().to(out_type), b.clone().to(out_type)] - - # if inplace: - # in_list = out_list - # else: - # in_list = [out.clone().to(in_type) for out in out_list] - - # applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale) - - # self.overflow_buf.zero_() - # in_list[t][ind] = val - # applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale) - # self.assertTrue(self.overflow_buf.item()) - - @unittest.skipIf(disabled, "amp_C is unavailable") - def test_fuzz(self): - input_size_pairs = ( - (7777*77, 555*555), - (777, 555), - (555, 2048*32+1), - (2048*32+1, 555), - (555, 2048*32), - (2048*32, 555), - (33333, 555), - (555, 33333)) - appliers = ( - MultiTensorApply(2048*32), - MultiTensorApply(333), - MultiTensorApply(33333)) - repeat_tensors = ( - 1, - 55) - - for sizea, sizeb in input_size_pairs: - for applier in appliers: - for repeat in repeat_tensors: - for x_type in (torch.float32, torch.float16, torch.bfloat16): - for y_type in (torch.float32, torch.float16, torch.bfloat16): - for out_type in (torch.float32, torch.float16, torch.bfloat16): - for inplace in (True, False): - if inplace is True and (y_type is not out_type): - continue - else: - self.axpby(sizea, sizeb, applier, repeat, - x_type, y_type, out_type, inplace=inplace) - # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type, - # 0, 0, float('nan'), inplace=inplace) - # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type, - # 2*repeat-1, sizeb-1, float('inf'), inplace=inplace) - # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type, - # 2*(repeat//2), sizea//2, float('inf'), inplace=inplace) - - @unittest.skipIf(disabled, "amp_C is unavailable") - @unittest.skipIf(not try_nhwc, "torch version is 1.4 or earlier, may not support nhwc") - def test_fuzz_nhwc(self): - input_size_pairs = ( - ((7, 77, 7, 77), (5, 55, 5, 55)), - ((1, 1, 777, 1), (1, 1, 555, 1)), - ((5, 47, 5, 55), (1, 1, 1, 2048*32 + 1)), - ((1, 1, 1, 2048*32 + 1), (55, 47, 5, 55)), - ((555, 1, 1, 1), (32, 8, 32, 8)), - ((32, 8, 32, 8), (55, 47, 5, 55)), - ((1, 1, 33333, 1), (55, 47, 55, 5)), - ((55, 47, 55, 5), (1, 1, 33333, 1))) - appliers = ( - MultiTensorApply(2048*32), - MultiTensorApply(333), - MultiTensorApply(33333)) - repeat_tensors = ( - 1, - 55) - - for sizea, sizeb in input_size_pairs: - for applier in appliers: - for repeat in repeat_tensors: - for x_type in (torch.float32, torch.float16): - for y_type in (torch.float32, torch.float16): - for out_type in (torch.float32, torch.float16): - for inplace in (True, False): - if inplace is True and (y_type is not out_type): - continue - else: - self.axpby(sizea, sizeb, applier, repeat, - x_type, y_type, out_type, inplace=inplace, nhwc=True) - # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type, - # 0, 0, float('nan'), inplace=inplace) - # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type, - # 2*repeat-1, sizeb-1, float('inf'), inplace=inplace) - # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type, - # 2*(repeat//2), sizea//2, float('inf'), inplace=inplace) - - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/L0/run_amp/test_multi_tensor_l2norm.py b/tests/L0/run_amp/test_multi_tensor_l2norm.py deleted file mode 100644 index ef09e33..0000000 --- a/tests/L0/run_amp/test_multi_tensor_l2norm.py +++ /dev/null @@ -1,87 +0,0 @@ -import unittest - -import functools as ft -import itertools as it - -from apex import amp -import torch -from torch import nn -import torch.nn.functional as F - -from utils import common_init, HALF, FLOAT,\ - ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT - -try: - import amp_C - from amp_C import multi_tensor_l2norm - from apex.multi_tensor_apply import MultiTensorApply - disabled = False -except ImportError as err: - print("amp_C fused kernels unavailable, disabling TestMultiTensorApply. ImportError was ", err) - disabled = True - - -class TestMultiTensorL2Norm(unittest.TestCase): - - def setUp(self): - common_init(self) - self.val = 4.0 - self.overflow_buf = torch.cuda.IntTensor(1).zero_() - - def tearDown(self): - pass - - # The tensor creation here is written for convenience, not speed. - def l2norm(self, sizea, sizeb, applier, repeat_tensors, in_type, per_tensor): - self.overflow_buf.zero_() - a = torch.cuda.FloatTensor(sizea).fill_(self.val) - b = torch.cuda.FloatTensor(sizeb).fill_(self.val) - - in_list = [] - for i in range(repeat_tensors): - in_list += [a.clone().to(in_type), b.clone().to(in_type)] - - if per_tensor: - norm, norm_per_tensor = applier(multi_tensor_l2norm, self.overflow_buf, [in_list], True) - normab = torch.cat((a.norm().view(1), b.norm().view(1))) - norm_per_tensor = norm_per_tensor.view(-1, 2) - else: - norm, _ = applier(multi_tensor_l2norm, self.overflow_buf, [in_list], True) - - reference = torch.cuda.FloatTensor((sizea + sizeb)*repeat_tensors).fill_(self.val).norm() - - self.assertTrue(torch.allclose(norm, reference)) - if per_tensor: - self.assertTrue(torch.allclose(norm_per_tensor, normab)) - self.assertTrue(self.overflow_buf.item() == 0) - - @unittest.skipIf(disabled, "amp_C is unavailable") - def test_fuzz(self): - input_size_pairs = ( - (7777*77, 555*555), - (777, 555), - (555, 2048*32+1), - (2048*32+1, 555), - (555, 2048*32), - (2048*32, 555), - (33333, 555), - (555, 33333)) - appliers = ( - MultiTensorApply(2048*32), - MultiTensorApply(333), - MultiTensorApply(33333)) - repeat_tensors = ( - 1, - 55) - - for sizea, sizeb in input_size_pairs: - for applier in appliers: - for repeat in repeat_tensors: - for in_type in (torch.float32, torch.float16): - for per_tensor in (False, True): - self.l2norm(sizea, sizeb, applier, repeat, in_type, per_tensor) - - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/L0/run_amp/test_multi_tensor_scale.py b/tests/L0/run_amp/test_multi_tensor_scale.py deleted file mode 100644 index 11a8f5e..0000000 --- a/tests/L0/run_amp/test_multi_tensor_scale.py +++ /dev/null @@ -1,129 +0,0 @@ -import unittest - -import functools as ft -import itertools as it - -from apex import amp -import torch -from torch import nn -import torch.nn.functional as F - -from utils import common_init, HALF, FLOAT,\ - ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT - -try: - import amp_C - from amp_C import multi_tensor_scale - from apex.multi_tensor_apply import MultiTensorApply - disabled = False -except ImportError as err: - print("amp_C fused kernels unavailable, disabling TestMultiTensorApply. ImportError was ", err) - disabled = True - - -class TestMultiTensorScale(unittest.TestCase): - - def setUp(self): - common_init(self) - self.scale = 4.0 - self.overflow_buf = torch.cuda.IntTensor(1).zero_() - self.ref = torch.cuda.FloatTensor([1.0]) - - def tearDown(self): - pass - - # The tensor creation here is written for convenience, not speed. - def downscale(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, inplace=False): - self.overflow_buf.zero_() - a = torch.cuda.FloatTensor(sizea).fill_(self.scale) - b = torch.cuda.FloatTensor(sizeb).fill_(self.scale) - - out_list = [] - for i in range(repeat_tensors): - out_list += [a.clone().to(out_type), b.clone().to(out_type)] - - if inplace: - in_list = out_list - else: - in_list = [out.clone().to(in_type) for out in out_list] - - applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale) - - # TODO: Remove this workaround for bfloat16 after torch.allcose() support bfloat16 - if out_type == torch.bfloat16: - out_list = [out.float() for out in out_list] - self.assertTrue(all([torch.allclose(out, self.ref.to(out.dtype)) for out in out_list])) - self.assertTrue(self.overflow_buf.item() == 0) - - def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, t, ind, val, inplace=False): - self.overflow_buf.zero_() - a = torch.cuda.FloatTensor(sizea).fill_(self.scale) - b = torch.cuda.FloatTensor(sizeb).fill_(self.scale) - - out_list = [] - for i in range(repeat_tensors): - out_list += [a.clone().to(out_type), b.clone().to(out_type)] - - if inplace: - in_list = out_list - else: - in_list = [out.clone().to(in_type) for out in out_list] - - applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale) - - self.overflow_buf.zero_() - in_list[t][ind] = val - applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale) - self.assertTrue(self.overflow_buf.item()) - - # Currently, the fused kernel gives a hard error if you attempt to downscale - # into fp16 output, which imo is the desired behavior. Maybe someday we - # will learn otherwise. - # @unittest.skipIf(disabled, "amp_C is unavailable") - # def test_fp16_to_fp16(self): - # self.downscale(self.fp16, self.fp16, self.fp16_ref) - # - # @unittest.skipIf(disabled, "amp_C is unavailable") - # def test_fp32_to_fp16(self): - # self.downscale(self.fp32, self.fp16, self.fp16_ref) - - @unittest.skipIf(disabled, "amp_C is unavailable") - def test_fuzz(self): - input_size_pairs = ( - (7777*77, 555*555), - (777, 555), - (555, 2048*32+1), - (2048*32+1, 555), - (555, 2048*32), - (2048*32, 555), - (33333, 555), - (555, 33333)) - appliers = ( - MultiTensorApply(2048*32), - MultiTensorApply(333), - MultiTensorApply(33333)) - repeat_tensors = ( - 1, - 55) - - for sizea, sizeb in input_size_pairs: - for applier in appliers: - for repeat in repeat_tensors: - for in_type in (torch.float32, torch.float16, torch.bfloat16): - for out_type in (torch.float32, torch.float16, torch.bfloat16): - for inplace in (True, False): - if inplace is True and (out_type is not in_type): - continue - else: - self.downscale(sizea, sizeb, applier, repeat, in_type, out_type, inplace=inplace) - self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type, - 0, 0, float('nan'), inplace=inplace) - self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type, - 2*repeat-1, sizeb-1, float('inf'), inplace=inplace) - self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type, - 2*(repeat//2), sizea//2, float('inf'), inplace=inplace) - - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/L0/run_amp/test_multiple_models_optimizers_losses.py b/tests/L0/run_amp/test_multiple_models_optimizers_losses.py deleted file mode 100644 index 068c845..0000000 --- a/tests/L0/run_amp/test_multiple_models_optimizers_losses.py +++ /dev/null @@ -1,762 +0,0 @@ -import unittest - -import functools as ft -import itertools as it - -from apex import amp -from apex.amp import _amp_state -import torch -from torch import nn -import torch.nn.functional as F -from torch.nn import Parameter - -from utils import common_init, HALF, FLOAT,\ - ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT - -class MyModel(torch.nn.Module): - def __init__(self, unique): - super(MyModel, self).__init__() - self.weight0 = Parameter(unique + - torch.arange(2, device='cuda', dtype=torch.float32)) - self.weight1 = Parameter(1. + unique + torch.arange(2, device='cuda', dtype=torch.float16)) - - @staticmethod - def ops(input, weight0, weight1): - return ((input*(weight0.float()))*(weight1.float())).sum() - - def forward(self, input): - return self.ops(input, self.weight0, self.weight1) - -# Abandon all hope, ye who enter here. - -# This is hands down the ugliest code I have ever written, but it succeeds in testing -# multiple models/optimizers/losses fairly thoroughly. Many of the different test cases -# require slightly divergent code in a way that seems near-impossible to genericize into a simple -# cross product or nested loops. - -class TestMultipleModelsOptimizersLosses(unittest.TestCase): - def setUp(self): - self.x = torch.ones((2), device='cuda', dtype=torch.float32) - common_init(self) - - def tearDown(self): - pass - - def test_2models2losses1optimizer(self): - model0 = MyModel(1) - model1 = MyModel(2) - - optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}, - {'params' : model1.parameters(), 'lr' : 0.5}], - momentum=0.125) - - reference_grads = [] - for i in range(2): - optimizer.zero_grad() - loss0 = model0(self.x) - loss1 = model1(self.x) - loss0.backward() - loss1.backward() - - reference_grads.append([param.grad.data.clone() for param in model0.parameters()] + - [param.grad.data.clone() for param in model1.parameters()]) - - optimizer.step() - - final_params = [param.data.clone() for param in model0.parameters()] + \ - [param.data.clone() for param in model1.parameters()] - - for opt_level in ("O0", "O1", "O2", "O3"): - for how_to_zero in ("none", "model", "optimizer"): - for use_multiple_loss_scalers in (True, False): - if opt_level == "O1" or opt_level == "O2": - inject_inf_iters = (-1, 0, 1) - else: - inject_inf_iters = (-1,) - - for inject_inf in inject_inf_iters: - if inject_inf >= 0: - inject_inf_locs = ("fp16", "fp32") - which_backwards = (0, 1) - else: - inject_inf_locs = ("fdsa",) - which_backwards = (None,) - - for inject_inf_loc in inject_inf_locs: - for which_backward in which_backwards: - if use_multiple_loss_scalers: - num_losses = 2 - loss_ids = [0, 1] - else: - num_losses = 1 - loss_ids = [0, 0] - - if inject_inf >= 0: - iters = 3 - else: - iters = 2 - - model0 = MyModel(1) - model1 = MyModel(2) - - models = [model0, model1] - - optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}, - {'params' : model1.parameters(), 'lr' : 0.5}], - momentum=0.125) - - _amp_state.allow_incoming_model_not_fp32 = True - [model0, model1], optimizer = amp.initialize( - [model0, model1], - optimizer, - opt_level=opt_level, - verbosity=0, - cast_model_type=False, - num_losses=num_losses) - _amp_state.allow_incoming_model_not_fp32 = False - - _amp_state.loss_scalers[0]._loss_scale = 4.0 - if use_multiple_loss_scalers: - _amp_state.loss_scalers[1]._loss_scale = 16.0 - - unskipped = 0 - for i in range(iters): - if how_to_zero == "none": - for model in models: - for param in model.parameters(): - param.grad = None - elif how_to_zero == "model": - for model in models: - model.zero_grad() - else: - optimizer.zero_grad() - - loss0 = model0(self.x) - loss1 = model1(self.x) - - with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss: - scaled_loss.backward() - if i == inject_inf and which_backward == 0: - if inject_inf_loc == "fp32": - model0.weight0.grad[0] = float('inf') - elif inject_inf_loc == "fp16": - model0.weight1.grad[0] = float('inf') - with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss: - scaled_loss.backward() - if i == inject_inf and which_backward == 1: - if inject_inf_loc == "fp32": - model1.weight0.grad[0] = float('inf') - elif inject_inf_loc == "fp16": - model1.weight1.grad[0] = float('inf') - - if i != inject_inf: - for param, reference_grad in zip(amp.master_params(optimizer), - reference_grads[unskipped]): - self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float())) - unskipped += 1 - optimizer.step() - - model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()] - for model, master, reference in zip( - model_params, - amp.master_params(optimizer), - final_params): - self.assertTrue(torch.allclose(model, reference)) - self.assertTrue(torch.allclose(model, master.to(model.dtype))) - - if opt_level == "O1": - _amp_state.handle._deactivate() - - def test_3models2losses1optimizer(self): - - model0 = MyModel(1) - model1 = MyModel(2) - model2 = MyModel(3) - - optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}, - {'params' : model1.parameters(), 'lr' : 0.5}, - {'params' : model2.parameters(), 'lr' : 0.125}], - momentum=0.125) - - reference_grads = [] - for i in range(2): - optimizer.zero_grad() - loss0 = model0(self.x) + model2(self.x) - loss1 = model1(self.x) + model2(self.x) - loss0.backward() - loss1.backward() - - reference_grads.append([param.grad.data.clone() for param in model0.parameters()] + - [param.grad.data.clone() for param in model1.parameters()] + - [param.grad.data.clone() for param in model2.parameters()]) - - optimizer.step() - - - final_params = [param.data.clone() for param in model0.parameters()] + \ - [param.data.clone() for param in model1.parameters()] + \ - [param.data.clone() for param in model2.parameters()] - - for opt_level in ("O0", "O1", "O2", "O3"): - for how_to_zero in ("none", "model", "optimizer"): - for use_multiple_loss_scalers in (True, False): - if opt_level == "O1" or opt_level == "O2": - inject_inf_iters = (-1, 0, 1) - else: - inject_inf_iters = (-1,) - - for inject_inf in inject_inf_iters: - if inject_inf >= 0: - inject_inf_locs = ("fp16", "fp32") - which_backwards = (0, 1) - else: - inject_inf_locs = ("fdsa",) - which_backwards = (None,) - - for inject_inf_loc in inject_inf_locs: - for which_backward in which_backwards: - if use_multiple_loss_scalers: - num_losses = 2 - loss_ids = [0, 1] - else: - num_losses = 1 - loss_ids = [0, 0] - - if inject_inf >= 0: - iters = 3 - if which_backward == 0: - which_models = (0, 2) - elif which_backward == 1: - which_models = (1, 2) - else: - iters = 2 - which_models = (None,) - - for which_model in which_models: - model0 = MyModel(1) - model1 = MyModel(2) - model2 = MyModel(3) - - models = [model0, model1, model2] - - optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}, - {'params' : model1.parameters(), 'lr' : 0.5}, - {'params' : model2.parameters(), 'lr' : 0.125}], - momentum=0.125) - - _amp_state.allow_incoming_model_not_fp32 = True - [model0, model1, model2], optimizer = amp.initialize( - [model0, model1, model2], - optimizer, - opt_level=opt_level, - verbosity=0, - cast_model_type=False, - num_losses=num_losses) - _amp_state.allow_incoming_model_not_fp32 = False - - _amp_state.loss_scalers[0]._loss_scale = 4.0 - if use_multiple_loss_scalers: - _amp_state.loss_scalers[1]._loss_scale = 16.0 - - unskipped = 0 - for i in range(iters): - if how_to_zero == "none": - for model in models: - for param in model.parameters(): - param.grad = None - elif how_to_zero == "model": - for model in models: - model.zero_grad() - else: - optimizer.zero_grad() - - # print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} which_model {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, which_model, use_multiple_loss_scalers)) - - loss0 = model0(self.x) + model2(self.x) - loss1 = model1(self.x) + model2(self.x) - - with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss: - scaled_loss.backward() - if i == inject_inf and which_backward == 0: - if which_model == 0: - inj_model = model0 - elif which_model == 2: - inj_model = model2 - else: - raise RuntimeError(which_model + " invalid for loss 0") - if inject_inf_loc == "fp32": - inj_model.weight0.grad[0] = float('inf') - elif inject_inf_loc == "fp16": - inj_model.weight1.grad[0] = float('inf') - with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss: - scaled_loss.backward() - if i == inject_inf and which_backward == 1: - if which_model == 1: - inj_model = model1 - elif which_model == 2: - inj_model = model2 - else: - raise RuntimeError(which_model + " invalid for loss 1 ") - if inject_inf_loc == "fp32": - inj_model.weight0.grad[0] = float('inf') - elif inject_inf_loc == "fp16": - inj_model.weight1.grad[0] = float('inf') - - if i != inject_inf: - for param, reference_grad in zip(amp.master_params(optimizer), - reference_grads[unskipped]): - self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float())) - unskipped += 1 - - optimizer.step() - - model_params = [p for p in model0.parameters()] + \ - [p for p in model1.parameters()] + \ - [p for p in model2.parameters()] - for model, master, reference in zip( - model_params, - amp.master_params(optimizer), - final_params): - self.assertTrue(torch.allclose(model, reference)) - self.assertTrue(torch.allclose(model, master.to(model.dtype))) - - if opt_level == "O1": - _amp_state.handle._deactivate() - - def test_2models2losses2optimizers(self): - model0 = MyModel(1) - model1 = MyModel(2) - - optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}], - momentum=0.125) - optimizer1 = torch.optim.SGD([{'params' : model1.parameters(), 'lr' : 0.5}], - momentum=0.25) - - # Don't do it like this: reference_grads = [[]]*5 - # because then it creates a list of 5 references to the same "[]" and appending - # to any of them effectively makes you append to all of them, which multiplies - # the resulting size of reference_grads by 5x and needless to say makes the test fail. - reference_grads = [[], [], [], [], []] - final_params = [None, None, None, None, None] - for i in range(2): - optimizer0.zero_grad() - optimizer1.zero_grad() - loss0 = model0(self.x) - loss1 = model1(self.x) - loss0.backward() - loss1.backward() - - reference_grads[0].append([param.grad.data.clone() for param in model0.parameters()] + - [param.grad.data.clone() for param in model1.parameters()]) - - optimizer0.step() - optimizer1.step() - - final_params[0] = [param.data.clone() for param in model0.parameters()] + \ - [param.data.clone() for param in model1.parameters()] - - def what_got_skipped(which_iter, which_backward): - if which_iter == 0 and which_backward == 0: - return 1 - if which_iter == 0 and which_backward == 1: - return 2 - if which_iter == 1 and which_backward == 0: - return 3 - if which_iter == 1 and which_backward == 1: - return 4 - return 0 - - for which_iter in (0,1): - for which_backward in (0,1): - model0 = MyModel(1) - model1 = MyModel(2) - - optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}], - momentum=0.125) - optimizer1 = torch.optim.SGD([{'params' : model1.parameters(), 'lr' : 0.5}], - momentum=0.25) - - for i in range(3): - optimizer0.zero_grad() - optimizer1.zero_grad() - loss0 = model0(self.x) - loss1 = model1(self.x) - loss0.backward() - loss1.backward() - - if i != which_iter: - reference_grads[what_got_skipped(which_iter, which_backward)].append( - [param.grad.data.clone() for param in model0.parameters()] + - [param.grad.data.clone() for param in model1.parameters()]) - - if i == which_iter: - if which_backward == 0: - optimizer1.step() - else: - optimizer0.step() - else: - optimizer0.step() - optimizer1.step() - - final_params[what_got_skipped(which_iter, which_backward)] = \ - [param.data.clone() for param in model0.parameters()] + \ - [param.data.clone() for param in model1.parameters()] - - for opt_level in ("O0", "O1", "O2", "O3"): - for how_to_zero in ("none", "model", "optimizer"): - for use_multiple_loss_scalers in (True, False): - if opt_level == "O1" or opt_level == "O2": - inject_inf_iters = (-1, 0, 1) - else: - inject_inf_iters = (-1,) - - for inject_inf in inject_inf_iters: - if inject_inf >= 0: - inject_inf_locs = ("fp16", "fp32") - which_backwards = (0, 1) - else: - inject_inf_locs = ("fdsa",) - which_backwards = (None,) - - for inject_inf_loc in inject_inf_locs: - for which_backward in which_backwards: - if use_multiple_loss_scalers: - num_losses = 2 - loss_ids = [0, 1] - else: - num_losses = 1 - loss_ids = [0, 0] - - if inject_inf >= 0: - iters = 3 - else: - iters = 2 - - model0 = MyModel(1) - model1 = MyModel(2) - - models = [model0, model1] - - optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}], - momentum=0.125) - optimizer1 = torch.optim.SGD([{'params' : model1.parameters(), 'lr' : 0.5}], - momentum=0.25) - - _amp_state.allow_incoming_model_not_fp32 = True - [model0, model1], [optimizer0, optimizer1] = amp.initialize( - [model0, model1], - [optimizer0, optimizer1], - opt_level=opt_level, - verbosity=0, - cast_model_type=False, - num_losses=num_losses) - _amp_state.allow_incoming_model_not_fp32 = False - - _amp_state.loss_scalers[0]._loss_scale = 4.0 - if use_multiple_loss_scalers: - _amp_state.loss_scalers[1]._loss_scale = 16.0 - - unskipped = 0 - for i in range(iters): - if how_to_zero == "none": - for model in models: - for param in model.parameters(): - param.grad = None - elif how_to_zero == "model": - for model in models: - model.zero_grad() - else: - optimizer0.zero_grad() - optimizer1.zero_grad() - - loss0 = model0(self.x) - loss1 = model1(self.x) - - with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss: - scaled_loss.backward() - if i == inject_inf and which_backward == 0: - if inject_inf_loc == "fp32": - model0.weight0.grad[0] = float('inf') - elif inject_inf_loc == "fp16": - model0.weight1.grad[0] = float('inf') - with amp.scale_loss(loss1, optimizer1, loss_id=loss_ids[1]) as scaled_loss: - scaled_loss.backward() - if i == inject_inf and which_backward == 1: - if inject_inf_loc == "fp32": - model1.weight0.grad[0] = float('inf') - elif inject_inf_loc == "fp16": - model1.weight1.grad[0] = float('inf') - - # print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers)) - - if i != inject_inf: - master_params = list(amp.master_params(optimizer0)) + \ - list(amp.master_params(optimizer1)) - for param, reference_grad in zip(master_params, - reference_grads[what_got_skipped(inject_inf, which_backward)][unskipped]): - self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float())) - unskipped += 1 - - optimizer0.step() - optimizer1.step() - - model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()] - master_params = [p for p in amp.master_params(optimizer0)] + \ - [p for p in amp.master_params(optimizer1)] - for model, master, reference in zip( - model_params, - master_params, - final_params[what_got_skipped(inject_inf, which_backward)]): - self.assertTrue(torch.allclose(model, reference)) - self.assertTrue(torch.allclose(model, master.to(model.dtype))) - - if opt_level == "O1": - _amp_state.handle._deactivate() - - def test_3models2losses2optimizers(self): - model0 = MyModel(1) - model1 = MyModel(2) - model2 = MyModel(3) - - optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}, - {'params' : model1.parameters(), 'lr' : 1.0}], - momentum=0.5) - optimizer1 = torch.optim.SGD([{'params' : model2.parameters(), 'lr' : 0.5}], - momentum=0.25) - - # Again, can't do this: reference_grads = [[]]*9 - reference_grads = [[], [], [], [], [], [], [], [], []] - final_params = [None, None, None, None, None, None, None, None, None] - for i in range(2): - optimizer0.zero_grad() - optimizer1.zero_grad() - loss0 = model0(self.x) + model1(self.x) - loss1 = model2(self.x) + model1(self.x) - loss0.backward() - loss1.backward() - - reference_grads[0].append([param.grad.data.clone() for param in model0.parameters()] + - [param.grad.data.clone() for param in model1.parameters()]) - - optimizer0.step() - optimizer1.step() - - final_params[0] = \ - [param.data.clone() for param in model0.parameters()] + \ - [param.data.clone() for param in model1.parameters()] + \ - [param.data.clone() for param in model2.parameters()] - - def what_got_skipped(which_iter, which_backward, which_model): - if which_iter == 0: - if which_backward == 0: - if which_model == 0: - return 1 - if which_model == 1: - return 2 - if which_backward == 1: - if which_model == 2: - return 3 - if which_model == 1: - return 4 - if which_iter == 1: - if which_backward == 0: - if which_model == 0: - return 5 - if which_model == 1: - return 6 - if which_backward == 1: - if which_model == 2: - return 7 - if which_model == 1: - return 8 - return 0 - - for which_iter in (0,1): - for which_backward in (0,1): - if which_backward == 0: - which_models = (0,1) - if which_backward == 1: - which_models = (2,1) - for which_model in which_models: - - model0 = MyModel(1) - model1 = MyModel(2) - model2 = MyModel(3) - - optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}, - {'params' : model1.parameters(), 'lr' : 1.0}], - momentum=0.5) - optimizer1 = torch.optim.SGD([{'params' : model2.parameters(), 'lr' : 0.5}], - momentum=0.25) - - for i in range(3): - optimizer0.zero_grad() - optimizer1.zero_grad() - loss0 = model0(self.x) + model1(self.x) - loss1 = model2(self.x) + model1(self.x) - loss0.backward() - loss1.backward() - - if i != which_iter: - reference_grads[what_got_skipped(which_iter, - which_backward, which_model)].append( - [param.grad.data.clone() for param in model0.parameters()] + - [param.grad.data.clone() for param in model1.parameters()]) - - if i == which_iter: - if which_backward == 0: - # if which_model == 0: - optimizer1.step() - # if which_model == 1: - # optimizer1.step() - if which_backward == 1: - # if which_model == 2: - # optimizer0.step() - # if which_model == 1: - continue - else: - optimizer0.step() - optimizer1.step() - - final_params[what_got_skipped(which_iter, which_backward, which_model)] = \ - [param.data.clone() for param in model0.parameters()] + \ - [param.data.clone() for param in model1.parameters()] + \ - [param.data.clone() for param in model2.parameters()] - - for opt_level in ("O0", "O1", "O2", "O3"): - for how_to_zero in ("none", "model", "optimizer"): - for use_multiple_loss_scalers in (True, False): - if opt_level == "O1" or opt_level == "O2": - inject_inf_iters = (-1, 0, 1) - else: - inject_inf_iters = (-1,) - - for inject_inf in inject_inf_iters: - if inject_inf >= 0: - inject_inf_locs = ("fp16", "fp32") - which_backwards = (0, 1) - else: - inject_inf_locs = ("fdsa",) - which_backwards = (None,) - - for inject_inf_loc in inject_inf_locs: - for which_backward in which_backwards: - if use_multiple_loss_scalers: - num_losses = 2 - loss_ids = [0, 1] - else: - num_losses = 1 - loss_ids = [0, 0] - - if inject_inf >= 0: - iters = 3 - if which_backward == 0: - which_models = (0, 1) - elif which_backward == 1: - which_models = (2, 1) - else: - iters = 2 - which_models = (None,) - - for which_model in which_models: - model0 = MyModel(1) - model1 = MyModel(2) - model2 = MyModel(3) - - models = [model0, model1, model2] - - optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}, - {'params' : model1.parameters(), 'lr' : 1.0}], - momentum=0.5) - optimizer1 = torch.optim.SGD([{'params' : model2.parameters(), 'lr' : 0.5}], - momentum=0.25) - - _amp_state.allow_incoming_model_not_fp32 = True - [model0, model1, model2], [optimizer0, optimizer1] = amp.initialize( - [model0, model1, model2], - [optimizer0, optimizer1], - opt_level=opt_level, - verbosity=0, - cast_model_type=False, - num_losses=num_losses) - _amp_state.allow_incoming_model_not_fp32 = False - - _amp_state.loss_scalers[0]._loss_scale = 4.0 - if use_multiple_loss_scalers: - _amp_state.loss_scalers[1]._loss_scale = 16.0 - - unskipped = 0 - for i in range(iters): - if how_to_zero == "none": - for model in models: - for param in model.parameters(): - param.grad = None - elif how_to_zero == "model": - for model in models: - model.zero_grad() - else: - optimizer0.zero_grad() - optimizer1.zero_grad() - - loss0 = model0(self.x) + model1(self.x) - loss1 = model2(self.x) + model1(self.x) - - with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss: - scaled_loss.backward() - if i == inject_inf and which_backward == 0: - if which_model == 0: - inj_model = model0 - elif which_model == 1: - inj_model = model1 - else: - raise RuntimeError(which_model + " invalid for loss 0") - if inject_inf_loc == "fp32": - inj_model.weight0.grad[0] = float('inf') - elif inject_inf_loc == "fp16": - inj_model.weight1.grad[0] = float('inf') - with amp.scale_loss(loss1, [optimizer0, optimizer1], loss_id=loss_ids[1]) as scaled_loss: - scaled_loss.backward() - if i == inject_inf and which_backward == 1: - if which_model == 2: - inj_model = model2 - elif which_model == 1: - inj_model = model1 - else: - raise RuntimeError(which_model + " invalid for loss 1 ") - if inject_inf_loc == "fp32": - inj_model.weight0.grad[0] = float('inf') - elif inject_inf_loc == "fp16": - inj_model.weight1.grad[0] = float('inf') - - if i != inject_inf: - master_params = list(amp.master_params(optimizer0)) + \ - list(amp.master_params(optimizer1)) - for param, reference_grad in zip(master_params, - reference_grads[what_got_skipped(inject_inf, - which_backward, which_model)][unskipped]): - self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float())) - unskipped += 1 - - optimizer0.step() - optimizer1.step() - - model_params = [p for p in model0.parameters()] + \ - [p for p in model1.parameters()] + \ - [p for p in model2.parameters()] - master_params = [p for p in amp.master_params(optimizer0)] + \ - [p for p in amp.master_params(optimizer1)] - - # print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {} which_model {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers, which_model)) - - for model, master, reference in zip( - model_params, - master_params, - final_params[what_got_skipped(inject_inf, which_backward, which_model)]): - self.assertTrue(torch.allclose(model, reference)) - self.assertTrue(torch.allclose(model, master.to(model.dtype))) - - if opt_level == "O1": - _amp_state.handle._deactivate() - -if __name__ == '__main__': - unittest.main() diff --git a/tests/L0/run_amp/test_promotion.py b/tests/L0/run_amp/test_promotion.py deleted file mode 100644 index fcc27e4..0000000 --- a/tests/L0/run_amp/test_promotion.py +++ /dev/null @@ -1,112 +0,0 @@ -import unittest - -import itertools as it - -from apex import amp -import torch -from torch import nn -import torch.nn.functional as F - -from utils import common_init, HALF, FLOAT, DTYPES, DTYPES2, MATCH_INPUT - -class _TestPromotion(unittest.TestCase): - def run_binary_promote_test(self, fns, input_shape, lp_type, x_inplace=False): - if lp_type == torch.half: - dtypes = DTYPES - elif lp_type == torch.bfloat16: - dtypes = DTYPES2 - else: - raise RuntimeError("Creating test class with invalid low_precision type. \ - Supported types are torch.half and torch.bfloat16") - type_pairs = it.product(dtypes, dtypes) - for fn, (xtype, ytype) in it.product(fns, type_pairs): - x = torch.randn(input_shape, dtype=xtype).requires_grad_() - x_leaf = x - if x_inplace: - # We need a non-leaf to call in place on - x = x.clone() - y = torch.randn(input_shape, dtype=ytype) - out = fn(x, y) - if x_inplace: - # In place: always match xtype - self.assertEqual(out.type(), x.type()) - else: - # Out of place: match widest type - if xtype == torch.float or ytype == torch.float: - self.assertEqual(out.type(), FLOAT) - else: - self.assertEqual(out.type(), MATCH_INPUT[lp_type]) - out.float().sum().backward() - self.assertEqual(x_leaf.grad.dtype, xtype) - - def _test_cat_matches_widest(self, lp_type): - shape = self.b - ys = [torch.randn(shape, dtype=lp_type) for _ in range(5)] - x_float = torch.randn(shape) - out = torch.cat(ys + [x_float]) - self.assertEqual(out.type(), FLOAT) - x_lp = torch.randn(shape, dtype=lp_type) - out = torch.cat(ys + [x_lp]) - self.assertEqual(out.type(), MATCH_INPUT[lp_type]) - - def _test_inplace_exp_is_error_for_lp(self, lp_type): - xs = torch.randn(self.b) - xs.exp_() - self.assertEqual(xs.type(), FLOAT) - xs = torch.randn(self.b, dtype=lp_type) - with self.assertRaises(NotImplementedError): - xs.exp_() - -class TestPromotionHalf(_TestPromotion): - def setUp(self): - self.handle = amp.init(enabled=True, patch_type=torch.half) - common_init(self) - - def tearDown(self): - self.handle._deactivate() - - def test_atan2_matches_widest(self): - fns = [lambda x, y : torch.atan2(x, y), - lambda x, y : x.atan2(y)] - self.run_binary_promote_test(fns, (self.b,), torch.half) - - def test_mul_matches_widest(self): - fns = [lambda x, y : torch.mul(x, y), - lambda x, y: x.mul(y)] - self.run_binary_promote_test(fns, (self.b,), torch.half) - - def test_cat_matches_widest(self): - self._test_cat_matches_widest(torch.half) - - def test_inplace_exp_is_error_for_half(self): - self._test_inplace_exp_is_error_for_lp(torch.half) - - def test_inplace_add_matches_self(self): - fn = lambda x, y: x.add_(y) - self.run_binary_promote_test([fn], (self.b,), torch.half, x_inplace=True) - -class TestPromotionBFloat16(_TestPromotion): - def setUp(self): - self.handle = amp.init(enabled=True, patch_type=torch.bfloat16) - common_init(self) - - def tearDown(self): - self.handle._deactivate() - - def test_mul_matches_widest(self): - fns = [lambda x, y : torch.mul(x, y), - lambda x, y: x.mul(y)] - self.run_binary_promote_test(fns, (self.b,), torch.bfloat16) - - def test_cat_matches_widest(self): - self._test_cat_matches_widest(torch.bfloat16) - - def test_inplace_exp_is_error_for_bfloat16(self): - self._test_inplace_exp_is_error_for_lp(torch.bfloat16) - - def test_inplace_add_matches_self(self): - fn = lambda x, y: x.add_(y) - self.run_binary_promote_test([fn], (self.b,), torch.bfloat16, x_inplace=True) - -if __name__ == '__main__': - unittest.main() diff --git a/tests/L0/run_amp/test_rnn.py b/tests/L0/run_amp/test_rnn.py deleted file mode 100644 index 4543450..0000000 --- a/tests/L0/run_amp/test_rnn.py +++ /dev/null @@ -1,121 +0,0 @@ -import unittest - -from apex import amp -import random -import torch -from torch import nn - -from utils import common_init, HALF -from apex.testing.common_utils import skipIfRocm - -class TestRnnCells(unittest.TestCase): - def setUp(self): - self.handle = amp.init(enabled=True) - common_init(self) - - def tearDown(self): - self.handle._deactivate() - - def run_cell_test(self, cell, state_tuple=False): - shape = (self.b, self.h) - for typ in [torch.float, torch.half]: - xs = [torch.randn(shape, dtype=typ).requires_grad_() - for _ in range(self.t)] - hidden_fn = lambda: torch.zeros(shape, dtype=typ) - if state_tuple: - hidden = (hidden_fn(), hidden_fn()) - else: - hidden = hidden_fn() - outputs = [] - for i in range(self.t): - hidden = cell(xs[i], hidden) - if state_tuple: - output = hidden[0] - else: - output = hidden - outputs.append(output) - for y in outputs: - self.assertEqual(y.type(), HALF) - outputs[-1].float().sum().backward() - for i, x in enumerate(xs): - self.assertEqual(x.grad.dtype, x.dtype) - - def test_rnn_cell_is_half(self): - cell = nn.RNNCell(self.h, self.h) - self.run_cell_test(cell) - - def test_gru_cell_is_half(self): - cell = nn.GRUCell(self.h, self.h) - self.run_cell_test(cell) - - def test_lstm_cell_is_half(self): - cell = nn.LSTMCell(self.h, self.h) - self.run_cell_test(cell, state_tuple=True) - -class TestRnns(unittest.TestCase): - def setUp(self): - self.handle = amp.init(enabled=True) - common_init(self) - - def tearDown(self): - self.handle._deactivate() - - def run_rnn_test(self, rnn, layers, bidir, state_tuple=False): - for typ in [torch.float, torch.half]: - x = torch.randn((self.t, self.b, self.h), dtype=typ).requires_grad_() - hidden_fn = lambda: torch.zeros((layers + (layers * bidir), - self.b, self.h), dtype=typ) - if state_tuple: - hidden = (hidden_fn(), hidden_fn()) - else: - hidden = hidden_fn() - output, _ = rnn(x, hidden) - self.assertEqual(output.type(), HALF) - output[-1, :, :].float().sum().backward() - self.assertEqual(x.grad.dtype, x.dtype) - - @skipIfRocm - def test_rnn_is_half(self): - configs = [(1, False), (2, False), (2, True)] - for layers, bidir in configs: - rnn = nn.RNN(input_size=self.h, hidden_size=self.h, num_layers=layers, - nonlinearity='relu', bidirectional=bidir) - self.run_rnn_test(rnn, layers, bidir) - - @skipIfRocm - def test_gru_is_half(self): - configs = [(1, False), (2, False), (2, True)] - for layers, bidir in configs: - rnn = nn.GRU(input_size=self.h, hidden_size=self.h, num_layers=layers, - bidirectional=bidir) - self.run_rnn_test(rnn, layers, bidir) - - @skipIfRocm - def test_lstm_is_half(self): - configs = [(1, False), (2, False), (2, True)] - for layers, bidir in configs: - rnn = nn.LSTM(input_size=self.h, hidden_size=self.h, num_layers=layers, - bidirectional=bidir) - self.run_rnn_test(rnn, layers, bidir, state_tuple=True) - - @skipIfRocm - def test_rnn_packed_sequence(self): - num_layers = 2 - rnn = nn.RNN(input_size=self.h, hidden_size=self.h, num_layers=num_layers) - for typ in [torch.float, torch.half]: - x = torch.randn((self.t, self.b, self.h), dtype=typ).requires_grad_() - lens = sorted([random.randint(self.t // 2, self.t) for _ in range(self.b)], - reverse=True) - # `pack_padded_sequence` breaks if default tensor type is non-CPU - torch.set_default_tensor_type(torch.FloatTensor) - lens = torch.tensor(lens, dtype=torch.int64, device=torch.device('cpu')) - packed_seq = nn.utils.rnn.pack_padded_sequence(x, lens) - torch.set_default_tensor_type(torch.cuda.FloatTensor) - hidden = torch.zeros((num_layers, self.b, self.h), dtype=typ) - output, _ = rnn(packed_seq, hidden) - self.assertEqual(output.data.type(), HALF) - output.data.float().sum().backward() - self.assertEqual(x.grad.dtype, x.dtype) - -if __name__ == '__main__': - unittest.main() diff --git a/tests/L0/run_amp/utils.py b/tests/L0/run_amp/utils.py deleted file mode 100644 index 8e163ee..0000000 --- a/tests/L0/run_amp/utils.py +++ /dev/null @@ -1,27 +0,0 @@ -import torch - -HALF = 'torch.cuda.HalfTensor' -FLOAT = 'torch.cuda.FloatTensor' -BFLOAT16 = 'torch.cuda.BFloat16Tensor' - -DTYPES = [torch.half, torch.float] - -DTYPES2 = [torch.bfloat16, torch.float] - -ALWAYS_HALF = {torch.float: HALF, - torch.half: HALF} -ALWAYS_BFLOAT16 = {torch.bfloat16: BFLOAT16, - torch.float: BFLOAT16} -ALWAYS_FLOAT = {torch.float: FLOAT, - torch.half: FLOAT} -MATCH_INPUT = {torch.float: FLOAT, - torch.half: HALF, - torch.bfloat16: BFLOAT16} - -def common_init(test_case): - test_case.h = 64 - test_case.b = 16 - test_case.c = 16 - test_case.k = 3 - test_case.t = 10 - torch.set_default_tensor_type(torch.cuda.FloatTensor) diff --git a/tests/L0/run_fp16util/__init__.py b/tests/L0/run_fp16util/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/L0/run_fp16util/test_fp16util.py b/tests/L0/run_fp16util/test_fp16util.py deleted file mode 100644 index eecddbc..0000000 --- a/tests/L0/run_fp16util/test_fp16util.py +++ /dev/null @@ -1,75 +0,0 @@ -import unittest - -import torch -import torch.nn as nn - -from apex.fp16_utils import FP16Model - - -class DummyBlock(nn.Module): - def __init__(self): - super(DummyBlock, self).__init__() - - self.conv = nn.Conv2d(10, 10, 2) - self.bn = nn.BatchNorm2d(10, affine=True) - - def forward(self, x): - return self.conv(self.bn(x)) - - -class DummyNet(nn.Module): - def __init__(self): - super(DummyNet, self).__init__() - - self.conv1 = nn.Conv2d(3, 10, 2) - self.bn1 = nn.BatchNorm2d(10, affine=False) - self.db1 = DummyBlock() - self.db2 = DummyBlock() - - def forward(self, x): - out = x - out = self.conv1(out) - out = self.bn1(out) - out = self.db1(out) - out = self.db2(out) - return out - - -class DummyNetWrapper(nn.Module): - def __init__(self): - super(DummyNetWrapper, self).__init__() - - self.bn = nn.BatchNorm2d(3, affine=True) - self.dn = DummyNet() - - def forward(self, x): - return self.dn(self.bn(x)) - - -class TestFP16Model(unittest.TestCase): - def setUp(self): - self.N = 64 - self.C_in = 3 - self.H_in = 16 - self.W_in = 32 - self.in_tensor = torch.randn((self.N, self.C_in, self.H_in, self.W_in)).cuda() - self.orig_model = DummyNetWrapper().cuda() - self.fp16_model = FP16Model(self.orig_model) - - def test_params_and_buffers(self): - exempted_modules = [ - self.fp16_model.network.bn, - self.fp16_model.network.dn.db1.bn, - self.fp16_model.network.dn.db2.bn, - ] - for m in self.fp16_model.modules(): - expected_dtype = torch.float if (m in exempted_modules) else torch.half - for p in m.parameters(recurse=False): - assert p.dtype == expected_dtype - for b in m.buffers(recurse=False): - assert b.dtype in (expected_dtype, torch.int64) - - def test_output_is_half(self): - out_tensor = self.fp16_model(self.in_tensor) - assert out_tensor.dtype == torch.half - diff --git a/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py b/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py deleted file mode 100644 index 1821952..0000000 --- a/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py +++ /dev/null @@ -1,298 +0,0 @@ -import itertools -import unittest - -import torch - -import apex -from apex.testing.common_utils import skipFlakyTest - -class TestFusedLayerNorm(unittest.TestCase): - dtype = torch.float - elementwise_affine = False - normalized_shape = [32, 16] - rtol, atol = None, None - fwd_thresholds = dict(rtol=None, atol=None) - bwd_thresholds = dict(rtol=None, atol=None) - mixed_fused = False - - def setUp(self): - # bias and weight are set to 0 and 1 respectively, so no need to copy parameters from cpu module to the gpu one - if not self.mixed_fused: - self.module_cpu_ = apex.normalization.FusedLayerNorm( - normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).cpu() - self.module_cuda_ = apex.normalization.FusedLayerNorm( - normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).to(device="cuda", dtype=self.dtype) - else: - assert self.elementwise_affine - self.module_cpu_ = apex.normalization.MixedFusedLayerNorm( - normalized_shape=self.normalized_shape).cpu() - self.module_cuda_ = apex.normalization.MixedFusedLayerNorm( - normalized_shape=self.normalized_shape).to(device="cuda", dtype=self.dtype) - - - def _check_same_output(self, batch_size, contiguous): - torch.cuda.manual_seed(42) - if contiguous: - input_shape = [batch_size] + self.normalized_shape - input_ = torch.randn(input_shape, device="cpu").requires_grad_(True) - input_cuda_ = input_.to(device="cuda", dtype=self.dtype).detach().requires_grad_(True) - self.assertTrue(input_.is_contiguous()) - self.assertTrue(input_cuda_.is_contiguous()) - else: - input_shape = [batch_size] + self.normalized_shape - input_shape = [batch_size * 3] + [self.normalized_shape[0] * 5, self.normalized_shape[1] * 3] - input_src_ = torch.randn(input_shape, device="cpu") - input_ = input_src_[::3, ::5, ::3].detach().requires_grad_(True) - input_cuda_ = input_src_.to(device="cuda", dtype=self.dtype)[::3, ::5, ::3].detach().requires_grad_(True) - # make sure that tensors are NOT contiguous. - self.assertFalse(input_.is_contiguous()) - self.assertFalse(input_cuda_.is_contiguous()) - out_cpu_ = self.module_cpu_(input_) - gO = torch.rand_like(out_cpu_) - out_cpu_.backward(gO) - out_cuda_ = self.module_cuda_(input_cuda_) - gO = gO.to(device="cuda", dtype=self.dtype) - out_cuda_.backward(gO) - self.assertFalse(out_cpu_.is_cuda) - self.assertTrue(out_cuda_.is_cuda) - # TODO (mkozuki): `torch.testing.assert_allclose` is deprecated. - # Use `torch.testing.assert_close`. - # See https://github.com/pytorch/pytorch/issues/61844 - torch.testing.assert_allclose( - out_cpu_.to(device="cuda", dtype=self.dtype), out_cuda_, **self.fwd_thresholds) - torch.testing.assert_allclose( - input_.grad.to(device="cuda", dtype=self.dtype), input_cuda_.grad, **self.bwd_thresholds) - - def _test_same_output(self, batch_size): - for contiguous in (True, False): - with self.subTest(contiguous=contiguous): - self._check_same_output(batch_size, contiguous) - - def test_layer_norm(self): - self._test_same_output(16) - - def test_large_batch(self): - self._test_same_output(65536) - - -class TestFusedRMSNorm(unittest.TestCase): - dtype = torch.float - elementwise_affine = False - normalized_shape = [32, 16] - rtol, atol = None, None - fwd_thresholds = dict(rtol=None, atol=None) - bwd_thresholds = dict(rtol=None, atol=None) - mixed_fused = False - - def setUp(self): - # bias and weight are set to 0 and 1 respectively, so no need to copy parameters from cpu module to the gpu one - if not self.mixed_fused: - self.module_cpu_ = apex.normalization.FusedRMSNorm( - normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).cpu() - self.module_cuda_ = apex.normalization.FusedRMSNorm( - normalized_shape=self.normalized_shape, elementwise_affine=self.elementwise_affine).to(device="cuda", dtype=self.dtype) - else: - assert self.elementwise_affine - self.module_cpu_ = apex.normalization.MixedFusedRMSNorm( - normalized_shape=self.normalized_shape).cpu() - self.module_cuda_ = apex.normalization.MixedFusedRMSNorm( - normalized_shape=self.normalized_shape).to(device="cuda", dtype=self.dtype) - - def _check_same_output(self, batch_size, contiguous): - torch.cuda.manual_seed(42) - if contiguous: - input_shape = [batch_size] + self.normalized_shape - input_ = torch.randn(input_shape, device="cpu").requires_grad_(True) - input_cuda_ = input_.to(device="cuda", dtype=self.dtype).detach().requires_grad_(True) - self.assertTrue(input_.is_contiguous()) - self.assertTrue(input_cuda_.is_contiguous()) - else: - input_shape = [batch_size] + self.normalized_shape - input_shape = [batch_size * 3] + [self.normalized_shape[0] * 5, self.normalized_shape[1] * 3] - input_src_ = torch.randn(input_shape, device="cpu") - input_ = input_src_[::3, ::5, ::3].detach().requires_grad_(True) - input_cuda_ = input_src_.to(device="cuda", dtype=self.dtype)[::3, ::5, ::3].detach().requires_grad_(True) - # make sure that tensors are NOT contiguous. - self.assertFalse(input_.is_contiguous()) - self.assertFalse(input_cuda_.is_contiguous()) - out_cpu_ = self.module_cpu_(input_) - gO = torch.rand_like(out_cpu_) - out_cpu_.backward(gO) - out_cuda_ = self.module_cuda_(input_cuda_) - # TODO (mkozuki): `torch.testing.assert_allclose` is deprecated. - # Use `torch.testing.assert_close`. - # See https://github.com/pytorch/pytorch/issues/61844 - torch.testing.assert_allclose( - out_cpu_.to(device="cuda", dtype=self.dtype), out_cuda_.clone().detach(), **self.fwd_thresholds) - gO = gO.to(device="cuda", dtype=self.dtype) - out_cuda_.backward(gO) - self.assertFalse(out_cpu_.is_cuda) - self.assertTrue(out_cuda_.is_cuda) - torch.testing.assert_allclose( - input_.grad.to(device="cuda", dtype=self.dtype), input_cuda_.grad, **self.bwd_thresholds) - if self.elementwise_affine: - torch.testing.assert_allclose(self.module_cpu_.weight.grad.to(device="cuda", dtype=self.dtype), - self.module_cuda_.weight.grad, **self.bwd_thresholds) - - def _test_same_output(self, batch_size): - for contiguous in (True, False): - with self.subTest(contiguous=contiguous): - self._check_same_output(batch_size, contiguous) - - def test_layer_norm(self): - self._test_same_output(16) - - def test_large_batch(self): - self._test_same_output(65536) - - -class TestFusedLayerNormElemWise(TestFusedLayerNorm): - elementwise_affine = True - -class TestMixedFusedLayerNormElemWise(TestFusedLayerNorm): - elementwise_affine = True - mixed_fused = True - -class TestFusedLayerNormElemWiseHalf(TestFusedLayerNormElemWise): - dtype = torch.half - - def test_large_batch(self): - self.skipTest("Skip to save time") - -class TestFusedLayerNormElemWiseBFloat16(TestFusedLayerNormElemWise): - dtype = torch.bfloat16 - # NOTE (mkozuki): [BFloat16 Layer Norm flakiness] - # Use thresholds larger than those used in pytorch, see - # https://github.com/pytorch/pytorch/blob/72274e2a2fd55019ec860e1743dbdc5b0c5a5624/torch/testing/_asserts.py#L26 - fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4) - bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3) - - def test_large_batch(self): - self.skipTest("Skip to save time") - - -class TestFusedRMSNormElemWise(TestFusedRMSNorm): - bwd_thresholds = dict(rtol=2e-3, atol=2e-4) - elementwise_affine = True - -class TestMixedFusedRMSNormElemWise(TestFusedRMSNorm): - bwd_thresholds = dict(rtol=2e-3, atol=2e-4) - elementwise_affine = True - mixed_fused = True - -@skipFlakyTest -class TestFusedRMSNormElemWiseHalf(TestFusedRMSNormElemWise): - dtype = torch.half - bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3) - - def test_large_batch(self): - self.skipTest("Skip to save time") - - -@skipFlakyTest -class TestFusedLayerNormElemWiseBFloat16(TestFusedLayerNormElemWise): - dtype = torch.bfloat16 - # NOTE (mkozuki): [BFloat16 Layer Norm flakiness] - # Use thresholds larger than those used in pytorch, see - # https://github.com/pytorch/pytorch/blob/72274e2a2fd55019ec860e1743dbdc5b0c5a5624/torch/testing/_asserts.py#L26 - fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4) - bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3) - - def test_large_batch(self): - self.skipTest("Skip to save time") - - -def _prep_layers(normalized_shape, elementwise_affine, dtype): - native = torch.nn.LayerNorm( - normalized_shape=normalized_shape, elementwise_affine=elementwise_affine - ).to(device="cuda", dtype=dtype) - fused = apex.normalization.FusedLayerNorm( - normalized_shape=normalized_shape, elementwise_affine=elementwise_affine - ).cuda() - return native, fused - - -def _prep_rms_layers(normalized_shape, elementwise_affine, dtype): - native = apex.normalization.FusedRMSNorm( - normalized_shape=normalized_shape, elementwise_affine=elementwise_affine - ) - fused = apex.normalization.FusedRMSNorm( - normalized_shape=normalized_shape, elementwise_affine=elementwise_affine - ).cuda() - return native, fused - - -def _prep_inputs(batch_size, normalized_shape, dtype): - shape = (batch_size, *normalized_shape) - fused = torch.randn(shape).cuda().requires_grad_(True) - with torch.no_grad(): - native = fused.clone().to(dtype).requires_grad_(True) - return native, fused - - -autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,) - -class TestAutocastFusedLayerNorm(unittest.TestCase): - bf16_fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4) - bf16_bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3) - - def setUp(self): - self.batch_size = 16 - self.normalized_shape = [32, 16] - - def _run_test(self, dtype, elementwise_affine): - native, fused = _prep_layers(self.normalized_shape, elementwise_affine, dtype) - native_x, fused_x = _prep_inputs(self.batch_size, self.normalized_shape, dtype) - - expected = native(native_x) - with torch.cuda.amp.autocast(dtype=dtype): - actual = fused(fused_x) - tols = {'rtol': None, 'atol': None} if dtype == torch.half else TestAutocastFusedLayerNorm.bf16_fwd_thresholds - torch.testing.assert_allclose(actual, expected, **tols) - - g_native = torch.rand_like(expected) - with torch.no_grad(): - g_fused = g_native.clone() - expected.backward(g_native) - actual.backward(g_fused) - - tols = {'rtol': None, 'atol': None} if dtype == torch.half else TestAutocastFusedLayerNorm.bf16_bwd_thresholds - torch.testing.assert_allclose(native_x.grad, fused_x.grad, **tols) - - def test_autocast(self): - for (dtype, elementwise_affine) in itertools.product(autocast_dtypes, (True, False)): - with self.subTest(f"{dtype}-{elementwise_affine}"): - self._run_test(dtype, elementwise_affine) - -@unittest.skip("Skipped on ROCm5.2 due to the failure of reproducing the issue locally. (Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!) Please refer to https://github.com/ROCmSoftwarePlatform/apex/pull/78") -class TestAutocastFusedRMSNorm(unittest.TestCase): - bf16_fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4) - bf16_bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3) - - def setUp(self): - self.batch_size = 16 - self.normalized_shape = [32, 16] - - def _run_test(self, dtype, elementwise_affine): - native, fused = _prep_rms_layers(self.normalized_shape, elementwise_affine, dtype) - native_x, fused_x = _prep_inputs(self.batch_size, self.normalized_shape, dtype) - - expected = native(native_x.cpu()) - with torch.cuda.amp.autocast(dtype=dtype): - actual = fused(fused_x) - tols = {'rtol': None, 'atol': None} if dtype == torch.half else TestAutocastFusedRMSNorm.bf16_fwd_thresholds - torch.testing.assert_allclose(actual, expected.detach().clone().cuda(), **tols) - - g_native = torch.rand_like(expected) - with torch.no_grad(): - g_fused = g_native.detach().clone().cuda() - expected.backward(g_native) - actual.backward(g_fused) - - tols = {'rtol': None, 'atol': None} if dtype == torch.half else TestAutocastFusedRMSNorm.bf16_bwd_thresholds - torch.testing.assert_allclose(native_x.grad.cuda(), fused_x.grad, **tols) - - def test_autocast(self): - for (dtype, elementwise_affine) in itertools.product(autocast_dtypes, (True, False)): - with self.subTest(f"{dtype}-{elementwise_affine}"): - self._run_test(dtype, elementwise_affine) diff --git a/tests/L0/run_mlp/test_mlp.py b/tests/L0/run_mlp/test_mlp.py deleted file mode 100644 index 615dec9..0000000 --- a/tests/L0/run_mlp/test_mlp.py +++ /dev/null @@ -1,222 +0,0 @@ -"""Tests for c++ MLP""" -import unittest -from time import time -import numpy as np - -import torch -from torch import nn - -from apex.mlp import MLP -from apex.testing.common_utils import skipFlakyTest - -batch_size = 1024 -mlp_sizes = [480, 1024, 1024, 512, 256, 1] -num_iters = 10 - -class TestMLP(unittest.TestCase): - - def test_creation(self): - MLP(mlp_sizes) - - @skipFlakyTest - def test_numeric(self): - mlp = MLP(mlp_sizes).cuda() - - mlp_layers = [] - for i in range(mlp.num_layers): - linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1]) - mlp.weights[i].data.copy_(linear.weight) - mlp.biases[i].data.copy_(linear.bias) - mlp_layers.append(linear) - mlp_layers.append(nn.ReLU(inplace=True)) - - ref_mlp = nn.Sequential(*mlp_layers).cuda() - - test_input = torch.empty(batch_size, mlp_sizes[0], device="cuda").uniform_(-1., 1.).requires_grad_() - ref_input = test_input.clone().detach().requires_grad_() - mlp_out = mlp(test_input) - ref_out = ref_mlp(ref_input) - np.testing.assert_allclose( - mlp_out.detach().cpu().numpy(), - ref_out.detach().cpu().numpy(), - atol=1e-7, rtol=1e-5) - - # Use mean value as scalar loss. Multiply 10 to make it big enough not zero out - mlp_out.mean().mul(10.).backward() - ref_out.mean().mul(10.).backward() - np.testing.assert_allclose( - test_input.grad.detach().cpu().numpy(), - ref_input.grad.detach().cpu().numpy(), - atol=0, rtol=1e-5) - np.testing.assert_allclose( - mlp.biases[0].grad.detach().cpu().numpy(), - ref_mlp[0].bias.grad.detach().cpu().numpy(), - atol=1e-7, rtol=1e-5) - - @skipFlakyTest - def test_no_bias(self): - for use_activation in ['none', 'relu', 'sigmoid']: - mlp = MLP(mlp_sizes, bias=False, activation=use_activation).cuda() - - mlp_layers = [] - for i in range(mlp.num_layers): - linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1], bias=False) - mlp.weights[i].data.copy_(linear.weight) - mlp_layers.append(linear) - if use_activation == 'relu': - mlp_layers.append(nn.ReLU(inplace=True)) - if use_activation == 'sigmoid': - mlp_layers.append(nn.Sigmoid()) - - ref_mlp = nn.Sequential(*mlp_layers).cuda() - - test_input = torch.empty(batch_size, mlp_sizes[0], device="cuda").uniform_(-1., 1.).requires_grad_() - ref_input = test_input.clone().detach().requires_grad_() - mlp_out = mlp(test_input) - ref_out = ref_mlp(ref_input) - np.testing.assert_allclose( - mlp_out.detach().cpu().numpy(), - ref_out.detach().cpu().numpy(), - atol=1e-7, rtol=1e-5) - - # Use mean value as scalar loss. Multiply 10 to make it big enough not zero out - mlp_out.mean().mul(10.).backward() - ref_out.mean().mul(10.).backward() - np.testing.assert_allclose( - test_input.grad.detach().cpu().numpy(), - ref_input.grad.detach().cpu().numpy(), - atol=0, rtol=100) - np.testing.assert_allclose( - mlp.weights[0].grad.detach().cpu().numpy(), - ref_mlp[0].weight.grad.detach().cpu().numpy(), - atol=1e-7, rtol=100) - - @skipFlakyTest - def test_with_bias(self): - for use_activation in ['none', 'relu', 'sigmoid']: - mlp = MLP(mlp_sizes, bias=True, activation=use_activation).cuda() - - mlp_layers = [] - for i in range(mlp.num_layers): - linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1], bias=True) - mlp.weights[i].data.copy_(linear.weight) - mlp.biases[i].data.copy_(linear.bias) - mlp_layers.append(linear) - if use_activation == 'relu': - mlp_layers.append(nn.ReLU(inplace=True)) - if use_activation == 'sigmoid': - mlp_layers.append(nn.Sigmoid()) - - ref_mlp = nn.Sequential(*mlp_layers).cuda() - - test_input = torch.empty(batch_size, mlp_sizes[0], device="cuda").uniform_(-1., 1.).requires_grad_() - ref_input = test_input.clone().detach().requires_grad_() - mlp_out = mlp(test_input) - ref_out = ref_mlp(ref_input) - np.testing.assert_allclose( - mlp_out.detach().cpu().numpy(), - ref_out.detach().cpu().numpy(), - atol=1e-7, rtol=1e-5) - - # Use mean value as scalar loss. Multiply 10 to make it big enough not zero out - mlp_out.mean().mul(10.).backward() - ref_out.mean().mul(10.).backward() - np.testing.assert_allclose( - test_input.grad.detach().cpu().numpy(), - ref_input.grad.detach().cpu().numpy(), - atol=0, rtol=1) - np.testing.assert_allclose( - mlp.weights[0].grad.detach().cpu().numpy(), - ref_mlp[0].weight.grad.detach().cpu().numpy(), - atol=1e-7, rtol=1) - np.testing.assert_allclose( - mlp.biases[0].grad.detach().cpu().numpy(), - ref_mlp[0].bias.grad.detach().cpu().numpy(), - atol=1e-7, rtol=1e-5) - - @skipFlakyTest - def test_no_grad(self): - mlp = MLP(mlp_sizes).cuda() - - mlp_layers = [] - for i in range(mlp.num_layers): - linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1]) - mlp.weights[i].data.copy_(linear.weight) - mlp.biases[i].data.copy_(linear.bias) - mlp_layers.append(linear) - mlp_layers.append(nn.ReLU(inplace=True)) - - ref_mlp = nn.Sequential(*mlp_layers).cuda() - - test_input = torch.empty(batch_size, mlp_sizes[0], device="cuda").uniform_(-1., 1.) - ref_input = test_input.clone().detach() - mlp_out = mlp(test_input) - ref_out = ref_mlp(ref_input) - np.testing.assert_allclose( - mlp_out.detach().cpu().numpy(), - ref_out.detach().cpu().numpy(), - atol=1e-7, rtol=1e-5) - - # Use mean value as scalar loss. Multiply 10 to make it big enough not zero out - mlp_out.mean().mul(10.).backward() - ref_out.mean().mul(10.).backward() - np.testing.assert_allclose( - mlp.weights[0].grad.detach().cpu().numpy(), - ref_mlp[0].weight.grad.detach().cpu().numpy(), - atol=1e-7, rtol=1e-5) - - def test_performance_half(self): - mlp = MLP(mlp_sizes).cuda().half() - - mlp_layers = [] - for i in range(mlp.num_layers): - linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1]) - mlp.weights[i].data.copy_(linear.weight) - mlp.biases[i].data.copy_(linear.bias) - mlp_layers.append(linear) - mlp_layers.append(nn.ReLU(inplace=True)) - - ref_mlp = nn.Sequential(*mlp_layers).cuda().half() - - test_input = torch.empty( - batch_size, mlp_sizes[0], device="cuda", dtype=torch.half).fill_(10.).requires_grad_() - ref_input = torch.empty( - batch_size, mlp_sizes[0], device="cuda", dtype=torch.half).fill_(10.).requires_grad_() - - # Warm up GPU - for _ in range(100): - ref_out = ref_mlp(ref_input) - ref_loss = ref_out.mean() - ref_mlp.zero_grad() - ref_loss.backward() - mlp_out = mlp(test_input) - test_loss = mlp_out.mean() - mlp.zero_grad() - test_loss.backward() - - #torch.cuda.profiler.start() - torch.cuda.synchronize() - start_time = time() - for _ in range(num_iters): - ref_out = ref_mlp(ref_input) - ref_loss = ref_out.mean() - ref_mlp.zero_grad() - ref_loss.backward() - torch.cuda.synchronize() - stop_time = time() - print(F"\nPytorch MLP time {(stop_time - start_time) * 1000. / num_iters:.4f} ms") - - torch.cuda.synchronize() - start_time = time() - for _ in range(num_iters): - mlp_out = mlp(test_input) - test_loss = mlp_out.mean() - mlp.zero_grad() - test_loss.backward() - torch.cuda.synchronize() - stop_time = time() - print(F"C++ MLP time {(stop_time - start_time) * 1000. / num_iters:.4f} ms") - #torch.cuda.profiler.stop() - -if __name__ == '__main__': - unittest.main() diff --git a/tests/L0/run_optimizers/__init__.py b/tests/L0/run_optimizers/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/L0/run_optimizers/test_fused_novograd.py b/tests/L0/run_optimizers/test_fused_novograd.py deleted file mode 100755 index fa94e71..0000000 --- a/tests/L0/run_optimizers/test_fused_novograd.py +++ /dev/null @@ -1,170 +0,0 @@ -import torch -from torch.optim import Optimizer -import math -import apex -import unittest - -from test_fused_optimizer import TestFusedOptimizer -from itertools import product - -class Novograd(Optimizer): - """ - Implements Novograd algorithm. - - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.95, 0)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - grad_averaging: gradient averaging - amsgrad (boolean, optional): whether to use the AMSGrad variant of this - algorithm from the paper `On the Convergence of Adam and Beyond`_ - (default: False) - """ - - def __init__(self, params, lr=1e-3, betas=(0.95, 0), eps=1e-8, - weight_decay=0, grad_averaging=False, amsgrad=False): - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay, - grad_averaging=grad_averaging, - amsgrad=amsgrad) - - super(Novograd, self).__init__(params, defaults) - - def __setstate__(self, state): - super(Novograd, self).__setstate__(state) - for group in self.param_groups: - group.setdefault('amsgrad', False) - - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - loss = closure() - - for group in self.param_groups: - for p in group['params']: - if p.grad is None: - continue - grad = p.grad.data - if grad.is_sparse: - raise RuntimeError('Sparse gradients are not supported.') - amsgrad = group['amsgrad'] - - state = self.state[p] - - # State initialization - if len(state) == 0: - state['step'] = 0 - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p.data) - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) - if amsgrad: - # Maintains max of all exp. moving avg. of sq. grad. values - state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) - - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] - if amsgrad: - max_exp_avg_sq = state['max_exp_avg_sq'] - beta1, beta2 = group['betas'] - - state['step'] += 1 - - norm = torch.sum(torch.pow(grad, 2)) - - if exp_avg_sq == 0: - exp_avg_sq.copy_(norm) - else: - exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2) - - if amsgrad: - # Maintains the maximum of all 2nd moment running avg. till now - torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) - # Use the max. for normalizing running avg. of gradient - denom = max_exp_avg_sq.sqrt().add_(group['eps']) - else: - denom = exp_avg_sq.sqrt().add_(group['eps']) - - grad.div_(denom) - if group['weight_decay'] != 0: - grad.add_(p.data, alpha=group['weight_decay']) - if group['grad_averaging']: - grad.mul_(1 - beta1) - exp_avg.mul_(beta1).add_(grad) - - p.data.add_(exp_avg, alpha=-group['lr']) - - return loss - - -class TestFusedNovoGrad(TestFusedOptimizer): - - def __init__(self, *args, **kwargs): - super(TestFusedNovoGrad, self).__init__(*args, **kwargs) - - # The options for NovoGrad and FusedNovoGrad are very specific if they - # are expected to behave the same. - self.options = {'lr':1e-3, 'betas':(0.95, 0), 'eps':1e-8, - 'weight_decay':0, 'grad_averaging':False, 'amsgrad':False} - - self.tst_options = {'lr':1e-3, 'betas':(0.95, 0), 'eps':1e-8, - 'weight_decay':0, 'grad_averaging':False, 'amsgrad':False, - 'bias_correction':False, 'reg_inside_moment':True, - 'norm_type':2, 'init_zero':False, 'set_grad_none':True} - - self.ref_optim = Novograd - self.fused_optim = apex.optimizers.FusedNovoGrad - - def test_float(self): - self.gen_single_type_test(param_type=torch.float) - - def test_half(self): - self.gen_single_type_test(param_type=torch.float16) - - @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required") - def test_multi_device(self): - devices = ("cuda:1", "cuda:0") - for current_dev, tensor_dev in product(devices, devices): - with torch.cuda.device(current_dev): - torch.cuda.synchronize() - self.gen_single_type_test(param_type=torch.float, device=tensor_dev) - - - def test_multi_params(self): - sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]] - - tensors = [] - for size in sizes: - tensors.append(torch.rand(size, dtype=torch.float, device="cuda")) - ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim( - tensors, self.options, self.tst_options - ) - - for _ in range(self.iters): - self.gen_grad(ref_param, tst_param) - ref_optim.step() - tst_optim.step() - max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param) - self.assertLessEqual(max_abs_diff, self.max_abs_diff) - self.assertLessEqual(max_rel_diff, self.max_rel_diff) - -if __name__ == '__main__': - unittest.main() diff --git a/tests/L0/run_optimizers/test_fused_optimizer.py b/tests/L0/run_optimizers/test_fused_optimizer.py deleted file mode 100644 index 852068a..0000000 --- a/tests/L0/run_optimizers/test_fused_optimizer.py +++ /dev/null @@ -1,310 +0,0 @@ -from itertools import product -import random -import unittest - -import torch - -import apex - -from apex.testing.common_utils import skipIfRocm - - -class TestFusedOptimizer(unittest.TestCase): - def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7): - self.max_abs_diff = max_abs_diff - self.max_rel_diff = max_rel_diff - self.iters = iters - torch.manual_seed(9876) - - def tearDown(self): - pass - - def gen_param_optim(self, tensors, options, tst_options=None): - - # Adding this to make backward compatible with existing tests. Just in - # case "tst_options" are not provided, it gets a copy of options - # which contains the parameters for the reference optimizer - if tst_options == None: - tst_options = options - - ref_param = [] - tst_param = [] - for tensor in tensors: - ref_param.append(torch.nn.Parameter(tensor.clone())) - tst_param.append(torch.nn.Parameter(tensor.clone())) - - ref_optim = self.ref_optim(ref_param, **options) - tst_optim = self.fused_optim(tst_param, **tst_options) - - return (ref_param, tst_param, ref_optim, tst_optim) - - def gen_grad(self, ref_param, tst_param): - for p_ref, p_tst in zip(ref_param, tst_param): - p_ref.grad = torch.rand_like(p_ref) - p_tst.grad = p_ref.grad - - def gen_mixed_grad(self, ref_param, tst_param, scale=1.0): - half_grads = [] - for p_ref, p_tst in zip(ref_param, tst_param): - half_grads.append(torch.rand_like(p_ref).half()) - p_ref.grad = half_grads[-1].float() / scale - return half_grads - - def get_max_diff(self, ref_param, tst_param): - max_abs_diff = max_rel_diff = 0 - for p_ref, p_tst in zip(ref_param, tst_param): - max_abs_diff_p = (p_ref - p_tst).abs().max().item() - max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item() - - if max_abs_diff_p > max_abs_diff: max_abs_diff = max_abs_diff_p - if max_rel_diff_p > max_rel_diff: max_rel_diff = max_rel_diff_p - - return max_abs_diff, max_rel_diff - - def gen_single_type_test(self, param_type=torch.float, device='cuda', *, skip_assert: bool = False): - nelem = 278011 - - # Some ref and test optimizers may require different set of options. - # This is a quick workaround to add that functionality while making - # minimum changes in existing code. - # If there is no "tst_options" field provided, safe to initialize - # the test optimizer with the parameters of reference optimizer. - if not hasattr(self, 'tst_options'): - self.tst_options = self.options - - tensor = torch.rand(nelem, dtype=param_type, device=device) - - ref_param, tst_param, ref_optim, tst_optim = \ - self.gen_param_optim([tensor], self.options, self.tst_options) - - for i in range(self.iters): - self.gen_grad(ref_param, tst_param) - ref_optim.step() - tst_optim.step() - if skip_assert: - return - max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param) - self.assertLessEqual(max_abs_diff, self.max_abs_diff) - self.assertLessEqual(max_rel_diff, self.max_rel_diff) - - -class TestFusedAdam(TestFusedOptimizer): - - def setUp(self): - super().setUp() - torch_ver = torch.__version__.split('a0')[0] - if torch_ver == '1.10.0': - self.options = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, - 'weight_decay': 0, 'amsgrad': False} - else: - self.options = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, - 'weight_decay': 0, 'amsgrad': False, "capturable": True} - self.tst_options = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, - 'weight_decay': 0, 'amsgrad': False} - self.ref_optim = torch.optim.Adam - self.fused_optim = apex.optimizers.FusedAdam - - def test_float(self): - self.gen_single_type_test(param_type=torch.float) - - # NOTE(mkozuki): Current threshold values look too small for BFloat16. - # TODO(mkozuki): Refactor `TestFusedOptimizer` - @unittest.skip("NaN issue observed on ROCm as of 12/1/2021. The failing unit test is introduced by a PyTorch commit sometime in between rocm/pytorch:rocm4.3.1_ubuntu18.04_py3.6_pytorch_1.9.0 and 2021/12/01. Please refer to https://github.com/ROCmSoftwarePlatform/apex/issues/63") - def test_half(self): - self.gen_single_type_test(param_type=torch.float16, skip_assert=True) - - @skipIfRocm - def test_bfloat16(self): - self.gen_single_type_test(param_type=torch.bfloat16, skip_assert=True) - - @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required") - def test_multi_device(self): - devices = ("cuda:0", "cuda:1") - for current_dev, tensor_dev in product(devices, devices): - with torch.cuda.device(current_dev): - self.gen_single_type_test(param_type=torch.float, device=tensor_dev) - - @unittest.skip('Disable until 8/1/2019 adam/adamw upstream picked') - def test_multi_params(self): - sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]] - - tensors = [] - for size in sizes: - tensors.append(torch.rand(size, dtype=torch.float, device='cuda')) - ref_param, tst_param, ref_optim, tst_optim = \ - self.gen_param_optim(tensors, self.options) - - for i in range(self.iters): - self.gen_grad(ref_param, tst_param) - ref_optim.step() - tst_optim.step() - max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param) - self.assertLessEqual(max_abs_diff, self.max_abs_diff) - self.assertLessEqual(max_rel_diff, self.max_rel_diff) - - @unittest.skip('No longer support fuse scaling') - def test_scale(self): - nelem = 278011 - tensor = torch.rand(nelem, dtype=torch.float, device='cuda') - ref_param, tst_param, ref_optim, tst_optim = \ - self.gen_param_optim([tensor], self.options) - - for i in range(self.iters): - scale = random.random() * 1000 - half_grads = self.gen_mixed_grad(ref_param, tst_param, scale) - ref_optim.step() - tst_optim.step(grads=half_grads, scale=scale) - max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param) - - self.assertLessEqual(max_abs_diff, self.max_abs_diff) - self.assertLessEqual(max_rel_diff, self.max_rel_diff) - - @unittest.skip('No longer support output fp16 param') - def test_fp16_output(self): - nelem = 278011 - - tensor = torch.rand(nelem, dtype=torch.float, device='cuda') - ref_param, tst_param, ref_optim, tst_optim = \ - self.gen_param_optim([tensor], self.options) - - fp16_param = torch.nn.Parameter(tensor.clone().half()) - - for i in range(self.iters): - half_grads = self.gen_mixed_grad(ref_param, tst_param) - ref_optim.step() - tst_optim.step(grads=half_grads, output_params=[fp16_param]) - - max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param) - self.assertLessEqual(max_abs_diff, self.max_abs_diff) - self.assertLessEqual(max_rel_diff, self.max_rel_diff) - - max_abs_diff, max_rel_diff = self.get_max_diff(tst_param, \ - [fp16_param.float()]) - self.assertLessEqual(max_abs_diff, self.max_abs_diff) - self.assertLessEqual(max_rel_diff, self.max_rel_diff) - - def test_adam_option(self): - nelem = 1 - torch_ver = torch.__version__.split('a0')[0] - adam_option = None - if torch_ver == '1.10.0': - adam_option = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, - 'weight_decay':0, 'amsgrad':False} - else: - adam_option = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, - 'weight_decay':0, 'amsgrad':False, 'capturable':True} - - adam_option_tst = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, - 'weight_decay':0, 'amsgrad':False} - - tensor = torch.rand(nelem, dtype=torch.float, device='cuda') - ref_param, tst_param, ref_optim, tst_optim = \ - self.gen_param_optim([tensor], adam_option, adam_option_tst) - - for i in range(self.iters): - self.gen_grad(ref_param, tst_param) - ref_optim.step() - tst_optim.step() - max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param) - - self.assertLessEqual(max_abs_diff, self.max_abs_diff) - self.assertLessEqual(max_rel_diff, self.max_rel_diff) - - -class TestFusedAdagrad(TestFusedOptimizer): - def __init__(self, *args, **kwargs): - super(TestFusedAdagrad, self).__init__(*args, **kwargs) - self.options = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 1.0e-5} - self.ref_optim = torch.optim.Adagrad - self.fused_optim = apex.optimizers.FusedAdagrad - - def test_float(self): - self.gen_single_type_test(param_type=torch.float) - - @unittest.skip("PyTorch optimizer is not numerically correct for fp16") - def test_half(self): - self.gen_single_type_test(param_type=torch.float16) - - @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required") - def test_multi_device(self): - devices = ("cuda:0", "cuda:1") - for current_dev, tensor_dev in product(devices, devices): - with torch.cuda.device(current_dev): - self.gen_single_type_test(param_type=torch.float, device=tensor_dev) - - - def test_multi_params(self): - sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]] - adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 0} - - tensors = [] - for size in sizes: - tensors.append(torch.rand(size, dtype=torch.float, device="cuda")) - ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim( - tensors, adagrad_option - ) - - for _ in range(self.iters): - self.gen_grad(ref_param, tst_param) - ref_optim.step() - tst_optim.step() - max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param) - self.assertLessEqual(max_abs_diff, self.max_abs_diff) - self.assertLessEqual(max_rel_diff, self.max_rel_diff) - - @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required") - def test_multi_params_different_devices_throws(self): - sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]] - adagrad_option = {"lr": 5e-4, "eps": 1e-08, "weight_decay": 0} - - tensors = [] - for i, size in enumerate(sizes): - tensors.append(torch.rand(size, dtype=torch.float, device="cuda:"+str(i % 2))) - ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim( - tensors, adagrad_option - ) - self.gen_grad(ref_param, tst_param) - with self.assertRaisesRegex(RuntimeError, "not on the same device"): - tst_optim.step() - - def test_adagrad_option(self): - nelem = 1 - adagrad_option = {"lr": 0.01, "eps": 3e-06, "weight_decay": 0} - - tensor = torch.rand(nelem, dtype=torch.float, device="cuda") - ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim( - [tensor], adagrad_option - ) - - for _ in range(self.iters): - self.gen_grad(ref_param, tst_param) - ref_optim.step() - tst_optim.step() - max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param) - - self.assertLessEqual(max_abs_diff, self.max_abs_diff) - self.assertLessEqual(max_rel_diff, self.max_rel_diff) - - -class TestFusedSGD(TestFusedOptimizer): - def __init__(self, *args, **kwargs): - super(TestFusedSGD, self).__init__(*args, **kwargs) - self.options = {"lr": .25, "momentum": .125} - self.ref_optim = torch.optim.SGD - self.fused_optim = apex.optimizers.FusedSGD - - def test_float(self): - self.gen_single_type_test(param_type=torch.float) - - def test_half(self): - self.gen_single_type_test(param_type=torch.float16) - - @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required") - def test_multi_device(self): - devices = ("cuda:0", "cuda:1") - for current_dev, tensor_dev in product(devices, devices): - with torch.cuda.device(current_dev): - self.gen_single_type_test(param_type=torch.float, device=tensor_dev) - -if __name__ == '__main__': - unittest.main() diff --git a/tests/L0/run_optimizers/test_fused_optimizer_channels_last.py b/tests/L0/run_optimizers/test_fused_optimizer_channels_last.py deleted file mode 100644 index 7db329b..0000000 --- a/tests/L0/run_optimizers/test_fused_optimizer_channels_last.py +++ /dev/null @@ -1,112 +0,0 @@ -from itertools import product -import random -import unittest - -import torch - -import apex - -# NHWC -class TestFusedOptimizerChannelsLast(unittest.TestCase): - def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7): - self.max_abs_diff = max_abs_diff - self.max_rel_diff = max_rel_diff - self.iters = iters - torch.manual_seed(9876) - - def tearDown(self): - pass - - def gen_param_optim(self, tensors, options, device, tst_options=None): - - # Adding this to make backward compatible with existing tests. Just in - # case "tst_options" are not provided, it gets a copy of options - # which contains the parameters for the reference optimizer - if tst_options == None: - tst_options = options - - ref_param = [] - tst_param = [] - for tensor in tensors: - input = tensor.clone().contiguous(memory_format=torch.channels_last).to(device) # channels_last - ref_input = tensor.clone().contiguous().to(device) - - self.assertTrue(input.is_contiguous(memory_format=torch.channels_last)) - self.assertTrue(ref_input.is_contiguous(memory_format=torch.contiguous_format)) - - tst_param.append(torch.nn.Parameter(input)) - ref_param.append(torch.nn.Parameter(ref_input)) - - ref_optim = self.ref_optim(ref_param, **options) - tst_optim = self.fused_optim(tst_param, **tst_options) - return (ref_param, tst_param, ref_optim, tst_optim) - - def gen_grad(self, ref_param, tst_param): - for p_ref, p_tst in zip(ref_param, tst_param): - p_ref.grad = torch.rand_like(p_ref) - p_tst.grad = p_ref.grad.clone() #### p_tst is =torch.channels_last but p_tst.grad is torch.contiguous_format - - self.assertTrue(p_tst.grad.is_contiguous(memory_format=torch.contiguous_format)) - self.assertTrue(p_ref.grad.is_contiguous(memory_format=torch.contiguous_format)) - - - def get_max_diff(self, ref_param, tst_param): - max_abs_diff = max_rel_diff = 0 - for p_ref, p_tst in zip(ref_param, tst_param): - self.assertTrue(p_ref.is_contiguous(memory_format=torch.contiguous_format)) - self.assertTrue(p_tst.is_contiguous(memory_format=torch.channels_last)) - max_abs_diff_p = (p_ref - p_tst).abs().max().item() - max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item() - - if max_abs_diff_p > max_abs_diff: max_abs_diff = max_abs_diff_p - if max_rel_diff_p > max_rel_diff: max_rel_diff = max_rel_diff_p - - return max_abs_diff, max_rel_diff - - def gen_single_type_test(self, param_type=torch.float, device='cuda', *, skip_assert: bool = False): - # nelem = 278011 - - # Some ref and test optimizers may require different set of options. - # This is a quick workaround to add that functionality while making - # minimum changes in existing code. - # If there is no "tst_options" field provided, safe to initialize - # the test optimizer with the parameters of reference optimizer. - if not hasattr(self, 'tst_options'): - self.tst_options = self.options - - tensor = torch.rand([3,4,2,3], dtype=param_type, device=device) - ref_param, tst_param, ref_optim, tst_optim = \ - self.gen_param_optim([tensor], self.options, device, self.tst_options) - - for i in range(self.iters): - self.gen_grad(ref_param, tst_param) - ref_optim.step() - tst_optim.step() - if skip_assert: - return - max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param) - self.assertLessEqual(max_abs_diff, self.max_abs_diff) - self.assertLessEqual(max_rel_diff, self.max_rel_diff) - -class TestFusedSGDChannelLast(TestFusedOptimizerChannelsLast): - def __init__(self, *args, **kwargs): - super(TestFusedSGDChannelLast, self).__init__(*args, **kwargs) - self.options = {"lr": .25, "momentum": .125} - self.ref_optim = torch.optim.SGD - self.fused_optim = apex.optimizers.FusedSGD - - def test_float(self): - self.gen_single_type_test(param_type=torch.float) - - def test_half(self): - self.gen_single_type_test(param_type=torch.float16) - - @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required") - def test_multi_device(self): - devices = ("cuda:0", "cuda:1") - for current_dev, tensor_dev in product(devices, devices): - with torch.cuda.device(current_dev): - self.gen_single_type_test(param_type=torch.float, device=tensor_dev) - -if __name__ == '__main__': - unittest.main() diff --git a/tests/L0/run_optimizers/test_lamb.py b/tests/L0/run_optimizers/test_lamb.py deleted file mode 100644 index c6ef9aa..0000000 --- a/tests/L0/run_optimizers/test_lamb.py +++ /dev/null @@ -1,337 +0,0 @@ -import unittest -import os - -import torch -from torch.optim import Optimizer -import apex -from apex.multi_tensor_apply import multi_tensor_applier -from itertools import product - -class RefLAMB(Optimizer): - r"""Implements Lamb algorithm. - - It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-6) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01) - - .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: - https://arxiv.org/abs/1904.00962 - """ - - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01): - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - super(RefLAMB, self).__init__(params, defaults) - if multi_tensor_applier.available: - import amp_C - self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm - # Skip buffer - self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device) - self.multi_tensor_lamb = amp_C.multi_tensor_lamb - else: - raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions') - - def step(self, closure=None): - """Performs a single optimization step. - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - loss = closure() - - # create separate grad lists for fp32 and fp16 params - g_all_32, g_all_16 = [], [] - for group in self.param_groups: - for p in group['params']: - if p.grad is None: - continue - if p.dtype == torch.float32: - g_all_32.append(p.grad.data) - elif p.dtype == torch.float16: - g_all_16.append(p.grad.data) - else: - raise RuntimeError('FusedLAMB only support fp16 and fp32.') - - device = self.param_groups[0]["params"][0].device - g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device) - # compute grad norm for two lists - if len(g_all_32) > 0: - g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm, - self._dummy_overflow_buf, - [g_all_32], False)[0] - if len(g_all_16) > 0: - g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm, - self._dummy_overflow_buf, - [g_all_16], False)[0] - - # blend two grad norms to get global grad norm - global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm, - self._dummy_overflow_buf, - [[g_norm_32, g_norm_16]], - False)[0] - - max_grad_norm = 1.0 - clipped_ratio = max_grad_norm / max(global_grad_norm, max_grad_norm) - - for group in self.param_groups: - for p in group['params']: - if p.grad is None: - continue - p.grad.data *= clipped_ratio - grad = p.grad.data - if grad.is_sparse: - raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') - - state = self.state[p] - - # State initialization - if len(state) == 0: - state['step'] = 0 - # Exponential moving average of gradient values - state['m'] = torch.zeros_like(p.data) - # Exponential moving average of squared gradient values - state['v'] = torch.zeros_like(p.data) - - m_t, v_t = state['m'], state['v'] - beta1, beta2 = group['betas'] - - state['step'] += 1 - - # m_t = beta1 * m + (1 - beta1) * g_t - m_t.mul_(beta1).add_(grad, alpha=1-beta1) - # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) - v_t.mul_(beta2).addcmul_(grad, grad, value=1-beta2) - - # Debiasing - m_t_hat = m_t / (1.0 - beta1 ** state['step']) - v_t_hat = v_t / (1.0 - beta2 ** state['step']) - - update = m_t_hat / v_t_hat.sqrt().add(group['eps']) - - if group['weight_decay'] != 0: - update.add_(p.data, alpha=group['weight_decay']) - - trust_ratio = 1.0 - w_norm = p.data.pow(2).sum().sqrt() - g_norm = update.pow(2).sum().sqrt() - if w_norm > 0 and g_norm > 0: - trust_ratio = w_norm / g_norm - - state['w_norm'] = w_norm - state['g_norm'] = g_norm - state['trust_ratio'] = trust_ratio - - step_size = group['lr'] - - p.data.add_(update, alpha=-step_size*trust_ratio) - - return loss - -class TestLamb(unittest.TestCase): - def setUp(self, max_abs_diff=1e-3, max_rel_diff=1, iters=7): - self.max_abs_diff = max_abs_diff - self.max_rel_diff = max_rel_diff - self.iters = iters - torch.cuda.manual_seed(9876) - - - def tearDown(self): - pass - - def gen_param_optim(self, tensors, lamb_option): - ref_param = [] - tst_param = [] - for tensor in tensors: - ref_param.append(torch.nn.Parameter(tensor.clone())) - tst_param.append(torch.nn.Parameter(tensor.clone())) - - ref_optim = self.ref_optim(ref_param, **lamb_option) - tst_optim = self.tst_optim(tst_param, use_nvlamb=True, **lamb_option) - - return (ref_param, tst_param, ref_optim, tst_optim) - - def gen_grad(self, ref_param, tst_param): - for p_ref, p_tst in zip(ref_param, tst_param): - p_ref.grad = torch.rand_like(p_ref) - p_tst.grad = p_ref.grad - - def gen_mixed_grad(self, ref_param, tst_param, scale=1.0): - half_grads = [] - for p_ref, _ in zip(ref_param, tst_param): - half_grads.append(torch.rand_like(p_ref).half()) - p_ref.grad = half_grads[-1].float() / scale - return half_grads - - def get_max_diff(self, ref_param, tst_param): - max_abs_diff = max_rel_diff = 0 - for p_ref, p_tst in zip(ref_param, tst_param): - max_abs_diff_p = (p_ref - p_tst).abs().max().item() - max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item() - - if max_abs_diff_p > max_abs_diff: max_abs_diff = max_abs_diff_p - if max_rel_diff_p > max_rel_diff: max_rel_diff = max_rel_diff_p - - return max_abs_diff, max_rel_diff - - def gen_single_type_test(self, param_type=torch.float, device="cuda"): - nelem = 278011 - tensor = torch.rand(nelem, dtype=param_type, device=device) - weight_decay = [0, 0.01] - - for wd in weight_decay: - lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd} - ref_param, tst_param, ref_optim, tst_optim = \ - self.gen_param_optim([tensor], lamb_option) - - for i in range(self.iters): - self.gen_grad(ref_param, tst_param) - ref_optim.step() - torch.cuda.synchronize() - tst_optim.step() - torch.cuda.synchronize() - max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param) - - self.assertLessEqual(max_abs_diff, self.max_abs_diff) - self.assertLessEqual(max_rel_diff, self.max_rel_diff) - -class TestFusedLAMB(TestLamb): - def __init__(self, *args, **kwargs): - super(TestLamb, self).__init__(*args, **kwargs) - self.ref_optim = RefLAMB - self.tst_optim = apex.optimizers.FusedLAMB - - - def test_float(self): - self.gen_single_type_test(param_type=torch.float) - - @unittest.skip("PyTorch optimizer is not numerically correct for fp16") - def test_half(self): - self.gen_single_type_test(param_type=torch.float16) - - @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required") - def test_multi_device(self): - devices = ("cuda:0", "cuda:1") - for current_dev, tensor_dev in product(devices, devices): - with torch.cuda.device(current_dev): - self.gen_single_type_test(param_type=torch.float, device=tensor_dev) - - def test_multi_params(self): - sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]] - weight_decay = [0, 0.01] - - for wd in weight_decay: - lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd} - tensors = [] - for size in sizes: - tensors.append(torch.rand(size, dtype=torch.float, device='cuda')) - ref_param, tst_param, ref_optim, tst_optim = \ - self.gen_param_optim(tensors, lamb_option) - - for i in range(self.iters): - self.gen_grad(ref_param, tst_param) - ref_optim.step() - tst_optim.step() - max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param) - self.assertLessEqual(max_abs_diff, self.max_abs_diff) - self.assertLessEqual(max_rel_diff, self.max_rel_diff) - - def test_lamb_option(self): - nelem = 1 - tensor = torch.rand(nelem, dtype=torch.float, device='cuda') - weight_decay = [0, 0.01] - - for wd in weight_decay: - lamb_option = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, 'weight_decay':wd} - ref_param, tst_param, ref_optim, tst_optim = \ - self.gen_param_optim([tensor], lamb_option) - - for i in range(self.iters): - self.gen_grad(ref_param, tst_param) - ref_optim.step() - tst_optim.step() - max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param) - - self.assertLessEqual(max_abs_diff, self.max_abs_diff) - self.assertLessEqual(max_rel_diff, self.max_rel_diff) - -class TestFusedMixedPrecisionLamb(TestLamb): - def __init__(self, *args, **kwargs): - super(TestLamb, self).__init__(*args, **kwargs) - self.ref_optim = RefLAMB - self.tst_optim = apex.optimizers.FusedMixedPrecisionLamb - - - def test_float(self): - self.gen_single_type_test(param_type=torch.float) - - @unittest.skip("PyTorch optimizer is not numerically correct for fp16") - def test_half(self): - self.gen_single_type_test(param_type=torch.float16) - - @unittest.skip("Skipped the test since it failed the accuracy test on the PyTorch as of 8/1/2022. Please refer to https://github.com/ROCmSoftwarePlatform/apex/issues/83") - @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required") - def test_multi_device(self): - devices = ("cuda:0", "cuda:1") - for current_dev, tensor_dev in product(devices, devices): - with torch.cuda.device(current_dev): - self.gen_single_type_test(param_type=torch.float, device=tensor_dev) - - def test_multi_params(self): - sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]] - weight_decay = [0, 0.01] - - for wd in weight_decay: - lamb_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08, 'weight_decay':wd} - tensors = [] - for size in sizes: - tensors.append(torch.rand(size, dtype=torch.float, device='cuda')) - ref_param, tst_param, ref_optim, tst_optim = \ - self.gen_param_optim(tensors, lamb_option) - - for i in range(self.iters): - self.gen_grad(ref_param, tst_param) - ref_optim.step() - tst_optim.step() - max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param) - self.assertLessEqual(max_abs_diff, self.max_abs_diff) - self.assertLessEqual(max_rel_diff, self.max_rel_diff) - - def test_lamb_option(self): - nelem = 1 - tensor = torch.rand(nelem, dtype=torch.float, device='cuda') - weight_decay = [0, 0.01] - - for wd in weight_decay: - lamb_option = {'lr':0.01, 'betas':(0.6, 0.9), 'eps':3e-06, 'weight_decay':wd} - ref_param, tst_param, ref_optim, tst_optim = \ - self.gen_param_optim([tensor], lamb_option) - - for i in range(self.iters): - self.gen_grad(ref_param, tst_param) - ref_optim.step() - tst_optim.step() - max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param) - - self.assertLessEqual(max_abs_diff, self.max_abs_diff) - self.assertLessEqual(max_rel_diff, self.max_rel_diff) - -if __name__ == '__main__': - script_path = os.path.dirname(os.path.realpath(__file__)) - unittest.main() diff --git a/tests/L0/run_rocm.sh b/tests/L0/run_rocm.sh deleted file mode 100755 index 32405e7..0000000 --- a/tests/L0/run_rocm.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -APEX_TEST_WITH_ROCM=1 APEX_SKIP_FLAKY_TEST=1 python run_test.py diff --git a/tests/L0/run_test.py b/tests/L0/run_test.py deleted file mode 100644 index e87a1e8..0000000 --- a/tests/L0/run_test.py +++ /dev/null @@ -1,72 +0,0 @@ -"""L0 Tests Runner. - -How to run this script? - -1. Run all the tests: `python /path/to/apex/tests/L0/run_test.py` -2. Run one of the tests (e.g. fused layer norm): - `python /path/to/apex/tests/L0/run_test.py --include run_fused_layer_norm` -3. Run two or more of the tests (e.g. optimizers and fused layer norm): - `python /path/to/apex/tests/L0/run_test.py --include run_optimizers run_fused_layer_norm` -""" -import argparse -import os -import unittest -import sys - -from apex.testing.common_utils import TEST_WITH_ROCM -from apex.testing.common_utils import SKIP_FLAKY_TEST - -TEST_ROOT = os.path.dirname(os.path.abspath(__file__)) -TEST_DIRS = [ - "run_amp", - "run_fp16util", - "run_optimizers", - "run_fused_layer_norm", - "run_mlp", - "run_transformer", # not fully supported on ROCm -] -DEFAULT_TEST_DIRS = [ - "run_amp", - "run_fp16util", - "run_optimizers", - "run_fused_layer_norm", - "run_mlp", -] - - -def parse_args(): - parser = argparse.ArgumentParser( - description="L0 test runner", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - parser.add_argument( - "--include", - nargs="+", - choices=TEST_DIRS, - default=DEFAULT_TEST_DIRS, - help="select a set of tests to run (defaults to ALL tests).", - ) - args, _ = parser.parse_known_args() - return args - - -def main(args): - runner = unittest.TextTestRunner(verbosity=2) - errcode = 0 - for test_dir in args.include: - test_dir = os.path.join(TEST_ROOT, test_dir) - print(test_dir) - suite = unittest.TestLoader().discover(test_dir) - - print("\nExecuting tests from " + test_dir) - result = runner.run(suite) - if not result.wasSuccessful(): - errcode = 1 - - sys.exit(errcode) - - -if __name__ == '__main__': - args = parse_args() - main(args) - diff --git a/tests/L0/run_transformer/__init__.py b/tests/L0/run_transformer/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/L0/run_transformer/gpt_scaling_test.py b/tests/L0/run_transformer/gpt_scaling_test.py deleted file mode 100644 index eb70e25..0000000 --- a/tests/L0/run_transformer/gpt_scaling_test.py +++ /dev/null @@ -1,116 +0,0 @@ -import subprocess -import os - -from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE - - -def run_gpt(cmd): - args = list(cmd.split(" ")) - p = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - outs, errs = p.communicate() - outs = list(str((outs).decode("utf-8")).splitlines()) - success = False - runtime = 0 - num_params = 0 - for out in outs: - out = str(out) - if "Average Iteration Time:" in str(out): - slicey = out[out.find(":") + 2 :] - try: - runtime = float(slicey) - except: - print(slicey) - quit() - if "Number of Parameters:" in str(out): - slicey = out[out.find(":") + 2 :] - try: - num_params = int(slicey) - except: - print(slicey) - quit() - if str(out) == str(TEST_SUCCESS_MESSAGE): - success = True - return runtime, round(float(int(num_params)) / 10.0 ** 9, 3), success, errs - - -def plot(runtimes): - import matplotlib.pyplot as plt - - for distributed_setting in runtimes.keys(): - plt.scatter( - runtimes[distributed_setting].keys(), - runtimes[distributed_setting].values(), - label=distributed_setting, - ) - plt.legend() - plt.xlabel("Parameters (Billions)") - plt.ylabel("Training Iteration time (s)") - plt.title(str("GPT Scaling w/ Offloading")) - plt.savefig("offload_gpt_scaling.png") - plt.close() - if not os.path.exists("/my_workspace/"): - os.system("mkdir /my_workspace/") - os.system("cp *.png /my_workspace/") - - -def main(): - runtimes = {} - nlist = ( - list(range(2000, 10000, 2000)) - + list(range(10000, 50000, 5000)) - + list(range(50000, 100000, 10000)) - ) - print("N-List:", nlist) - for data_parr, tens_parr, pipe_parr in [(8, 1, 1), (4, 2, 1), (2, 1, 4), (1, 2, 4)]: - for offload in [True, False]: - dist_setting = ( - "ddp=" - + str(data_parr) - + ", tensor_parr=" - + str(tens_parr) - + ", pipe_parr=" - + str(pipe_parr) - + ", offload=" - + str(offload) - ) - runtimes[dist_setting] = {} - print("Beginning Testing for", dist_setting) - for n in nlist: - cmd = "python3 -m torch.distributed.launch --nproc_per_node=8 run_gpt_minimal_test.py" - cmd += ( - " --micro-batch-size 1 --num-layers " - + str(n) - + " --hidden-size 128 --num-attention-heads 16" - ) - cmd += ( - " --max-position-embeddings 128 --seq-length 128 --tensor-model-parallel-size " - + str(tens_parr) - ) - cmd += ( - " --pipeline-model-parallel-size " - + str(pipe_parr) - + (" --cpu-offload" if offload else "") - ) - print(cmd) - runtime, bill_params, success, errs = run_gpt(cmd) - if success: - runtimes[dist_setting][bill_params] = runtime - print( - str(runtime) + "s per training iter for", - str(bill_params) + "B parameter GPT-2", - ) - if n >= 10000: - plot(runtimes) - else: - print("GPT-2 w/", n, "layers failed using", dist_setting) - print("Moving on to the next distributed setting...") - print("#" * (25)) - print() - plot(runtimes) - break - print(runtimes) - plot(runtimes) - - -if __name__ == "__main__": - main() diff --git a/tests/L0/run_transformer/run_bert_minimal_test.py b/tests/L0/run_transformer/run_bert_minimal_test.py deleted file mode 100644 index 639c31e..0000000 --- a/tests/L0/run_transformer/run_bert_minimal_test.py +++ /dev/null @@ -1,260 +0,0 @@ -import random -import torch -try: - import torch_ucc -except ImportError: - HAS_TORCH_UCC = False -else: - HAS_TORCH_UCC = True - print("Use UCC as backend of Pipeline Parallel ProcessGroups") - -from apex.transformer.enums import ModelType -from apex.transformer import tensor_parallel -from apex.transformer import parallel_state -from apex.transformer.log_util import set_logging_level -from apex.transformer.tensor_parallel import vocab_parallel_cross_entropy -from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator -from apex.transformer.pipeline_parallel.utils import unwrap_model -from apex.transformer.pipeline_parallel.utils import ( - average_losses_across_data_parallel_group, -) -from apex.transformer.pipeline_parallel.schedules import get_forward_backward_func -from apex.transformer.pipeline_parallel.schedules.common import build_model -from apex.transformer.pipeline_parallel.schedules.common import ( - _get_params_for_weight_decay_optimization, -) - -from apex.transformer.testing.standalone_bert import bert_model_provider -from apex.transformer.testing import global_vars -from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE -from apex.transformer.testing.commons import initialize_distributed -from apex.transformer.testing.commons import print_separator - -import warnings - - -class DebugWarning(Warning): - pass - - -set_logging_level("WARNING") -mode = None -MANUAL_SEED = 42 -inds = None -masks = None -data_idx = 0 -MASK_PROB = 0.1 -EASY_MODE = False -EASY_MODE_SIZ = 32 -ONCE = False - - -def download_fancy_data(): - # import requests - # response = requests.get('https://internet.com/book.txt') - # text = ' '.join(response.text.split()) - text = """ - An original sentence not subject to any license restrictions, copyright, or royalty payments. Nothing to see here. Commercial or non-commercial use. Research or non-research purposes. The quick brown fox jumps over the lazy dog. Lorem ipsum. - """ - text = text * 1024 - encoded = text.encode("ascii", "replace") - ints = [int(encoded[i]) for i in range(len(encoded))] - return torch.tensor(ints) - - -# build a batch given sequence_len and batch size -def generate_fancy_data_labels(sequence_len, batch_size): - global data_idx - global inds - global masks - global MANUAL_SEED - temps = [] - for i in range(batch_size): - if inds is None or data_idx >= len(inds): - # hack as use of RNG will fall out of sync due to pipelines being different - torch.manual_seed(MANUAL_SEED) - inds = torch.randperm(effective_length, device="cuda") - masks = ( - torch.rand( - len(inds) // batch_size + 1, batch_size, sequence_len, device="cuda" - ) - >= MASK_PROB - ).long() - MANUAL_SEED += 1 - print("new epoch", len(inds)) - data_idx = 0 - print("my start", inds[0:5]) - print("masks_checksum:", torch.sum(masks)) - if EASY_MODE: - data_idx_ = data_idx % EASY_MODE_SIZ - else: - data_idx_ = data_idx - offset = inds[data_idx_] # * SEQUENCE_LEN - data_idx += 1 - - curr = fancy_data[offset : offset + sequence_len].clone().detach() - temps.append(curr) - temp = torch.stack(temps, dim=0).cuda() - mask = masks[data_idx // batch_size] - mask_not = torch.logical_not(mask).long() - data = mask * temp + mask_not * 124 - label = temp - if parallel_state.get_tensor_model_parallel_rank() == 0: - data_dict = {"text": data, "label": label, "mask_not": mask_not} - else: - data_dict = None - keys = ["text", "label", "mask_not"] - dtype = torch.int64 - broadcasted_data = tensor_parallel.broadcast_data(keys, data_dict, torch.long) - return ( - broadcasted_data["text"].long(), - broadcasted_data["label"].long(), - broadcasted_data["mask_not"], - ) - - -easy_data = None - - -def fwd_step_func(batch, model): - data, label, loss_mask = batch - y = model(data, torch.ones_like(data), lm_labels=label) - - def loss_func(output_tensor): - global ONCE - output_tensor, _ = output_tensor - lm_loss_ = output_tensor.float() - lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() - averaged_loss = average_losses_across_data_parallel_group([lm_loss]) - if data_idx >= 1536: - assert averaged_loss < 4.8 - if not ONCE: - print("LOSS OK") - ONCE = True - return lm_loss, {"avg": averaged_loss} - - return y, loss_func - - -def train( - model, optim, virtual_pipeline_model_parallel_size, pipeline_model_parallel_size, async_comm -): - sequence_len = global_vars.get_args().seq_length - micro_batch_size = global_vars.get_args().micro_batch_size - hidden_size = global_vars.get_args().hidden_size - forward_backward_func = get_forward_backward_func( - virtual_pipeline_model_parallel_size, pipeline_model_parallel_size - ) - tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size) - for _ in range(16): - batch = generate_fancy_data_labels(sequence_len, batch_size) - optim.zero_grad() - forward_backward_func( - fwd_step_func, - batch, - model, - forward_only=False, - tensor_shape=tensor_shape, - async_comm=async_comm, - sequence_parallel_enabled=global_vars.get_args().sequence_parallel, - ) - # All-reduce layernorm parameters across model parallel nodes - # when sequence parallelism is used - if parallel_state.get_tensor_model_parallel_world_size() > 1 and global_vars.get_args().sequence_parallel: - for model_module in model: - unwrapped_model = unwrap_model(model_module) - for param in unwrapped_model.parameters(): - if getattr(param, 'sequence_parallel_enabled', False): - grad = param.grad - torch.distributed.all_reduce(grad, group=parallel_state.get_tensor_model_parallel_group()) - - optim.step() - - -if __name__ == "__main__": - global fancy_data - global effective_length - - global_vars.set_global_variables() - - fancy_data = download_fancy_data() - effective_length = fancy_data.size(0) // global_vars.get_args().seq_length - effective_length = fancy_data.size(0) - global_vars.get_args().seq_length - - initialize_distributed("nccl") - world_size = torch.distributed.get_world_size() - failure = None - init = True - try: - virtual_pipeline_model_parallel_sizes = (None, 2,) - if HAS_TORCH_UCC: - # Deliberately skipping test with interleaved schedule for BERT model. - # It deadlocks on hybrid UCC/NCCL backend. - virtual_pipeline_model_parallel_sizes = (None,) - for virtual_pipeline_model_parallel_size in virtual_pipeline_model_parallel_sizes: - args = global_vars.get_args() - async_comm = not args.sequence_parallel and virtual_pipeline_model_parallel_size is None - data_idx = 0 - ONCE = False - if init: - init = False - args = global_vars.get_args() - args.padded_vocab_size = 128 # needed in standalone gpt - args.model_type = ModelType.encoder_or_decoder - batch_size = args.global_batch_size - micro_batch_size = args.micro_batch_size - setup_microbatch_calculator( - args.rank, - args.rampup_batch_size, - args.global_batch_size, - args.micro_batch_size, - args.data_parallel_size, - ) - else: - parallel_state.destroy_model_parallel() - parallel_state.initialize_model_parallel( - args.tensor_model_parallel_size, - args.pipeline_model_parallel_size, - virtual_pipeline_model_parallel_size, - default_backend="nccl", - p2p_backend="ucc" if HAS_TORCH_UCC else "nccl", - ) - pipeline_model_parallel_size = ( - parallel_state.get_pipeline_model_parallel_world_size() - ) - - tensor_parallel.random.model_parallel_cuda_manual_seed(0) - model = build_model( - bert_model_provider, - wrap_with_ddp=parallel_state.get_data_parallel_world_size() > 1, - virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size, - cpu_offload=args.cpu_offload, - ) - assert isinstance(model, list) - assert len(model) == ( - 1 - if virtual_pipeline_model_parallel_size is None - else virtual_pipeline_model_parallel_size - ) - _param_groups = _get_params_for_weight_decay_optimization(model) - optim = torch.optim.Adam(_param_groups) - print(effective_length) - print(fancy_data.size(0)) - train( - model, - optim, - virtual_pipeline_model_parallel_size, - args.pipeline_model_parallel_size, - async_comm, - ) - except Exception as e: - failure = str(e) - finally: - parallel_state.destroy_model_parallel() - if failure is not None: - warnings.warn( - f"Minimal BERT Pipeline Parallel Failed with: {failure}", DebugWarning - ) - print(f"Minimal BERT Pipeline Parallel Failed with: {failure}") - torch.distributed.barrier() - print(TEST_SUCCESS_MESSAGE) diff --git a/tests/L0/run_transformer/run_dynamic_batchsize_test.py b/tests/L0/run_transformer/run_dynamic_batchsize_test.py deleted file mode 100644 index b2c020a..0000000 --- a/tests/L0/run_transformer/run_dynamic_batchsize_test.py +++ /dev/null @@ -1,202 +0,0 @@ -from typing import Tuple, List - -import torch - -from apex.transformer import parallel_state -from apex.transformer.pipeline_parallel.utils import get_num_microbatches -from apex.transformer.pipeline_parallel.schedules.common import ( - _get_params_for_weight_decay_optimization, -) -from apex.transformer.pipeline_parallel.schedules.common import build_model -from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import ( - _forward_backward_pipelining_with_interleaving, -) -from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator -from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator -from apex.transformer.pipeline_parallel.utils import update_num_microbatches -from apex.transformer.testing import global_vars -from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE -from apex.transformer.testing.commons import initialize_distributed -from apex.transformer.testing.commons import print_separator -from apex.transformer.testing.commons import fwd_step_func -from apex.transformer.log_util import get_transformer_logger, set_logging_level -from apex.transformer.testing.commons import model_provider_func -from apex.transformer._data import MegatronPretrainingRandomSampler -from apex.transformer._data import MegatronPretrainingSampler - - -# note(mkozuki): To see warmup, steady, cooldown iterations, uncomment the line below -# set_logging_level("INFO") -_logger = get_transformer_logger("pipeline_parallel_test") -# note(mkozuki): To see if local batch size increases, uncomment the line below -# _logger.setLevel("INFO") -global_vars.set_global_variables( - args_defaults={"global_batch_size": 512, "rampup_batch_size": [64, 64, 1000],}, - ignore_unknown_args=True, -) - - -RAMPUP_BATCH_SIZE = [] -NUM_ITERATIONS = 20 -NUM_SAMPLES = 16384 // 2 -batch_size, micro_batch_size = None, None -HIDDEN_SIZE = 16 - - -def Dataset(num_samples: int) -> List[Tuple[torch.Tensor, torch.Tensor]]: - return [ - ( - torch.randn(HIDDEN_SIZE, HIDDEN_SIZE), - torch.randn(HIDDEN_SIZE // 2, HIDDEN_SIZE // 2), - ) - for _ in range(num_samples) - ] - - -# Run forward & backward with dynamic batch size. -def run_interleaved_with_dynamic_batch_size( - pipeline_model_parallel_size: int, forward_only: bool, BatchSamplerCls, -) -> None: - args = global_vars.get_args() - _reconfigure_microbatch_calculator( - args.rank, - args.rampup_batch_size, - args.global_batch_size, - args.micro_batch_size, - 1, # args.data_parallel_size, - ) - virtual_pipeline_model_parallel_size = 2 - # NOTE (mkozuki): `virtual_pipeline_model_parallel_size` is a requisite for the interleaving scheduling - # In megatron, `args.virtual_pipeline_model_parallel_size` is computed in megatron/arguments.py and - # used ubiquitously but this test uses custom model so it's safe to abuse. - parallel_state.initialize_model_parallel( - 1, pipeline_model_parallel_size, virtual_pipeline_model_parallel_size - ) - pipeline_model_parallel_size = ( - parallel_state.get_pipeline_model_parallel_world_size() - ) - - print_separator( - f"BatchSamplerCls: {BatchSamplerCls.__name__}, forward_only: {forward_only}" - ) - - model = build_model( - model_provider_func, - wrap_with_ddp=True, - virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size, - hidden_size=HIDDEN_SIZE, - ) - assert isinstance(model, list) - assert len(model) == virtual_pipeline_model_parallel_size - optimizer = torch.optim.Adam(_get_params_for_weight_decay_optimization(model)) - - initial_local_minibatch_size = get_num_microbatches() * micro_batch_size - dataset = Dataset(NUM_SAMPLES) - data_loader = torch.utils.data.DataLoader( - dataset, - batch_sampler=BatchSamplerCls( - NUM_SAMPLES, - 0, - initial_local_minibatch_size, - parallel_state.get_data_parallel_rank(), - parallel_state.get_data_parallel_world_size(), - ), - ) - data_iter = iter(data_loader) - - def get_num_samples(batch): - if isinstance(batch, torch.Tensor): - return len(batch) - assert isinstance(batch, (list, tuple)) - return [get_num_samples(b) for b in batch] - - tensor_shape = [micro_batch_size, HIDDEN_SIZE, HIDDEN_SIZE] - consumed_samples = 0 - for i in range(NUM_ITERATIONS): - update_num_microbatches(consumed_samples, consistency_check=False) - local_batch_size = get_num_microbatches() * micro_batch_size - data_iter._index_sampler.local_minibatch_size = local_batch_size - local_mini_batch = next(data_iter) - - _logger.info( - f"iter: {i} / {NUM_ITERATIONS} " - f"local batchsize: {get_num_samples(local_mini_batch)} " - f"consumed_samples: {consumed_samples} / {NUM_SAMPLES}" - ) - _forward_backward_pipelining_with_interleaving( - fwd_step_func, - local_mini_batch, - model, - forward_only=forward_only, - tensor_shape=tensor_shape, - ) - - consumed_samples += ( - parallel_state.get_data_parallel_world_size() - * get_num_microbatches() - * micro_batch_size - ) - - if not forward_only: - for m in model: - for p in m.parameters(): - if p.grad is None: - raise RuntimeError("grad not found") - else: - optimizer.zero_grad(set_to_none=True) - - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print(TEST_SUCCESS_MESSAGE) - - -if __name__ == "__main__": - torch.backends.cuda.matmul.allow_tf32 = False - torch.backends.cudnn.allow_tf32 = False - n_tests = 0 - failures = [] - - initialize_distributed() - world_size = torch.distributed.get_world_size() - args = global_vars.get_args() - batch_size = args.global_batch_size - micro_batch_size = args.micro_batch_size - setup_microbatch_calculator( - args.rank, - args.rampup_batch_size, - args.global_batch_size, - args.micro_batch_size, - 1, # args.data_parallel_size, - ) - for BatchSamplerCls in ( - MegatronPretrainingSampler, - MegatronPretrainingRandomSampler, - ): - for forward_only in (False, True): - n_tests += 1 - pipeline_model_parallel_size = world_size - try: - run_interleaved_with_dynamic_batch_size( - pipeline_model_parallel_size, forward_only, BatchSamplerCls, - ) - except Exception as e: - msg = ( - f"\tforward_only: {forward_only}\n" - f"pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()}, " - f"virtual pipeline rank: {parallel_state.get_virtual_pipeline_model_parallel_rank()}\n" - f"{str(e)}" - ) - raise RuntimeError(msg) - finally: - parallel_state.destroy_model_parallel() - print_separator("TEST RESULT") - if failures: - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print("\n".join(failures)) - msg = f"{len(failures)} / {n_tests} cases failed" - raise RuntimeError(msg) - else: - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print("### PASS!") diff --git a/tests/L0/run_transformer/run_gpt_minimal_test.py b/tests/L0/run_transformer/run_gpt_minimal_test.py deleted file mode 100644 index b3674c8..0000000 --- a/tests/L0/run_transformer/run_gpt_minimal_test.py +++ /dev/null @@ -1,223 +0,0 @@ -from functools import partial -from typing import List -import time - -import torch -try: - import torch_ucc -except ImportError: - HAS_TORCH_UCC = False -else: - HAS_TORCH_UCC = True - print("Use UCC as backend of Pipeline Parallel ProcessGroups") - -from apex.transformer import parallel_state -from apex.transformer.enums import ModelType -from apex.transformer.tensor_parallel import model_parallel_cuda_manual_seed -from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator -from apex.transformer.pipeline_parallel.utils import unwrap_model -from apex.transformer.pipeline_parallel.utils import ( - average_losses_across_data_parallel_group, -) -from apex.transformer.pipeline_parallel.utils import get_ltor_masks_and_position_ids -from apex.transformer.pipeline_parallel.schedules.common import build_model -from apex.transformer.pipeline_parallel.schedules.common import ( - _get_params_for_weight_decay_optimization, -) -from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import ( - forward_backward_pipelining_without_interleaving, -) -from apex.transformer.testing.standalone_gpt import gpt_model_provider -from apex.transformer.testing import global_vars -from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE -from apex.transformer.testing.commons import initialize_distributed - -MANUAL_SEED = 42 -inds = None -data_idx = 0 -N_VOCAB = 128 - - -def download_fancy_data(): - # import requests - # response = requests.get('https://internet.com/book.txt') - # text = ' '.join(response.text.split()) - text = """ - An original sentence not subject to any license restrictions, copyright, or royalty payments. Nothing to see here. Commercial or non-commercial use. Research or non-research purposes. The quick brown fox jumps over the lazy dog. Lorem ipsum. - """ - text = text * 1024 - encoded = text.encode("ascii", "replace") - ints = [int(encoded[i]) for i in range(len(encoded))] - return torch.tensor(ints) - - -# build a batch given sequence_len and batch size -def generate_fancy_data_labels(sequence_len, batch_size): - global data_idx - global inds - global MANUAL_SEED - temps = list() - for i in range(batch_size): - if inds is None or data_idx >= len(inds): - # hack as use of RNG will fall out of sync due to pipelines being different - model_parallel_cuda_manual_seed(MANUAL_SEED) - inds = torch.randperm(effective_length, device="cuda") - MANUAL_SEED += 1 - data_idx = 0 - data_idx_ = data_idx - offset = inds[data_idx_] - data_idx += 1 - curr = fancy_data[offset : offset + sequence_len + 1].clone().detach() - temps.append(curr) - temp = torch.stack(temps, dim=0).cuda() - return temp - - -easy_data = None - - -def get_batch(int_tensors: List[torch.Tensor]): - data = int_tensors[0] - # Unpack. - tokens_ = data.long() - labels = tokens_[:, 1:].contiguous() - tokens = tokens_[:, :-1].contiguous() - # Get the masks and position ids. - attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( - tokens, - N_VOCAB, # tokenizer.eod, - False, # args.reset_position_ids, - False, # args.reset_attention_mask, - False, # args.eod_mask_loss, - ) - return tokens, labels, loss_mask, attention_mask, position_ids - - -# Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L75 -def loss_func(loss_mask, output_tensor): - losses = output_tensor.float() - loss_mask = loss_mask.view(-1).float() - loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() - - # Reduce loss for logging. - averaged_loss = average_losses_across_data_parallel_group([loss]) - - return loss, {"lm loss": averaged_loss[0]} - - -# Ref: https://github.com/NVIDIA/Megatron-LM/blob/b31e1296354e979722627a6c4dedafe19b51fa97/pretrain_gpt.py#L86 -def fwd_step_func(batch, model): - """Forward step.""" - tokens, labels, loss_mask, attention_mask, position_ids = get_batch(batch) - output_tensor = model(tokens, position_ids, attention_mask, labels=labels) - return output_tensor, partial(loss_func, loss_mask) - - -def train(model, optim, pipeline_model_parallel_size, async_comm): - sequence_len = global_vars.get_args().seq_length - micro_batch_size = global_vars.get_args().micro_batch_size - hidden_size = global_vars.get_args().hidden_size - fwd_bwd_func = forward_backward_pipelining_without_interleaving - - tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size) - runtime = 0 - # training loop - for i in range(3): - since = time.time() - if torch.distributed.get_rank() == 0: - print("begin iter", i) - batch = [ - generate_fancy_data_labels(args.seq_length, args.global_batch_size) - for _ in range(pipeline_model_parallel_size) - ] - if torch.distributed.get_rank() == 0: - print("finished making batch...") - optim.zero_grad() - fwd_bwd_func( - fwd_step_func, - batch, - model, - forward_only=False, - tensor_shape=tensor_shape, - async_comm=async_comm, - sequence_parallel_enabled=args.sequence_parallel, - ) - if torch.distributed.get_rank() == 0: - print("finished forward step") - # All-reduce layernorm parameters across model parallel nodes - # when sequence parallelism is used - if parallel_state.get_tensor_model_parallel_world_size() > 1 and global_vars.get_args().sequence_parallel: - for model_module in model: - unwrapped_model = unwrap_model(model_module) - for param in unwrapped_model.parameters(): - if getattr(param, 'sequence_parallel_enabled', False): - grad = param.grad - torch.distributed.all_reduce(grad, group=parallel_state.get_tensor_model_parallel_group()) - optim.step() - if torch.distributed.get_rank() == 0: - print("finished iter", i) - runtime += time.time() - since - return runtime / 3.0 - - -if __name__ == "__main__": - init = True - global_vars.set_global_variables() - for async_comm in (False,) if global_vars.get_args().sequence_parallel else (False, True): - global fancy_data - global effective_length - - if init: - init = False - - fancy_data = download_fancy_data() - args = global_vars.get_args() - args.model_type = ModelType.encoder_or_decoder - effective_length = fancy_data.size(0) // args.seq_length - effective_length = fancy_data.size(0) - args.seq_length - - initialize_distributed("nccl") - world_size = torch.distributed.get_world_size() - - failure = None - args.padded_vocab_size = 128 - batch_size = args.global_batch_size - micro_batch_size = args.micro_batch_size - setup_microbatch_calculator( - args.rank, - args.rampup_batch_size, - args.global_batch_size, - args.micro_batch_size, - args.data_parallel_size, # args.data_parallel_size, - ) - world_size = torch.distributed.get_world_size() - - print(args.tensor_model_parallel_size, "MODEL PARALLEL SIZE") - - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=args.tensor_model_parallel_size, - pipeline_model_parallel_size_=args.pipeline_model_parallel_size, - default_backend="nccl", - p2p_backend="ucc" if HAS_TORCH_UCC else "nccl", - ) - - pipeline_model_parallel_size = ( - parallel_state.get_pipeline_model_parallel_world_size() - ) - model_parallel_cuda_manual_seed(0) - model = build_model( - gpt_model_provider, - wrap_with_ddp=parallel_state.get_data_parallel_world_size() > 1, - virtual_pipeline_model_parallel_size=None, - cpu_offload=args.cpu_offload, - ) - assert isinstance(model, list), model - _param_groups = _get_params_for_weight_decay_optimization(model) - optim = torch.optim.Adam(_param_groups) - runtime = train(model, optim, args.pipeline_model_parallel_size, async_comm) - - parallel_state.destroy_model_parallel() - torch.distributed.barrier() - if torch.distributed.get_rank() == 0: - print(TEST_SUCCESS_MESSAGE) - print("Average Iteration Time:", runtime) diff --git a/tests/L0/run_transformer/test_batch_sampler.py b/tests/L0/run_transformer/test_batch_sampler.py deleted file mode 100644 index 52175d5..0000000 --- a/tests/L0/run_transformer/test_batch_sampler.py +++ /dev/null @@ -1,142 +0,0 @@ -from itertools import product - -import torch -from torch.testing._internal import common_utils -from torch.utils.data import Dataset -from torch.utils.data import RandomSampler -from torch.utils.data import BatchSampler -from torch.utils.data import DataLoader - -from apex.transformer.pipeline_parallel.utils import _split_batch_into_microbatch as split_batch_into_microbatch - - -class MyIterableDataset(Dataset): - def __init__(self, start, end): - super().__init__() - assert end > start, "this example code only works with end >= start" - self.start = start - self.end = end - self.samples = list(range(self.start, self.end)) - - def __iter__(self): - return iter(range(self.start, self.end)) - - def __getitem__(self, index): - return self.samples[index] - - -class MegatronPretrainingRandomSampler: - - def __init__(self, total_samples, consumed_samples, micro_batch_size, - data_parallel_rank, data_parallel_size): - # Keep a copy of input params for later use. - self.total_samples = total_samples - self.consumed_samples = consumed_samples - self.micro_batch_size = micro_batch_size - self.data_parallel_rank = data_parallel_rank - self.data_parallel_size = data_parallel_size - self.micro_batch_times_data_parallel_size = \ - self.micro_batch_size * data_parallel_size - self.last_batch_size = \ - self.total_samples % self.micro_batch_times_data_parallel_size - - # Sanity checks. - assert self.total_samples > 0, \ - 'no sample to consume: {}'.format(self.total_samples) - assert self.micro_batch_size > 0 - assert data_parallel_size > 0 - assert self.data_parallel_rank < data_parallel_size, \ - 'data_parallel_rank should be smaller than data size: {}, ' \ - '{}'.format(self.data_parallel_rank, data_parallel_size) - - def __len__(self): - return self.total_samples - - def __iter__(self): - active_total_samples = self.total_samples - self.last_batch_size - self.epoch = self.consumed_samples // active_total_samples - current_epoch_samples = self.consumed_samples % active_total_samples - assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0 - - # data sharding and random sampling - bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size - bucket_offset = current_epoch_samples // self.data_parallel_size - start_idx = self.data_parallel_rank * bucket_size - - g = torch.Generator() - g.manual_seed(self.epoch) - random_idx = torch.randperm(bucket_size, generator=g).tolist() - idx_range = [start_idx + x for x in random_idx[bucket_offset:]] - - batch = [] - # Last batch if not complete will be dropped. - for idx in idx_range: - batch.append(idx) - if len(batch) == self.micro_batch_size: - self.consumed_samples += self.micro_batch_times_data_parallel_size - yield batch - batch = [] - - -# Samples 8 tensors in total. -# First sample 4 tensors twice, then sample 2 tensors fourth. -class TestBatchSamplerBehavior(common_utils.TestCase): - def test_batch_sampler_behavior(self): - dataset = MyIterableDataset(0, 100) - - for num_workers in (1, 2, 4): - with self.subTest(f"{num_workers}"): - torch.manual_seed(42) - loader = DataLoader(dataset, batch_sampler=MegatronPretrainingRandomSampler(100, 0, 4, 0, 1), num_workers=num_workers) - samples = [] - for i, batch in enumerate(loader): - samples.append(batch) - if i == 2 - 1: - break - - torch.manual_seed(42) - loader = DataLoader(dataset, batch_sampler=MegatronPretrainingRandomSampler(100, 0, 2, 0, 1), num_workers=num_workers) - samples2 = [] - for i, batch in enumerate(loader): - samples2.append(batch) - if i == 4 - 1: - break - self.assertEqual(torch.cat(samples), torch.cat(samples2)) - - def test_split_batch(self): - - class MyIterableDataset(Dataset): - def __init__(self, start, end): - super().__init__() - assert end > start, "this example code only works with end >= start" - self.start = start - self.end = end - self.samples = list(range(self.start, self.end)) - - def __len__(self): - return self.end - self.start - - def __iter__(self): - return iter(range(self.start, self.end)) - - def __getitem__(self, index): - return (torch.tensor([index, index]), torch.tensor([index // 2, index // 2])) - - dataset = MyIterableDataset(0, 100) - torch.manual_seed(42) - global_batch_size = 16 - loader = DataLoader(dataset, batch_sampler=MegatronPretrainingRandomSampler(100, 0, global_batch_size, 0, 1), num_workers=2) - batch = next(iter(loader)) - - for _micro_batch_size in (1, 2, 4, 8): - microbatches = list(split_batch_into_microbatch( - batch, - _micro_batch_size=_micro_batch_size, - _global_batch_size=global_batch_size, - )) - self.assertEqual(len(microbatches), global_batch_size // _micro_batch_size) - self.assertEqual(len(microbatches[0][0]), _micro_batch_size) - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_cross_entropy.py b/tests/L0/run_transformer/test_cross_entropy.py deleted file mode 100644 index 1f51628..0000000 --- a/tests/L0/run_transformer/test_cross_entropy.py +++ /dev/null @@ -1,94 +0,0 @@ -import logging -from typing import Tuple - -import torch -import torch.nn.functional as F -from torch.testing._internal import common_utils - -logging.getLogger("torch").setLevel(logging.WARNING) - -from apex.transformer import parallel_state -from apex.transformer import tensor_parallel -from apex.transformer.tensor_parallel import cross_entropy -from apex.transformer.testing.commons import set_random_seed, IdentityLayer -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase -from apex.transformer.testing.distributed_test_base import UccDistributedTestBase - -logging.getLogger("apex").setLevel(logging.WARNING) - - -def torch_cross_entropy( - batch_size: int, seq_length: int, vocab_size: int, logits_scale: float, seed: int, -) -> Tuple[torch.Tensor, torch.Tensor]: - set_random_seed(seed) - identity = IdentityLayer( - (batch_size, seq_length, vocab_size), scale=logits_scale - ).cuda() - logits = identity() - target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size) - loss = ( - F.cross_entropy( - logits.view(-1, logits.size()[-1]), target.view(-1), reduction="none" - ) - .view_as(target) - .mean() - ) - loss.backward() - return loss, identity.weight.grad - - -def tensor_sharded_cross_entropy( - batch_size, seq_length, vocab_size, logits_scale, seed -): - set_random_seed(seed) - identity = IdentityLayer( - (batch_size, seq_length, vocab_size), scale=logits_scale - ).cuda() - logits = identity() - logits_parallel = tensor_parallel.scatter_to_tensor_model_parallel_region(logits) - target = torch.cuda.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size) - logits_parallel_ = logits_parallel.clone().detach() - loss = cross_entropy.vocab_parallel_cross_entropy(logits_parallel, target).mean() - loss.backward() - # check for mutation - assert torch.equal(logits_parallel_, logits_parallel) - return loss, identity.weight.grad - - -class VocabParallelCrossEntropyTestBase: - def test_cross_entropy(self): - batch_size, sequence_length, vocab_size_per_partition = 13, 17, 11 - logits_scale = 1000.0 - seed = 1234 - for tensor_model_parallel_world_size in range(1, self.world_size + 1): - if self.world_size % tensor_model_parallel_world_size: - continue - with self.subTest( - tensor_model_parallel_world_size=tensor_model_parallel_world_size - ): - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - ) - vocab_size = vocab_size_per_partition * tensor_model_parallel_world_size - loss_torch, grad_torch = torch_cross_entropy( - batch_size, sequence_length, vocab_size, logits_scale, seed - ) - ( - loss_tensor_parallel, - grad_tensor_parallel, - ) = tensor_sharded_cross_entropy( - batch_size, sequence_length, vocab_size, logits_scale, seed - ) - - self.assertEqual(loss_torch, loss_tensor_parallel) - self.assertEqual(grad_torch, grad_tensor_parallel) - - parallel_state.destroy_model_parallel() - - -class NcclVocabParallelCrossEntropyTest(VocabParallelCrossEntropyTestBase, NcclDistributedTestBase): pass -class UccVocabParallelCrossEntropyTest(VocabParallelCrossEntropyTestBase, UccDistributedTestBase): pass - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_data.py b/tests/L0/run_transformer/test_data.py deleted file mode 100644 index 38dc752..0000000 --- a/tests/L0/run_transformer/test_data.py +++ /dev/null @@ -1,64 +0,0 @@ -import logging - -import torch.testing -from torch.testing._internal import common_utils - -logging.getLogger("torch").setLevel(logging.WARNING) - -from apex.transformer import parallel_state -from apex.transformer.tensor_parallel import data as data_utils -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase -from apex.transformer.testing.distributed_test_base import UccDistributedTestBase - -logging.getLogger("torch").setLevel(logging.WARNING) - - -class BroadcastDataTestBase: - def test_broadcast_data(self): - tensor_model_parallel_world_size: int = self.world_size // ( - 1 + self.world_size > 1 - ) - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size - ) - - target_key_size = { - "key1": [7, 11], - "key2": [8, 2, 1], - "key3": [13], - "key4": [5, 1, 2], - "key5": [5, 12], - } - keys = [k for k in target_key_size] - - data = {} - data_t = {} - with torch.no_grad(): - for key in target_key_size: - data[key] = torch.randint(0, 1000, size=target_key_size[key]) - data_t[key] = data[key].clone() - # "key_x" is supposed to be ignored. - data["key_x"] = torch.rand(5) - data_t["key_x"] = data["key_x"].clone() - if parallel_state.get_tensor_model_parallel_rank() != 0: - data = None - - data_utils._check_data_types(keys, data_t, torch.int64) - key_size, _, _ = data_utils._build_key_size_numel_dictionaries(keys, data) - - for key in keys: - self.assertEqual(target_key_size[key], key_size[key]) - - broadcasted_data = data_utils.broadcast_data(keys, data, torch.int64) - for key in keys: - self.assertEqual(broadcasted_data[key], data_t[key].cuda()) - - parallel_state.destroy_model_parallel() - - -class NcclBroadcastDataTest(BroadcastDataTestBase, NcclDistributedTestBase): pass -class UccBroadcastDataTest(BroadcastDataTestBase, UccDistributedTestBase): pass - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_fused_softmax.py b/tests/L0/run_transformer/test_fused_softmax.py deleted file mode 100644 index 278df69..0000000 --- a/tests/L0/run_transformer/test_fused_softmax.py +++ /dev/null @@ -1,220 +0,0 @@ -"""Test for fused softmax functions. - -Ref: https://github.com/NVIDIA/Megatron-LM/blob/40becfc96c4144985458ac0e0fae45dbb111fbd2/megatron/fused_kernels/tests/test_fused_kernels.py -""" # NOQA -import itertools - -import torch -from torch.testing._internal import common_utils - -from apex.transformer import AttnMaskType -from apex.transformer.functional import FusedScaleMaskSoftmax - - -def attention_mask_func(attention_scores, attention_mask): - return attention_scores.masked_fill(attention_mask, -10000.0) - - -autocast_dtypes = ( - (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,) -) - - -class TestFusedScaleMaskSoftmax(common_utils.TestCase): - def _setup_fused_softmax( - self, - input_in_fp16, - input_in_bf16, - scale=None, - softmax_in_fp32=False, - attn_mask_type=AttnMaskType.padding, - ): - fused_fn = FusedScaleMaskSoftmax( - input_in_fp16=input_in_fp16, - input_in_bf16=input_in_bf16, - mask_func=attention_mask_func, - scale=scale, - softmax_in_fp32=softmax_in_fp32, - attn_mask_type=attn_mask_type, - scaled_masked_softmax_fusion=True, - ) - torch_fn = FusedScaleMaskSoftmax( - input_in_fp16=input_in_fp16, - input_in_bf16=input_in_bf16, - mask_func=attention_mask_func, - scale=scale, - softmax_in_fp32=softmax_in_fp32, - attn_mask_type=attn_mask_type, - scaled_masked_softmax_fusion=False, - ) - return fused_fn, torch_fn - - def test_fused_scale_mask_softmax(self): - """ - attention_scores.shape = [4, 12, 24, 24] - mask.shape = [4, 1, 24, 24] - """ - for (dtype, scale, softmax_in_fp32, shape) in itertools.product( - (torch.half, torch.bfloat16), (None, 2.0), (False, True), ((4, 12, 24, 24), (32, 12, 4, 214)) - ): - with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"): - input_in_fp16 = dtype == torch.half - input_in_bf16 = dtype == torch.bfloat16 - if not (scale is None or softmax_in_fp32): - with self.assertRaises(RuntimeError): - self._setup_fused_softmax( - input_in_fp16, - input_in_bf16, - scale, - softmax_in_fp32, - AttnMaskType.padding, - ) - return - fused_fn, torch_fn = self._setup_fused_softmax( - input_in_fp16, - input_in_bf16, - scale, - softmax_in_fp32, - AttnMaskType.padding, - ) - - attention_scores_0 = ( - torch.randn(shape) - .to(device="cuda", dtype=dtype) - .requires_grad_(True) - ) - with torch.no_grad(): - attention_scores_1 = attention_scores_0.clone().requires_grad_(True) - mask_shape = (shape[0],) + (1,) + shape[2:] - mask = torch.randint(0, 2, mask_shape, device="cuda").bool() - expected = fused_fn(attention_scores_0, mask) - actual = torch_fn(attention_scores_1, mask) - self.assertEqual(actual, expected) - - g0 = torch.rand_like(actual) - with torch.no_grad(): - g1 = g0.clone() - expected.backward(g0) - actual.backward(g1) - - def test_autocast_fused_scale_mask_softmax(self): - for dtype in autocast_dtypes: - with self.subTest(f"{dtype}"): - input_in_fp16 = dtype == torch.half - input_in_bf16 = dtype == torch.bfloat16 - fused_fn, torch_fn = self._setup_fused_softmax( - input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.padding - ) - - attention_scores_0 = ( - torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True) - ) - with torch.no_grad(): - attention_scores_1 = ( - attention_scores_0.clone().to(dtype).requires_grad_(True) - ) - mask = torch.randint(0, 2, (4, 1, 24, 24)).bool().cuda() - - expected = torch_fn(attention_scores_1, mask) - with torch.cuda.amp.autocast(dtype=dtype): - actual = fused_fn(attention_scores_0, mask) - self.assertEqual(actual.dtype, dtype) - self.assertEqual(actual, expected) - - g0 = torch.rand_like(actual) - with torch.no_grad(): - g1 = g0.clone() - expected.backward(g0) - actual.backward(g1) - - def test_fused_upper_triangle_mask_softmax(self): - """ - attn_weights.shape: [4, 12, 24, 24] - total_mask.shape: [4, 1, 24, 24] - - total_mask[0, 0], a 24x24 matrix is like a lower triangular matrix, but - upper elements are True and lower elements and diagonal are False. - """ - for (dtype, scale, softmax_in_fp32) in itertools.product( - (torch.half, torch.bfloat16), (None, 2.0), (False, True), - ): - with self.subTest(f"{dtype}-{scale}-{softmax_in_fp32}"): - input_in_fp16 = dtype == torch.half - input_in_bf16 = dtype == torch.bfloat16 - if not (scale is None or softmax_in_fp32): - with self.assertRaises(RuntimeError): - self._setup_fused_softmax( - input_in_fp16, - input_in_bf16, - scale, - softmax_in_fp32, - AttnMaskType.causal, - ) - return - fused_fn, torch_fn = self._setup_fused_softmax( - input_in_fp16, - input_in_bf16, - scale, - softmax_in_fp32, - AttnMaskType.causal, - ) - - attn_weights_0 = ( - torch.randn((4, 12, 24, 24)) - .to(device="cuda", dtype=dtype) - .requires_grad_(True) - ) - with torch.no_grad(): - attn_weights_1 = attn_weights_0.clone().requires_grad_(True) - total_mask = ( - ~(torch.tril(torch.randn((24, 24), device="cuda")).bool()) - .unsqueeze(0) - .unsqueeze(0) - ) - total_mask = total_mask.repeat((4, 1, 1, 1)) - expected = fused_fn(attn_weights_0, total_mask) - actual = torch_fn(attn_weights_1, total_mask) - self.assertEqual(actual, expected) - - g0 = torch.randn_like(actual) - with torch.no_grad(): - g1 = g0.clone() - actual.backward(g0) - expected.backward(g1) - - def test_autocast_fused_upper_triangle_mask_softmax(self): - for dtype in autocast_dtypes: - with self.subTest(f"{dtype}"): - input_in_fp16 = dtype == torch.half - input_in_bf16 = dtype == torch.bfloat16 - fused_fn, torch_fn = self._setup_fused_softmax( - input_in_fp16, input_in_bf16, attn_mask_type=AttnMaskType.causal - ) - - attn_weights_0 = ( - torch.randn((4, 12, 24, 24)).cuda().requires_grad_(True) - ) - with torch.no_grad(): - attn_weights_1 = ( - attn_weights_0.clone().to(dtype).requires_grad_(True) - ) - total_mask = ( - ~(torch.tril(torch.randn((24, 24), device="cuda")).bool()) - .unsqueeze(0) - .unsqueeze(0) - ) - - with torch.cuda.amp.autocast(dtype=dtype): - actual = fused_fn(attn_weights_0, total_mask) - self.assertEqual(actual.dtype, dtype) - expected = torch_fn(attn_weights_1, total_mask) - self.assertEqual(actual, expected) - - g0 = torch.randn_like(actual) - with torch.no_grad(): - g1 = g0.clone() - actual.backward(g0) - expected.backward(g1) - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_layers.py b/tests/L0/run_transformer/test_layers.py deleted file mode 100644 index b3b2eb2..0000000 --- a/tests/L0/run_transformer/test_layers.py +++ /dev/null @@ -1,558 +0,0 @@ -import logging -import unittest -import typing - -import torch -import torch.nn as nn -from torch.testing._internal import common_utils - -from apex.transformer import parallel_state -from apex.transformer.tensor_parallel import layers -from apex.transformer.testing.commons import set_random_seed -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase -from apex.transformer.testing.distributed_test_base import UccDistributedTestBase - - -logging.getLogger("torch").setLevel(logging.WARNING) -logging.getLogger("apex").setLevel(logging.WARNING) - - -# N.B.(mkozuki): Disable TF32 matrix multiply. -# Matrices used in this test are so small that TF32 matmul -# can be less precise so that `self.assertEqual` raises. -torch.backends.cuda.matmul.allow_tf32 = False - - -class TensorParallelLayerTestBase: - - BATCH_SIZE: int = 8 - SEQUENCE_LENGTH: int = 128 - VOCAB_SIZE: int = 1024 - HIDDEN_SIZE: int = 256 - INPUT_SIZE_COEFF: int = 256 - OUTPUT_SIZE_COEFF: int = 256 - SEED: int = 123456 - - @property - def tensor_shape(self) -> typing.Sequence[int]: - return [self.SEQUENCE_LENGTH, self.BATCH_SIZE, self.HIDDEN_SIZE] - - @torch.no_grad() - @unittest.skipIf(torch.cuda.device_count() < 2, "Requires >=2 GPUs") - def test_all_gather_parity(self) -> None: - if self.DISTRIBUTED_BACKEND == "ucc": - self.skipTest("torch_ucc does NOT support `torch.distributed._all_gather_base` as of 2022/06/15") - from torch.distributed.distributed_c10d import all_gather, _all_gather_base # NOQA - - for tensor_model_parallel_world_size in range(1, self.world_size + 1): - if self.world_size % tensor_model_parallel_world_size: - continue - with self.subTest(tensor_model_parallel_world_size=tensor_model_parallel_world_size): - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - ) - tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank() - cur_tensor_model_device = torch.device(f"cuda:{tensor_model_parallel_rank}") - with torch.no_grad(): - tensor = tensor_model_parallel_rank * torch.ones( - self.tensor_shape, dtype=torch.float32, device=cur_tensor_model_device) - numel = tensor.numel() - numel_gathered = tensor_model_parallel_world_size * numel - gathered = torch.empty( - torch.Size((numel_gathered,)), - device=cur_tensor_model_device, - dtype=torch.float32, - requires_grad=False, - ) - chunks = [ - gathered[i * numel : (i + 1) * numel] - for i in range(tensor_model_parallel_world_size) - ] - all_gather(chunks, tensor, group=parallel_state.get_tensor_model_parallel_group()) - - gathered_for_base = torch.empty( - torch.Size((numel_gathered,)), - device=cur_tensor_model_device, - dtype=torch.float32, - requires_grad=False, - ) - _all_gather_base( - gathered_for_base, - tensor, - group=parallel_state.get_tensor_model_parallel_group(), - ) - - self.assertEqual(gathered, gathered_for_base) - parallel_state.destroy_model_parallel() - - @torch.no_grad() - @unittest.skipIf(torch.cuda.device_count() < 2, "Requires >=2 GPUs") - def test_reduce_scatter_parity(self) -> None: - if self.DISTRIBUTED_BACKEND == "ucc": - self.skipTest("torch_ucc does NOT support `torch.distributed._reduce_scatter_base` as of 2022/06/15") - from torch.distributed.distributed_c10d import reduce_scatter, _reduce_scatter_base # NOQA - - for tensor_model_parallel_world_size in range(2, self.world_size + 1): - if self.world_size % tensor_model_parallel_world_size: - continue - with self.subTest(tensor_model_parallel_world_size=tensor_model_parallel_world_size): - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - ) - tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank() - cur_tensor_model_device = torch.device(f"cuda:{tensor_model_parallel_rank}") - with torch.no_grad(): - input = torch.cat([ - i * torch.ones(self.tensor_shape, dtype=torch.float32, device=cur_tensor_model_device) - for i in range(tensor_model_parallel_world_size) - ]) - input_list = [t.clone() for t in input.chunk(tensor_model_parallel_world_size)] - output = torch.empty( - self.tensor_shape, - device=cur_tensor_model_device, - dtype=torch.float32, - requires_grad=False, - ) - reduce_scatter( - output, input_list, - group=parallel_state.get_tensor_model_parallel_group(), - ) - - output_for_base = torch.empty( - self.tensor_shape, - device=cur_tensor_model_device, - dtype=torch.float32, - requires_grad=False, - ) - _reduce_scatter_base( - output_for_base, - input, - group=parallel_state.get_tensor_model_parallel_group(), - ) - - self.assertEqual(output, output_for_base) - self.assertEqual(input, torch.cat(input_list)) - parallel_state.destroy_model_parallel() - - def test_parallel_embedding(self) -> None: - for tensor_model_parallel_world_size in range(1, self.world_size + 1): - if self.world_size % tensor_model_parallel_world_size: - continue - with self.subTest( - tensor_model_parallel_world_size=tensor_model_parallel_world_size - ): - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - ) - set_random_seed(self.SEED + 1) - input_tensor = torch.randint( - 0, - self.VOCAB_SIZE, - ( - self.BATCH_SIZE, - self.SEQUENCE_LENGTH, - ), - device="cuda", - ) - loss_weight = torch.randn( - ( - self.BATCH_SIZE, - self.SEQUENCE_LENGTH, - self.HIDDEN_SIZE, - ), - device="cuda", - ) - - set_random_seed(self.SEED) - embedding_torch = nn.Embedding( - self.VOCAB_SIZE, - self.HIDDEN_SIZE, - ).cuda() - output_torch = embedding_torch(input_tensor) - loss_torch = torch.mul(output_torch, loss_weight).sum() - loss_torch.backward() - - # N.B.(mkozuki): With affine weight initialization on GPU, - # it's super difficult to keep the consistency with nn.Embedding. - # Thus, turning on `use_cpu_initialization`. - set_random_seed(self.SEED) - embedding_vocab_parallel = layers.VocabParallelEmbedding( - self.VOCAB_SIZE, - self.HIDDEN_SIZE, - init_method=nn.init.normal_, - use_cpu_initialization=True, - ).cuda() - output_vocab_parallel = embedding_vocab_parallel(input_tensor) - loss_vocab_parallel = torch.mul( - output_vocab_parallel, loss_weight - ).sum() - loss_vocab_parallel.backward() - - self.assertEqual(output_torch, output_vocab_parallel) - self.assertEqual(loss_torch, loss_vocab_parallel) - - splitted_weight_torch = torch.split( - embedding_torch.weight.grad, - self.VOCAB_SIZE - // tensor_model_parallel_world_size, - 0, - )[parallel_state.get_tensor_model_parallel_rank()] - self.assertEqual( - splitted_weight_torch, embedding_vocab_parallel.weight.grad - ) - - parallel_state.destroy_model_parallel() - - def _affine_weight_init_test_impl( - self, init_device: str, is_column_parallel: bool - ) -> None: - dim = int(not is_column_parallel) - for tensor_model_parallel_world_size in range(1, self.world_size + 1): - if self.world_size % tensor_model_parallel_world_size: - continue - with self.subTest( - tensor_model_parallel_world_size=tensor_model_parallel_world_size - ): - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size - ) - input_size: int = self.INPUT_SIZE_COEFF * tensor_model_parallel_world_size - output_size: int = self.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size - - weight_shape = ( - (self.OUTPUT_SIZE_COEFF, input_size) - if is_column_parallel - else (output_size, self.INPUT_SIZE_COEFF) - ) - weight = torch.empty(weight_shape) - set_random_seed(self.SEED) - - sharding_dim_size = ( - self.OUTPUT_SIZE_COEFF - if is_column_parallel - else self.INPUT_SIZE_COEFF - ) - - if init_device == "cpu": - layers._initialize_affine_weight_cpu( - weight, - output_size, - input_size, - sharding_dim_size, - dim, - nn.init.normal_, - params_dtype=torch.float32, - ) - else: - layers._initialize_affine_weight_gpu( - weight, torch.nn.init.normal_, dim - ) - # Target - set_random_seed(self.SEED) - if init_device == "cpu": - main_weight = torch.empty(output_size, input_size) - nn.init.normal_(main_weight) - curr_weight = torch.split(main_weight, sharding_dim_size, dim=dim)[ - parallel_state.get_tensor_model_parallel_rank() - ] - else: - curr_weight = torch.empty(*weight_shape) - nn.init.normal_(curr_weight) - self.assertEqual(curr_weight, weight) - parallel_state.destroy_model_parallel() - - def test_affine_weight_init_column_parallel_cpu(self) -> None: - self._affine_weight_init_test_impl(init_device="cpu", is_column_parallel=True) - - def test_affine_weight_init_column_parallel_gpu(self) -> None: - self._affine_weight_init_test_impl(init_device="gpu", is_column_parallel=True) - - def test_affine_weight_init_row_parallel_cpu(self) -> None: - self._affine_weight_init_test_impl(init_device="cpu", is_column_parallel=False) - - def test_affine_weight_init_row_parallel_gpu(self) -> None: - self._affine_weight_init_test_impl(init_device="gpu", is_column_parallel=False) - - def test_row_parallel_linear(self) -> None: - self._row_parallel_linear_test_impl(False, False, False) - - def test_row_parallel_linear_gradient_accumulation_fusion(self) -> None: - self._row_parallel_linear_test_impl(True, False, False) - - def test_row_parallel_linear_gradient_accumulation_fusion_in_fp16(self) -> None: - self._row_parallel_linear_test_impl(True, True, False) - - @unittest.skipIf(torch.cuda.device_count() < 2, "Sequence Parallel requires >=2 GPUs") - def test_row_parallel_linear_sequence_parallel(self) -> None: - self._row_parallel_linear_test_impl(False, False, True) - - # TODO(mkozuki): Merge this with `_column_parallel_linear_test_impl` - # Note that `input_is_parallel` is unique to `RowParallelLinear` which could make the merge complicated. - def _row_parallel_linear_test_impl( - self, - gradient_accumulation_fusion: bool, - accumulation_in_fp16: bool, - sequence_parallel_enabled: bool, - ) -> None: - tensor_shape = ( - self.SEQUENCE_LENGTH, - self.BATCH_SIZE, - self.HIDDEN_SIZE, - ) - for tensor_model_parallel_world_size in range( - 1 + int(sequence_parallel_enabled), self.world_size + 1 - ): - if self.world_size % tensor_model_parallel_world_size: - continue - with self.subTest( - tensor_model_parallel_world_size=tensor_model_parallel_world_size, - ): - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - ) - set_random_seed(self.SEED) - - linear = layers.RowParallelLinear( - self.HIDDEN_SIZE, - self.HIDDEN_SIZE, - keep_master_weight_for_test=True, - params_dtype=torch.float32, - use_cpu_initialization=True, - gradient_accumulation_fusion=gradient_accumulation_fusion, - accumulation_in_fp16=accumulation_in_fp16, - sequence_parallel_enabled=sequence_parallel_enabled, - # n.b.(mkozuki): RowParallelLinear is constructed with `input_is_parallel=True` - # by default, e.g. https://github.com/NVIDIA/NeMo/blob/782b4e1652aaa43c8be390d9\ - # db0dc89544afa080/nemo/collections/nlp/modules/common/megatron/transformer.py#L204 - input_is_parallel=True, - ).cuda() - if accumulation_in_fp16: - linear = linear.half() - # Simulate the situation where fusion of weight grad calculation and gradient accumulation is enabled. - if gradient_accumulation_fusion: - with torch.no_grad(): - linear.weight.main_grad = torch.zeros_like(linear.weight) - - with torch.no_grad(): - orig_input_tensor = torch.randn(tensor_shape, requires_grad=True, device="cuda") - orig_loss_weight = torch.randn(tensor_shape, device="cuda") - input_tensor = orig_input_tensor.chunk( - chunks=tensor_model_parallel_world_size, - dim=2, - )[parallel_state.get_tensor_model_parallel_rank()].contiguous() - if sequence_parallel_enabled: - loss_weight = orig_loss_weight.chunk( - chunks=tensor_model_parallel_world_size, - dim=0, - )[parallel_state.get_tensor_model_parallel_rank()] - else: - loss_weight = orig_loss_weight - if accumulation_in_fp16: - orig_input_tensor = orig_input_tensor.half() - input_tensor = input_tensor.half() - loss_weight = loss_weight.half() - input_tensor.requires_grad_() - output, _ = linear(input_tensor) - loss = torch.mul(output, loss_weight).sum() - loss.backward() - self.assertIsNotNone(input_tensor.grad) - - ref_linear = nn.Linear( - in_features=self.HIDDEN_SIZE, - out_features=self.HIDDEN_SIZE, - bias=False, - device="cuda", - ) - with torch.no_grad(): - dldy = orig_loss_weight.clone() - x = orig_input_tensor.clone() - ref_linear.weight.copy_(linear.master_weight) - if accumulation_in_fp16: - ref_linear = ref_linear.half() - x.requires_grad_() - expected_output = ref_linear(x) - expected_loss = torch.mul(expected_output, dldy).sum() - expected_loss.backward() - - if not accumulation_in_fp16: - if sequence_parallel_enabled: - self.assertEqual( - x=output, - y=expected_output.chunk( - chunks=tensor_model_parallel_world_size, - dim=0, - )[parallel_state.get_tensor_model_parallel_rank()], - ) - else: - self.assertEqual( - x=output, - y=expected_output, - ) - - grad_attr_name = "main_grad" if gradient_accumulation_fusion else "grad" - # NOTE(mkozuki): Numerical errors seems to be enlarged by tensor model parallel. - if tensor_model_parallel_world_size == 1: - self.assertEqual( - x=getattr(linear.weight, grad_attr_name), - y=ref_linear.weight.grad.chunk( - chunks=tensor_model_parallel_world_size, - dim=0, - )[parallel_state.get_tensor_model_parallel_rank()], - ) - - parallel_state.destroy_model_parallel() - - def test_column_parallel_linear(self): - self._column_parallel_linear_test_impl(False, False, False, False) - - def test_column_parallel_linear_async(self): - self._column_parallel_linear_test_impl(True, False, False, False) - - def test_column_parallel_linear_gradient_accumulation_fusion(self): - self._column_parallel_linear_test_impl(False, True, False, False) - - def test_column_parallel_linear_gradient_accumulation_fusion_in_fp16(self): - self._column_parallel_linear_test_impl(False, True, True, False) - - def test_column_parallel_linear_sequence_parallel(self): - if self.DISTRIBUTED_BACKEND == "ucc": - self.skipTest("Backward's reduce_scatter fails. as of 2022/06/15") - self._column_parallel_linear_test_impl(False, False, False, True) - - @unittest.skipIf(torch.cuda.device_count() < 2, "Sequence Parallel requires >= 2 GPUs") - def test_column_parallel_linear_exception(self): - with self.assertRaisesRegex( - RuntimeError, - "`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` cannot be enabled at the same time.", - ): - self._column_parallel_linear_test_impl(True, False, False, True) - - def _column_parallel_linear_test_impl( - self, - async_tensor_model_parallel_allreduce: bool, - gradient_accumulation_fusion: bool, - accumulation_in_fp16: bool, - sequence_parallel_enabled: bool, - ): - for tensor_model_parallel_world_size in range(1, self.world_size + 1): - if async_tensor_model_parallel_allreduce and sequence_parallel_enabled: - if tensor_model_parallel_world_size == 1: - continue - with self.subTest(tensor_model_parallel_world_size=tensor_model_parallel_world_size): - if self.world_size % tensor_model_parallel_world_size: - continue - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - ) - - input_tensor_shape = self.tensor_shape - expected_output_shape = self.tensor_shape - # When sequence parallel, `gather_output` is disabled, i.e., - # output of matmul isn't gathered in dimension of feature/hidden (last dim). - if sequence_parallel_enabled: - expected_output_shape[-1] //= tensor_model_parallel_world_size - - # tensor's shape is [sequence length, batch size, hidden size] - set_random_seed(self.SEED) - linear = layers.ColumnParallelLinear( - self.HIDDEN_SIZE, - self.HIDDEN_SIZE, - bias=False, - keep_master_weight_for_test=True, - params_dtype=torch.float32, - use_cpu_initialization=True, - gather_output=not sequence_parallel_enabled, - no_async_tensor_model_parallel_allreduce=not async_tensor_model_parallel_allreduce, - gradient_accumulation_fusion=gradient_accumulation_fusion, - accumulation_in_fp16=accumulation_in_fp16, - sequence_parallel_enabled=sequence_parallel_enabled, - ).cuda() - if accumulation_in_fp16: - linear = linear.half() - - # Simulate the situation where fusion of weight grad calculation and gradient accumulation happens. - if gradient_accumulation_fusion: - with torch.no_grad(): - linear.weight.main_grad = torch.zeros_like(linear.weight) - - orig_input_tensor = torch.randn(input_tensor_shape, device="cuda", requires_grad=True) - if accumulation_in_fp16: - orig_input_tensor = orig_input_tensor.half() - if sequence_parallel_enabled: - input_tensor = list( - orig_input_tensor.chunk(tensor_model_parallel_world_size, dim=0) - )[parallel_state.get_tensor_model_parallel_rank()] - else: - input_tensor = orig_input_tensor - output, _ = linear(input_tensor) - # The order of dimension is expected to be (sequence, batch, hidden) - self.assertEqual(output.shape, expected_output_shape) - - orig_loss_weight = torch.randn(input_tensor_shape, device="cuda") - if accumulation_in_fp16: - orig_loss_weight = orig_loss_weight.half() - if sequence_parallel_enabled: - loss_weight = orig_loss_weight.chunk( - tensor_model_parallel_world_size, dim=2, - )[parallel_state.get_tensor_model_parallel_rank()] - else: - loss_weight = orig_loss_weight - loss = torch.mul(output, loss_weight).sum() - loss.backward() - - with torch.no_grad(): - dldy = orig_loss_weight.clone() - x = orig_input_tensor.clone() - ref_linear = nn.Linear( - in_features=self.HIDDEN_SIZE, - out_features=self.HIDDEN_SIZE, - bias=False, - device="cuda", - ) - if accumulation_in_fp16: - ref_linear = ref_linear.half() - # NOTE(mkozuki): `master_weight` is available because `keep_master_weight_for_test` is set. - ref_linear.weight.copy_(linear.master_weight) - x.requires_grad_() - expected_output = ref_linear(x) - if sequence_parallel_enabled: - chunk = expected_output.chunk( - tensor_model_parallel_world_size, - dim=2, - )[parallel_state.get_tensor_model_parallel_rank()] - self.assertEqual( - x=output, - y=chunk, - ) - else: - self.assertEqual( - x=output, - y=expected_output, - ) - - expected_loss = torch.mul(expected_output, dldy).sum() - expected_loss.backward() - grad_attr_name = "main_grad" if gradient_accumulation_fusion else "grad" - # NOTE(mkozuki): Numerical errors seems to be enlarged by tensor model parallel. - if tensor_model_parallel_world_size == 1: - self.assertEqual( - x=getattr(linear.weight, grad_attr_name), - y=ref_linear.weight.grad.chunk( - chunks=tensor_model_parallel_world_size, - dim=0, - )[parallel_state.get_tensor_model_parallel_rank()], - ) - - parallel_state.destroy_model_parallel() - - -class NcclTensorParallelLayerTest(TensorParallelLayerTestBase, NcclDistributedTestBase): - pass - - -class UccTensorParallelLayerTest(TensorParallelLayerTestBase, UccDistributedTestBase): - pass - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_mapping.py b/tests/L0/run_transformer/test_mapping.py deleted file mode 100644 index 9ebda6d..0000000 --- a/tests/L0/run_transformer/test_mapping.py +++ /dev/null @@ -1,89 +0,0 @@ -import logging - -import torch -from torch.testing._internal import common_utils - -from apex.transformer import parallel_state -from apex.transformer.tensor_parallel import mappings -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase -from apex.transformer.testing.distributed_test_base import UccDistributedTestBase - - -logging.getLogger("torch").setLevel(logging.WARNING) -logging.getLogger("apex").setLevel(logging.WARNING) - - -class MappingTestBase: - def test_reduce(self): - for tensor_model_paralell_world_size in range(1, self.world_size + 1): - if self.world_size % tensor_model_paralell_world_size > 0: - continue - with self.subTest( - tensor_model_paralell_world_size=tensor_model_paralell_world_size - ): - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_paralell_world_size - ) - t = torch.full((10, 10, 10, 10), 50, device=f"cuda:{self.rank}") - expected = torch.full( - (10, 10, 10, 10), - 50 * tensor_model_paralell_world_size, - device=f"cuda:{self.rank}", - ) - self.assertTrue(torch.equal(mappings._reduce(t), expected)) - parallel_state.destroy_model_parallel() - - def test_split(self): - for tensor_model_paralell_world_size in range(1, self.world_size + 1): - if self.world_size % tensor_model_paralell_world_size > 0: - continue - with self.subTest( - tensor_model_paralell_world_size=tensor_model_paralell_world_size - ): - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_paralell_world_size - ) - - tensors = [ - torch.randn(10, 1) - for rank in range(tensor_model_paralell_world_size) - ] - x = torch.cat(tensors, 1) - out = mappings._split_along_last_dim(x) - self.assertTrue( - torch.equal( - out, tensors[parallel_state.get_tensor_model_parallel_rank()] - ) - ) - parallel_state.destroy_model_parallel() - - def test_gather(self): - for tensor_model_paralell_world_size in range(1, self.world_size + 1): - if self.world_size % tensor_model_paralell_world_size > 0: - continue - with self.subTest( - tensor_model_paralell_world_size=tensor_model_paralell_world_size - ): - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_paralell_world_size - ) - device = f"cuda:{self.rank}" - gathered = mappings._gather_along_last_dim( - torch.tensor( - [parallel_state.get_tensor_model_parallel_rank()], device=device - ) - ) - expected = torch.tensor( - [rank for rank in range(tensor_model_paralell_world_size)], - device=device, - ) - self.assertTrue(torch.equal(gathered, expected)) - parallel_state.destroy_model_parallel() - - -class NcclMappingTest(MappingTestBase, NcclDistributedTestBase): pass -class UccMappingTest(MappingTestBase, UccDistributedTestBase): pass - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_microbatches.py b/tests/L0/run_transformer/test_microbatches.py deleted file mode 100644 index 0d4b509..0000000 --- a/tests/L0/run_transformer/test_microbatches.py +++ /dev/null @@ -1,85 +0,0 @@ -import logging -from typing import List, Optional - -from torch.testing._internal import common_utils - -logging.getLogger("torch").setLevel(logging.WARNING) - -from apex.transformer import parallel_state -from apex.transformer.pipeline_parallel.utils import ( - _reconfigure_microbatch_calculator, - get_micro_batch_size, - get_num_microbatches, - get_current_global_batch_size, - update_num_microbatches, -) -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase -from apex.transformer.testing.distributed_test_base import UccDistributedTestBase - -logging.getLogger("apex").setLevel(logging.WARNING) - - -class MicrobatchCalculatorTestBase: - - GLOBAL_BATCH_SIZE: int = 1024 - MICRO_BATCH_SIZE: int = 1 - - def _test(self, rampup_batch_size: Optional[List[int]]) -> None: - for data_parallel_size in range(1, self.world_size + 1): - - expected_global_batch_size = self.GLOBAL_BATCH_SIZE - expected_micro_batch_size = self.MICRO_BATCH_SIZE - if rampup_batch_size: - expected_global_batch_size = rampup_batch_size[0] - num_consumed_samples = 0 - step_of_global_batch_size = rampup_batch_size[1] - threshold = rampup_batch_size[2] - - if data_parallel_size > 1 and data_parallel_size % 2 != 0: - continue - if self.world_size % data_parallel_size != 0: - continue - with self.subTest(data_parallel_size=data_parallel_size): - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=self.world_size // data_parallel_size, - pipeline_model_parallel_size_=1, - ) - self.assertEqual(data_parallel_size, parallel_state.get_data_parallel_world_size()) - - _reconfigure_microbatch_calculator( - self.rank, - rampup_batch_size, - self.GLOBAL_BATCH_SIZE, - self.MICRO_BATCH_SIZE, - data_parallel_size, - ) - - self.assertEqual(get_micro_batch_size(), expected_micro_batch_size) - self.assertEqual(get_num_microbatches(), expected_global_batch_size / expected_micro_batch_size / data_parallel_size) - current_global_batch_size = get_current_global_batch_size() - self.assertEqual(current_global_batch_size, expected_global_batch_size) - - # Make sure `global_batch_size` equals to the final global batch size after - # certain number of updates. - if rampup_batch_size: - update_num_microbatches(current_global_batch_size) - for i in range(100): - current_global_batch_size = get_current_global_batch_size() - update_num_microbatches(current_global_batch_size) - current_global_batch_size = get_current_global_batch_size() - self.assertEqual(get_current_global_batch_size(), self.GLOBAL_BATCH_SIZE) - parallel_state.destroy_model_parallel() - - def test_constant_microbatch_calculator(self): - self._test(rampup_batch_size=None) - - def test_dynamic_microbatch_calculator(self): - self._test(rampup_batch_size=[256, 128, 500]) - - -class NcclMicrobatchCalculatorTest(MicrobatchCalculatorTestBase, NcclDistributedTestBase): pass -class UccMicrobatchCalculatorTest(MicrobatchCalculatorTestBase, UccDistributedTestBase): pass - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_p2p_comm.py b/tests/L0/run_transformer/test_p2p_comm.py deleted file mode 100644 index c93b19a..0000000 --- a/tests/L0/run_transformer/test_p2p_comm.py +++ /dev/null @@ -1,122 +0,0 @@ -import logging -import unittest - -import torch -from torch.testing._internal import common_utils - -logging.getLogger("torch").setLevel(logging.WARNING) - -from apex.transformer import parallel_state -from apex.transformer.pipeline_parallel import p2p_communication -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase -from apex.transformer.testing.distributed_test_base import UccDistributedTestBase - -logging.getLogger("apex").setLevel(logging.DEBUG) - - -# [P2P Ops Involved in Pipeline Model Parallel forward/backward] -# **forward_backward_pipelining_without_interleaving** -# - send_forward / recv_forward -# - send_backward / recv_backward -# - send_forward_recv_backward -# - send_backward_recv_forward -# **forward_backward_pipelining_with_interleaving** -# - send_backward_recv_backward -# - recv_backward -# - recv_forward -# - send_forward_backward_recv_forward_backward -# - send_forward_recv_forward -class P2PCommTestBase: - - numel = 4 - shape = (2, 2) - dtype = torch.float32 - - @property - def world_size(self): - return min(2, torch.cuda.device_count()) - - def _init_model_parallel(self): - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=1, - pipeline_model_parallel_size_=self.world_size, - virtual_pipeline_model_parallel_size_=None, - ) - - def create_tensor(self, value: int = None): - return torch.tensor( - [value] * self.numel).view(self.shape).to(device="cuda", dtype=self.dtype) - - # Brief: Simulate warm-up. - # Brief: test `recv_forward` & `send_forward`. - def test_no_interleaving_warmup(self): - self.assertEqual(self.world_size, 2) - self._init_model_parallel() - input_tensor = None - if parallel_state.is_pipeline_first_stage(): - tensor = self.create_tensor(self.rank) - print(tensor) - p2p_communication.send_forward(output_tensor=tensor, tensor_shape=self.shape, dtype=self.dtype) - else: - input_tensor = p2p_communication.recv_forward(tensor_shape=self.shape, dtype=self.dtype) - - if parallel_state.is_pipeline_first_stage(): - self.assertIsNone(input_tensor) - else: - expected_input_tensor = self.create_tensor(self.rank - 1) - self.assertEqual(input_tensor, expected_input_tensor) - - # Brief: test `send_forward`, `send_forward_recv_forward`, and `recv_forward`. - def test_send_forward_recv_forward(self): - self._init_model_parallel() - prev_tensor = None - tensor = self.create_tensor(self.rank) - if parallel_state.is_pipeline_first_stage(): - p2p_communication.send_forward(output_tensor=tensor, tensor_shape=self.shape, dtype=self.dtype) - elif parallel_state.is_pipeline_last_stage(): - prev_tensor = p2p_communication.recv_forward(tensor_shape=self.shape, dtype=self.dtype) - else: - prev_tensor = p2p_communication.send_forward_recv_forward( - output_tensor=tensor, - recv_prev=True, - tensor_shape=self.shape, - dtype=self.dtype, - ) - - if parallel_state.is_pipeline_first_stage(): - self.assertIsNone(prev_tensor) - else: - expected_prev_tensor = self.create_tensor(self.rank - 1) - self.assertEqual(prev_tensor, expected_prev_tensor) - - # Brief: test `send_backward`, `send_backward_recv_backward`, and `recv_backward`. - def test_send_backward_recv_backward(self): - self._init_model_parallel() - tensor = self.create_tensor(self.rank) - - next_tensor = None - if parallel_state.is_pipeline_first_stage(): - next_tensor = p2p_communication.recv_backward(tensor_shape=self.shape, dtype=self.dtype) - elif parallel_state.is_pipeline_last_stage(): - p2p_communication.send_backward(input_tensor_grad=tensor, tensor_shape=self.shape, dtype=self.dtype) - else: - next_tensor = p2p_communication.send_backward_recv_backward( - input_tensor_grad=tensor, - recv_next=True, - tensor_shape=self.shape, - dtype=self.dtype, - ) - - if parallel_state.is_pipeline_last_stage(): - self.assertIsNone(next_tensor) - else: - expected_next_tensor = self.create_tensor(self.rank + 1) - self.assertEqual(next_tensor, expected_next_tensor) - - -# n.b.(mkozuki): Intentionally skip NCCL backend tests as I trust pytorch/pytorch repo. -class UccP2PCommTest(P2PCommTestBase, UccDistributedTestBase): pass - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_parallel_state.py b/tests/L0/run_transformer/test_parallel_state.py deleted file mode 100644 index 0314895..0000000 --- a/tests/L0/run_transformer/test_parallel_state.py +++ /dev/null @@ -1,185 +0,0 @@ -import logging -import os - -from torch.testing._internal import common_utils - -logging.getLogger("torch").setLevel(logging.WARNING) - -from apex.transformer import parallel_state -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase -from apex.transformer.testing.distributed_test_base import UccDistributedTestBase - -logging.getLogger("apex").setLevel(logging.WARNING) - - -os.environ["BACKEND"] = "NCCL" -DATA_PARALLEL_WORLD_SIZE: int = 1 - - -def calc_expected_tensor_model_paralell_rank( - rank: int, tensor_model_parallel_world_size: int, -) -> int: - return rank % tensor_model_parallel_world_size - - -class ParallelStateTestBase: - def test_initialize_model_parallel(self) -> None: - - self.assertFalse(parallel_state.model_parallel_is_initialized()) - - for tensor_model_parallel_world_size in range(1, self.world_size + 1): - with self.subTest( - tensor_model_parallel_world_size=tensor_model_parallel_world_size - ): - if self.world_size % tensor_model_parallel_world_size: - continue - - pipeline_model_parallel_world_size = ( - self.world_size // tensor_model_parallel_world_size - ) - - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - pipeline_model_parallel_size_=pipeline_model_parallel_world_size, - ) - self.assertEqual( - tensor_model_parallel_world_size, - parallel_state.get_tensor_model_parallel_world_size(), - ) - expected_tensor_model_parallel_rank = calc_expected_tensor_model_paralell_rank( - self.rank, tensor_model_parallel_world_size - ) - self.assertEqual( - expected_tensor_model_parallel_rank, - parallel_state.get_tensor_model_parallel_rank(), - ) - - expected_tensor_model_parallel_src_rank = ( - self.rank // tensor_model_parallel_world_size - ) * tensor_model_parallel_world_size - self.assertEqual( - expected_tensor_model_parallel_src_rank, - parallel_state.get_tensor_model_parallel_src_rank(), - ) - - parallel_state.destroy_model_parallel() - self.assertFalse(parallel_state.model_parallel_is_initialized()) - - def test_initialize_model_parallel_with_virtual_and_split(self) -> None: - if self.world_size < 4: - self.skipTest("requires >= 4 GPUs") - self.assertFalse(parallel_state.model_parallel_is_initialized()) - - tensor_model_parallel_world_size = 1 + int(self.world_size > 4) - pipeline_model_parallel_world_size = ( - self.world_size // tensor_model_parallel_world_size - ) - virtual_pipeline_model_parallel_world_size = 2 - pipeline_model_parallel_split_rank = pipeline_model_parallel_world_size // 2 - - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - pipeline_model_parallel_size_=pipeline_model_parallel_world_size, - virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_world_size, - pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank, - ) - self.assertEqual( - calc_expected_tensor_model_paralell_rank( - self.rank, tensor_model_parallel_world_size - ), - parallel_state.get_tensor_model_parallel_rank(), - ) - self.assertEqual( - pipeline_model_parallel_world_size, - parallel_state.get_pipeline_model_parallel_world_size(), - ) - self.assertEqual( - virtual_pipeline_model_parallel_world_size, - parallel_state.get_virtual_pipeline_model_parallel_world_size(), - ) - - expected_pipeline_rank = ( - self.rank - (self.rank % tensor_model_parallel_world_size) - ) % pipeline_model_parallel_world_size - self.assertEqual( - expected_pipeline_rank, parallel_state.get_pipeline_model_parallel_rank(), - ) - # virtual pipeline model parallel rank is lazily set, i.e., right after the call of - # `initialize_model_parallel`, it's set to 0. - self.assertEqual( - 0, parallel_state.get_virtual_pipeline_model_parallel_rank(), - ) - self.assertEqual( - pipeline_model_parallel_split_rank, - parallel_state.get_pipeline_model_parallel_split_rank(), - ) - - fake_split_rank = 77 - parallel_state.set_pipeline_model_parallel_split_rank(fake_split_rank) - self.assertEqual( - fake_split_rank, parallel_state.get_pipeline_model_parallel_split_rank() - ) - - # relative position embedding groups check - self.assertEqual( - expected_pipeline_rank < pipeline_model_parallel_split_rank, - parallel_state.is_rank_in_encoder_relative_position_embedding_group(), - ) - self.assertEqual( - expected_pipeline_rank >= pipeline_model_parallel_split_rank, - parallel_state.is_rank_in_decoder_relative_position_embedding_group(), - ) - - parallel_state.destroy_model_parallel() - - def test_initialize_model_parallel_decoder_only(self) -> None: - """Initialize model parallelism for decoder-only Transformers like GPT-3""" - - self.assertFalse(parallel_state.model_parallel_is_initialized()) - - for tensor_model_parallel_world_size in range(1, self.world_size + 1): - with self.subTest( - tensor_model_parallel_world_size=tensor_model_parallel_world_size - ): - if self.world_size % tensor_model_parallel_world_size: - continue - - pipeline_model_parallel_world_size = ( - self.world_size // tensor_model_parallel_world_size - ) - - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - pipeline_model_parallel_size_=pipeline_model_parallel_world_size, - pipeline_model_parallel_split_rank_=0, - ) - self.assertEqual( - tensor_model_parallel_world_size, - parallel_state.get_tensor_model_parallel_world_size(), - ) - expected_tensor_model_parallel_rank = calc_expected_tensor_model_paralell_rank( - self.rank, tensor_model_parallel_world_size - ) - self.assertEqual( - expected_tensor_model_parallel_rank, - parallel_state.get_tensor_model_parallel_rank(), - ) - - expected_tensor_model_parallel_src_rank = ( - self.rank // tensor_model_parallel_world_size - ) * tensor_model_parallel_world_size - self.assertEqual( - expected_tensor_model_parallel_src_rank, - parallel_state.get_tensor_model_parallel_src_rank(), - ) - - parallel_state.destroy_model_parallel() - self.assertFalse(parallel_state.model_parallel_is_initialized()) - - -class NcclParallelStateTest(ParallelStateTestBase, NcclDistributedTestBase): pass -class UccParallelStateTest(ParallelStateTestBase, UccDistributedTestBase): pass - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py b/tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py deleted file mode 100644 index a409c40..0000000 --- a/tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py +++ /dev/null @@ -1,447 +0,0 @@ -import logging -import itertools -import re -from typing import Optional, Tuple, List -import unittest - -import torch -from torch.testing._internal import common_utils -from torch.testing._internal import common_cuda - -from apex._autocast_utils import _get_autocast_dtypes -from apex.transformer import parallel_state -from apex.transformer.enums import ModelType -from apex.transformer.pipeline_parallel import utils as pp_utils -from apex.transformer.pipeline_parallel.schedules.common import ( - FwdStepFunc, - build_model, - _get_params_for_weight_decay_optimization, -) -from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import ( - forward_backward_no_pipelining, -) -from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import ( - _forward_backward_pipelining_with_interleaving, -) -from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import ( - forward_backward_pipelining_without_interleaving, -) -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase -from apex.transformer.testing.distributed_test_base import UccDistributedTestBase -from apex.transformer.testing.distributed_test_base import HAS_TORCH_UCC -from apex.transformer.testing.distributed_test_base import HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER -from apex.transformer.testing import commons as testing_utils - - -logging.getLogger("torch").setLevel(logging.WARNING) -logging.getLogger("apex").setLevel(logging.WARNING) - -weight_coeff = 1024 - - -def get_init_weights_func(offset: int = 0): - @torch.no_grad() - def init_weights(m): - rank = parallel_state.get_pipeline_model_parallel_rank() - if isinstance(m, torch.nn.Linear): - m.weight.fill_((rank + offset + 1.0) / weight_coeff) - m.bias.fill_(1.0) - return init_weights - - -def get_dtype_for_comparison(): - if(torch.cuda.get_device_capability() >= (8, 0)): - return torch.float64 - return torch.float32 - - -def get_target_loss_and_model(global_batch_shape: tuple, hidden_size: int, total_layers: int) -> Tuple[torch.Tensor, List[torch.Tensor]]: - model = [] - dtype = get_dtype_for_comparison() - data = torch.ones(global_batch_shape, dtype=dtype) - for i in range(total_layers): - w = torch.ones((hidden_size, hidden_size), dtype=dtype) * (i + 1.0) / weight_coeff - b = torch.ones(hidden_size, dtype=dtype) - - w.requires_grad_() - b.requires_grad_() - - # don't need to care about transpose semantics as all values are the same - data = torch.matmul(w, data) + b - model.append([w, b]) - - loss = data.sum() / global_batch_shape[0] - loss.backward() - - return loss, model - - -def _get_default_world_sizes_model_parallel_world_size(pipeline_model_parallel_world_size: Optional[int] = None - ) -> Tuple[int, int, int]: - # TODO: revisit if we can fold this into the class for skip logic / avoid duplication - # of world size computation - world_size = torch.cuda.device_count() - tensor_model_parallel_world_size = 1 - data_parallel_size = 1 + (world_size >= 8 and world_size % 2 == 0) - - if pipeline_model_parallel_world_size is None: - pipeline_model_parallel_world_size = world_size // (tensor_model_parallel_world_size * data_parallel_size) - else: - data_parallel_size = world_size // (tensor_model_parallel_world_size * pipeline_model_parallel_world_size) - - return tensor_model_parallel_world_size, data_parallel_size, pipeline_model_parallel_world_size - - -class PipelineParallelForwardBackwardTestBase: - - GLOBAL_BATCH_SIZE = 16 - MICRO_BATCH_SIZE = 2 - HIDDEN_SIZE = 32 - - deallocate_options = (True, False) - # If :obj:`None`, (torch.float32, torch.float16, torch.bfloat16) are dtype options on Ampere. - # You can limit the options by overriding the following `dtypes`. - dtypes = None - - def _forward_backward_test_impl( - self, - forward_only: bool, - fwd_bwd_func: FwdStepFunc, - pipeline_model_parallel_world_size: Optional[int], - virtual_pipeline_model_parallel_size: Optional[int], - async_comm: bool = False, - *, - default_backend: Optional[str] = None, - p2p_backend: Optional[str] = None, - ) -> None: - if fwd_bwd_func == _forward_backward_pipelining_with_interleaving: - self.assertIsNotNone(virtual_pipeline_model_parallel_size) - self.assertGreater(virtual_pipeline_model_parallel_size, 1) - dtype_options = self.dtypes or [torch.float32, torch.double] + _get_autocast_dtypes() - - for dtype, deallocate_pipeline_outputs in itertools.product( - dtype_options, self.deallocate_options, - ): - grad_scaler = ( - torch.cuda.amp.GradScaler(init_scale=4.0) - if dtype == torch.half - else None - ) - - (tensor_model_parallel_world_size, - data_parallel_size, - pipeline_model_parallel_world_size) = _get_default_world_sizes_model_parallel_world_size(pipeline_model_parallel_world_size) - - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - pipeline_model_parallel_size_=pipeline_model_parallel_world_size, - virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size, - default_backend=default_backend, - p2p_backend=p2p_backend, - ) - pp_utils._reconfigure_microbatch_calculator( - rank=parallel_state.get_tensor_model_parallel_rank(), - rampup_batch_size=None, - global_batch_size=self.GLOBAL_BATCH_SIZE, - micro_batch_size=self.MICRO_BATCH_SIZE, - data_parallel_size=parallel_state.get_data_parallel_world_size(), - ) - - global_batch_shape = ( - self.GLOBAL_BATCH_SIZE - // parallel_state.get_data_parallel_world_size(), - self.HIDDEN_SIZE, - self.HIDDEN_SIZE, - ) - - batch = None - if parallel_state.is_pipeline_first_stage(): - batch = (torch.ones(global_batch_shape, dtype=dtype).cuda(), ) - - model = build_model( - testing_utils.model_provider_func, - # Use DDP only when it's better to have - wrap_with_ddp=data_parallel_size > 1, - virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size, - hidden_size=self.HIDDEN_SIZE, - ) - - - offset = pipeline_model_parallel_world_size if virtual_pipeline_model_parallel_size is not None else 0 - for idx, model_module in enumerate(model): - model_module = model_module.to(dtype) - model_module.apply(get_init_weights_func(idx*offset)) - - _param_groups = _get_params_for_weight_decay_optimization(model) - optimizer = torch.optim.Adam(_param_groups, lr=1e-3) - - pp_utils.update_num_microbatches(0) - - loss = fwd_bwd_func( - testing_utils.fwd_step_func, - batch, - model, - forward_only=forward_only, - # `tensor_shape` is the shape of micro batch. - tensor_shape=( - self.MICRO_BATCH_SIZE, - self.HIDDEN_SIZE, - self.HIDDEN_SIZE, - ), - dtype=dtype, - async_comm=async_comm, - grad_scaler=grad_scaler, - deallocate_pipeline_output=deallocate_pipeline_outputs, - ) - - if dtype == get_dtype_for_comparison(): - torch.cuda.synchronize() - hidden_size = self.HIDDEN_SIZE - microbatch_size = self.MICRO_BATCH_SIZE - total_layers = pipeline_model_parallel_world_size - if virtual_pipeline_model_parallel_size is not None: - total_layers *= virtual_pipeline_model_parallel_size - target_loss, target_model = get_target_loss_and_model(global_batch_shape, hidden_size, total_layers) - - for loss_item in loss: - x = loss_item['avg'] - self.assertEqual(x.item() / microbatch_size, target_loss.item()) - - if not forward_only: - for vm_id, model_module in enumerate(model): - params = list(model_module.parameters()) - rank = params[0].get_device() - offset = pipeline_model_parallel_world_size - param_id = rank // data_parallel_size + vm_id * offset - target_params = target_model[param_id] - - self.assertEqual(params[0].cpu(), target_params[0]) - self.assertEqual(params[1].cpu(), target_params[1]) - self.assertEqual(params[0].grad.cpu() / microbatch_size, target_params[0].grad) - self.assertEqual(params[1].grad.cpu() / microbatch_size, target_params[1].grad) - - if not forward_only: - for m in model: - for p in m.parameters(): - self.assertIsNotNone(p.grad) - optimizer.step() - optimizer.zero_grad(set_to_none=True) - - parallel_state.destroy_model_parallel() - - def test_learning_no_pipelining(self): - self._forward_backward_test_impl(False, forward_backward_no_pipelining, 1, None) - - def test_inference_no_pipelining(self): - self._forward_backward_test_impl(True, forward_backward_no_pipelining, 1, None) - - def test_learning_pipelining_without_interleaving(self): - self._forward_backward_test_impl( - False, forward_backward_pipelining_without_interleaving, None, None - ) - - def test_inference_pipelining_without_interleaving(self): - self._forward_backward_test_impl( - True, forward_backward_pipelining_without_interleaving, None, None - ) - - def test_learning_async_pipelining_without_interleaving(self): - self._forward_backward_test_impl( - False, forward_backward_pipelining_without_interleaving, None, None, async_comm=True - ) - - def test_inference_async_pipelining_without_interleaving(self): - self._forward_backward_test_impl( - True, forward_backward_pipelining_without_interleaving, None, None, async_comm=True - ) - - @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2") - def test_learning_pipelining_with_interleaving(self): - self._forward_backward_test_impl( - False, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2 - ) - - @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2") - def test_inference_pipelining_with_interleaving(self): - self._forward_backward_test_impl( - True, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2 - ) - - @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2") - def test_learning_async_pipelining_with_interleaving(self): - self._forward_backward_test_impl( - False, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2, async_comm=True - ) - - @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2") - def test_inference_async_pipelining_with_interleaving(self): - self._forward_backward_test_impl( - True, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2, async_comm=True - ) - - -class NcclPipelineParallelForwardBackwardTest(NcclDistributedTestBase, PipelineParallelForwardBackwardTestBase): - - @property - def world_size(self) -> int: - return min(torch.cuda.device_count(), 8) - - def _run_hybrid_distributed_backend(self, forward_only: bool) -> None: - self._forward_backward_test_impl( - forward_only, forward_backward_pipelining_without_interleaving, None, None, - default_backend="nccl", p2p_backend="ucc", - ) - - @unittest.skipUnless(HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER, "Needs driver >= 470.42.01") - def _test_hybrid_backends(self, forward_only: bool) -> None: - if HAS_TORCH_UCC: - self._run_hybrid_distributed_backend(forward_only) - else: - with self.assertRaisesRegex( - ImportError, - re.escape("UCC backend requires [torch_ucc](https://github.com/facebookresearch/torch_ucc) but not found"), - ): - self._run_hybrid_distributed_backend(forward_only) - - def test_learning_pipelining_without_interleaving_ucc_for_p2p(self): - self._test_hybrid_backends(False) - - def test_inference_pipelining_without_interleaving_ucc_for_p2p(self): - self._test_hybrid_backends(True) - - -# n.b.(mkozuki): pipeline parallel w/o interleaving with UCX_TLS=tcp,sm fails. -class UccPipelineParallelForwardBackwardTest(UccDistributedTestBase, PipelineParallelForwardBackwardTestBase): - - @property - def world_size(self) -> int: - return min(torch.cuda.device_count(), 8) - - deallocate_options = (False,) - dtypes = (torch.float32,) - - -# Sanity checking the functionality of `forward_backward_pipelining_without_interleaving` with -# `model_type=ModelType.encoder_and_decoder` which is used for pipeline training of transformer -# models such as T5. -@unittest.skipIf(torch.cuda.device_count() < 4, "Requires >= 4 GPUs") -class NcclPipelineParallelWithToyParallelMLP(NcclDistributedTestBase): - - GLOBAL_BATCH_SIZE = 16 - MICRO_BATCH_SIZE = 2 - HIDDEN_SIZE = 64 - # TODO(mkozuki): Change `DECODER_SEQUENCE_LENGTH` to a value different from `ENCODER_SEQUENCE_LENGTH`. - # To test forward_backward_pipelining_without_interleaving with `model_type=ModelType.encoder_and_decoder`, - # `decoder_seq_length` is necessary and ideally should be different from `encoder_sequence_length` - # but my laziness let me use the same value. - # Note that you may have to either update `MyModel` def or define another `MyModel`. - # to support different `DECODER_SEQUENCE_LENGTH`. - ENCODER_SEQUENCE_LENGTH = 32 - DECODER_SEQUENCE_LENGTH = 32 - - @property - def world_size(self) -> int: - return min(torch.cuda.device_count(), 8) - - # TODO(mkozuki): Add cases of async_comm=True - # TODO(mkozuki): Add loss check. - # TODO(mkozuki): Call `build_model` with `model_type`. - # TODO(mkozuki): Set `tensor_model_parallel>1` for encoder_and_decoder as well if there's enough GPUs - # in order to let `sequence_parallel_enabled` have an effect on tensor shape logic. - def _forward_backward_test_impl( - self, - *, - forward_only: bool, - sequence_parallel_enabled: bool, - model_type: ModelType, - dtype: torch.dtype = torch.float32, - ) -> None: - # N.B.(mkozuki): It might be better to set `tensor_model_parallel_size` to >1 - # if `self.world_size > 5`. Otherwise, `pipeline_model_parallel_split_rank` - # can be 1, which can be too far real usecase. - tensor_model_parallel_size = 1 + int(self.world_size >= 4) - pipeline_model_parallel_world_size = self.world_size // tensor_model_parallel_size - if model_type == ModelType.encoder_and_decoder: - pipeline_model_parallel_split_rank = pipeline_model_parallel_world_size // 2 - else: - pipeline_model_parallel_split_rank = None - - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_size, - pipeline_model_parallel_size_=pipeline_model_parallel_world_size, - virtual_pipeline_model_parallel_size_=None, - pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank, - ) - testing_utils.set_random_seed(567) - pp_utils._reconfigure_microbatch_calculator( - rank=parallel_state.get_tensor_model_parallel_rank(), - rampup_batch_size=None, - global_batch_size=self.GLOBAL_BATCH_SIZE, - micro_batch_size=self.MICRO_BATCH_SIZE, - data_parallel_size=parallel_state.get_data_parallel_world_size(), - ) - model = build_model( - testing_utils.mlp_provider_func, - wrap_with_ddp=False, - virtual_pipeline_model_parallel_size=None, - hidden_size=self.HIDDEN_SIZE, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - model = [m.to(dtype=dtype) for m in model] - - if parallel_state.is_pipeline_first_stage(): - batch: Tuple[torch.Tensor] = ( - torch.ones( - (self.GLOBAL_BATCH_SIZE, self.ENCODER_SEQUENCE_LENGTH, self.HIDDEN_SIZE), - dtype=dtype, - device="cuda", - ), - ) - else: - batch = None - - forward_backward_pipelining_without_interleaving( - forward_step_func=testing_utils.ToyParallelMLPFwdBwdStepFunc( - sequence_parallel_enabled=sequence_parallel_enabled, - ), - batch=batch, - model=model, - forward_only=forward_only, - tensor_shape=( - self.ENCODER_SEQUENCE_LENGTH, - self.MICRO_BATCH_SIZE, - self.HIDDEN_SIZE, - ), - model_type=model_type, - decoder_sequence_length=self.DECODER_SEQUENCE_LENGTH, - async_comm=False, - grad_scaler=None, - deallocate_pipeline_outputs=False, - dtype=dtype, - sequence_parallel_enabled=sequence_parallel_enabled, - ) - - def test_pipelining_without_interleaving_encoder_and_decoder(self) -> None: - self._forward_backward_test_impl(forward_only=False, sequence_parallel_enabled=False, model_type=ModelType.encoder_and_decoder) - - def test_pipelining_without_interleaving_inferenc_encoder_and_decoder(self) -> None: - self._forward_backward_test_impl(forward_only=True, sequence_parallel_enabled=False, model_type=ModelType.encoder_and_decoder) - - def test_pipelining_without_interleaving_sequence_paralle_encoder_and_decoder(self) -> None: - self._forward_backward_test_impl(forward_only=False, sequence_parallel_enabled=True, model_type=ModelType.encoder_and_decoder) - - def test_pipelining_without_interleaving_inference_sequence_paralle_encoder_and_decoder(self) -> None: - self._forward_backward_test_impl(forward_only=True, sequence_parallel_enabled=True, model_type=ModelType.encoder_and_decoder) - - def test_pipelining_without_interleaving_encoder_or_decoder(self) -> None: - self._forward_backward_test_impl(forward_only=False, sequence_parallel_enabled=False, model_type=ModelType.encoder_or_decoder) - - def test_pipelining_without_interleaving_sequence_parallel_encoder_or_decoder(self) -> None: - self._forward_backward_test_impl(forward_only=False, sequence_parallel_enabled=True, model_type=ModelType.encoder_or_decoder) - - def test_pipelining_without_interleaving_sequence_parallel_encoder_or_decoder_half(self) -> None: - self._forward_backward_test_impl(forward_only=False, sequence_parallel_enabled=True, model_type=ModelType.encoder_or_decoder, dtype=torch.half) - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_random.py b/tests/L0/run_transformer/test_random.py deleted file mode 100644 index 6060f9e..0000000 --- a/tests/L0/run_transformer/test_random.py +++ /dev/null @@ -1,120 +0,0 @@ -import logging - -import torch -from torch.testing._internal import common_utils - -logging.getLogger("torch").setLevel(logging.WARNING) - -from apex.transformer import parallel_state -from apex.transformer import tensor_parallel -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase -from apex.transformer.testing.distributed_test_base import UccDistributedTestBase - -logging.getLogger("apex").setLevel(logging.WARNING) - - -class TransformerRandomTestBase: - def test_set_cuda_rng_state(self): - for tensor_model_parallel_world_size in range(1, self.world_size + 1): - if self.world_size % tensor_model_parallel_world_size: - continue - with self.subTest( - tensor_model_parallel_world_size=tensor_model_parallel_world_size - ): - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size - ) - - size, seed = 123, 1234 - torch.cuda.manual_seed(seed) - tensor = torch.cuda.FloatTensor(size) - - rng_state = torch.cuda.get_rng_state() - rng_state_clone = rng_state.clone() - - for _ in range(5): - torch.randn(size, out=tensor) - result_1 = tensor.clone() - - self.assertEqual(rng_state.sub(rng_state_clone).max(), 0) - self.assertGreater( - torch.cuda.get_rng_state().sub(rng_state_clone).max(), 0 - ) - - new_rng_state = torch.cuda.get_rng_state() - self.assertGreater(new_rng_state.sub(rng_state).max(), 0) - - tensor_parallel.random._set_cuda_rng_state(rng_state) - for _ in range(5): - torch.randn(size, out=tensor) - tensor_parallel.random._set_cuda_rng_state(rng_state) - for _ in range(5): - torch.randn(size, out=tensor) - result_2 = tensor.clone() - - self.assertEqual(result_2, result_1) - - self.assertEqual(rng_state.sub(rng_state_clone).max(), 0) - - parallel_state.destroy_model_parallel() - - def test_cuda_rng_tracker(self): - for tensor_model_parallel_world_size in range(1, self.world_size + 1): - if self.world_size % tensor_model_parallel_world_size: - continue - with self.subTest( - tensor_model_parallel_world_size=tensor_model_parallel_world_size - ): - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size - ) - - seed_1, seed_2, size = 1234, 4321, [12, 21] - tensor = torch.cuda.FloatTensor(size) - - torch.cuda.manual_seed(seed_1) - torch.randn(size, out=tensor) - target_11 = tensor.clone() - torch.randn(size, out=tensor) - target_12 = tensor.clone() - - torch.cuda.manual_seed(seed_2) - torch.randn(size, out=tensor) - targt_21 = tensor.clone() - torch.randn(size, out=tensor) - target_22 = tensor.clone() - - torch.cuda.manual_seed(seed_1) - tensor_parallel.random.get_cuda_rng_tracker().add("test", seed_2) - - torch.randn(size, out=tensor) - result_11 = tensor.clone() - - with tensor_parallel.random.get_cuda_rng_tracker().fork("test"): - torch.randn(size, out=tensor) - result_21 = tensor.clone() - - torch.randn(size, out=tensor) - result_12 = tensor.clone() - - with tensor_parallel.random.get_cuda_rng_tracker().fork("test"): - torch.randn(size, out=tensor) - result_22 = tensor.clone() - - self.assertEqual(target_11, result_11) - self.assertEqual(target_12, result_12) - self.assertEqual(targt_21, result_21) - self.assertEqual(target_22, result_22) - self.assertNotEqual(result_11, result_21) - self.assertNotEqual(result_21, result_22) - - tensor_parallel.random.get_cuda_rng_tracker().reset() - parallel_state.destroy_model_parallel() - - -class NcclTransformerRandomTest(TransformerRandomTestBase, NcclDistributedTestBase): pass -class UccTransformerRandomTest(TransformerRandomTestBase, UccDistributedTestBase): pass - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L0/run_transformer/test_transformer_module.py b/tests/L0/run_transformer/test_transformer_module.py deleted file mode 100644 index 77ce67e..0000000 --- a/tests/L0/run_transformer/test_transformer_module.py +++ /dev/null @@ -1,107 +0,0 @@ -from typing import Tuple -import os -import subprocess -import sys -import unittest - - -SEVERALGPU_TEST = [ - "bert_minimal_test", - "gpt_minimal_test", - "dynamic_batchsize_test", -] - - -def get_multigpu_launch_option(min_gpu): - should_skip = False - import torch - - num_devices = torch.cuda.device_count() - if num_devices < min_gpu: - should_skip = True - distributed_run_options = f"-m torch.distributed.run --nproc_per_node={num_devices}" - return should_skip, distributed_run_options - - -def get_launch_option(test_filename) -> Tuple[bool, str]: - should_skip = False - for severalgpu_test in SEVERALGPU_TEST: - if severalgpu_test in test_filename: - return get_multigpu_launch_option(3) - return should_skip, "" - - -def run_transformer_tests(): - python_executable_path = sys.executable - directory = os.path.dirname(__file__) - files = [ - os.path.join(directory, f) - for f in os.listdir(directory) - if f.startswith("run_") and os.path.isfile(os.path.join(directory, f)) - ] - print("#######################################################") - print(f"# Python executable path: {python_executable_path}") - print(f"# {len(files)} tests: {files}") - print("#######################################################") - errors = [] - for i, test_file in enumerate(files, 1): - is_denied = False - should_skip, launch_option = get_launch_option(test_file) - if should_skip: - print( - f"### {i} / {len(files)}: {test_file} skipped. Requires multiple GPUs." - ) - continue - test_run_cmd = ( - f"{python_executable_path} {launch_option} {test_file} " - "--micro-batch-size 2 --num-layers 16 --hidden-size 256 --num-attention-heads 8 --max-position-embeddings " - "512 --seq-length 512 --global-batch-size 128" - ) - if "bert" in test_file or "gpt" in test_file: - import torch - - num_devices = torch.cuda.device_count() - if "bert" in test_file: - # "bert" uses the interleaving. - tensor_model_parallel_size = 2 if num_devices % 2 == 0 and num_devices > 4 else 1 - if "gpt" in test_file: - # "gpt" uses the non-interleaving. - tensor_model_parallel_size = 2 if num_devices % 2 == 0 and num_devices >= 4 else 1 - pipeline_model_parallel_size = num_devices // tensor_model_parallel_size - test_run_cmd += f" --pipeline-model-parallel-size {pipeline_model_parallel_size} --tensor-model-parallel-size {tensor_model_parallel_size}" - - if "bert" in test_file: - test_run_cmd += f" --bert-no-binary-head" - else: - test_run_cmd += f" --use-cpu-initialization" - print(f"### {i} / {len(files)}: cmd: {test_run_cmd}") - try: - output = ( - subprocess.check_output(test_run_cmd, shell=True) - .decode(sys.stdout.encoding) - .strip() - ) - except Exception as e: - errors.append((test_file, str(e))) - else: - if ">> passed the test :-)" not in output: - errors.append((test_file, output)) - else: - if not errors: - print("### PASSED") - else: - print("### FAILED") - short_msg = f"{len(errors)} out of {len(files)} tests failed" - print(short_msg) - for (filename, log) in errors: - print(f"File: {filename}\nLog: {log}") - raise RuntimeError(short_msg) - - -class TestTransformer(unittest.TestCase): - def test_transformer(self): - run_transformer_tests() - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/L0/run_transformer/test_transformer_utils.py b/tests/L0/run_transformer/test_transformer_utils.py deleted file mode 100644 index d5d1608..0000000 --- a/tests/L0/run_transformer/test_transformer_utils.py +++ /dev/null @@ -1,40 +0,0 @@ -import logging - -import torch -from torch.testing._internal import common_utils - -logging.getLogger("torch").setLevel(logging.WARNING) - -from apex.transformer import parallel_state -from apex.transformer.tensor_parallel import utils -from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase - -logging.getLogger("apex").setLevel(logging.WARNING) - - -class TransformerUtilsTest(NcclDistributedTestBase): - def test_split_tensor_along_last_dim(self): - for tensor_model_paralell_world_size in range(1, self.world_size + 1): - if self.world_size % tensor_model_paralell_world_size > 0: - continue - with self.subTest( - tensor_model_paralell_world_size=tensor_model_paralell_world_size - ): - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_paralell_world_size - ) - - device = "cpu" - input_tensor = torch.randn((100, 100, 100), device=device) - splits = utils.split_tensor_along_last_dim(input_tensor, 10) - last_dim_shapes = torch.tensor( - [int(split.size()[-1]) for split in splits] - ) - - self.assertTrue(torch.equal(last_dim_shapes, torch.full((10,), 10),)) - - parallel_state.destroy_model_parallel() - - -if __name__ == "__main__": - common_utils.run_tests() diff --git a/tests/L1/common/compare.py b/tests/L1/common/compare.py deleted file mode 100644 index 74374d4..0000000 --- a/tests/L1/common/compare.py +++ /dev/null @@ -1,64 +0,0 @@ -import argparse -import torch - -parser = argparse.ArgumentParser(description='Compare') -parser.add_argument('--opt-level', type=str) -parser.add_argument('--keep-batchnorm-fp32', type=str, default=None) -parser.add_argument('--loss-scale', type=str, default=None) -parser.add_argument('--fused-adam', action='store_true') -parser.add_argument('--use_baseline', action='store_true') -args = parser.parse_args() - -base_file = str(args.opt_level) + "_" +\ - str(args.loss_scale) + "_" +\ - str(args.keep_batchnorm_fp32) + "_" +\ - str(args.fused_adam) - -file_e = "True_" + base_file -file_p = "False_" + base_file -if args.use_baseline: - file_b = "baselines/True_" + base_file - -dict_e = torch.load(file_e) -dict_p = torch.load(file_p) -if args.use_baseline: - dict_b = torch.load(file_b) - -torch.set_printoptions(precision=10) - -print(file_e) -print(file_p) -if args.use_baseline: - print(file_b) - -# ugly duplication here... -if not args.use_baseline: - for n, (i_e, i_p) in enumerate(zip(dict_e["Iteration"], dict_p["Iteration"])): - assert i_e == i_p, "i_e = {}, i_p = {}".format(i_e, i_p) - - loss_e = dict_e["Loss"][n] - loss_p = dict_p["Loss"][n] - assert loss_e == loss_p, "Iteration {}, loss_e = {}, loss_p = {}".format(i_e, loss_e, loss_p) - print("{:4} {:15.10f} {:15.10f} {:15.10f} {:15.10f}".format( - i_e, - loss_e, - loss_p, - dict_e["Speed"][n], - dict_p["Speed"][n])) -else: - for n, (i_e, i_p) in enumerate(zip(dict_e["Iteration"], dict_p["Iteration"])): - assert i_e == i_p, "i_e = {}, i_p = {}".format(i_e, i_p) - - loss_e = dict_e["Loss"][n] - loss_p = dict_p["Loss"][n] - loss_b = dict_b["Loss"][n] - assert loss_e == loss_p, "Iteration {}, loss_e = {}, loss_p = {}".format(i_e, loss_e, loss_p) - assert loss_e == loss_b, "Iteration {}, loss_e = {}, loss_b = {}".format(i_e, loss_e, loss_b) - print("{:4} {:15.10f} {:15.10f} {:15.10f} {:15.10f} {:15.10f} {:15.10f}".format( - i_e, - loss_b, - loss_e, - loss_p, - dict_b["Speed"][n], - dict_e["Speed"][n], - dict_p["Speed"][n])) diff --git a/tests/L1/common/main_amp.py b/tests/L1/common/main_amp.py deleted file mode 100644 index 106a0f6..0000000 --- a/tests/L1/common/main_amp.py +++ /dev/null @@ -1,526 +0,0 @@ -import argparse -import os -import shutil -import time - -import torch -import torch.nn as nn -import torch.nn.parallel -import torch.backends.cudnn as cudnn -import torch.distributed as dist -import torch.optim -import torch.utils.data -import torch.utils.data.distributed -import torchvision.transforms as transforms -import torchvision.datasets as datasets -import torchvision.models as models - -import numpy as np - -try: - from apex.parallel import DistributedDataParallel as DDP - from apex.fp16_utils import * - from apex import amp, optimizers - from apex.multi_tensor_apply import multi_tensor_applier -except ImportError: - raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") - -model_names = sorted(name for name in models.__dict__ - if name.islower() and not name.startswith("__") - and callable(models.__dict__[name])) - -parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') -parser.add_argument('data', metavar='DIR', - help='path to dataset') -parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', - choices=model_names, - help='model architecture: ' + - ' | '.join(model_names) + - ' (default: resnet18)') -parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', - help='number of data loading workers (default: 4)') -parser.add_argument('--epochs', default=90, type=int, metavar='N', - help='number of total epochs to run') -parser.add_argument('--start-epoch', default=0, type=int, metavar='N', - help='manual epoch number (useful on restarts)') -parser.add_argument('-b', '--batch-size', default=256, type=int, - metavar='N', help='mini-batch size per process (default: 256)') -parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, - metavar='LR', help='Initial learning rate. Will be scaled by /256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.') -parser.add_argument('--momentum', default=0.9, type=float, metavar='M', - help='momentum') -parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, - metavar='W', help='weight decay (default: 1e-4)') -parser.add_argument('--print-freq', '-p', default=10, type=int, - metavar='N', help='print frequency (default: 10)') -parser.add_argument('--resume', default='', type=str, metavar='PATH', - help='path to latest checkpoint (default: none)') -parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', - help='evaluate model on validation set') -parser.add_argument('--pretrained', dest='pretrained', action='store_true', - help='use pre-trained model') - -parser.add_argument('--prof', dest='prof', action='store_true', - help='Only run 10 iterations for profiling.') -parser.add_argument('--deterministic', action='store_true') - -parser.add_argument("--local_rank", default=0, type=int) -parser.add_argument('--sync_bn', action='store_true', - help='enabling apex sync BN.') - -parser.add_argument('--has-ext', action='store_true') -parser.add_argument('--opt-level', type=str) -parser.add_argument('--keep-batchnorm-fp32', type=str, default=None) -parser.add_argument('--loss-scale', type=str, default=None) -parser.add_argument('--fused-adam', action='store_true') - -parser.add_argument('--prints-to-process', type=int, default=10) - -cudnn.benchmark = True - -def fast_collate(batch): - imgs = [img[0] for img in batch] - targets = torch.tensor([target[1] for target in batch], dtype=torch.int64) - w = imgs[0].size[0] - h = imgs[0].size[1] - tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 ) - for i, img in enumerate(imgs): - nump_array = np.asarray(img, dtype=np.uint8) - if(nump_array.ndim < 3): - nump_array = np.expand_dims(nump_array, axis=-1) - nump_array = np.rollaxis(nump_array, 2) - - tensor[i] += torch.from_numpy(nump_array) - - return tensor, targets - -best_prec1 = 0 -args = parser.parse_args() - -# Let multi_tensor_applier be the canary in the coalmine -# that verifies if the backend is what we think it is -assert multi_tensor_applier.available == args.has_ext - -print("opt_level = {}".format(args.opt_level)) -print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32)) -print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale)) - - -print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version())) - -if args.deterministic: - cudnn.benchmark = False - cudnn.deterministic = True - torch.manual_seed(args.local_rank) - torch.set_printoptions(precision=10) - -def main(): - global best_prec1, args - - args.distributed = False - if 'WORLD_SIZE' in os.environ: - args.distributed = int(os.environ['WORLD_SIZE']) > 1 - - args.gpu = 0 - args.world_size = 1 - - if args.distributed: - args.gpu = args.local_rank % torch.cuda.device_count() - torch.cuda.set_device(args.gpu) - torch.distributed.init_process_group(backend='nccl', - init_method='env://') - args.world_size = torch.distributed.get_world_size() - - assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled." - - # create model - if args.pretrained: - print("=> using pre-trained model '{}'".format(args.arch)) - model = models.__dict__[args.arch](pretrained=True) - else: - print("=> creating model '{}'".format(args.arch)) - model = models.__dict__[args.arch]() - - if args.sync_bn: - import apex - print("using apex synced BN") - model = apex.parallel.convert_syncbn_model(model) - - model = model.cuda() - - # Scale learning rate based on global batch size - args.lr = args.lr*float(args.batch_size*args.world_size)/256. - if args.fused_adam: - optimizer = optimizers.FusedAdam(model.parameters()) - else: - optimizer = torch.optim.SGD(model.parameters(), args.lr, - momentum=args.momentum, - weight_decay=args.weight_decay) - - model, optimizer = amp.initialize( - model, optimizer, - # enabled=False, - opt_level=args.opt_level, - keep_batchnorm_fp32=args.keep_batchnorm_fp32, - loss_scale=args.loss_scale - ) - - if args.distributed: - # By default, apex.parallel.DistributedDataParallel overlaps communication with - # computation in the backward pass. - # model = DDP(model) - # delay_allreduce delays all communication to the end of the backward pass. - model = DDP(model, delay_allreduce=True) - - # define loss function (criterion) and optimizer - criterion = nn.CrossEntropyLoss().cuda() - - # Optionally resume from a checkpoint - if args.resume: - # Use a local scope to avoid dangling references - def resume(): - if os.path.isfile(args.resume): - print("=> loading checkpoint '{}'".format(args.resume)) - checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu)) - args.start_epoch = checkpoint['epoch'] - best_prec1 = checkpoint['best_prec1'] - model.load_state_dict(checkpoint['state_dict']) - optimizer.load_state_dict(checkpoint['optimizer']) - print("=> loaded checkpoint '{}' (epoch {})" - .format(args.resume, checkpoint['epoch'])) - else: - print("=> no checkpoint found at '{}'".format(args.resume)) - resume() - - # Data loading code - traindir = os.path.join(args.data, 'train') - valdir = os.path.join(args.data, 'val') - - if(args.arch == "inception_v3"): - crop_size = 299 - val_size = 320 # I chose this value arbitrarily, we can adjust. - else: - crop_size = 224 - val_size = 256 - - train_dataset = datasets.ImageFolder( - traindir, - transforms.Compose([ - transforms.RandomResizedCrop(crop_size), - transforms.RandomHorizontalFlip(), - # transforms.ToTensor(), Too slow - # normalize, - ])) - val_dataset = datasets.ImageFolder(valdir, transforms.Compose([ - transforms.Resize(val_size), - transforms.CenterCrop(crop_size), - ])) - - train_sampler = None - val_sampler = None - if args.distributed: - train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) - val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) - - train_loader = torch.utils.data.DataLoader( - train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), - num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate) - - val_loader = torch.utils.data.DataLoader( - val_dataset, - batch_size=args.batch_size, shuffle=False, - num_workers=args.workers, pin_memory=True, - sampler=val_sampler, - collate_fn=fast_collate) - - if args.evaluate: - validate(val_loader, model, criterion) - return - - for epoch in range(args.start_epoch, args.epochs): - if args.distributed: - train_sampler.set_epoch(epoch) - - # train for one epoch - train(train_loader, model, criterion, optimizer, epoch) - if args.prof: - break - # evaluate on validation set - prec1 = validate(val_loader, model, criterion) - - # remember best prec@1 and save checkpoint - if args.local_rank == 0: - is_best = prec1 > best_prec1 - best_prec1 = max(prec1, best_prec1) - save_checkpoint({ - 'epoch': epoch + 1, - 'arch': args.arch, - 'state_dict': model.state_dict(), - 'best_prec1': best_prec1, - 'optimizer' : optimizer.state_dict(), - }, is_best) - -class data_prefetcher(): - def __init__(self, loader): - self.loader = iter(loader) - self.stream = torch.cuda.Stream() - self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1) - self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1) - # With Amp, it isn't necessary to manually convert data to half. - # if args.fp16: - # self.mean = self.mean.half() - # self.std = self.std.half() - self.preload() - - def preload(self): - try: - self.next_input, self.next_target = next(self.loader) - except StopIteration: - self.next_input = None - self.next_target = None - return - with torch.cuda.stream(self.stream): - self.next_input = self.next_input.cuda(non_blocking=True) - self.next_target = self.next_target.cuda(non_blocking=True) - # With Amp, it isn't necessary to manually convert data to half. - # if args.fp16: - # self.next_input = self.next_input.half() - # else: - self.next_input = self.next_input.float() - self.next_input = self.next_input.sub_(self.mean).div_(self.std) - - def next(self): - torch.cuda.current_stream().wait_stream(self.stream) - input = self.next_input - target = self.next_target - self.preload() - return input, target - - -def train(train_loader, model, criterion, optimizer, epoch): - batch_time = AverageMeter() - data_time = AverageMeter() - losses = AverageMeter() - top1 = AverageMeter() - top5 = AverageMeter() - - # switch to train mode - model.train() - end = time.time() - - run_info_dict = {"Iteration" : [], - "Loss" : [], - "Speed" : []} - - prefetcher = data_prefetcher(train_loader) - input, target = prefetcher.next() - i = -1 - while input is not None: - i += 1 - - # No learning rate warmup for this test, to expose bitwise inaccuracies more quickly - # adjust_learning_rate(optimizer, epoch, i, len(train_loader)) - - if args.prof: - if i > 10: - break - # measure data loading time - data_time.update(time.time() - end) - - # compute output - output = model(input) - loss = criterion(output, target) - - # measure accuracy and record loss - prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) - - if args.distributed: - reduced_loss = reduce_tensor(loss.data) - prec1 = reduce_tensor(prec1) - prec5 = reduce_tensor(prec5) - else: - reduced_loss = loss.data - - losses.update(to_python_float(reduced_loss), input.size(0)) - top1.update(to_python_float(prec1), input.size(0)) - top5.update(to_python_float(prec5), input.size(0)) - - # compute gradient and do SGD step - optimizer.zero_grad() - - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - - # for param in model.parameters(): - # print(param.data.double().sum().item(), param.grad.data.double().sum().item()) - - # torch.cuda.synchronize() - torch.cuda.nvtx.range_push("step") - optimizer.step() - torch.cuda.nvtx.range_pop() - - torch.cuda.synchronize() - # measure elapsed time - batch_time.update(time.time() - end) - - end = time.time() - - # If you decide to refactor this test, like examples/imagenet, to sample the loss every - # print_freq iterations, make sure to move this prefetching below the accuracy calculation. - input, target = prefetcher.next() - - if i % args.print_freq == 0 and i > 1: - if args.local_rank == 0: - print('Epoch: [{0}][{1}/{2}]\t' - 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - 'Speed {3:.3f} ({4:.3f})\t' - 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' - 'Loss {loss.val:.10f} ({loss.avg:.4f})\t' - 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' - 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( - epoch, i, len(train_loader), - args.world_size * args.batch_size / batch_time.val, - args.world_size * args.batch_size / batch_time.avg, - batch_time=batch_time, - data_time=data_time, loss=losses, top1=top1, top5=top5)) - run_info_dict["Iteration"].append(i) - run_info_dict["Loss"].append(losses.val) - run_info_dict["Speed"].append(args.world_size * args.batch_size / batch_time.val) - if len(run_info_dict["Loss"]) == args.prints_to_process: - if args.local_rank == 0: - torch.save(run_info_dict, - str(args.has_ext) + "_" + str(args.opt_level) + "_" + - str(args.loss_scale) + "_" + str(args.keep_batchnorm_fp32) + "_" + - str(args.fused_adam)) - quit() - - -def validate(val_loader, model, criterion): - batch_time = AverageMeter() - losses = AverageMeter() - top1 = AverageMeter() - top5 = AverageMeter() - - # switch to evaluate mode - model.eval() - - end = time.time() - - prefetcher = data_prefetcher(val_loader) - input, target = prefetcher.next() - i = -1 - while input is not None: - i += 1 - - # compute output - with torch.no_grad(): - output = model(input) - loss = criterion(output, target) - - # measure accuracy and record loss - prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) - - if args.distributed: - reduced_loss = reduce_tensor(loss.data) - prec1 = reduce_tensor(prec1) - prec5 = reduce_tensor(prec5) - else: - reduced_loss = loss.data - - losses.update(to_python_float(reduced_loss), input.size(0)) - top1.update(to_python_float(prec1), input.size(0)) - top5.update(to_python_float(prec5), input.size(0)) - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - if args.local_rank == 0 and i % args.print_freq == 0: - print('Test: [{0}/{1}]\t' - 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - 'Speed {2:.3f} ({3:.3f})\t' - 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' - 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' - 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( - i, len(val_loader), - args.world_size * args.batch_size / batch_time.val, - args.world_size * args.batch_size / batch_time.avg, - batch_time=batch_time, loss=losses, - top1=top1, top5=top5)) - - input, target = prefetcher.next() - - print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' - .format(top1=top1, top5=top5)) - - return top1.avg - - -def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): - torch.save(state, filename) - if is_best: - shutil.copyfile(filename, 'model_best.pth.tar') - - -class AverageMeter(object): - """Computes and stores the average and current value""" - def __init__(self): - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - -def adjust_learning_rate(optimizer, epoch, step, len_epoch): - """LR schedule that should yield 76% converged accuracy with batch size 256""" - factor = epoch // 30 - - if epoch >= 80: - factor = factor + 1 - - lr = args.lr*(0.1**factor) - - """Warmup""" - if epoch < 5: - lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch) - - # if(args.local_rank == 0): - # print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr)) - - for param_group in optimizer.param_groups: - param_group['lr'] = lr - - -def accuracy(output, target, topk=(1,)): - """Computes the precision@k for the specified values of k""" - maxk = max(topk) - batch_size = target.size(0) - - _, pred = output.topk(maxk, 1, True, True) - pred = pred.t() - correct = pred.eq(target.view(1, -1).expand_as(pred)) - - res = [] - for k in topk: - correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) - res.append(correct_k.mul_(100.0 / batch_size)) - return res - - -def reduce_tensor(tensor): - rt = tensor.clone() - dist.all_reduce(rt, op=dist.reduce_op.SUM) - rt /= args.world_size - return rt - -if __name__ == '__main__': - main() diff --git a/tests/L1/common/run_test.sh b/tests/L1/common/run_test.sh deleted file mode 100644 index f4ae06c..0000000 --- a/tests/L1/common/run_test.sh +++ /dev/null @@ -1,144 +0,0 @@ -#!/bin/bash - -print_banner() { - printf "\n\n\n\e[30m\e[42m$1\e[0m\n\n\n\n" -} - -print_banner "Distributed status: $1" - -echo $2 -DATADIR=$2 - -if [ -n "$3" ] -then - USE_BASELINE="" -else - USE_BASELINE="--use_baseline" -fi - -if [ "$1" == "single_gpu" ] -then - BASE_CMD="python main_amp.py -a resnet50 --b 128 --workers 4 --deterministic --prints-to-process 5" -fi - -if [ "$1" == "distributed" ] -then - BASE_CMD="python -m torch.distributed.launch --nproc_per_node=2 main_amp.py -a resnet50 --b 128 --workers 4 --deterministic --prints-to-process 5" -fi - -ADAM_ARGS="--opt-level O2 --keep-batchnorm-fp32 False --fused-adam" - -keep_batchnorms=( -"" -"--keep-batchnorm-fp32 True" -"--keep-batchnorm-fp32 False" -) - -loss_scales=( -"" -"--loss-scale 1.0" -"--loss-scale 128.0" -"--loss-scale dynamic" -) - -opt_levels=( -"O0" -"O1" -"O2" -"O3" -) - -rm True* -rm False* - -set -e - -print_banner "Installing Apex with --cuda_ext and --cpp_ext" - -pushd ../../.. -pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . -popd - -for opt_level in "${opt_levels[@]}" -do - for loss_scale in "${loss_scales[@]}" - do - for keep_batchnorm in "${keep_batchnorms[@]}" - do - if [ "$opt_level" == "O1" ] && [ -n "${keep_batchnorm}" ] - then - print_banner "Skipping ${opt_level} ${loss_scale} ${keep_batchnorm}" - continue - fi - print_banner "${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} --has-ext $DATADIR" - set -x - ${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} --has-ext $DATADIR - set +x - done - done -done - -# Handle FusedAdam separately due to limited support. -# FusedAdam will not be tested for bitwise accuracy against the Python implementation. -# The L0 tests already do so. These tests are here to ensure that it actually runs, -# and get an idea of performance. -for loss_scale in "${loss_scales[@]}" -do - print_banner "${BASE_CMD} ${ADAM_ARGS} ${loss_scale} --has-ext $DATADIR" - set -x - ${BASE_CMD} ${ADAM_ARGS} ${loss_scale} --has-ext $DATADIR - set +x -done - -print_banner "Reinstalling apex without extensions" - -pushd ../../.. -pip install -v --no-cache-dir . -popd - -for opt_level in "${opt_levels[@]}" -do - for loss_scale in "${loss_scales[@]}" - do - for keep_batchnorm in "${keep_batchnorms[@]}" - do - if [ "$opt_level" == "O1" ] && [ -n "${keep_batchnorm}" ] - then - print_banner "Skipping ${opt_level} ${loss_scale} ${keep_batchnorm}" - continue - fi - print_banner "${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} $DATADIR" - set -x - ${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} $DATADIR - set +x - done - done -done - -print_banner "Checking for bitwise accuracy between Python-only and cpp/cuda extension installs" - -for opt_level in "${opt_levels[@]}" -do - for loss_scale in "${loss_scales[@]}" - do - for keep_batchnorm in "${keep_batchnorms[@]}" - do - echo "" - if [ "$opt_level" == "O1" ] && [ -n "${keep_batchnorm}" ] - then - echo "Skipping ${opt_level} ${loss_scale} ${keep_batchnorm}" - continue - fi - echo "${BASE_CMD} --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} [--has-ext] $DATADIR" - set -x - python compare.py --opt-level ${opt_level} ${loss_scale} ${keep_batchnorm} --use_baseline - set +x - done - done -done - -print_banner "Reinstalling Apex with --cuda_ext and --cpp_ext" - -pushd ../../.. -pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . -popd diff --git a/tests/L1/cross_product/run.sh b/tests/L1/cross_product/run.sh deleted file mode 100644 index 7ccf9ec..0000000 --- a/tests/L1/cross_product/run.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash - -# DATADIR="/home/mcarilli/Desktop/pt18data/apex_stale/examples/imagenet/bare_metal_train_val/" -# DATADIR="/opt/home/apex/examples/imagenet/" -cp ../common/* . -bash run_test.sh single_gpu $1 diff --git a/tests/L1/cross_product_distributed/run.sh b/tests/L1/cross_product_distributed/run.sh deleted file mode 100644 index 917ec11..0000000 --- a/tests/L1/cross_product_distributed/run.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash - -cp ../common/* . -bash run_test.sh distributed $1 diff --git a/tests/L1/transformer/pipeline_parallel_fwd_bwd_ucc_async.py b/tests/L1/transformer/pipeline_parallel_fwd_bwd_ucc_async.py deleted file mode 100644 index 786c0ed..0000000 --- a/tests/L1/transformer/pipeline_parallel_fwd_bwd_ucc_async.py +++ /dev/null @@ -1,219 +0,0 @@ -import os -import logging -import itertools -from typing import Optional, Tuple, List -import unittest - -import torch -from torch.testing._internal import common_utils -from torch.testing._internal import common_cuda -from torch.testing._internal import common_distributed - -from apex._autocast_utils import _get_autocast_dtypes -from apex.transformer import parallel_state -from apex.transformer.pipeline_parallel import utils as pp_utils -from apex.transformer.pipeline_parallel.schedules.common import ( - FwdStepFunc, - build_model, - _get_params_for_weight_decay_optimization, -) -from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import ( - forward_backward_no_pipelining, -) -from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import ( - _forward_backward_pipelining_with_interleaving, -) -from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import ( - forward_backward_pipelining_without_interleaving, -) -from apex.transformer.testing.distributed_test_base import UccDistributedTestBase -from apex.transformer.testing import commons as testing_utils - - -logging.getLogger("torch").setLevel(logging.WARNING) -logging.getLogger("apex").setLevel(logging.WARNING) - - -def _get_default_world_sizes_model_parallel_world_size(pipeline_model_parallel_world_size: Optional[int] = None - ) -> Tuple[int, int, int]: - # TODO: revisit if we can fold this into the class for skip logic / avoid duplication - # of world size computation - world_size = torch.cuda.device_count() - tensor_model_parallel_world_size = 1 - data_parallel_size = 1 + (world_size >= 8 and world_size % 2 == 0) - - if pipeline_model_parallel_world_size is None: - pipeline_model_parallel_world_size = world_size // (tensor_model_parallel_world_size * data_parallel_size) - else: - data_parallel_size = world_size // (tensor_model_parallel_world_size * pipeline_model_parallel_world_size) - - return tensor_model_parallel_world_size, data_parallel_size, pipeline_model_parallel_world_size - - -class UccPipelineParallelForwardBackwardProf(UccDistributedTestBase): - - # The purpose of this class is to test and confirm asynchronous communication via profiling. - # Having that in mind, it is safe to skip all the numerical checks. - # For unit testing with numerical checks please refer to `tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py`. - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.GLOBAL_BATCH_SIZE = 1024 - self.MICRO_BATCH_SIZE = 64 - self.HIDDEN_SIZE = 256 - self.NUM_FWD_BWD_ITERATIONS = 4 - self.deallocate_options = (False,) - self.dtypes = (torch.float32,) - - @property - def world_size(self) -> int: - return min(torch.cuda.device_count(), 8) - - def _forward_backward_test_impl( - self, - forward_only: bool, - fwd_bwd_func: FwdStepFunc, - pipeline_model_parallel_world_size: Optional[int], - virtual_pipeline_model_parallel_size: Optional[int], - async_comm: bool = False, - *, - default_backend: Optional[str] = None, - p2p_backend: Optional[str] = None, - ) -> None: - if fwd_bwd_func == _forward_backward_pipelining_with_interleaving: - self.assertIsNotNone(virtual_pipeline_model_parallel_size) - self.assertGreater(virtual_pipeline_model_parallel_size, 1) - dtype_options = self.dtypes or [torch.float32, torch.double] + _get_autocast_dtypes() - - for dtype, deallocate_pipeline_outputs in itertools.product( - dtype_options, self.deallocate_options, - ): - grad_scaler = ( - torch.cuda.amp.GradScaler(init_scale=4.0) - if dtype == torch.half - else None - ) - - (tensor_model_parallel_world_size, - data_parallel_size, - pipeline_model_parallel_world_size) = _get_default_world_sizes_model_parallel_world_size(pipeline_model_parallel_world_size) - - parallel_state.initialize_model_parallel( - tensor_model_parallel_size_=tensor_model_parallel_world_size, - pipeline_model_parallel_size_=pipeline_model_parallel_world_size, - virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size, - default_backend=default_backend, - p2p_backend=p2p_backend, - ) - pp_utils._reconfigure_microbatch_calculator( - rank=parallel_state.get_tensor_model_parallel_rank(), - rampup_batch_size=None, - global_batch_size=self.GLOBAL_BATCH_SIZE, - micro_batch_size=self.MICRO_BATCH_SIZE, - data_parallel_size=parallel_state.get_data_parallel_world_size(), - ) - - global_batch_shape = ( - self.GLOBAL_BATCH_SIZE - // parallel_state.get_data_parallel_world_size(), - self.HIDDEN_SIZE, - self.HIDDEN_SIZE, - ) - - batch = None - if parallel_state.is_pipeline_first_stage(): - batch = (torch.ones(global_batch_shape, dtype=dtype).cuda(), ) - - model = build_model( - testing_utils.model_provider_func, - # Use DDP only when it's better to have - wrap_with_ddp=data_parallel_size > 1, - virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size, - hidden_size=self.HIDDEN_SIZE, - ) - - - offset = pipeline_model_parallel_world_size if virtual_pipeline_model_parallel_size is not None else 0 - for idx, model_module in enumerate(model): - model_module = model_module.to(dtype) - - _param_groups = _get_params_for_weight_decay_optimization(model) - optimizer = torch.optim.Adam(_param_groups, lr=1e-3) - - pp_utils.update_num_microbatches(0) - - for _ in range(self.NUM_FWD_BWD_ITERATIONS): - loss = fwd_bwd_func( - testing_utils.fwd_step_func, - batch, - model, - forward_only=forward_only, - # `tensor_shape` is the shape of micro batch. - tensor_shape=( - self.MICRO_BATCH_SIZE, - self.HIDDEN_SIZE, - self.HIDDEN_SIZE, - ), - dtype=dtype, - async_comm=async_comm, - grad_scaler=grad_scaler, - deallocate_pipeline_output=deallocate_pipeline_outputs, - ) - - parallel_state.destroy_model_parallel() - - def test_learning_no_pipelining(self): - self._forward_backward_test_impl(False, forward_backward_no_pipelining, 1, None) - - def test_inference_no_pipelining(self): - self._forward_backward_test_impl(True, forward_backward_no_pipelining, 1, None) - - def test_learning_pipelining_without_interleaving(self): - self._forward_backward_test_impl( - False, forward_backward_pipelining_without_interleaving, None, None - ) - - def test_inference_pipelining_without_interleaving(self): - self._forward_backward_test_impl( - True, forward_backward_pipelining_without_interleaving, None, None - ) - - def test_learning_async_pipelining_without_interleaving(self): - self._forward_backward_test_impl( - False, forward_backward_pipelining_without_interleaving, None, None, async_comm=True - ) - - def test_inference_async_pipelining_without_interleaving(self): - self._forward_backward_test_impl( - True, forward_backward_pipelining_without_interleaving, None, None, async_comm=True - ) - - @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2") - def test_learning_pipelining_with_interleaving(self): - self._forward_backward_test_impl( - False, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2 - ) - - @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2") - def test_inference_pipelining_with_interleaving(self): - self._forward_backward_test_impl( - True, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2 - ) - - @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2") - def test_learning_async_pipelining_with_interleaving(self): - self._forward_backward_test_impl( - False, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2, async_comm=True - ) - - @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2") - def test_inference_async_pipelining_with_interleaving(self): - self._forward_backward_test_impl( - True, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2, async_comm=True - ) - - -if __name__ == "__main__": - os.environ["UCC_TLS"] = "ucp,cuda" - common_distributed.TIMEOUT_DEFAULT = 500 - common_utils.run_tests() diff --git a/tests/distributed/DDP/ddp_race_condition_test.py b/tests/distributed/DDP/ddp_race_condition_test.py deleted file mode 100644 index 761a335..0000000 --- a/tests/distributed/DDP/ddp_race_condition_test.py +++ /dev/null @@ -1,69 +0,0 @@ -import torch -import torch.distributed as dist -from torch.nn import Parameter -from torch.nn import Module -from apex.parallel import DistributedDataParallel as DDP -import argparse -import os - - -parser = argparse.ArgumentParser(description='allreduce hook example') -parser.add_argument("--local_rank", default=0, type=int) -args = parser.parse_args() - -args.distributed = False -if 'WORLD_SIZE' in os.environ: - args.distributed = int(os.environ['WORLD_SIZE']) > 1 - -if args.distributed: - args.gpu = args.local_rank % torch.cuda.device_count() - torch.cuda.set_device(args.gpu) - torch.distributed.init_process_group(backend='nccl', - init_method='env://') - args.world_size = torch.distributed.get_world_size() - -torch.set_printoptions(precision=10) -torch.manual_seed(args.local_rank) - -class Model(Module): - def __init__(self): - super(Model, self).__init__() - self.a = Parameter(torch.cuda.FloatTensor(4096*4096).fill_(1.0)) - self.b = Parameter(torch.cuda.FloatTensor(4096*4096).fill_(2.0)) - def forward(self, input): - return (input*self.a)*self.b - -model = Model() -# model = DDP(model, message_size=1, gradient_predivide_factor=8.0) -# model = DDP(model, delay_allreduce=True) -# model = DDP(model, message_size=1, allreduce_trigger_params=[model.b]) -model = DDP(model, message_size=1, allreduce_trigger_params=[model.b], num_allreduce_streams=3) - -x = torch.cuda.FloatTensor(4096*4096) - -passed = True -torch.cuda.cudart().cudaProfilerStart() -for i in range(10): - x.fill_(i + args.local_rank) # fill x with new values every iteration for sanity - model.zero_grad() - out = model(x) - loss = out.sum() - # torch.cuda.nvtx.range_push("backward") - loss.backward() - # torch.cuda.nvtx.range_pop() - - # torch.cuda.nvtx.range_push("synchronize() + info") - # torch.cuda.synchronize() - print("i = {}".format(i)) - def info(name, param, val): - expected = val*4096*4096*(2.*i+1)/2. - actual = param.grad.data.sum().item() - print(name+": grad.data_ptr() = {}, expected sum {}, got {}".format( - param.grad.data_ptr(), expected, actual)) - return (expected == actual) - if not info("model.a", model.module.a, 2.): passed = False - if not info("model.b", model.module.b, 1.): passed = False - # torch.cuda.nvtx.range_pop() -torch.cuda.cudart().cudaProfilerStop() - -print("passed = ", passed) diff --git a/tests/distributed/DDP/run_race_test.sh b/tests/distributed/DDP/run_race_test.sh deleted file mode 100644 index 2c2bd26..0000000 --- a/tests/distributed/DDP/run_race_test.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash - -CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 ddp_race_condition_test.py diff --git a/tests/distributed/amp_master_params/amp_master_params.py b/tests/distributed/amp_master_params/amp_master_params.py deleted file mode 100644 index 4b3a804..0000000 --- a/tests/distributed/amp_master_params/amp_master_params.py +++ /dev/null @@ -1,71 +0,0 @@ -import torch -import argparse -import os -from apex import amp -# FOR DISTRIBUTED: (can also use torch.nn.parallel.DistributedDataParallel instead) -from apex.parallel import DistributedDataParallel - -parser = argparse.ArgumentParser() -# FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied -# automatically by torch.distributed.launch. -parser.add_argument("--local_rank", default=0, type=int) -parser.add_argument("--opt_level", default="O2", type=str) -args = parser.parse_args() - -# FOR DISTRIBUTED: If we are running under torch.distributed.launch, -# the 'WORLD_SIZE' environment variable will also be set automatically. -args.distributed = False -if 'WORLD_SIZE' in os.environ: - args.distributed = int(os.environ['WORLD_SIZE']) > 1 - -if args.distributed: - # FOR DISTRIBUTED: Set the device according to local_rank. - torch.cuda.set_device(args.local_rank) - - # FOR DISTRIBUTED: Initialize the backend. torch.distributed.launch will provide - # environment variables, and requires that you use init_method=`env://`. - torch.distributed.init_process_group(backend='nccl', - init_method='env://') - - torch.manual_seed(torch.distributed.get_rank()) - -torch.backends.cudnn.benchmark = True - -N, D_in, D_out = 64, 1024, 16 - -# Each process receives its own batch of "fake input data" and "fake target data." -# The "training loop" in each process just uses this fake batch over and over. -# https://github.com/NVIDIA/apex/tree/master/examples/imagenet provides a more realistic -# example of distributed data sampling for both training and validation. -x = torch.randn(N, D_in, device='cuda') -y = torch.randn(N, D_out, device='cuda') - -model = torch.nn.Linear(D_in, D_out).cuda() -optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) - -model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level) - -if args.distributed: - # FOR DISTRIBUTED: After amp.initialize, wrap the model with - # apex.parallel.DistributedDataParallel. - model = DistributedDataParallel(model) - # torch.nn.parallel.DistributedDataParallel is also fine, with some added args: - # model = torch.nn.parallel.DistributedDataParallel(model, - # device_ids=[args.local_rank], - # output_device=args.local_rank) - -loss_fn = torch.nn.MSELoss() - -for t in range(500): - optimizer.zero_grad() - y_pred = model(x) - loss = loss_fn(y_pred, y) - with amp.scale_loss(loss, optimizer) as scaled_loss: - scaled_loss.backward() - optimizer.step() - -if args.local_rank == 0: - print("final loss = ", loss) - -torch.save(list(model.parameters()), "rank{}model.pth".format(torch.distributed.get_rank())) -torch.save(list(amp.master_params(optimizer)), "rank{}master.pth".format(torch.distributed.get_rank())) diff --git a/tests/distributed/amp_master_params/compare.py b/tests/distributed/amp_master_params/compare.py deleted file mode 100644 index b804775..0000000 --- a/tests/distributed/amp_master_params/compare.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch - -model_params_rank0 = torch.load("rank0model.pth", - map_location = lambda storage, loc: storage.cuda(0)) -model_params_rank1 = torch.load("rank1model.pth", - map_location = lambda storage, loc: storage.cuda(0)) -master_params_rank0 = torch.load("rank0master.pth", - map_location = lambda storage, loc: storage.cuda(0)) -master_params_rank1 = torch.load("rank1master.pth", - map_location = lambda storage, loc: storage.cuda(0)) - -for model_rank0, model_rank1, master_rank0, master_rank1 in zip( - model_params_rank0, - model_params_rank1, - master_params_rank0, - master_params_rank1): - # converting model params to float is a hack since allclose doesn't support bfloat16 yet. - model_rank0 = model_rank0.float() - model_rank1 = model_rank1.float() - assert torch.allclose(model_rank0, model_rank1), "Model param mismatch" - assert torch.allclose(master_rank0, master_rank1), "Master param mismatch" - # Some debugging/investigation assistance code: - # maxval, maxind = torch.max(((torch.abs(model_rank0).float())/torch.abs(master_rank0)).view(-1), 0) - # offending_val_half = model_rank0.view(-1)[maxind.item()] - # offending_val_float = master_rank0.view(-1)[maxind.item()] - # print(maxval.item(), maxind.item(), offending_val_half.item(), offending_val_float.item(), - # offending_val_float.half().item()) - # rtol needs to be > 2^-11 because of denormals... - assert torch.allclose(model_rank0, master_rank0, rtol=.005), "Model-master mismatch" - -print("OK: Model and master params match across ranks.") diff --git a/tests/distributed/amp_master_params/run.sh b/tests/distributed/amp_master_params/run.sh deleted file mode 100644 index 8599dbb..0000000 --- a/tests/distributed/amp_master_params/run.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash -python -m torch.distributed.launch --nproc_per_node=2 amp_master_params.py - -python compare.py diff --git a/tests/distributed/run_rocm_distributed.sh b/tests/distributed/run_rocm_distributed.sh deleted file mode 100644 index 89cb4e1..0000000 --- a/tests/distributed/run_rocm_distributed.sh +++ /dev/null @@ -1,46 +0,0 @@ -#!/bin/bash -set -e - -# To run the test on 2 gpus -export WORLD_SIZE=2 - -# Test with opt_level="O2" -echo "running opt_level O2" -python -m torch.distributed.launch --nproc_per_node=2 amp_master_params/amp_master_params.py --opt_level "O2" -python amp_master_params/compare.py - -# delete the model files -echo -e "O2 test completed. Deleting model files\n" -rm rank0model.pth -rm rank1model.pth -rm rank0master.pth -rm rank1master.pth - - -# Test with opt_level="O5" -#echo "running opt_level O5" -#python -m torch.distributed.launch --nproc_per_node=2 amp_master_params/amp_master_params.py --opt_level "O5" -#python amp_master_params/compare.py - -## delete the model files -#echo "O5 test completed. Deleting model files" -#rm rank0model.pth -#rm rank1model.pth -#rm rank0master.pth -#rm rank1master.pth - -## Run the Sync BN Tests. -echo "Running syncbn tests" -python -m torch.distributed.launch --nproc_per_node=2 synced_batchnorm/two_gpu_unit_test.py -python -m torch.distributed.launch --nproc_per_node=2 synced_batchnorm/two_gpu_unit_test.py --fp16 -python -m torch.distributed.launch --nproc_per_node=2 synced_batchnorm/two_gpu_test_different_batch_size.py --apex -echo "Running syncbn python only tests" -python synced_batchnorm/python_single_gpu_unit_test.py -echo "Running syncbn batchnorm1d tests" -python synced_batchnorm/test_batchnorm1d.py -#beware, you need a system with at least 4 gpus to test group_size= error + error * np.abs(b)).nonzero() - print("dif : ", z[index]) - print("inp1 : ", a[index]) - print("inp2 : ", b[index]) - return close - -feature_size = 10 -space_size = 16 -batch_size = 5 - - -error = 1e-5 - -np.random.seed(1) -dtype = np.float32 -inp = (np.random.randn(batch_size, feature_size, space_size, space_size)).astype(dtype) -grad = (np.random.randn(batch_size, feature_size, space_size, space_size)).astype(dtype) -weight = (np.random.randn(feature_size)).astype(dtype) -bias = (np.random.randn(feature_size)).astype(dtype) - -type_tensor = torch.cuda.FloatTensor -ref_tensor = torch.cuda.DoubleTensor - -inp_t = type_tensor(inp) -weight_t = type_tensor(weight) -bias_t = type_tensor(bias) - -inp_r = ref_tensor(inp.transpose(1, 0, 2, 3).reshape(feature_size, -1)) -inp2_r = ref_tensor(inp) -weight_r = ref_tensor(weight).view(-1, 1, 1) -bias_r = ref_tensor(bias).view(-1, 1, 1) - -grad_output_t = type_tensor(grad) - -m = inp_r.mean(1) -b_v = inp_r.var(1, unbiased=False) -unb_v = inp_r.var(1, unbiased=True) - -eps = 1e-5 - -bn = torch.nn.BatchNorm2d(feature_size).cuda() -bn.momentum = 1.0 -bn.weight.data = weight_t.clone() -bn.bias.data = bias_t.clone() -inp_bn = inp_t.clone().requires_grad_() -grad_bn = grad_output_t.clone().detach() -out_bn = bn(inp_bn) -out_bn.backward(grad_bn) - -from apex.parallel.sync_batchnorm import SyncBatchNorm - -sbn = SyncBatchNorm(feature_size).cuda() -sbn.momentum = 1.0 -sbn.weight.data = weight_t.clone() -sbn.bias.data = bias_t.clone() -inp_sbn = inp_t.clone().requires_grad_() -grad_sbn = grad_output_t.clone().detach() -out_sbn = sbn(inp_sbn) -out_sbn.backward(grad_sbn) - -sbn_result = True -sbn_result_c_last = True -bn_result = True - -out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) + bias_r - -compare("comparing bn output: ", out_bn, out_r, error) - -grad_output_t = type_tensor(grad) - -grad_output_r = ref_tensor(grad.transpose(1, 0, 2, 3).reshape(feature_size, -1)) -grad_output2_r = ref_tensor(grad) - -grad_bias_r = grad_output_r.sum(1) -grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1) - -mean_dy_r = grad_output_r.mean(1) -mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).mean(1) - -grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1) - -compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error) -sbn_result = compare("comparing sbn input grad: ", inp_sbn.grad, grad_input_r, error) and sbn_result - -compare("comparing bn/sbn output: ", out_bn, out_sbn, error) -sbn_result = compare("comparing running_mean: ", bn.running_mean.data, sbn.running_mean.data, error) and sbn_result -sbn_result = compare("comparing running_variance: ", bn.running_var.data, sbn.running_var.data, error) and sbn_result -compare("comparing grad_input: ", inp_bn.grad, inp_sbn.grad, error) -compare("comparing grad_bias: ", bn.bias.grad, sbn.bias.grad, error) -compare("comparing grad_bias bn to ref: ", bn.bias.grad, grad_bias_r, error) -sbn_result = compare("comparing grad_bias sbn to ref: ", sbn.bias.grad, grad_bias_r, error) and sbn_result -compare("comparing grad_weight: ", bn.weight.grad, sbn.weight.grad, error) -compare("comparing grad_weight bn to ref: ", bn.weight.grad, grad_weight_r, error) -sbn_result = compare("comparing grad_weight sbn to ref: ", sbn.weight.grad, grad_weight_r, error) and sbn_result - -if sbn_result: - print("====SBN single gpu passed tests") -else: - print("*SBN single gpu failed*") - -assert sbn_result diff --git a/tests/distributed/synced_batchnorm/single_gpu_unit_test.py b/tests/distributed/synced_batchnorm/single_gpu_unit_test.py deleted file mode 100644 index 446b6b0..0000000 --- a/tests/distributed/synced_batchnorm/single_gpu_unit_test.py +++ /dev/null @@ -1,162 +0,0 @@ -import torch -import numpy as np -import apex -if True: - print("using setup tools") - import syncbn -else: - print("using jit") - from torch.utils.cpp_extension import load - syncbn = load(name='syncbn', sources=['../../csrc/syncbn.cpp', '../../csrc/welford.cu']) - -def compare(desc, inp1, inp2, error): - a = inp1.clone().detach().cpu().numpy() - b = inp2.clone().detach().cpu().numpy() - close = np.allclose(a,b, error, error) - if not close: - print(desc, close) - z = a - b - index = (np.abs(z) >= error + error * np.abs(b)).nonzero() - print("dif : ", z[index]) - print("inp1 : ", a[index]) - print("inp2 : ", b[index]) - return close - -feature_size = 10 -space_size = 16 -batch_size = 5 - - -error = 1e-5 - -np.random.seed(1) -dtype = np.float32 -inp = (np.random.randn(batch_size, feature_size, space_size, space_size)).astype(dtype) -grad = (np.random.randn(batch_size, feature_size, space_size, space_size)).astype(dtype) -weight = (np.random.randn(feature_size)).astype(dtype) -bias = (np.random.randn(feature_size)).astype(dtype) -count = torch.cuda.IntTensor([batch_size*space_size**2]) - -type_tensor = torch.cuda.FloatTensor -ref_tensor = torch.cuda.DoubleTensor - -inp_t = type_tensor(inp) -weight_t = type_tensor(weight) -bias_t = type_tensor(bias) - -inp_r = ref_tensor(inp.transpose(1, 0, 2, 3).reshape(feature_size, -1)) -inp2_r = ref_tensor(inp) -weight_r = ref_tensor(weight).view(-1, 1, 1) -bias_r = ref_tensor(bias).view(-1, 1, 1) - -grad_output_t = type_tensor(grad) - -m = inp_r.mean(1) -b_v = inp_r.var(1, unbiased=False) -unb_v = inp_r.var(1, unbiased=True) - -eps = 1e-5 - -#mean, var, var_biased = syncbn.welford_mean_var(inp_t) -mean, var_biased = syncbn.welford_mean_var(inp_t) -inv_std = 1.0 / torch.sqrt(var_biased + eps) - -bn = torch.nn.BatchNorm2d(feature_size).cuda() -bn.momentum = 1.0 -bn.weight.data = weight_t.clone() -bn.bias.data = bias_t.clone() -inp_bn = inp_t.clone().requires_grad_() -grad_bn = grad_output_t.clone().detach() -out_bn = bn(inp_bn) -out_bn.backward(grad_bn) - -sbn = apex.parallel.SyncBatchNorm(feature_size).cuda() -sbn.momentum = 1.0 -sbn.weight.data = weight_t.clone() -sbn.bias.data = bias_t.clone() -inp_sbn = inp_t.clone().requires_grad_() -grad_sbn = grad_output_t.clone().detach() -out_sbn = sbn(inp_sbn) -out_sbn.backward(grad_sbn) - -sbn_c_last = apex.parallel.SyncBatchNorm(feature_size, channel_last=True).cuda() -sbn_c_last.momentum = 1.0 -sbn_c_last.weight.data = weight_t.clone() -sbn_c_last.bias.data = bias_t.clone() -inp_sbn_c_last = inp_t.clone().transpose(-1, 1).contiguous().requires_grad_() -grad_sbn_c_last = grad_output_t.clone().transpose(-1, 1).contiguous().detach() -out_sbn_c_last = sbn_c_last(inp_sbn_c_last) -out_sbn_c_last.backward(grad_sbn_c_last) - -sbn_result = True -sbn_result_c_last = True -bn_result = True - -sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result -#sbn_result = compare("comparing variance: ", var, unb_v, error) and sbn_result -sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result - - -out = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t) -out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) + bias_r - -sbn_result = compare("comparing output: ", out, out_r, error) and sbn_result -compare("comparing bn output: ", out_bn, out_r, error) - -grad_output_t = type_tensor(grad) - -grad_output_r = ref_tensor(grad.transpose(1, 0, 2, 3).reshape(feature_size, -1)) -grad_output2_r = ref_tensor(grad) - -grad_bias_r = grad_output_r.sum(1) -grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1) - -sum_dy_r = grad_output_r.sum(1) -mean_dy_r = grad_output_r.mean(1) -sum_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1) -mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).mean(1) - -grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1) - -sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t) -grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, sum_dy, sum_dy_xmu, count) -sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result -sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result -sbn_result = compare("comparing sum_dy grad: ", sum_dy, sum_dy_r, error) and sbn_result -sbn_result = compare("comparing sum_dy_xmu grad: ", sum_dy_xmu, sum_dy_xmu_r, error) and sbn_result -sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error) and sbn_result -compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error) -sbn_result = compare("comparing sbn input grad: ", inp_sbn.grad, grad_input_r, error) and sbn_result - -compare("comparing bn/sbn output: ", out_bn, out_sbn, error) -sbn_result = compare("comparing running_mean: ", bn.running_mean.data, sbn.running_mean.data, error) and sbn_result -sbn_result = compare("comparing running_variance: ", bn.running_var.data, sbn.running_var.data, error) and sbn_result -compare("comparing grad_input: ", inp_bn.grad, inp_sbn.grad, error) -compare("comparing grad_bias: ", bn.bias.grad, sbn.bias.grad, error) -compare("comparing grad_bias bn to ref: ", bn.bias.grad, grad_bias_r, error) -sbn_result = compare("comparing grad_bias sbn to ref: ", sbn.bias.grad, grad_bias_r, error) and sbn_result -compare("comparing grad_weight: ", bn.weight.grad, sbn.weight.grad, error) -compare("comparing grad_weight bn to ref: ", bn.weight.grad, grad_weight_r, error) -sbn_result = compare("comparing grad_weight sbn to ref: ", sbn.weight.grad, grad_weight_r, error) and sbn_result - -compare("comparing channel last bn/sbn output: ", out_bn, out_sbn_c_last.transpose(-1, 1).contiguous(), error) -sbn_result_c_last = compare("comparing channel last running_mean: ", bn.running_mean.data, sbn_c_last.running_mean.data, error) and sbn_result_c_last -sbn_result_c_last = compare("comparing channel last running_variance: ", bn.running_var.data, sbn_c_last.running_var.data, error) and sbn_result_c_last -compare("comparing channel last grad_input: ", inp_bn.grad, inp_sbn_c_last.grad.transpose(-1, 1).contiguous(), error) -compare("comparing channel last grad_bias: ", bn.bias.grad, sbn_c_last.bias.grad, error) -sbn_result_c_last = compare("comparing channel last grad_bias sbn to ref: ", sbn_c_last.bias.grad, grad_bias_r, error) and sbn_result_c_last -compare("comparing channel last grad_weight: ", bn.weight.grad, sbn_c_last.weight.grad, error) -sbn_result_c_last = compare("comparing channel last grad_weight sbn to ref: ", sbn_c_last.weight.grad, grad_weight_r, error) and sbn_result_c_last - -if sbn_result: - print("====SBN single gpu passed tests") -else: - print("*SBN single gpu failed*") - -if sbn_result_c_last: - print("====SBN channel last single gpu passed tests") -else: - print("*SBN channel last single gpu failed*") - -assert sbn_result -assert sbn_result_c_last diff --git a/tests/distributed/synced_batchnorm/test_batchnorm1d.py b/tests/distributed/synced_batchnorm/test_batchnorm1d.py deleted file mode 100644 index f35ac47..0000000 --- a/tests/distributed/synced_batchnorm/test_batchnorm1d.py +++ /dev/null @@ -1,18 +0,0 @@ -import torch -import apex - -model = apex.parallel.SyncBatchNorm(4).cuda() -model.weight.data.uniform_() -model.bias.data.uniform_() -data = torch.rand((8,4)).cuda() - -model_ref = torch.nn.BatchNorm1d(4).cuda() -model_ref.load_state_dict(model.state_dict()) -data_ref = data.clone() - -output = model(data) -output_ref = model_ref(data_ref) - -assert(output.allclose(output_ref)) -assert(model.running_mean.allclose(model_ref.running_mean)) -assert(model.running_var.allclose(model_ref.running_var)) diff --git a/tests/distributed/synced_batchnorm/test_groups.py b/tests/distributed/synced_batchnorm/test_groups.py deleted file mode 100644 index 674f8e6..0000000 --- a/tests/distributed/synced_batchnorm/test_groups.py +++ /dev/null @@ -1,189 +0,0 @@ -import torch -import numpy as np -import apex -import syncbn -import os -import argparse -import torch.optim as optim - -def compare(desc, inp1, inp2, error): - a = inp1.clone().detach().cpu().numpy() - b = inp2.clone().detach().cpu().numpy() - close = np.allclose(a,b, error, error) - if not close: - print(desc, close) - z = a - b - index = (np.abs(z) >= error + error * np.abs(b)).nonzero() - print("dif : ", z[index]) - print("inp1 : ", a[index]) - print("inp2 : ", b[index]) - return close - -feature_size = 10 -space_size = 40 -batch_size = 32 - - -from apex.parallel import DistributedDataParallel as DDP -parser = argparse.ArgumentParser() -parser.add_argument("--local_rank", default=0, type=int) -parser.add_argument("--fp16", action='store_true', default=False) -parser.add_argument("--fp64", action='store_true', default=False) -parser.add_argument("--group_size", default=0, type=int) -args = parser.parse_args() - -try: - args.world_size = int(os.environ['WORLD_SIZE']) -except: - print("This is a multi-gpu test. To run it please use 'python -m torch.distributed.launch --nproc_per_node= test_groups.py '") - exit(1) - -torch.cuda.set_device(args.local_rank) -torch.distributed.init_process_group(backend='nccl', init_method='env://') - -start = (args.local_rank%args.group_size) * batch_size//args.group_size -finish = (args.local_rank%args.group_size + 1) * batch_size//args.group_size - -error = 1e-5 -dtype = np.float32 -if args.fp16: - error = 1e-3 - dtype = np.float16 -elif args.fp64: - error = 1e-8 - dtype = np.float64 - - -np.random.seed(18 + args.local_rank//args.group_size) - -inp = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype) -grad = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype) -weight = np.random.randn(feature_size).astype(dtype) -bias = np.random.randn(feature_size).astype(dtype) -#count = torch.cuda.IntTensor([batch_size*space_size**2]) -count = [ space_size**2 * ( (i+1) * batch_size // args.world_size - i * batch_size // args.world_size ) for i in range(0, args.world_size)] -count = torch.cuda.IntTensor(count) - -print("--- count : " , count) - -type_tensor = torch.cuda.FloatTensor -if args.fp16: - type_tensor = torch.cuda.HalfTensor -if args.fp64: - type_tensor = torch.cuda.DoubleTensor - -ref_tensor = torch.cuda.DoubleTensor - -inp_t = type_tensor(inp) -weight_t = type_tensor(weight) -bias_t = type_tensor(bias) - -inp_r = ref_tensor(inp.transpose(1, 0, 2, 3).reshape(feature_size, -1)) -inp2_r = ref_tensor(inp) -weight_r = ref_tensor(weight).view(-1, 1, 1) -bias_r = ref_tensor(bias).view(-1, 1, 1) - -grad_output_t = type_tensor(grad) - -m = inp_r.mean(1) -b_v = inp_r.var(1, unbiased=False) -unb_v = inp_r.var(1, unbiased=True) - -eps = 1e-5 - -mean, var_biased = syncbn.welford_mean_var(inp_t) -inv_std = 1.0 / torch.sqrt(var_biased + eps) - -bn = torch.nn.BatchNorm2d(feature_size).cuda() -bn.momentum = 1.0 -bn.weight.data = weight_t.clone() -bn.bias.data = bias_t.clone() -if args.fp16: - bn.half() -if args.fp64: - bn.double() -bn = DDP(bn) -inp_bn = inp_t.clone().requires_grad_() -grad_bn = grad_output_t.clone().detach() -out_bn = bn(inp_bn) -out_bn.backward(grad_bn) -# compensating the averaging over processes done by DDP -# in order to produce mathematically equivalent result -# https://github.com/NVIDIA/apex/issues/134#issuecomment-458307368 -for param in bn.parameters(): - param.grad = param.grad / args.group_size -bn_opt = optim.SGD(bn.parameters(), lr=1.0) - -sbn = apex.parallel.SyncBatchNorm(feature_size, process_group=apex.parallel.create_syncbn_process_group(args.group_size)).cuda() -sbn.momentum = 1.0 -sbn.weight.data = weight_t.clone() -sbn.bias.data = bias_t.clone() -if args.fp16: - sbn.half() -if args.fp64: - sbn.double() -sbn = DDP(sbn) -sbn_opt = optim.SGD(sbn.parameters(), lr=1.0) -inp_sbn = inp_t.clone().requires_grad_() -grad_sbn = grad_output_t.clone().detach() -out_sbn = sbn(inp_sbn[start:finish]) -out_sbn.backward(grad_sbn[start:finish]) - -sbn_result = True -bn_result = True - -if args.local_rank == 0: - sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result - sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result - -out = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t) -out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) + bias_r - -if args.local_rank == 0: - sbn_result = compare("comparing output: ", out, out_r, error) and sbn_result - compare("comparing bn output: ", out_bn, out_r, error) - -grad_output_t = type_tensor(grad) - -grad_output_r = ref_tensor(grad.transpose(1, 0, 2, 3).reshape(feature_size, -1)) -grad_output2_r = ref_tensor(grad) - -grad_bias_r = grad_output_r.sum(1) -grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1) - -mean_dy_r = grad_output_r.mean(1) -mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).mean(1) - -grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1) - -mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t) -grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu, count) - -if args.local_rank == 0: - sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result - sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result - sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r, error) and sbn_result - sbn_result = compare("comparing mean_dy_xmu grad: ", mean_dy_xmu, mean_dy_xmu_r, error) and sbn_result - sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error) and sbn_result - compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error) - -if args.local_rank == 0: - sbn_result = compare("comparing running_mean: ", bn.module.running_mean.data, sbn.module.running_mean.data, error) and sbn_result - sbn_result = compare("comparing running_variance: ", bn.module.running_var.data, sbn.module.running_var.data, error) and sbn_result - -# execute by both -compare("comparing layers output: ", out_bn[start:finish], out_sbn, error) and sbn_result -compare("comparing layers grad_input: ", inp_bn.grad[start:finish], inp_sbn.grad[start:finish], error) and sbn_result - -bn_opt.step() -sbn_opt.step() - -if args.local_rank == 0: - compare("comparing bn vs sbn bias: ", bn.module.bias, sbn.module.bias, error) - compare("comparing bn vs sbn weight: ", bn.module.weight, sbn.module.weight, error) - - -if sbn_result: - print("====SBN group test passed") -else: - print("*SBN group test failed*") diff --git a/tests/distributed/synced_batchnorm/two_gpu_test_different_batch_size.py b/tests/distributed/synced_batchnorm/two_gpu_test_different_batch_size.py deleted file mode 100755 index a9e8cb6..0000000 --- a/tests/distributed/synced_batchnorm/two_gpu_test_different_batch_size.py +++ /dev/null @@ -1,158 +0,0 @@ -import torch -import torch.nn as nn -from torch.nn.parallel import DistributedDataParallel as DDP -from apex.parallel import SyncBatchNorm as ApexSyncBatchNorm - -import argparse -import os -import numpy as np - -var_batch = 16 - -def compare(desc, inp1, inp2, error= 1e-5): - a = inp1.clone().detach().cpu().numpy() - b = inp2.clone().detach().cpu().numpy() - close = np.allclose(a,b, error, error) - if not close: - print(desc, close) - z = a - b - index = (np.abs(z) >= error + error * np.abs(b)).nonzero() - print("dif : ", z[index]) - print("inp1 : ", a[index]) - print("inp2 : ", b[index]) - return close - -parser = argparse.ArgumentParser() -parser.add_argument('--local_rank', type=int, default=0) -parser.add_argument('--apex', action='store_true') -args = parser.parse_args() - - -torch.manual_seed(2809) -# Setup DDP -torch.cuda.set_device(args.local_rank) -device = torch.device('cuda:{}'.format(args.local_rank)) - -torch.distributed.init_process_group( - 'nccl', - init_method='env://', - rank=args.local_rank, -) - -# Setup model -if args.apex: - model = nn.Sequential( - nn.Conv2d(3, 6, 3, 1, 1), - ApexSyncBatchNorm(6) - ) -else: - model = nn.Sequential( - nn.Conv2d(3, 6, 3, 1, 1), - nn.SyncBatchNorm(6) - ) - -# Setup reference model -model_reference = nn.Sequential( - nn.Conv2d(3, 6, 3, 1, 1), - nn.BatchNorm2d(6) -) - -with torch.no_grad(): - model_reference[0].weight.copy_(model[0].weight) - model_reference[0].bias.copy_(model[0].bias) -model_reference.to(device) - -model = model.to(device) -model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank) - -global_batch_size = var_batch + 8 -# Create random data -if args.local_rank == 0: - data = torch.randn(var_batch, 3, 8, 8, device=device, dtype=torch.float) * 50.0 - grad = torch.randint(0, 10, (var_batch, 6, 8, 8), device=device, dtype=torch.float) / 10.0 -else: - data = torch.randn(8, 3, 8, 8, device=device) - grad = torch.randint(0, 10, (8, 6, 8, 8), device=device, dtype=torch.float) / 10.0 - -data.requires_grad_() -data.retain_grad = True - -weighted_gradient = True - -# DDP forward/backward -output = model(data) - -if weighted_gradient: - output.backward(grad * 2 / global_batch_size) -else: - output.backward(grad / output.size(0)) - -d_list = [torch.randn(8, 3, 8, 8, device=device) for i in range(int(os.environ['WORLD_SIZE']))] -y_list = [torch.randn(8, 6, 8, 8, device=device) for i in range(int(os.environ['WORLD_SIZE']))] -dgrad_list = [torch.randn(8, 3, 8, 8, device=device) for i in range(int(os.environ['WORLD_SIZE']))] -grad_list = [torch.randn(8, 6, 8, 8, device=device) for i in range(int(os.environ['WORLD_SIZE']))] -if args.local_rank == 0: - # placeholder, these random data will later be discarded. - torch.distributed.all_gather(d_list, torch.randn(8, 3, 8, 8, device=device)) - torch.distributed.all_gather(y_list, torch.randn(8, 6, 8, 8, device=device)) - torch.distributed.all_gather(dgrad_list, torch.randn(8, 3, 8, 8, device=device)) - torch.distributed.all_gather(grad_list, torch.randn(8, 6, 8, 8, device=device)) -else: - torch.distributed.all_gather(d_list, data) - torch.distributed.all_gather(y_list, output) - torch.distributed.all_gather(dgrad_list, data.grad) - torch.distributed.all_gather(grad_list, grad) - -torch.distributed.barrier() - -if args.local_rank == 0: - ref_tensor = d_list[1:] - ref_tensor.insert(0, data) - assert(ref_tensor[0].equal(data)) - ref_tensor = torch.cat(ref_tensor, 0) - ref_tensor = ref_tensor.detach() - ref_tensor.requires_grad_() - ref_tensor.retain_grad() - - # Reference forward/backward - output_reference = model_reference(ref_tensor) - grad_tensor = grad_list[1:] - grad_tensor.insert(0, grad) - assert(grad_tensor[0].equal(grad)) - grad_tensor = torch.cat(grad_tensor, 0) - if weighted_gradient: - output_reference.backward(grad_tensor / output_reference.size(0)) - else: - output_reference.backward(grad_tensor / output_reference.size(0)) - - dgrad_tensor = dgrad_list[1:] - dgrad_tensor.insert(0, data.grad) - dgrad_tensor = torch.cat(dgrad_tensor, 0) - # check output - output_tensor = y_list[1:] - output_tensor.insert(0, output) - output_tensor = torch.cat(output_tensor, 0) - passed = True - passed = passed and compare("check output", - output_tensor, - output_reference) - # check stats - passed = passed and compare("check running mean failed", - model_reference[1].running_mean, - model.module[1].running_mean) - passed = passed and compare("check running var failed", - model_reference[1].running_var, - model.module[1].running_var) - passed = passed and compare("bn wgrad check failed!", - model_reference[1].weight.grad, - model.module[1].weight.grad, 1e-6) - passed = passed and compare("conv wgrad check failed!", - model_reference[0].weight.grad, - model.module[0].weight.grad) - # can't really compare dgrad directly, as we need to scale it to account for - # DDP - # passed = passed and compare("dgrad check failed!", ref_tensor.grad, dgrad_tensor) - if passed: - print("====SBN two gpu with different batches test passed") - else: - assert("*failed two gpu with different batches tests*") diff --git a/tests/distributed/synced_batchnorm/two_gpu_unit_test.py b/tests/distributed/synced_batchnorm/two_gpu_unit_test.py deleted file mode 100644 index 505ae8f..0000000 --- a/tests/distributed/synced_batchnorm/two_gpu_unit_test.py +++ /dev/null @@ -1,182 +0,0 @@ -import torch -import numpy as np -import apex -import syncbn -import os -import argparse -import torch.optim as optim - -def compare(desc, inp1, inp2, error): - a = inp1.clone().detach().cpu().numpy() - b = inp2.clone().detach().cpu().numpy() - close = np.allclose(a,b, error, error) - if not close: - print(desc, close) - z = a - b - index = (np.abs(z) >= error + error * np.abs(b)).nonzero() - print("dif : ", z[index]) - print("inp1 : ", a[index]) - print("inp2 : ", b[index]) - return close - -feature_size = 10 -space_size = 40 -batch_size = 32 - - -from apex.parallel import DistributedDataParallel as DDP -parser = argparse.ArgumentParser() -parser.add_argument("--local_rank", default=0, type=int) -parser.add_argument("--fp16", action='store_true', default=False) -parser.add_argument("--fp64", action='store_true', default=False) -args = parser.parse_args() -args.world_size = int(os.environ['WORLD_SIZE']) -torch.cuda.set_device(args.local_rank) -torch.distributed.init_process_group(backend='nccl', init_method='env://') -start = args.local_rank * batch_size//args.world_size -finish = (args.local_rank + 1) * batch_size//args.world_size - -error = 1e-5 -dtype = np.float32 -if args.fp16: - error = 1e-3 - dtype = np.float16 -elif args.fp64: - error = 1e-8 - dtype = np.float64 - -np.random.seed(18) -inp = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype) -grad = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype) -weight = np.random.randn(feature_size).astype(dtype) -bias = np.random.randn(feature_size).astype(dtype) - - -type_tensor = torch.cuda.FloatTensor -if args.fp16: - type_tensor = torch.cuda.HalfTensor -if args.fp64: - type_tensor = torch.cuda.DoubleTensor - -ref_tensor = torch.cuda.DoubleTensor - -inp_t = type_tensor(inp) -weight_t = type_tensor(weight) -bias_t = type_tensor(bias) - -inp_r = ref_tensor(inp.transpose(1, 0, 2, 3).reshape(feature_size, -1)) -inp2_r = ref_tensor(inp) -weight_r = ref_tensor(weight).view(-1, 1, 1) -bias_r = ref_tensor(bias).view(-1, 1, 1) - -grad_output_t = type_tensor(grad) - -m = inp_r.mean(1) -b_v = inp_r.var(1, unbiased=False) -unb_v = inp_r.var(1, unbiased=True) - -eps = 1e-5 - -mean, var_biased = syncbn.welford_mean_var(inp_t) -inv_std = 1.0 / torch.sqrt(var_biased + eps) - -bn = torch.nn.BatchNorm2d(feature_size).cuda() -bn.momentum = 1.0 -bn.weight.data = weight_t.clone() -bn.bias.data = bias_t.clone() -if args.fp16: - bn.half() -if args.fp64: - bn.double() -inp_bn = inp_t.clone().requires_grad_() -grad_bn = grad_output_t.clone().detach() -out_bn = bn(inp_bn) -out_bn.backward(grad_bn) -# compensating the averaging over processes done by DDP -# in order to produce mathematically equivalent result -# https://github.com/NVIDIA/apex/issues/134#issuecomment-458307368 -for param in bn.parameters(): - param.grad = param.grad / args.world_size -bn_opt = optim.SGD(bn.parameters(), lr=1.0) - -sbn = apex.parallel.SyncBatchNorm(feature_size).cuda() -sbn.momentum = 1.0 -sbn.weight.data = weight_t.clone() -sbn.bias.data = bias_t.clone() -if args.fp16: - sbn.half() -if args.fp64: - sbn.double() -sbn = DDP(sbn) -sbn_opt = optim.SGD(sbn.parameters(), lr=1.0) -inp_sbn = inp_t.clone().requires_grad_() -grad_sbn = grad_output_t.clone().detach() -out_sbn = sbn(inp_sbn[start:finish]) -out_sbn.backward(grad_sbn[start:finish]) - -count = [ space_size**2 * ( (i+1) * batch_size // args.world_size - i * batch_size // args.world_size ) for i in range(0, args.world_size)] -count = torch.cuda.IntTensor(count) - -print("--- count : " , count) - -sbn_result = True -bn_result = True - -if args.local_rank == 0: - sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result - sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result - -out = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t) -out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) + bias_r - -if args.local_rank == 0: - sbn_result = compare("comparing output: ", out, out_r, error) and sbn_result - compare("comparing bn output: ", out_bn, out_r, error) - -grad_output_t = type_tensor(grad) - -grad_output_r = ref_tensor(grad.transpose(1, 0, 2, 3).reshape(feature_size, -1)) -grad_output2_r = ref_tensor(grad) - -grad_bias_r = grad_output_r.sum(1) -grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1) - -sum_dy_r = grad_output_r.sum(1) -mean_dy_r = grad_output_r.mean(1) -mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).mean(1) -sum_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1) - -grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1) - -sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t) -grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, sum_dy, sum_dy_xmu, count) -if args.local_rank == 0: - sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result - sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result - sbn_result = compare("comparing sum_dy grad: ", sum_dy, sum_dy_r, error) and sbn_result - sbn_result = compare("comparing sum_dy_xmu grad: ", sum_dy_xmu, sum_dy_xmu_r, error) and sbn_result - sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error) and sbn_result - compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error) - -if args.local_rank == 0: - sbn_result = compare("comparing running_mean: ", bn.running_mean.data, sbn.module.running_mean.data, error) and sbn_result - sbn_result = compare("comparing running_variance: ", bn.running_var.data, sbn.module.running_var.data, error) and sbn_result - -# execute by both -compare("comparing layers output: ", out_bn[start:finish], out_sbn, error) and sbn_result -compare("comparing layers grad_input: ", inp_bn.grad[start:finish], inp_sbn.grad[start:finish], error) and sbn_result - -bn_opt.step() -sbn_opt.step() - -if args.local_rank == 0: - compare("comparing bn vs sbn bias: ", bn.bias, sbn.module.bias, error) - compare("comparing bn vs sbn weight: ", bn.weight, sbn.module.weight, error) - - -if sbn_result: - print("====SBN two gpu passed tests") -else: - print("*SBN two gpu failed*") - -assert sbn_result diff --git a/tests/distributed/synced_batchnorm/unit_test.sh b/tests/distributed/synced_batchnorm/unit_test.sh deleted file mode 100755 index 4cb4515..0000000 --- a/tests/distributed/synced_batchnorm/unit_test.sh +++ /dev/null @@ -1,8 +0,0 @@ -python python_single_gpu_unit_test.py || exit 1 -python single_gpu_unit_test.py || exit 1 -python test_batchnorm1d.py || exit 1 -python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py || exit 1 -python -m torch.distributed.launch --nproc_per_node=2 two_gpu_unit_test.py --fp16 || exit 1 -python -m torch.distributed.launch --nproc_per_node=2 two_gpu_test_different_batch_size.py --apex || exit 1 -#beware, you need a system with at least 4 gpus to test group_size