From 7bc1dae095efb910345602235394dfd74f391e92 Mon Sep 17 00:00:00 2001 From: Mick Date: Thu, 6 Nov 2025 04:28:52 +0800 Subject: [PATCH] WIP: initial multimodal-gen support (#12484) Co-authored-by: yhyang201 Co-authored-by: yizhang2077 <1109276519@qq.com> Co-authored-by: Xinyuan Tong Co-authored-by: ispobock Co-authored-by: JiLi Co-authored-by: CHEN Xi <78632976+RubiaCx@users.noreply.github.com> Co-authored-by: laixin Co-authored-by: SolitaryThinker Co-authored-by: jzhang38 Co-authored-by: BrianChen1129 Co-authored-by: Kevin Lin <42618777+kevin314@users.noreply.github.com> Co-authored-by: Edenzzzz Co-authored-by: rlsu9 Co-authored-by: Jinzhe Pan <48981407+eigensystem@users.noreply.github.com> Co-authored-by: foreverpiano Co-authored-by: RandNMR73 Co-authored-by: PorridgeSwim Co-authored-by: Jiali Chen <90408393+gary-chenjl@users.noreply.github.com> --- .pre-commit-config.yaml | 11 +- docker/Dockerfile.diffusion | 104 + python/pyproject.toml | 22 + python/sglang/cli/__init__.py | 0 python/sglang/cli/generate.py | 21 + python/sglang/cli/main.py | 178 + python/sglang/cli/serve.py | 42 + python/sglang/launch_server.py | 24 +- python/sglang/multimodal_gen/README.md | 83 + python/sglang/multimodal_gen/__init__.py | 7 + .../sglang/multimodal_gen/configs/__init__.py | 3 + .../backend/vmoba/wan_1.3B_77_448_832.json | 16 + .../backend/vmoba/wan_1.3B_77_480_832.json | 16 + .../sglang/multimodal_gen/configs/configs.py | 258 + .../configs/fasthunyuan_t2v.json | 48 + .../multimodal_gen/configs/models/__init__.py | 8 + .../multimodal_gen/configs/models/base.py | 105 + .../configs/models/dits/__init__.py | 7 + .../configs/models/dits/base.py | 69 + .../configs/models/dits/flux.py | 36 + .../configs/models/dits/hunyuanvideo.py | 185 + .../configs/models/dits/qwenimage.py | 36 + .../configs/models/dits/stepvideo.py | 64 + .../configs/models/dits/wanvideo.py | 103 + .../configs/models/encoders/__init__.py | 25 + .../configs/models/encoders/base.py | 85 + .../configs/models/encoders/clip.py | 95 + .../configs/models/encoders/llama.py | 69 + .../configs/models/encoders/qwen_image.py | 67 + .../configs/models/encoders/t5.py | 86 + .../configs/models/vaes/__init__.py | 11 + .../configs/models/vaes/base.py | 158 + .../configs/models/vaes/flux.py | 50 + .../configs/models/vaes/hunyuanvae.py | 41 + .../configs/models/vaes/qwenimage.py | 61 + .../configs/models/vaes/stepvideovae.py | 31 + .../configs/models/vaes/wanvae.py | 88 + .../configs/pipelines/__init__.py | 37 + .../multimodal_gen/configs/pipelines/base.py | 485 ++ .../multimodal_gen/configs/pipelines/flux.py | 174 + .../configs/pipelines/hunyuan.py | 109 + .../configs/pipelines/qwen_image.py | 299 + .../configs/pipelines/registry.py | 168 + .../configs/pipelines/stepvideo.py | 36 + .../multimodal_gen/configs/pipelines/wan.py | 190 + .../multimodal_gen/configs/sample/__init__.py | 5 + .../multimodal_gen/configs/sample/base.py | 494 ++ .../multimodal_gen/configs/sample/flux.py | 18 + .../multimodal_gen/configs/sample/hunyuan.py | 37 + .../configs/sample/qwenimage.py | 18 + .../multimodal_gen/configs/sample/registry.py | 122 + .../configs/sample/stepvideo.py | 22 + .../multimodal_gen/configs/sample/teacache.py | 43 + .../multimodal_gen/configs/sample/wan.py | 217 + python/sglang/multimodal_gen/configs/utils.py | 61 + .../configs/wan_1.3B_t2v_pipeline.json | 41 + .../configs/wan_14B_i2v_480p_pipeline.json | 49 + .../csrc/attn/vmoba_attn/README.md | 31 + .../csrc/attn/vmoba_attn/setup.py | 26 + .../attn/vmoba_attn/tests/test_vmoba_attn.py | 137 + 
.../csrc/attn/vmoba_attn/vmoba/__init__.py | 2 + .../csrc/attn/vmoba_attn/vmoba/vmoba.py | 1086 +++ python/sglang/multimodal_gen/docs/cli.md | 274 + python/sglang/multimodal_gen/docs/install.md | 52 + .../multimodal_gen/docs/support_matrix.md | 46 + python/sglang/multimodal_gen/envs.py | 326 + .../runtime/architectures/basic/__init__.py | 8 + .../architectures/basic/flux/__init__.py | 1 + .../runtime/architectures/basic/flux/flux.py | 126 + .../architectures/basic/hunyuan/__init__.py | 1 + .../basic/hunyuan/hunyuan_pipeline.py | 93 + .../basic/qwen_image/__init__.py | 1 + .../basic/qwen_image/qwen_image.py | 196 + .../architectures/basic/stepvideo/__init__.py | 1 + .../basic/stepvideo/stepvideo_pipeline.py | 182 + .../architectures/basic/wan/__init__.py | 1 + .../basic/wan/wan_causal_dmd_pipeline.py | 78 + .../basic/wan/wan_dmd_pipeline.py | 98 + .../basic/wan/wan_i2v_dmd_pipeline.py | 113 + .../basic/wan/wan_i2v_pipeline.py | 118 + .../architectures/basic/wan/wan_pipeline.py | 98 + .../architectures/preprocess/__init__.py | 1 + .../preprocess/preprocess_pipeline_base.py | 433 + .../preprocess/preprocess_pipeline_i2v.py | 247 + .../preprocess_pipeline_ode_trajectory.py | 355 + .../preprocess/preprocess_pipeline_t2v.py | 26 + .../preprocess/preprocess_pipeline_text.py | 200 + .../preprocess/preprocess_stages.py | 134 + .../architectures/preprocess/v1_preprocess.py | 147 + .../preprocess/v1_preprocessing_new.py | 26 + .../architectures/preprocess/wan/__init__.py | 1 + .../wan/wan_preprocess_pipelines.py | 118 + .../runtime/distributed/__init__.py | 55 + .../runtime/distributed/communication_op.py | 55 + .../device_communicators/__init__.py | 1 + .../base_device_communicator.py | 297 + .../device_communicators/cpu_communicator.py | 161 + .../device_communicators/cuda_communicator.py | 79 + .../device_communicators/pynccl.py | 258 + .../device_communicators/pynccl_wrapper.py | 450 + .../runtime/distributed/group_coordinator.py | 1226 +++ .../runtime/distributed/parallel_state.py | 1144 +++ .../runtime/distributed/utils.py | 195 + .../runtime/entrypoints/__init__.py | 1 + .../runtime/entrypoints/cli/__init__.py | 1 + .../runtime/entrypoints/cli/cli_types.py | 28 + .../runtime/entrypoints/cli/generate.py | 103 + .../runtime/entrypoints/cli/main.py | 44 + .../runtime/entrypoints/cli/serve.py | 69 + .../runtime/entrypoints/cli/utils.py | 74 + .../entrypoints/diffusion_generator.py | 429 + .../runtime/entrypoints/http_server.py | 58 + .../runtime/entrypoints/openai/image_api.py | 255 + .../runtime/entrypoints/openai/protocol.py | 65 + .../runtime/entrypoints/openai/stores.py | 46 + .../runtime/entrypoints/openai/utils.py | 77 + .../runtime/entrypoints/openai/video_api.py | 269 + .../runtime/entrypoints/utils.py | 139 + .../multimodal_gen/runtime/launch_server.py | 142 + .../multimodal_gen/runtime/layers/__init__.py | 1 + .../runtime/layers/activation.py | 129 + .../layers/attention/STA_configuration.py | 414 + .../runtime/layers/attention/__init__.py | 28 + .../layers/attention/backends/__init__.py | 1 + .../layers/attention/backends/aiter.py | 101 + .../attention/backends/attention_backend.py | 180 + .../layers/attention/backends/flash_attn.py | 132 + .../layers/attention/backends/flash_attn_2.py | 78 + .../layers/attention/backends/sage_attn.py | 70 + .../layers/attention/backends/sage_attn3.py | 78 + .../runtime/layers/attention/backends/sdpa.py | 77 + .../attention/backends/sliding_tile_attn.py | 313 + .../attention/backends/video_sparse_attn.py | 331 + .../layers/attention/backends/vmoba.py | 
258 + .../runtime/layers/attention/layer.py | 399 + .../runtime/layers/attention/selector.py | 197 + .../runtime/layers/custom_op.py | 110 + .../runtime/layers/layernorm.py | 429 + .../multimodal_gen/runtime/layers/linear.py | 1057 +++ .../runtime/layers/lora/linear.py | 426 + .../multimodal_gen/runtime/layers/mlp.py | 46 + .../runtime/layers/quantization/__init__.py | 71 + .../layers/quantization/base_config.py | 152 + .../runtime/layers/rotary_embedding.py | 886 ++ .../runtime/layers/triton_ops.py | 948 +++ .../multimodal_gen/runtime/layers/usp.py | 255 + .../multimodal_gen/runtime/layers/utils.py | 24 + .../runtime/layers/visual_embedding.py | 186 + .../layers/vocab_parallel_embedding.py | 480 ++ .../multimodal_gen/runtime/loader/__init__.py | 1 + .../runtime/loader/component_loader.py | 670 ++ .../runtime/loader/fsdp_load.py | 314 + .../multimodal_gen/runtime/loader/utils.py | 103 + .../runtime/loader/weight_utils.py | 238 + .../runtime/managers/forward_context.py | 120 + .../runtime/managers/gpu_worker.py | 171 + .../runtime/managers/scheduler.py | 177 + .../runtime/managers/schedulerbase.py | 103 + .../multimodal_gen/runtime/models/__init__.py | 1 + .../runtime/models/dits/base.py | 134 + .../runtime/models/dits/causal_wanvideo.py | 851 ++ .../runtime/models/dits/flux.py | 559 ++ .../runtime/models/dits/hunyuanvideo.py | 961 +++ .../runtime/models/dits/qwen_image.py | 651 ++ .../runtime/models/dits/stepvideo.py | 729 ++ .../runtime/models/dits/wanvideo.py | 945 +++ .../runtime/models/encoders/base.py | 71 + .../runtime/models/encoders/bert.py | 46 + .../runtime/models/encoders/clip.py | 700 ++ .../runtime/models/encoders/llama.py | 459 ++ .../runtime/models/encoders/qwen2_5vl.py | 1181 +++ .../runtime/models/encoders/stepllm.py | 614 ++ .../runtime/models/encoders/t5.py | 716 ++ .../runtime/models/encoders/vision.py | 96 + .../runtime/models/parameter.py | 423 + .../multimodal_gen/runtime/models/registry.py | 366 + .../runtime/models/schedulers/base.py | 37 + .../scheduling_flow_match_euler_discrete.py | 698 ++ .../scheduling_flow_unipc_multistep.py | 853 ++ .../scheduling_self_forcing_flow_match.py | 172 + .../schedulers/scheduling_unipc_multistep.py | 1207 +++ .../multimodal_gen/runtime/models/utils.py | 194 + .../runtime/models/vaes/autoencoder.py | 585 ++ .../models/vaes/autoencoder_kl_qwenimage.py | 1183 +++ .../runtime/models/vaes/common.py | 647 ++ .../runtime/models/vaes/hunyuanvae.py | 852 ++ .../runtime/models/vaes/stepvideovae.py | 1184 +++ .../runtime/models/vaes/wanvae.py | 1343 +++ .../runtime/models/vision_utils.py | 301 + .../runtime/pipelines/README.md | 18 + .../runtime/pipelines/__init__.py | 93 + .../pipelines/composed_pipeline_base.py | 354 + .../pipelines/executors/parallel_executor.py | 92 + .../pipelines/executors/pipeline_executor.py | 71 + .../pipelines/executors/sync_executor.py | 39 + .../runtime/pipelines/lora_pipeline.py | 227 + .../runtime/pipelines/pipeline_batch_info.py | 271 + .../runtime/pipelines/pipeline_registry.py | 239 + .../runtime/pipelines/stages/__init__.py | 59 + .../runtime/pipelines/stages/base.py | 254 + .../pipelines/stages/causal_denoising.py | 506 ++ .../runtime/pipelines/stages/conditioning.py | 105 + .../runtime/pipelines/stages/decoding.py | 232 + .../runtime/pipelines/stages/denoising.py | 1217 +++ .../runtime/pipelines/stages/denoising_dmd.py | 283 + .../runtime/pipelines/stages/encoding.py | 104 + .../pipelines/stages/image_encoding.py | 447 + .../pipelines/stages/input_validation.py | 211 + 
.../pipelines/stages/latent_preparation.py | 155 + .../pipelines/stages/stepvideo_encoding.py | 97 + .../runtime/pipelines/stages/text_encoding.py | 326 + .../pipelines/stages/timestep_preparation.py | 163 + .../runtime/pipelines/stages/validators.py | 522 ++ .../runtime/platforms/__init__.py | 172 + .../multimodal_gen/runtime/platforms/cpu.py | 61 + .../multimodal_gen/runtime/platforms/cuda.py | 430 + .../runtime/platforms/interface.py | 252 + .../multimodal_gen/runtime/platforms/mps.py | 88 + .../multimodal_gen/runtime/platforms/rocm.py | 138 + .../runtime/scheduler_client.py | 149 + .../multimodal_gen/runtime/server_args.py | 1025 +++ .../runtime/sync_scheduler_client.py | 92 + .../multimodal_gen/runtime/utils/common.py | 291 + .../runtime/utils/distributed.py | 231 + .../runtime/utils/hf_diffusers_utils.py | 384 + .../runtime/utils/logging_utils.py | 401 + .../runtime/utils/performance_logger.py | 76 + .../runtime/workflow/__init__.py | 1 + .../runtime/workflow/preprocess/__init__.py | 1 + .../runtime/workflow/preprocess/components.py | 341 + .../preprocess/preprocess_workflow.py | 143 + .../preprocess/preprocess_workflow_i2v.py | 70 + .../preprocess/preprocess_workflow_t2v.py | 70 + .../runtime/workflow/workflow_base.py | 188 + python/sglang/multimodal_gen/test/__init__.py | 1 + .../test/cli/test_generate_common.py | 105 + .../test/cli/test_generate_t2i_perf.py | 70 + .../test/cli/test_generate_t2v_perf.py | 68 + .../test/cli/test_generate_ti2v_perf.py | 62 + .../multimodal_gen/test/cli/test_serve.py | 287 + .../test/test_files/launch_flux.json | 11 + .../test/test_files/launch_wan.json | 11 + .../multimodal_gen/test/test_files/rabbit.jpg | Bin 0 -> 268656 bytes .../multimodal_gen/test/test_offline_api.py | 75 + .../sglang/multimodal_gen/test/test_utils.py | 260 + python/sglang/multimodal_gen/test/utils.py | 162 + .../multimodal_gen/third_party/__init__.py | 1 + .../multimodal_gen/third_party/pynvml.py | 7227 +++++++++++++++++ python/sglang/multimodal_gen/utils.py | 777 ++ 249 files changed, 63750 insertions(+), 11 deletions(-) create mode 100644 docker/Dockerfile.diffusion create mode 100644 python/sglang/cli/__init__.py create mode 100644 python/sglang/cli/generate.py create mode 100644 python/sglang/cli/main.py create mode 100644 python/sglang/cli/serve.py create mode 100644 python/sglang/multimodal_gen/README.md create mode 100644 python/sglang/multimodal_gen/__init__.py create mode 100644 python/sglang/multimodal_gen/configs/__init__.py create mode 100644 python/sglang/multimodal_gen/configs/backend/vmoba/wan_1.3B_77_448_832.json create mode 100644 python/sglang/multimodal_gen/configs/backend/vmoba/wan_1.3B_77_480_832.json create mode 100644 python/sglang/multimodal_gen/configs/configs.py create mode 100644 python/sglang/multimodal_gen/configs/fasthunyuan_t2v.json create mode 100644 python/sglang/multimodal_gen/configs/models/__init__.py create mode 100644 python/sglang/multimodal_gen/configs/models/base.py create mode 100644 python/sglang/multimodal_gen/configs/models/dits/__init__.py create mode 100644 python/sglang/multimodal_gen/configs/models/dits/base.py create mode 100644 python/sglang/multimodal_gen/configs/models/dits/flux.py create mode 100644 python/sglang/multimodal_gen/configs/models/dits/hunyuanvideo.py create mode 100644 python/sglang/multimodal_gen/configs/models/dits/qwenimage.py create mode 100644 python/sglang/multimodal_gen/configs/models/dits/stepvideo.py create mode 100644 python/sglang/multimodal_gen/configs/models/dits/wanvideo.py create mode 100644 
python/sglang/multimodal_gen/configs/models/encoders/__init__.py create mode 100644 python/sglang/multimodal_gen/configs/models/encoders/base.py create mode 100644 python/sglang/multimodal_gen/configs/models/encoders/clip.py create mode 100644 python/sglang/multimodal_gen/configs/models/encoders/llama.py create mode 100644 python/sglang/multimodal_gen/configs/models/encoders/qwen_image.py create mode 100644 python/sglang/multimodal_gen/configs/models/encoders/t5.py create mode 100644 python/sglang/multimodal_gen/configs/models/vaes/__init__.py create mode 100644 python/sglang/multimodal_gen/configs/models/vaes/base.py create mode 100644 python/sglang/multimodal_gen/configs/models/vaes/flux.py create mode 100644 python/sglang/multimodal_gen/configs/models/vaes/hunyuanvae.py create mode 100644 python/sglang/multimodal_gen/configs/models/vaes/qwenimage.py create mode 100644 python/sglang/multimodal_gen/configs/models/vaes/stepvideovae.py create mode 100644 python/sglang/multimodal_gen/configs/models/vaes/wanvae.py create mode 100644 python/sglang/multimodal_gen/configs/pipelines/__init__.py create mode 100644 python/sglang/multimodal_gen/configs/pipelines/base.py create mode 100644 python/sglang/multimodal_gen/configs/pipelines/flux.py create mode 100644 python/sglang/multimodal_gen/configs/pipelines/hunyuan.py create mode 100644 python/sglang/multimodal_gen/configs/pipelines/qwen_image.py create mode 100644 python/sglang/multimodal_gen/configs/pipelines/registry.py create mode 100644 python/sglang/multimodal_gen/configs/pipelines/stepvideo.py create mode 100644 python/sglang/multimodal_gen/configs/pipelines/wan.py create mode 100644 python/sglang/multimodal_gen/configs/sample/__init__.py create mode 100644 python/sglang/multimodal_gen/configs/sample/base.py create mode 100644 python/sglang/multimodal_gen/configs/sample/flux.py create mode 100644 python/sglang/multimodal_gen/configs/sample/hunyuan.py create mode 100644 python/sglang/multimodal_gen/configs/sample/qwenimage.py create mode 100644 python/sglang/multimodal_gen/configs/sample/registry.py create mode 100644 python/sglang/multimodal_gen/configs/sample/stepvideo.py create mode 100644 python/sglang/multimodal_gen/configs/sample/teacache.py create mode 100644 python/sglang/multimodal_gen/configs/sample/wan.py create mode 100644 python/sglang/multimodal_gen/configs/utils.py create mode 100644 python/sglang/multimodal_gen/configs/wan_1.3B_t2v_pipeline.json create mode 100644 python/sglang/multimodal_gen/configs/wan_14B_i2v_480p_pipeline.json create mode 100644 python/sglang/multimodal_gen/csrc/attn/vmoba_attn/README.md create mode 100644 python/sglang/multimodal_gen/csrc/attn/vmoba_attn/setup.py create mode 100644 python/sglang/multimodal_gen/csrc/attn/vmoba_attn/tests/test_vmoba_attn.py create mode 100644 python/sglang/multimodal_gen/csrc/attn/vmoba_attn/vmoba/__init__.py create mode 100644 python/sglang/multimodal_gen/csrc/attn/vmoba_attn/vmoba/vmoba.py create mode 100644 python/sglang/multimodal_gen/docs/cli.md create mode 100644 python/sglang/multimodal_gen/docs/install.md create mode 100644 python/sglang/multimodal_gen/docs/support_matrix.md create mode 100644 python/sglang/multimodal_gen/envs.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/basic/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/basic/flux/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/basic/flux/flux.py create mode 100644 
python/sglang/multimodal_gen/runtime/architectures/basic/hunyuan/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/basic/hunyuan/hunyuan_pipeline.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/basic/qwen_image/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/basic/qwen_image/qwen_image.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/basic/stepvideo/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/basic/stepvideo/stepvideo_pipeline.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/basic/wan/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/basic/wan/wan_causal_dmd_pipeline.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/basic/wan/wan_dmd_pipeline.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/basic/wan/wan_i2v_dmd_pipeline.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/basic/wan/wan_i2v_pipeline.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/basic/wan/wan_pipeline.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/preprocess/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_pipeline_base.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_pipeline_i2v.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_pipeline_ode_trajectory.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_pipeline_t2v.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_pipeline_text.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_stages.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/preprocess/v1_preprocess.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/preprocess/v1_preprocessing_new.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/preprocess/wan/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/architectures/preprocess/wan/wan_preprocess_pipelines.py create mode 100644 python/sglang/multimodal_gen/runtime/distributed/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/distributed/communication_op.py create mode 100644 python/sglang/multimodal_gen/runtime/distributed/device_communicators/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/distributed/device_communicators/base_device_communicator.py create mode 100644 python/sglang/multimodal_gen/runtime/distributed/device_communicators/cpu_communicator.py create mode 100644 python/sglang/multimodal_gen/runtime/distributed/device_communicators/cuda_communicator.py create mode 100644 python/sglang/multimodal_gen/runtime/distributed/device_communicators/pynccl.py create mode 100644 python/sglang/multimodal_gen/runtime/distributed/device_communicators/pynccl_wrapper.py create mode 100644 python/sglang/multimodal_gen/runtime/distributed/group_coordinator.py create mode 100644 python/sglang/multimodal_gen/runtime/distributed/parallel_state.py create mode 100644 python/sglang/multimodal_gen/runtime/distributed/utils.py create mode 100644 python/sglang/multimodal_gen/runtime/entrypoints/__init__.py create mode 100644 
python/sglang/multimodal_gen/runtime/entrypoints/cli/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/entrypoints/cli/cli_types.py create mode 100644 python/sglang/multimodal_gen/runtime/entrypoints/cli/generate.py create mode 100644 python/sglang/multimodal_gen/runtime/entrypoints/cli/main.py create mode 100644 python/sglang/multimodal_gen/runtime/entrypoints/cli/serve.py create mode 100644 python/sglang/multimodal_gen/runtime/entrypoints/cli/utils.py create mode 100644 python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py create mode 100644 python/sglang/multimodal_gen/runtime/entrypoints/http_server.py create mode 100644 python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py create mode 100644 python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py create mode 100644 python/sglang/multimodal_gen/runtime/entrypoints/openai/stores.py create mode 100644 python/sglang/multimodal_gen/runtime/entrypoints/openai/utils.py create mode 100644 python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py create mode 100644 python/sglang/multimodal_gen/runtime/entrypoints/utils.py create mode 100644 python/sglang/multimodal_gen/runtime/launch_server.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/activation.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/attention/STA_configuration.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/attention/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/attention/backends/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/attention/backends/aiter.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/attention/backends/attention_backend.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn_2.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/attention/backends/sage_attn.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/attention/backends/sage_attn3.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/attention/backends/sdpa.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/attention/backends/sliding_tile_attn.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/attention/backends/video_sparse_attn.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/attention/backends/vmoba.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/attention/layer.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/attention/selector.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/custom_op.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/layernorm.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/linear.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/lora/linear.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/mlp.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/quantization/base_config.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/rotary_embedding.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/triton_ops.py create mode 100644 
python/sglang/multimodal_gen/runtime/layers/usp.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/utils.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/visual_embedding.py create mode 100644 python/sglang/multimodal_gen/runtime/layers/vocab_parallel_embedding.py create mode 100644 python/sglang/multimodal_gen/runtime/loader/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/loader/component_loader.py create mode 100644 python/sglang/multimodal_gen/runtime/loader/fsdp_load.py create mode 100644 python/sglang/multimodal_gen/runtime/loader/utils.py create mode 100644 python/sglang/multimodal_gen/runtime/loader/weight_utils.py create mode 100644 python/sglang/multimodal_gen/runtime/managers/forward_context.py create mode 100644 python/sglang/multimodal_gen/runtime/managers/gpu_worker.py create mode 100644 python/sglang/multimodal_gen/runtime/managers/scheduler.py create mode 100644 python/sglang/multimodal_gen/runtime/managers/schedulerbase.py create mode 100644 python/sglang/multimodal_gen/runtime/models/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/models/dits/base.py create mode 100644 python/sglang/multimodal_gen/runtime/models/dits/causal_wanvideo.py create mode 100644 python/sglang/multimodal_gen/runtime/models/dits/flux.py create mode 100644 python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py create mode 100644 python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py create mode 100644 python/sglang/multimodal_gen/runtime/models/dits/stepvideo.py create mode 100644 python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py create mode 100644 python/sglang/multimodal_gen/runtime/models/encoders/base.py create mode 100644 python/sglang/multimodal_gen/runtime/models/encoders/bert.py create mode 100644 python/sglang/multimodal_gen/runtime/models/encoders/clip.py create mode 100644 python/sglang/multimodal_gen/runtime/models/encoders/llama.py create mode 100644 python/sglang/multimodal_gen/runtime/models/encoders/qwen2_5vl.py create mode 100644 python/sglang/multimodal_gen/runtime/models/encoders/stepllm.py create mode 100644 python/sglang/multimodal_gen/runtime/models/encoders/t5.py create mode 100644 python/sglang/multimodal_gen/runtime/models/encoders/vision.py create mode 100644 python/sglang/multimodal_gen/runtime/models/parameter.py create mode 100644 python/sglang/multimodal_gen/runtime/models/registry.py create mode 100644 python/sglang/multimodal_gen/runtime/models/schedulers/base.py create mode 100644 python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_match_euler_discrete.py create mode 100644 python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_unipc_multistep.py create mode 100644 python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_self_forcing_flow_match.py create mode 100644 python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_unipc_multistep.py create mode 100644 python/sglang/multimodal_gen/runtime/models/utils.py create mode 100644 python/sglang/multimodal_gen/runtime/models/vaes/autoencoder.py create mode 100644 python/sglang/multimodal_gen/runtime/models/vaes/autoencoder_kl_qwenimage.py create mode 100644 python/sglang/multimodal_gen/runtime/models/vaes/common.py create mode 100644 python/sglang/multimodal_gen/runtime/models/vaes/hunyuanvae.py create mode 100644 python/sglang/multimodal_gen/runtime/models/vaes/stepvideovae.py create mode 100644 python/sglang/multimodal_gen/runtime/models/vaes/wanvae.py create mode 
100644 python/sglang/multimodal_gen/runtime/models/vision_utils.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/README.md create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/composed_pipeline_base.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/executors/parallel_executor.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/executors/pipeline_executor.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/executors/sync_executor.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/lora_pipeline.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/pipeline_batch_info.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/pipeline_registry.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/stages/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/stages/base.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/stages/causal_denoising.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/stages/conditioning.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/stages/decoding.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/stages/denoising.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/stages/denoising_dmd.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/stages/encoding.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/stages/image_encoding.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/stages/input_validation.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/stages/latent_preparation.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/stages/stepvideo_encoding.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/stages/text_encoding.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/stages/timestep_preparation.py create mode 100644 python/sglang/multimodal_gen/runtime/pipelines/stages/validators.py create mode 100644 python/sglang/multimodal_gen/runtime/platforms/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/platforms/cpu.py create mode 100644 python/sglang/multimodal_gen/runtime/platforms/cuda.py create mode 100644 python/sglang/multimodal_gen/runtime/platforms/interface.py create mode 100644 python/sglang/multimodal_gen/runtime/platforms/mps.py create mode 100644 python/sglang/multimodal_gen/runtime/platforms/rocm.py create mode 100644 python/sglang/multimodal_gen/runtime/scheduler_client.py create mode 100644 python/sglang/multimodal_gen/runtime/server_args.py create mode 100644 python/sglang/multimodal_gen/runtime/sync_scheduler_client.py create mode 100644 python/sglang/multimodal_gen/runtime/utils/common.py create mode 100644 python/sglang/multimodal_gen/runtime/utils/distributed.py create mode 100644 python/sglang/multimodal_gen/runtime/utils/hf_diffusers_utils.py create mode 100644 python/sglang/multimodal_gen/runtime/utils/logging_utils.py create mode 100644 python/sglang/multimodal_gen/runtime/utils/performance_logger.py create mode 100644 python/sglang/multimodal_gen/runtime/workflow/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/workflow/preprocess/__init__.py create mode 100644 python/sglang/multimodal_gen/runtime/workflow/preprocess/components.py create mode 100644 
python/sglang/multimodal_gen/runtime/workflow/preprocess/preprocess_workflow.py create mode 100644 python/sglang/multimodal_gen/runtime/workflow/preprocess/preprocess_workflow_i2v.py create mode 100644 python/sglang/multimodal_gen/runtime/workflow/preprocess/preprocess_workflow_t2v.py create mode 100644 python/sglang/multimodal_gen/runtime/workflow/workflow_base.py create mode 100644 python/sglang/multimodal_gen/test/__init__.py create mode 100644 python/sglang/multimodal_gen/test/cli/test_generate_common.py create mode 100644 python/sglang/multimodal_gen/test/cli/test_generate_t2i_perf.py create mode 100644 python/sglang/multimodal_gen/test/cli/test_generate_t2v_perf.py create mode 100644 python/sglang/multimodal_gen/test/cli/test_generate_ti2v_perf.py create mode 100644 python/sglang/multimodal_gen/test/cli/test_serve.py create mode 100644 python/sglang/multimodal_gen/test/test_files/launch_flux.json create mode 100644 python/sglang/multimodal_gen/test/test_files/launch_wan.json create mode 100644 python/sglang/multimodal_gen/test/test_files/rabbit.jpg create mode 100644 python/sglang/multimodal_gen/test/test_offline_api.py create mode 100644 python/sglang/multimodal_gen/test/test_utils.py create mode 100644 python/sglang/multimodal_gen/test/utils.py create mode 100644 python/sglang/multimodal_gen/third_party/__init__.py create mode 100644 python/sglang/multimodal_gen/third_party/pynvml.py create mode 100644 python/sglang/multimodal_gen/utils.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 21afe122c..53b7dbf11 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,5 @@ default_stages: [pre-commit, pre-push, manual] +exclude: ^python/sglang/multimodal_gen/csrc repos: - repo: https://github.com/pre-commit/pre-commit-hooks @@ -31,7 +32,15 @@ repos: - --select=F401,F821 - --fix files: ^(benchmark/|docs/|examples/|python/sglang/|sgl-router/py_*) - exclude: __init__\.py$|\.ipynb$|^python/sglang/srt/grpc/.*_pb2\.py$|^python/sglang/srt/grpc/.*_pb2_grpc\.py$|^python/sglang/srt/grpc/.*_pb2\.pyi$|^python/sglang/srt/grpc/.*_pb2_grpc\.pyi$ + exclude: | + (?x)^( + .*/__init__\.py$| + .*\.ipynb$| + python/sglang/srt/grpc/.*_pb2\.py$| + python/sglang/srt/grpc/.*_pb2_grpc\.py$| + python/sglang/srt/grpc/.*_pb2\.pyi$| + python/sglang/srt/grpc/.*_pb2_grpc\.pyi$| + )$ - repo: https://github.com/psf/black rev: 24.10.0 hooks: diff --git a/docker/Dockerfile.diffusion b/docker/Dockerfile.diffusion new file mode 100644 index 000000000..9eec1aa63 --- /dev/null +++ b/docker/Dockerfile.diffusion @@ -0,0 +1,104 @@ +FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04 + +ENV DEBIAN_FRONTEND=noninteractive + +SHELL ["/bin/bash", "-c"] + +WORKDIR /sgl-workspace/sglang + +RUN apt-get update && apt-get install -y --no-install-recommends \ + wget \ + git \ + ca-certificates \ + openssh-server \ + zsh \ + vim \ + curl \ + gcc-11 \ + g++-11 \ + clang-11 \ + libnuma1 libnuma-dev \ + && rm -rf /var/lib/apt/lists/* + +# Install oh-my-zsh and plugins +RUN sh -c "$(curl -fsSL https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh)" "" --unattended \ + && git clone https://github.com/zsh-users/zsh-autosuggestions ${ZSH_CUSTOM:-~/.oh-my-zsh/custom}/plugins/zsh-autosuggestions \ + && git clone https://github.com/zsh-users/zsh-syntax-highlighting.git ${ZSH_CUSTOM:-~/.oh-my-zsh/custom}/plugins/zsh-syntax-highlighting + + +# Set up C++20 compilers for ThunderKittens +RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-11 100 --slave /usr/bin/g++ g++ 
/usr/bin/g++-11 + +# Set CUDA environment variables +ENV CUDA_HOME=/usr/local/cuda-12.8 +ENV PATH=${CUDA_HOME}/bin:${PATH} +ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:$LD_LIBRARY_PATH + +# Install uv and source its environment +RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \ + echo 'source $HOME/.local/bin/env' >> /root/.zshrc + +# Copy just the pyproject.toml first to leverage Docker cache +COPY python/pyproject.toml python/ + +# Create a dummy README to satisfy the installation +RUN mkdir -p python && echo "# Placeholder" > python/README.md + +# Create and activate virtual environment with specific Python version and seed +RUN source $HOME/.local/bin/env && \ + uv venv --python 3.12 --seed /opt/venv && \ + source /opt/venv/bin/activate && \ + uv pip install nvitop && \ + uv pip install --no-cache-dir --upgrade pip && \ + uv pip install --no-cache-dir --prerelease=allow./python[diffusion] + +COPY . . + +# Install dependencies using uv and set up shell configuration +RUN source $HOME/.local/bin/env && \ + source /opt/venv/bin/activate && \ + git config --unset-all http.https://github.com/.extraheader || true && \ + echo 'source /opt/venv/bin/activate' >> /root/.zshrc && \ + echo 'if [ -n "$ZSH_VERSION" ] && [ -f ~/.zshrc ]; then . ~/.zshrc; elif [ -f ~/.bashrc ]; then . ~/.bashrc; fi' > /root/.profile + +# Set PATH to include venv bin +ENV PATH=/opt/venv/bin:$PATH + +# Configure zsh +COPY --chown=root:root <<-"EOF" /root/.zshrc +export ZSH="/root/.oh-my-zsh" + +source $HOME/.local/bin/env +source /opt/venv/bin/activate + +## Theme +ZSH_THEME="robbyrussell" + +## Plugins +plugins=( + git + z + zsh-autosuggestions + zsh-syntax-highlighting +) + +source $ZSH/oh-my-zsh.sh + +## Aliases +alias ll='ls -alF' +alias la='ls -A' +alias l='ls -CF' +alias vi='vim' + +## Enhanced history +HISTSIZE=10000 +SAVEHIST=10000 +setopt HIST_IGNORE_ALL_DUPS +setopt HIST_FIND_NO_DUPS +setopt INC_APPEND_HISTORY +EOF + + +EXPOSE 22 + +CMD ["/bin/zsh"] diff --git a/python/pyproject.toml b/python/pyproject.toml index f7cfc482c..a30431241 100755 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -79,6 +79,25 @@ dependencies = [ [project.optional-dependencies] checkpoint-engine = ["checkpoint-engine==0.1.2"] +diffusion = [ + "diffusers==0.35.2", + "yunchang==0.6.3.post1", + "opencv-python==4.10.0.84", + "imageio==2.36.0", + "imageio-ffmpeg==0.5.1", + "PyYAML==6.0.1", + "moviepy>=2.0.0", + "cloudpickle", + "remote-pdb", + "torchcodec==0.5.0", + "st_attn ==0.0.7", + "vsa==0.0.4", +] + +[tool.uv.extra-build-dependencies] +st-attn = ["torch", "setuptools"] +vsa = ["torch", "setuptools"] + test = [ "accelerate", "expecttest", @@ -102,6 +121,9 @@ tracing = [ "Homepage" = "https://github.com/sgl-project/sglang" "Bug Tracker" = "https://github.com/sgl-project/sglang/issues" +[project.scripts] +sglang = "sglang.cli.main:main" + [tool.setuptools.package-data] "sglang" = [ "srt/layers/moe/fused_moe_triton/configs/*/*.json", diff --git a/python/sglang/cli/__init__.py b/python/sglang/cli/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/sglang/cli/generate.py b/python/sglang/cli/generate.py new file mode 100644 index 000000000..a354846f6 --- /dev/null +++ b/python/sglang/cli/generate.py @@ -0,0 +1,21 @@ +import argparse + +from sglang.cli.main import get_is_diffusion_model, get_model_path +from sglang.multimodal_gen.runtime.entrypoints.cli.generate import ( + add_multimodal_gen_generate_args, + generate_cmd, +) + + +def generate(args, extra_argv): + model_path = get_model_path(extra_argv) 
+ is_diffusion_model = get_is_diffusion_model(model_path) + if is_diffusion_model: + parser = argparse.ArgumentParser(description="SGLang Multimodal Generation") + add_multimodal_gen_generate_args(parser) + parsed_args = parser.parse_args(extra_argv) + generate_cmd(parsed_args) + else: + raise Exception( + f"Generate subcommand is not yet supported for model: {model_path}" + ) diff --git a/python/sglang/cli/main.py b/python/sglang/cli/main.py new file mode 100644 index 000000000..59235897d --- /dev/null +++ b/python/sglang/cli/main.py @@ -0,0 +1,178 @@ +import argparse +import hashlib +import json +import logging +import os +import tempfile +from typing import Optional + +import filelock +from huggingface_hub import hf_hub_download + +from sglang.cli.generate import generate +from sglang.cli.serve import serve + +logger = logging.getLogger(__name__) + +temp_dir = tempfile.gettempdir() + + +def _get_lock(model_name_or_path: str, cache_dir: Optional[str] = None): + lock_dir = cache_dir or temp_dir + os.makedirs(os.path.dirname(lock_dir), exist_ok=True) + model_name = model_name_or_path.replace("/", "-") + hash_name = hashlib.sha256(model_name.encode()).hexdigest() + lock_file_name = hash_name + model_name + ".lock" + lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666) + return lock + + +# Copied and adapted from hf_diffusers_utils.py +def _maybe_download_model( + model_name_or_path: str, local_dir: str | None = None, download: bool = True +) -> str: + """ + Resolve a model path. If it's a local directory, return it. + If it's a Hugging Face Hub ID, download only the config file + (`model_index.json` or `config.json`) and return its directory. + + Args: + model_name_or_path: Local path or Hugging Face Hub model ID + local_dir: Local directory to save the downloaded file (if any) + download: Whether to download from Hugging Face Hub when needed + + Returns: + Local directory path that contains the downloaded config file, or the original local directory. + """ + + if os.path.exists(model_name_or_path): + logger.info("Model already exists locally") + return model_name_or_path + + if not download: + return model_name_or_path + + with _get_lock(model_name_or_path): + # Try `model_index.json` first (diffusers models) + try: + logger.info( + "Downloading model_index.json from HF Hub for %s...", + model_name_or_path, + ) + file_path = hf_hub_download( + repo_id=model_name_or_path, + filename="model_index.json", + local_dir=local_dir, + ) + logger.info("Downloaded to %s", file_path) + return os.path.dirname(file_path) + except Exception as e_index: + logger.debug("model_index.json not found or failed: %s", e_index) + + # Fallback to `config.json` + try: + logger.info( + "Downloading config.json from HF Hub for %s...", model_name_or_path + ) + file_path = hf_hub_download( + repo_id=model_name_or_path, + filename="config.json", + local_dir=local_dir, + ) + logger.info("Downloaded to %s", file_path) + return os.path.dirname(file_path) + except Exception as e_config: + raise ValueError( + ( + "Could not find model locally at %s and failed to download " + "model_index.json/config.json from HF Hub: %s" + ) + % (model_name_or_path, e_config) + ) from e_config + + +# Copied and adapted from hf_diffusers_utils.py +def is_diffusers_model_path(model_path: str) -> True: + """ + Verify if the model directory contains a valid diffusers configuration. 
+ + Args: + model_path: Path to the model directory + + Returns: + The loaded model configuration as a dictionary if the model is a diffusers model + None if the model is not a diffusers model + """ + + # Prefer model_index.json which indicates a diffusers pipeline + config_path = os.path.join(model_path, "model_index.json") + if not os.path.exists(config_path): + return False + + # Load the config + with open(config_path) as f: + config = json.load(f) + + # Verify diffusers version exists + if "_diffusers_version" not in config: + return False + return True + + +def get_is_diffusion_model(model_path: str): + model_path = _maybe_download_model(model_path) + is_diffusion_model = is_diffusers_model_path(model_path) + if is_diffusion_model: + logger.info("Diffusion model detected") + return is_diffusion_model + + +def get_model_path(extra_argv): + # Find the model_path argument + model_path = None + for i, arg in enumerate(extra_argv): + if arg == "--model-path": + if i + 1 < len(extra_argv): + model_path = extra_argv[i + 1] + break + elif arg.startswith("--model-path="): + model_path = arg.split("=", 1)[1] + break + + if model_path is None: + # Fallback for --help or other cases where model-path is not provided + if any(h in extra_argv for h in ["-h", "--help"]): + raise Exception( + "Usage: sglang serve --model-path [additional-arguments]\n\n" + "This command can launch either a standard language model server or a diffusion model server.\n" + "The server type is determined by the model path.\n" + "For specific arguments, please provide a model_path." + ) + else: + raise Exception( + "Error: --model-path is required. " + "Please provide the path to the model." + ) + return model_path + + +def main(): + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest="subcommand", required=True) + + serve_parser = subparsers.add_parser( + "serve", + help="Launch the SGLang server.", + add_help=False, # Defer help to the specific parser + ) + serve_parser.set_defaults(func=serve) + + generate_parser = subparsers.add_parser( + "generate", + help="Run inference on a multimodal model.", + add_help=False, # Defer help to the specific parser + ) + generate_parser.set_defaults(func=generate) + + args, extra_argv = parser.parse_known_args() + args.func(args, extra_argv) diff --git a/python/sglang/cli/serve.py b/python/sglang/cli/serve.py new file mode 100644 index 000000000..5a10e56c1 --- /dev/null +++ b/python/sglang/cli/serve.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import logging +import os + +from sglang.cli.main import get_is_diffusion_model, get_model_path +from sglang.srt.utils import kill_process_tree + +logger = logging.getLogger(__name__) + + +def serve(args, extra_argv): + model_path = get_model_path(extra_argv) + try: + is_diffusion_model = get_is_diffusion_model(model_path) + if is_diffusion_model: + # Logic for Diffusion Models + from sglang.multimodal_gen.runtime.entrypoints.cli.serve import ( + add_multimodal_gen_serve_args, + execute_serve_cmd, + ) + + parser = argparse.ArgumentParser( + description="SGLang Diffusion Model Serving" + ) + add_multimodal_gen_serve_args(parser) + parsed_args, remaining_argv = parser.parse_known_args(extra_argv) + + execute_serve_cmd(parsed_args, remaining_argv) + else: + # Logic for Standard Language Models + from sglang.launch_server import run_server + from sglang.srt.server_args import prepare_server_args + + # Add a dummy argument for the program name, expected by prepare_server_args + # as it 
typically processes sys.argv + server_args = prepare_server_args(extra_argv) + + run_server(server_args) + finally: + kill_process_tree(os.getpid(), include_parent=False) diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index 452c4b9a7..9e3e82a78 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -7,19 +7,23 @@ import sys from sglang.srt.server_args import prepare_server_args from sglang.srt.utils import kill_process_tree + +def run_server(server_args): + """Run the server based on server_args.grpc_mode.""" + if server_args.grpc_mode: + from sglang.srt.entrypoints.grpc_server import serve_grpc + + asyncio.run(serve_grpc(server_args)) + else: + from sglang.srt.entrypoints.http_server import launch_server + + launch_server(server_args) + + if __name__ == "__main__": server_args = prepare_server_args(sys.argv[1:]) try: - if server_args.grpc_mode: - # Handle gRPC server - from sglang.srt.entrypoints.grpc_server import serve_grpc - - asyncio.run(serve_grpc(server_args)) - else: - # Handle HTTP server - from sglang.srt.entrypoints.http_server import launch_server - - launch_server(server_args) + run_server(server_args) finally: kill_process_tree(os.getpid(), include_parent=False) diff --git a/python/sglang/multimodal_gen/README.md b/python/sglang/multimodal_gen/README.md new file mode 100644 index 000000000..d45337491 --- /dev/null +++ b/python/sglang/multimodal_gen/README.md @@ -0,0 +1,83 @@ +
+ +
+
+**sgl-diffusion is an inference framework for accelerated image/video generation.**
+
+sgl-diffusion features an end-to-end unified pipeline for accelerating diffusion models. It is designed to be modular and extensible, allowing users to easily add new optimizations and techniques.
+
+## Key Features
+
+sgl-diffusion has the following features:
+
+- State-of-the-art performance optimizations for inference
+  - [Video Sparse Attention](https://arxiv.org/pdf/2505.13389)
+  - [Sliding Tile Attention](https://arxiv.org/pdf/2502.04507)
+  - [TeaCache](https://arxiv.org/pdf/2411.19108)
+  - [Sage Attention](https://arxiv.org/abs/2410.02367)
+  - USP
+  - CFG Parallel
+- Diverse hardware and OS support
+  - Supported hardware: H100, H200, A100, B200, 4090
+  - Supported OS: Linux, Windows, macOS
+
+## Getting Started
+
+```bash
+uv pip install "sglang[diffusion]" --prerelease=allow
+```
+
+For more information, check the [docs](https://github.com/sgl-project/sglang/tree/main/python/sglang/multimodal_gen/docs/install.md).
+
+
+## Inference
+
+Here's a minimal example to generate a video using the default settings:
+
+```python
+from sglang.multimodal_gen import DiffGenerator
+
+def main():
+    # Create a diff generator from a pre-trained model
+    generator = DiffGenerator.from_pretrained(
+        model_path="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
+        num_gpus=1,  # Adjust based on your hardware
+    )
+
+    # Provide a prompt for your video
+    prompt = "A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes wide with interest."
+
+    # Generate the video
+    video = generator.generate(
+        prompt,
+        return_frames=True,  # Also return frames from this call (defaults to False)
+        output_path="my_videos/",  # Controls where videos are saved
+        save_output=True
+    )
+
+if __name__ == '__main__':
+    main()
+```
+
+Or, more simply, with the CLI:
+
+```bash
+sglang generate --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
+    --text-encoder-cpu-offload --pin-cpu-memory \
+    --prompt "A curious raccoon" \
+    --save-output
+```
+
+For more information, check the [docs](https://github.com/sgl-project/sglang/tree/main/python/sglang/multimodal_gen/docs/cli.md).
+
+## Contributing
+
+All contributions are welcome.
+
+## Acknowledgement
+
+We learned from and reused code from the following projects:
+
+- [FastVideo](https://github.com/hao-ai-lab/FastVideo.git). The major components of this repo are based on a fork of FastVideo taken on Sept. 24, 2025.
+- [xDiT](https://github.com/xdit-project/xDiT). We used the parallelism library from it.
+- [diffusers](https://github.com/huggingface/diffusers). We used the pipeline design from it.
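+
+## Serving
+
+This patch also wires up a `sglang serve` subcommand (`python/sglang/cli/serve.py`), which routes diffusers-style model paths to the diffusion server and exposes the OpenAI-style image/video endpoints implemented under `runtime/entrypoints/openai/`. A minimal sketch of launching it, assuming only `--model-path` is required and all other server flags keep their defaults:
+
+```bash
+sglang serve --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers
+```
+
+Additional server flags, if needed, are parsed by `add_multimodal_gen_serve_args` (see `runtime/entrypoints/cli/serve.py`).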
diff --git a/python/sglang/multimodal_gen/__init__.py b/python/sglang/multimodal_gen/__init__.py new file mode 100644 index 000000000..11f4ecd4e --- /dev/null +++ b/python/sglang/multimodal_gen/__init__.py @@ -0,0 +1,7 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from sglang.multimodal_gen.configs.pipelines import PipelineConfig +from sglang.multimodal_gen.configs.sample import SamplingParams +from sglang.multimodal_gen.runtime.entrypoints.diffusion_generator import DiffGenerator + +__all__ = ["DiffGenerator", "PipelineConfig", "SamplingParams"] diff --git a/python/sglang/multimodal_gen/configs/__init__.py b/python/sglang/multimodal_gen/configs/__init__.py new file mode 100644 index 000000000..dfff5f2c4 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/__init__.py @@ -0,0 +1,3 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# Configs for pipelines, and pipeline modules (in models folder) diff --git a/python/sglang/multimodal_gen/configs/backend/vmoba/wan_1.3B_77_448_832.json b/python/sglang/multimodal_gen/configs/backend/vmoba/wan_1.3B_77_448_832.json new file mode 100644 index 000000000..1e55b5f2e --- /dev/null +++ b/python/sglang/multimodal_gen/configs/backend/vmoba/wan_1.3B_77_448_832.json @@ -0,0 +1,16 @@ +{ + "temporal_chunk_size": 2, + "temporal_topk": 2, + "spatial_chunk_size": [4, 13], + "spatial_topk": 6, + "st_chunk_size": [4, 4, 13], + "st_topk": 18, + "moba_select_mode": "topk", + "moba_threshold": 0.25, + "moba_threshold_type": "query_head", + "first_full_layer": 0, + "first_full_step": 12, + "temporal_layer": 1, + "spatial_layer": 1, + "st_layer": 1 +} diff --git a/python/sglang/multimodal_gen/configs/backend/vmoba/wan_1.3B_77_480_832.json b/python/sglang/multimodal_gen/configs/backend/vmoba/wan_1.3B_77_480_832.json new file mode 100644 index 000000000..ddf66f48e --- /dev/null +++ b/python/sglang/multimodal_gen/configs/backend/vmoba/wan_1.3B_77_480_832.json @@ -0,0 +1,16 @@ +{ + "temporal_chunk_size": 2, + "temporal_topk": 3, + "spatial_chunk_size": [3, 4], + "spatial_topk": 20, + "st_chunk_size": [4, 6, 4], + "st_topk": 15, + "moba_select_mode": "threshold", + "moba_threshold": 0.25, + "moba_threshold_type": "query_head", + "first_full_layer": 0, + "first_full_step": 12, + "temporal_layer": 1, + "spatial_layer": 1, + "st_layer": 1 +} diff --git a/python/sglang/multimodal_gen/configs/configs.py b/python/sglang/multimodal_gen/configs/configs.py new file mode 100644 index 000000000..f74cd6d9e --- /dev/null +++ b/python/sglang/multimodal_gen/configs/configs.py @@ -0,0 +1,258 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import dataclasses +from enum import Enum +from typing import Any, Optional + +from sglang.multimodal_gen.configs.utils import update_config_from_args +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import FlexibleArgumentParser, StoreBoolean + +logger = init_logger(__name__) + + +class DatasetType(str, Enum): + """ + Enumeration for different dataset types. + """ + + HF = "hf" + MERGED = "merged" + + @classmethod + def from_string(cls, value: str) -> "DatasetType": + """Convert string to DatasetType enum.""" + try: + return cls(value.lower()) + except ValueError: + raise ValueError( + f"Invalid dataset type: {value}. 
Must be one of: {', '.join([m.value for m in cls])}" + ) from None + + @classmethod + def choices(cls) -> list[str]: + """Get all available choices as strings for argparse.""" + return [dataset_type.value for dataset_type in cls] + + +class VideoLoaderType(str, Enum): + """ + Enumeration for different video loaders. + """ + + TORCHCODEC = "torchcodec" + TORCHVISION = "torchvision" + + @classmethod + def from_string(cls, value: str) -> "VideoLoaderType": + """Convert string to VideoLoader enum.""" + try: + return cls(value.lower()) + except ValueError: + raise ValueError( + f"Invalid video loader: {value}. Must be one of: {', '.join([m.value for m in cls])}" + ) from None + + @classmethod + def choices(cls) -> list[str]: + """Get all available choices as strings for argparse.""" + return [video_loader.value for video_loader in cls] + + +@dataclasses.dataclass +class PreprocessConfig: + """Configuration for preprocessing operations.""" + + # Model and dataset configuration + model_path: str = "" + dataset_path: str = "" + dataset_type: DatasetType = DatasetType.HF + dataset_output_dir: str = "./output" + + # Dataloader configuration + dataloader_num_workers: int = 1 + preprocess_video_batch_size: int = 2 + + # Saver configuration + samples_per_file: int = 64 + flush_frequency: int = 256 + + # Video processing parameters + video_loader_type: VideoLoaderType = VideoLoaderType.TORCHCODEC + max_height: int = 480 + max_width: int = 848 + num_frames: int = 163 + video_length_tolerance_range: float = 2.0 + train_fps: int = 30 + speed_factor: float = 1.0 + drop_short_ratio: float = 1.0 + do_temporal_sample: bool = False + + # Model configuration + training_cfg_rate: float = 0.0 + + # framework configuration + seed: int = 42 + + @staticmethod + def add_cli_args( + parser: FlexibleArgumentParser, prefix: str = "preprocess" + ) -> FlexibleArgumentParser: + """Add preprocessing configuration arguments to the parser.""" + prefix_with_dot = f"{prefix}." if (prefix.strip() != "") else "" + + preprocess_args = parser.add_argument_group("Preprocessing Arguments") + # Model & Dataset + preprocess_args.add_argument( + f"--{prefix_with_dot}model-path", + type=str, + default=PreprocessConfig.model_path, + help="Path to the model for preprocessing", + ) + preprocess_args.add_argument( + f"--{prefix_with_dot}dataset-path", + type=str, + default=PreprocessConfig.dataset_path, + help="Path to the dataset directory for preprocessing", + ) + preprocess_args.add_argument( + f"--{prefix_with_dot}dataset-type", + type=str, + choices=DatasetType.choices(), + default=PreprocessConfig.dataset_type.value, + help="Type of the dataset", + ) + preprocess_args.add_argument( + f"--{prefix_with_dot}dataset-output-dir", + type=str, + default=PreprocessConfig.dataset_output_dir, + help="The output directory where the dataset will be written.", + ) + + # Dataloader + preprocess_args.add_argument( + f"--{prefix_with_dot}dataloader-num-workers", + type=int, + default=PreprocessConfig.dataloader_num_workers, + help="Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.", + ) + preprocess_args.add_argument( + f"--{prefix_with_dot}preprocess-video-batch-size", + type=int, + default=PreprocessConfig.preprocess_video_batch_size, + help="Batch size (per device) for the training dataloader.", + ) + + # Saver + preprocess_args.add_argument( + f"--{prefix_with_dot}samples-per-file", + type=int, + default=PreprocessConfig.samples_per_file, + help="Number of samples per output file", + ) + preprocess_args.add_argument( + f"--{prefix_with_dot}flush-frequency", + type=int, + default=PreprocessConfig.flush_frequency, + help="How often to save to parquet files", + ) + + # Video processing parameters + preprocess_args.add_argument( + f"--{prefix_with_dot}video-loader-type", + type=str, + choices=VideoLoaderType.choices(), + default=PreprocessConfig.video_loader_type.value, + help="Type of the video loader", + ) + preprocess_args.add_argument( + f"--{prefix_with_dot}max-height", + type=int, + default=PreprocessConfig.max_height, + help="Maximum height for video processing", + ) + preprocess_args.add_argument( + f"--{prefix_with_dot}max-width", + type=int, + default=PreprocessConfig.max_width, + help="Maximum width for video processing", + ) + preprocess_args.add_argument( + f"--{prefix_with_dot}num-frames", + type=int, + default=PreprocessConfig.num_frames, + help="Number of frames to process", + ) + preprocess_args.add_argument( + f"--{prefix_with_dot}video-length-tolerance-range", + type=float, + default=PreprocessConfig.video_length_tolerance_range, + help="Video length tolerance range", + ) + preprocess_args.add_argument( + f"--{prefix_with_dot}train-fps", + type=int, + default=PreprocessConfig.train_fps, + help="Training FPS", + ) + preprocess_args.add_argument( + f"--{prefix_with_dot}speed-factor", + type=float, + default=PreprocessConfig.speed_factor, + help="Speed factor for video processing", + ) + preprocess_args.add_argument( + f"--{prefix_with_dot}drop-short-ratio", + type=float, + default=PreprocessConfig.drop_short_ratio, + help="Ratio for dropping short videos", + ) + preprocess_args.add_argument( + f"--{prefix_with_dot}do-temporal-sample", + action=StoreBoolean, + default=PreprocessConfig.do_temporal_sample, + help="Whether to do temporal sampling", + ) + + # Model Training configuration + preprocess_args.add_argument( + f"--{prefix_with_dot}training-cfg-rate", + type=float, + default=PreprocessConfig.training_cfg_rate, + help="Training CFG rate", + ) + preprocess_args.add_argument( + f"--{prefix_with_dot}seed", + type=int, + default=PreprocessConfig.seed, + help="Seed for random number generator", + ) + + return parser + + @classmethod + def from_kwargs(cls, kwargs: dict[str, Any]) -> Optional["PreprocessConfig"]: + """Create PreprocessConfig from keyword arguments.""" + if "dataset_type" in kwargs and isinstance(kwargs["dataset_type"], str): + kwargs["dataset_type"] = DatasetType.from_string(kwargs["dataset_type"]) + if "video_loader_type" in kwargs and isinstance( + kwargs["video_loader_type"], str + ): + kwargs["video_loader_type"] = VideoLoaderType.from_string( + kwargs["video_loader_type"] + ) + + preprocess_config = cls() + if not update_config_from_args( + preprocess_config, kwargs, prefix="preprocess", pop_args=True + ): + return None + return preprocess_config + + def check_preprocess_config(self) -> None: + if self.dataset_path == "": + raise ValueError("dataset_path must be set for preprocess mode") + if self.samples_per_file <= 0: + raise ValueError("samples_per_file must be greater 
than 0") + if self.flush_frequency <= 0: + raise ValueError("flush_frequency must be greater than 0") diff --git a/python/sglang/multimodal_gen/configs/fasthunyuan_t2v.json b/python/sglang/multimodal_gen/configs/fasthunyuan_t2v.json new file mode 100644 index 000000000..ac570a6b2 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/fasthunyuan_t2v.json @@ -0,0 +1,48 @@ +{ + "embedded_cfg_scale": 6, + "flow_shift": 17, + "dit_cpu_offload": false, + "disable_autocast": false, + "precision": "bf16", + "vae_precision": "fp32", + "vae_tiling": true, + "vae_sp": true, + "vae_config": { + "load_encoder": false, + "load_decoder": true, + "tile_sample_min_height": 256, + "tile_sample_min_width": 256, + "tile_sample_min_num_frames": 16, + "tile_sample_stride_height": 192, + "tile_sample_stride_width": 192, + "tile_sample_stride_num_frames": 12, + "blend_num_frames": 4, + "use_tiling": true, + "use_temporal_tiling": true, + "use_parallel_tiling": true + }, + "dit_config": { + "prefix": "Hunyuan", + "quant_config": null + }, + "text_encoder_precisions": [ + "fp16", + "fp16" + ], + "text_encoder_configs": [ + { + "prefix": "llama", + "quant_config": null, + "lora_config": null + }, + { + "prefix": "clip", + "quant_config": null, + "lora_config": null, + "num_hidden_layers_override": null, + "require_post_norm": null + } + ], + "mask_strategy_file_path": null, + "enable_torch_compile": false +} diff --git a/python/sglang/multimodal_gen/configs/models/__init__.py b/python/sglang/multimodal_gen/configs/models/__init__.py new file mode 100644 index 000000000..62c0aadfd --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/__init__.py @@ -0,0 +1,8 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from sglang.multimodal_gen.configs.models.base import ModelConfig +from sglang.multimodal_gen.configs.models.dits.base import DiTConfig +from sglang.multimodal_gen.configs.models.encoders.base import EncoderConfig +from sglang.multimodal_gen.configs.models.vaes.base import VAEConfig + +__all__ = ["ModelConfig", "VAEConfig", "DiTConfig", "EncoderConfig"] diff --git a/python/sglang/multimodal_gen/configs/models/base.py b/python/sglang/multimodal_gen/configs/models/base.py new file mode 100644 index 000000000..2820a4585 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/base.py @@ -0,0 +1,105 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field, fields +from typing import Any, Dict + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +# 1. ArchConfig contains all fields from diffuser's/transformer's config.json (i.e. all fields related to the architecture of the model) +# 2. ArchConfig should be inherited & overridden by each model arch_config +# 3. 
Any field in ArchConfig is fixed upon initialization, and should be hidden away from users +@dataclass +class ArchConfig: + stacked_params_mapping: list[tuple[str, str, str]] = field( + default_factory=list + ) # mapping from huggingface weight names to custom names + extra_attrs: Dict[str, Any] = field(default_factory=dict) + + def __getattr__(self, name: str): + d = object.__getattribute__(self, "__dict__") + extras = d.get("extra_attrs") + if extras is not None and name in extras: + return extras[name] + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) + + def __setattr__(self, key, value): + if key in type(self).__dataclass_fields__: + object.__setattr__(self, key, value) + else: + d = object.__getattribute__(self, "__dict__") + extras = d.get("extra_attrs") + if extras is None: + extras = {} + d["extra_attrs"] = extras + extras[key] = value + + +@dataclass +class ModelConfig: + # Every model config parameter can be categorized into either ArchConfig or everything else + # Diffuser/Transformer parameters + arch_config: ArchConfig = field(default_factory=ArchConfig) + + # sgl-diffusion-specific parameters here + # i.e. STA, quantization, teacache + + def __getattr__(self, name): + # Only called if 'name' is not found in ModelConfig directly + if hasattr(self.arch_config, name): + return getattr(self.arch_config, name) + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + def __getstate__(self): + # Return a dictionary of attributes to pickle + # Convert to dict and exclude any problematic attributes + state = self.__dict__.copy() + return state + + def __setstate__(self, state): + # Restore instance attributes from the unpickled state + self.__dict__.update(state) + + # This should be used only when loading from transformers/diffusers + def update_model_arch(self, source_model_dict: dict[str, Any]) -> None: + """ + Update arch_config with source_model_dict + """ + arch_config = self.arch_config + valid_fields = {f.name for f in fields(arch_config)} + + for key, value in source_model_dict.items(): + setattr(arch_config, key, value) + # else: + # raise AttributeError( + # f"{type(arch_config).__name__} has no field '{key}'" + # ) + + if hasattr(arch_config, "__post_init__"): + arch_config.__post_init__() + + def update_model_config(self, source_model_dict: dict[str, Any]) -> None: + assert ( + "arch_config" not in source_model_dict + ), "Source model config shouldn't contain arch_config." 
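+ # Unlike update_model_arch above (where unknown keys are routed into
+ # arch_config.extra_attrs by ArchConfig.__setattr__), this method only
+ # accepts declared dataclass fields and raises AttributeError otherwise.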
+ + valid_fields = {f.name for f in fields(self)} + + for key, value in source_model_dict.items(): + if key in valid_fields: + setattr(self, key, value) + else: + logger.warning( + "%s does not contain field '%s'!", type(self).__name__, key + ) + raise AttributeError(f"Invalid field: {key}") + + if hasattr(self, "__post_init__"): + self.__post_init__() diff --git a/python/sglang/multimodal_gen/configs/models/dits/__init__.py b/python/sglang/multimodal_gen/configs/models/dits/__init__.py new file mode 100644 index 000000000..67e6d97b4 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/dits/__init__.py @@ -0,0 +1,7 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from sglang.multimodal_gen.configs.models.dits.hunyuanvideo import HunyuanVideoConfig +from sglang.multimodal_gen.configs.models.dits.stepvideo import StepVideoConfig +from sglang.multimodal_gen.configs.models.dits.wanvideo import WanVideoConfig + +__all__ = ["HunyuanVideoConfig", "WanVideoConfig", "StepVideoConfig"] diff --git a/python/sglang/multimodal_gen/configs/models/dits/base.py b/python/sglang/multimodal_gen/configs/models/dits/base.py new file mode 100644 index 000000000..128238b36 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/dits/base.py @@ -0,0 +1,69 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field +from typing import Any + +from sglang.multimodal_gen.configs.models.base import ArchConfig, ModelConfig +from sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum + + +@dataclass +class DiTArchConfig(ArchConfig): + _fsdp_shard_conditions: list = field(default_factory=list) + _compile_conditions: list = field(default_factory=list) + param_names_mapping: dict = field(default_factory=dict) + reverse_param_names_mapping: dict = field(default_factory=dict) + lora_param_names_mapping: dict = field(default_factory=dict) + _supported_attention_backends: set[AttentionBackendEnum] = field( + default_factory=lambda: { + AttentionBackendEnum.SLIDING_TILE_ATTN, + AttentionBackendEnum.SAGE_ATTN, + AttentionBackendEnum.FA3, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.VIDEO_SPARSE_ATTN, + AttentionBackendEnum.VMOBA_ATTN, + AttentionBackendEnum.SAGE_ATTN_THREE, + } + ) + + hidden_size: int = 0 + num_attention_heads: int = 0 + num_channels_latents: int = 0 + exclude_lora_layers: list[str] = field(default_factory=list) + boundary_ratio: float | None = None + + def __post_init__(self) -> None: + if not self._compile_conditions: + self._compile_conditions = self._fsdp_shard_conditions.copy() + + +@dataclass +class DiTConfig(ModelConfig): + arch_config: DiTArchConfig = field(default_factory=DiTArchConfig) + + # sgl-diffusionDiT-specific parameters + prefix: str = "" + quant_config: QuantizationConfig | None = None + + @staticmethod + def add_cli_args(parser: Any, prefix: str = "dit-config") -> Any: + """Add CLI arguments for DiTConfig fields""" + parser.add_argument( + f"--{prefix}.prefix", + type=str, + dest=f"{prefix.replace('-', '_')}.prefix", + default=DiTConfig.prefix, + help="Prefix for the DiT model", + ) + + parser.add_argument( + f"--{prefix}.quant-config", + type=str, + dest=f"{prefix.replace('-', '_')}.quant_config", + default=None, + help="Quantization configuration for the DiT model", + ) + + return parser diff --git 
a/python/sglang/multimodal_gen/configs/models/dits/flux.py b/python/sglang/multimodal_gen/configs/models/dits/flux.py new file mode 100644 index 000000000..285acecc0 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/dits/flux.py @@ -0,0 +1,36 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field +from typing import Tuple + +from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig + + +@dataclass +class FluxArchConfig(DiTArchConfig): + patch_size: int = 1 + in_channels: int = 64 + out_channels: int | None = None + num_layers: int = 19 + num_single_layers: int = 38 + attention_head_dim: int = 128 + num_attention_heads: int = 24 + joint_attention_dim: int = 4096 + pooled_projection_dim: int = 768 + guidance_embeds: bool = False + axes_dims_rope: Tuple[int, int, int] = (16, 56, 56) + + def __post_init__(self): + super().__post_init__() + self.out_channels = self.out_channels or self.in_channels + self.hidden_size = self.num_attention_heads * self.attention_head_dim + self.num_channels_latents = self.out_channels + + +@dataclass +class FluxConfig(DiTConfig): + + arch_config: DiTArchConfig = field(default_factory=FluxArchConfig) + + prefix: str = "Flux" diff --git a/python/sglang/multimodal_gen/configs/models/dits/hunyuanvideo.py b/python/sglang/multimodal_gen/configs/models/dits/hunyuanvideo.py new file mode 100644 index 000000000..23a6c715b --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/dits/hunyuanvideo.py @@ -0,0 +1,185 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +import torch + +from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig + + +def is_double_block(n: str, m) -> bool: + return "double" in n and str.isdigit(n.split(".")[-1]) + + +def is_single_block(n: str, m) -> bool: + return "single" in n and str.isdigit(n.split(".")[-1]) + + +def is_refiner_block(n: str, m) -> bool: + return "refiner" in n and str.isdigit(n.split(".")[-1]) + + +def is_txt_in(n: str, m) -> bool: + return n.split(".")[-1] == "txt_in" + + +@dataclass +class HunyuanVideoArchConfig(DiTArchConfig): + _fsdp_shard_conditions: list = field( + default_factory=lambda: [is_double_block, is_single_block, is_refiner_block] + ) + + _compile_conditions: list = field( + default_factory=lambda: [is_double_block, is_single_block, is_txt_in] + ) + + param_names_mapping: dict = field( + default_factory=lambda: { + # 1. 
context_embedder.time_text_embed submodules (specific rules, applied first): + r"^context_embedder\.time_text_embed\.timestep_embedder\.linear_1\.(.*)$": r"txt_in.t_embedder.mlp.fc_in.\1", + r"^context_embedder\.time_text_embed\.timestep_embedder\.linear_2\.(.*)$": r"txt_in.t_embedder.mlp.fc_out.\1", + r"^context_embedder\.proj_in\.(.*)$": r"txt_in.input_embedder.\1", + r"^context_embedder\.time_text_embed\.text_embedder\.linear_1\.(.*)$": r"txt_in.c_embedder.fc_in.\1", + r"^context_embedder\.time_text_embed\.text_embedder\.linear_2\.(.*)$": r"txt_in.c_embedder.fc_out.\1", + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.norm1\.(.*)$": r"txt_in.refiner_blocks.\1.norm1.\2", + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.norm2\.(.*)$": r"txt_in.refiner_blocks.\1.norm2.\2", + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.attn\.to_q\.(.*)$": ( + r"txt_in.refiner_blocks.\1.self_attn_qkv.\2", + 0, + 3, + ), + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.attn\.to_k\.(.*)$": ( + r"txt_in.refiner_blocks.\1.self_attn_qkv.\2", + 1, + 3, + ), + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.attn\.to_v\.(.*)$": ( + r"txt_in.refiner_blocks.\1.self_attn_qkv.\2", + 2, + 3, + ), + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.attn\.to_out\.0\.(.*)$": r"txt_in.refiner_blocks.\1.self_attn_proj.\2", + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.ff\.net\.0(?:\.proj)?\.(.*)$": r"txt_in.refiner_blocks.\1.mlp.fc_in.\2", + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.ff\.net\.2(?:\.proj)?\.(.*)$": r"txt_in.refiner_blocks.\1.mlp.fc_out.\2", + r"^context_embedder\.token_refiner\.refiner_blocks\.(\d+)\.norm_out\.linear\.(.*)$": r"txt_in.refiner_blocks.\1.adaLN_modulation.linear.\2", + # 3. x_embedder mapping: + r"^x_embedder\.proj\.(.*)$": r"img_in.proj.\1", + # 4. Top-level time_text_embed mappings: + r"^time_text_embed\.timestep_embedder\.linear_1\.(.*)$": r"time_in.mlp.fc_in.\1", + r"^time_text_embed\.timestep_embedder\.linear_2\.(.*)$": r"time_in.mlp.fc_out.\1", + r"^time_text_embed\.guidance_embedder\.linear_1\.(.*)$": r"guidance_in.mlp.fc_in.\1", + r"^time_text_embed\.guidance_embedder\.linear_2\.(.*)$": r"guidance_in.mlp.fc_out.\1", + r"^time_text_embed\.text_embedder\.linear_1\.(.*)$": r"vector_in.fc_in.\1", + r"^time_text_embed\.text_embedder\.linear_2\.(.*)$": r"vector_in.fc_out.\1", + # 5. 
transformer_blocks mapping: + r"^transformer_blocks\.(\d+)\.norm1\.linear\.(.*)$": r"double_blocks.\1.img_mod.linear.\2", + r"^transformer_blocks\.(\d+)\.norm1_context\.linear\.(.*)$": r"double_blocks.\1.txt_mod.linear.\2", + r"^transformer_blocks\.(\d+)\.attn\.norm_q\.(.*)$": r"double_blocks.\1.img_attn_q_norm.\2", + r"^transformer_blocks\.(\d+)\.attn\.norm_k\.(.*)$": r"double_blocks.\1.img_attn_k_norm.\2", + r"^transformer_blocks\.(\d+)\.attn\.to_q\.(.*)$": ( + r"double_blocks.\1.img_attn_qkv.\2", + 0, + 3, + ), + r"^transformer_blocks\.(\d+)\.attn\.to_k\.(.*)$": ( + r"double_blocks.\1.img_attn_qkv.\2", + 1, + 3, + ), + r"^transformer_blocks\.(\d+)\.attn\.to_v\.(.*)$": ( + r"double_blocks.\1.img_attn_qkv.\2", + 2, + 3, + ), + r"^transformer_blocks\.(\d+)\.attn\.add_q_proj\.(.*)$": ( + r"double_blocks.\1.txt_attn_qkv.\2", + 0, + 3, + ), + r"^transformer_blocks\.(\d+)\.attn\.add_k_proj\.(.*)$": ( + r"double_blocks.\1.txt_attn_qkv.\2", + 1, + 3, + ), + r"^transformer_blocks\.(\d+)\.attn\.add_v_proj\.(.*)$": ( + r"double_blocks.\1.txt_attn_qkv.\2", + 2, + 3, + ), + r"^transformer_blocks\.(\d+)\.attn\.to_out\.0\.(.*)$": r"double_blocks.\1.img_attn_proj.\2", + # Corrected: merge attn.to_add_out into the main projection. + r"^transformer_blocks\.(\d+)\.attn\.to_add_out\.(.*)$": r"double_blocks.\1.txt_attn_proj.\2", + r"^transformer_blocks\.(\d+)\.attn\.norm_added_q\.(.*)$": r"double_blocks.\1.txt_attn_q_norm.\2", + r"^transformer_blocks\.(\d+)\.attn\.norm_added_k\.(.*)$": r"double_blocks.\1.txt_attn_k_norm.\2", + r"^transformer_blocks\.(\d+)\.ff\.net\.0(?:\.proj)?\.(.*)$": r"double_blocks.\1.img_mlp.fc_in.\2", + r"^transformer_blocks\.(\d+)\.ff\.net\.2(?:\.proj)?\.(.*)$": r"double_blocks.\1.img_mlp.fc_out.\2", + r"^transformer_blocks\.(\d+)\.ff_context\.net\.0(?:\.proj)?\.(.*)$": r"double_blocks.\1.txt_mlp.fc_in.\2", + r"^transformer_blocks\.(\d+)\.ff_context\.net\.2(?:\.proj)?\.(.*)$": r"double_blocks.\1.txt_mlp.fc_out.\2", + # 6. single_transformer_blocks mapping: + r"^single_transformer_blocks\.(\d+)\.attn\.norm_q\.(.*)$": r"single_blocks.\1.q_norm.\2", + r"^single_transformer_blocks\.(\d+)\.attn\.norm_k\.(.*)$": r"single_blocks.\1.k_norm.\2", + r"^single_transformer_blocks\.(\d+)\.attn\.to_q\.(.*)$": ( + r"single_blocks.\1.linear1.\2", + 0, + 4, + ), + r"^single_transformer_blocks\.(\d+)\.attn\.to_k\.(.*)$": ( + r"single_blocks.\1.linear1.\2", + 1, + 4, + ), + r"^single_transformer_blocks\.(\d+)\.attn\.to_v\.(.*)$": ( + r"single_blocks.\1.linear1.\2", + 2, + 4, + ), + r"^single_transformer_blocks\.(\d+)\.proj_mlp\.(.*)$": ( + r"single_blocks.\1.linear1.\2", + 3, + 4, + ), + # Corrected: map proj_out to modulation.linear rather than a separate proj_out branch. + r"^single_transformer_blocks\.(\d+)\.proj_out\.(.*)$": r"single_blocks.\1.linear2.\2", + r"^single_transformer_blocks\.(\d+)\.norm\.linear\.(.*)$": r"single_blocks.\1.modulation.linear.\2", + # 7. 
Final layers mapping: + r"^norm_out\.linear\.(.*)$": r"final_layer.adaLN_modulation.linear.\1", + r"^proj_out\.(.*)$": r"final_layer.linear.\1", + } + ) + + # Reverse mapping for saving checkpoints: custom -> hf + reverse_param_names_mapping: dict = field(default_factory=lambda: {}) + + patch_size: int = 2 + patch_size_t: int = 1 + in_channels: int = 16 + out_channels: int = 16 + num_attention_heads: int = 24 + attention_head_dim: int = 128 + mlp_ratio: float = 4.0 + num_layers: int = 20 + num_single_layers: int = 40 + num_refiner_layers: int = 2 + rope_axes_dim: tuple[int, int, int] = (16, 56, 56) + guidance_embeds: bool = False + dtype: torch.dtype | None = None + text_embed_dim: int = 4096 + pooled_projection_dim: int = 768 + rope_theta: int = 256 + qk_norm: str = "rms_norm" + exclude_lora_layers: list[str] = field( + default_factory=lambda: ["img_in", "txt_in", "time_in", "vector_in"] + ) + + def __post_init__(self): + super().__post_init__() + self.hidden_size: int = self.attention_head_dim * self.num_attention_heads + self.num_channels_latents: int = self.in_channels + + +@dataclass +class HunyuanVideoConfig(DiTConfig): + arch_config: DiTArchConfig = field(default_factory=HunyuanVideoArchConfig) + + prefix: str = "Hunyuan" diff --git a/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py b/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py new file mode 100644 index 000000000..4cf46a089 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/dits/qwenimage.py @@ -0,0 +1,36 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field +from typing import Tuple + +from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig + + +@dataclass +class QwenImageArchConfig(DiTArchConfig): + patch_size: int = 1 + in_channels: int = 64 + out_channels: int | None = None + num_layers: int = 19 + num_single_layers: int = 38 + attention_head_dim: int = 128 + num_attention_heads: int = 24 + joint_attention_dim: int = 4096 + pooled_projection_dim: int = 768 + guidance_embeds: bool = False + axes_dims_rope: Tuple[int, int, int] = (16, 56, 56) + + def __post_init__(self): + super().__post_init__() + self.out_channels = self.out_channels or self.in_channels + self.hidden_size = self.num_attention_heads * self.attention_head_dim + self.num_channels_latents = self.out_channels + + +@dataclass +class QwenImageDitConfig(DiTConfig): + + arch_config: DiTArchConfig = field(default_factory=QwenImageArchConfig) + + prefix: str = "qwenimage" diff --git a/python/sglang/multimodal_gen/configs/models/dits/stepvideo.py b/python/sglang/multimodal_gen/configs/models/dits/stepvideo.py new file mode 100644 index 000000000..1d7fe21a6 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/dits/stepvideo.py @@ -0,0 +1,64 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig + + +def is_transformer_blocks(n, m): + return "transformer_blocks" in n and n.split(".")[-1].isdigit() + + +@dataclass +class StepVideoArchConfig(DiTArchConfig): + _fsdp_shard_conditions: list = field( + default_factory=lambda: [is_transformer_blocks] + ) + + param_names_mapping: dict = field( + default_factory=lambda: { + # transformer block + r"^transformer_blocks\.(\d+)\.norm1\.(weight|bias)$": 
r"transformer_blocks.\1.norm1.norm.\2", + r"^transformer_blocks\.(\d+)\.norm2\.(weight|bias)$": r"transformer_blocks.\1.norm2.norm.\2", + r"^transformer_blocks\.(\d+)\.ff\.net\.0\.proj\.weight$": r"transformer_blocks.\1.ff.fc_in.weight", + r"^transformer_blocks\.(\d+)\.ff\.net\.2\.weight$": r"transformer_blocks.\1.ff.fc_out.weight", + # adanorm block + r"^adaln_single\.emb\.timestep_embedder\.linear_1\.(weight|bias)$": r"adaln_single.emb.mlp.fc_in.\1", + r"^adaln_single\.emb\.timestep_embedder\.linear_2\.(weight|bias)$": r"adaln_single.emb.mlp.fc_out.\1", + # caption projection + r"^caption_projection\.linear_1\.(weight|bias)$": r"caption_projection.fc_in.\1", + r"^caption_projection\.linear_2\.(weight|bias)$": r"caption_projection.fc_out.\1", + } + ) + + num_attention_heads: int = 48 + attention_head_dim: int = 128 + in_channels: int = 64 + out_channels: int | None = 64 + num_layers: int = 48 + dropout: float = 0.0 + patch_size: int = 1 + norm_type: str = "ada_norm_single" + norm_elementwise_affine: bool = False + norm_eps: float = 1e-6 + caption_channels: int | list[int] | tuple[int, ...] | None = field( + default_factory=lambda: [6144, 1024] + ) + attention_type: str | None = "torch" + use_additional_conditions: bool | None = False + exclude_lora_layers: list[str] = field(default_factory=lambda: []) + + def __post_init__(self): + self.hidden_size = self.num_attention_heads * self.attention_head_dim + self.out_channels = ( + self.in_channels if self.out_channels is None else self.out_channels + ) + self.num_channels_latents = self.out_channels + + +@dataclass +class StepVideoConfig(DiTConfig): + arch_config: DiTArchConfig = field(default_factory=StepVideoArchConfig) + + prefix: str = "StepVideo" diff --git a/python/sglang/multimodal_gen/configs/models/dits/wanvideo.py b/python/sglang/multimodal_gen/configs/models/dits/wanvideo.py new file mode 100644 index 000000000..68e6801d7 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/dits/wanvideo.py @@ -0,0 +1,103 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.dits.base import DiTArchConfig, DiTConfig + + +def is_blocks(n: str, m) -> bool: + return "blocks" in n and str.isdigit(n.split(".")[-1]) + + +@dataclass +class WanVideoArchConfig(DiTArchConfig): + _fsdp_shard_conditions: list = field(default_factory=lambda: [is_blocks]) + + param_names_mapping: dict = field( + default_factory=lambda: { + r"^patch_embedding\.(.*)$": r"patch_embedding.proj.\1", + r"^condition_embedder\.text_embedder\.linear_1\.(.*)$": r"condition_embedder.text_embedder.fc_in.\1", + r"^condition_embedder\.text_embedder\.linear_2\.(.*)$": r"condition_embedder.text_embedder.fc_out.\1", + r"^condition_embedder\.time_embedder\.linear_1\.(.*)$": r"condition_embedder.time_embedder.mlp.fc_in.\1", + r"^condition_embedder\.time_embedder\.linear_2\.(.*)$": r"condition_embedder.time_embedder.mlp.fc_out.\1", + r"^condition_embedder\.time_proj\.(.*)$": r"condition_embedder.time_modulation.linear.\1", + r"^condition_embedder\.image_embedder\.ff\.net\.0\.proj\.(.*)$": r"condition_embedder.image_embedder.ff.fc_in.\1", + r"^condition_embedder\.image_embedder\.ff\.net\.2\.(.*)$": r"condition_embedder.image_embedder.ff.fc_out.\1", + r"^blocks\.(\d+)\.attn1\.to_q\.(.*)$": r"blocks.\1.to_q.\2", + r"^blocks\.(\d+)\.attn1\.to_k\.(.*)$": r"blocks.\1.to_k.\2", + r"^blocks\.(\d+)\.attn1\.to_v\.(.*)$": r"blocks.\1.to_v.\2", + 
r"^blocks\.(\d+)\.attn1\.to_out\.0\.(.*)$": r"blocks.\1.to_out.\2", + r"^blocks\.(\d+)\.attn1\.norm_q\.(.*)$": r"blocks.\1.norm_q.\2", + r"^blocks\.(\d+)\.attn1\.norm_k\.(.*)$": r"blocks.\1.norm_k.\2", + r"^blocks\.(\d+)\.attn2\.to_out\.0\.(.*)$": r"blocks.\1.attn2.to_out.\2", + r"^blocks\.(\d+)\.ffn\.net\.0\.proj\.(.*)$": r"blocks.\1.ffn.fc_in.\2", + r"^blocks\.(\d+)\.ffn\.net\.2\.(.*)$": r"blocks.\1.ffn.fc_out.\2", + r"^blocks\.(\d+)\.norm2\.(.*)$": r"blocks.\1.self_attn_residual_norm.norm.\2", + } + ) + + # Reverse mapping for saving checkpoints: custom -> hf + reverse_param_names_mapping: dict = field(default_factory=lambda: {}) + + # Some LoRA adapters use the original official layer names instead of hf layer names, + # so apply this before the param_names_mapping + lora_param_names_mapping: dict = field( + default_factory=lambda: { + r"^blocks\.(\d+)\.self_attn\.q\.(.*)$": r"blocks.\1.attn1.to_q.\2", + r"^blocks\.(\d+)\.self_attn\.k\.(.*)$": r"blocks.\1.attn1.to_k.\2", + r"^blocks\.(\d+)\.self_attn\.v\.(.*)$": r"blocks.\1.attn1.to_v.\2", + r"^blocks\.(\d+)\.self_attn\.o\.(.*)$": r"blocks.\1.attn1.to_out.0.\2", + r"^blocks\.(\d+)\.cross_attn\.q\.(.*)$": r"blocks.\1.attn2.to_q.\2", + r"^blocks\.(\d+)\.cross_attn\.k\.(.*)$": r"blocks.\1.attn2.to_k.\2", + r"^blocks\.(\d+)\.cross_attn\.v\.(.*)$": r"blocks.\1.attn2.to_v.\2", + r"^blocks\.(\d+)\.cross_attn\.o\.(.*)$": r"blocks.\1.attn2.to_out.0.\2", + r"^blocks\.(\d+)\.ffn\.0\.(.*)$": r"blocks.\1.ffn.fc_in.\2", + r"^blocks\.(\d+)\.ffn\.2\.(.*)$": r"blocks.\1.ffn.fc_out.\2", + } + ) + + patch_size: tuple[int, int, int] = (1, 2, 2) + text_len = 512 + num_attention_heads: int = 40 + attention_head_dim: int = 128 + in_channels: int = 16 + out_channels: int = 16 + text_dim: int = 4096 + freq_dim: int = 256 + ffn_dim: int = 13824 + num_layers: int = 40 + cross_attn_norm: bool = True + qk_norm: str = "rms_norm_across_heads" + eps: float = 1e-6 + image_dim: int | None = None + added_kv_proj_dim: int | None = None + rope_max_seq_len: int = 1024 + pos_embed_seq_len: int | None = None + exclude_lora_layers: list[str] = field(default_factory=lambda: ["embedder"]) + + # Wan MoE + boundary_ratio: float | None = None + + # Causal Wan + local_attn_size: int = ( + -1 + ) # Window size for temporal local attention (-1 indicates global attention) + sink_size: int = ( + 0 # Size of the attention sink, we keep the first `sink_size` frames unchanged when rolling the KV cache + ) + num_frames_per_block: int = 3 + sliding_window_num_frames: int = 21 + + def __post_init__(self): + super().__post_init__() + self.out_channels = self.out_channels or self.in_channels + self.hidden_size = self.num_attention_heads * self.attention_head_dim + self.num_channels_latents = self.out_channels + + +@dataclass +class WanVideoConfig(DiTConfig): + arch_config: DiTArchConfig = field(default_factory=WanVideoArchConfig) + + prefix: str = "Wan" diff --git a/python/sglang/multimodal_gen/configs/models/encoders/__init__.py b/python/sglang/multimodal_gen/configs/models/encoders/__init__.py new file mode 100644 index 000000000..70851bfa5 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/encoders/__init__.py @@ -0,0 +1,25 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from sglang.multimodal_gen.configs.models.encoders.base import ( + BaseEncoderOutput, + EncoderConfig, + ImageEncoderConfig, + TextEncoderConfig, +) +from sglang.multimodal_gen.configs.models.encoders.clip import ( + CLIPTextConfig, + CLIPVisionConfig, +) +from 
sglang.multimodal_gen.configs.models.encoders.llama import LlamaConfig +from sglang.multimodal_gen.configs.models.encoders.t5 import T5Config + +__all__ = [ + "EncoderConfig", + "TextEncoderConfig", + "ImageEncoderConfig", + "BaseEncoderOutput", + "CLIPTextConfig", + "CLIPVisionConfig", + "LlamaConfig", + "T5Config", +] diff --git a/python/sglang/multimodal_gen/configs/models/encoders/base.py b/python/sglang/multimodal_gen/configs/models/encoders/base.py new file mode 100644 index 000000000..1ae63fe92 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/encoders/base.py @@ -0,0 +1,85 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field +from typing import Any + +import torch + +from sglang.multimodal_gen.configs.models.base import ArchConfig, ModelConfig +from sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum + + +@dataclass +class EncoderArchConfig(ArchConfig): + architectures: list[str] = field(default_factory=lambda: []) + _supported_attention_backends: set[AttentionBackendEnum] = field( + default_factory=lambda: { + AttentionBackendEnum.FA3, + AttentionBackendEnum.TORCH_SDPA, + } + ) + output_hidden_states: bool = False + use_return_dict: bool = True + + +@dataclass +class TextEncoderArchConfig(EncoderArchConfig): + vocab_size: int = 0 + hidden_size: int = 0 + num_hidden_layers: int = 0 + num_attention_heads: int = 0 + pad_token_id: int = 0 + eos_token_id: int = 0 + text_len: int = 0 + hidden_state_skip_layer: int = 0 + decoder_start_token_id: int = 0 + output_past: bool = True + scalable_attention: bool = True + tie_word_embeddings: bool = False + stacked_params_mapping: list[tuple[str, str, str]] = field( + default_factory=list + ) # mapping from huggingface weight names to custom names + tokenizer_kwargs: dict[str, Any] = field(default_factory=dict) + _fsdp_shard_conditions: list = field(default_factory=lambda: []) + + def __post_init__(self) -> None: + self.tokenizer_kwargs = { + "truncation": True, + "max_length": self.text_len, + "return_tensors": "pt", + } + + +@dataclass +class ImageEncoderArchConfig(EncoderArchConfig): + pass + + +@dataclass +class BaseEncoderOutput: + last_hidden_state: torch.FloatTensor | None = None + pooler_output: torch.FloatTensor | None = None + hidden_states: tuple[torch.FloatTensor, ...] | None = None + attentions: tuple[torch.FloatTensor, ...] 
| None = None + attention_mask: torch.Tensor | None = None + + +@dataclass +class EncoderConfig(ModelConfig): + arch_config: ArchConfig = field(default_factory=EncoderArchConfig) + + prefix: str = "" + quant_config: QuantizationConfig | None = None + lora_config: Any | None = None + + +@dataclass +class TextEncoderConfig(EncoderConfig): + arch_config: ArchConfig = field(default_factory=TextEncoderArchConfig) + + +@dataclass +class ImageEncoderConfig(EncoderConfig): + arch_config: ArchConfig = field(default_factory=ImageEncoderArchConfig) diff --git a/python/sglang/multimodal_gen/configs/models/encoders/clip.py b/python/sglang/multimodal_gen/configs/models/encoders/clip.py new file mode 100644 index 000000000..6b36fc88b --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/encoders/clip.py @@ -0,0 +1,95 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.encoders.base import ( + ImageEncoderArchConfig, + ImageEncoderConfig, + TextEncoderArchConfig, + TextEncoderConfig, +) + + +def _is_transformer_layer(n: str, m) -> bool: + return "layers" in n and str.isdigit(n.split(".")[-1]) + + +def _is_embeddings(n: str, m) -> bool: + return n.endswith("embeddings") + + +@dataclass +class CLIPTextArchConfig(TextEncoderArchConfig): + vocab_size: int = 49408 + hidden_size: int = 512 + intermediate_size: int = 2048 + projection_dim: int = 512 + num_hidden_layers: int = 12 + num_attention_heads: int = 8 + max_position_embeddings: int = 77 + hidden_act: str = "quick_gelu" + layer_norm_eps: float = 1e-5 + dropout: float = 0.0 + attention_dropout: float = 0.0 + initializer_range: float = 0.02 + initializer_factor: float = 1.0 + pad_token_id: int = 1 + bos_token_id: int = 49406 + eos_token_id: int = 49407 + text_len: int = 77 + stacked_params_mapping: list[tuple[str, str, str]] = field( + default_factory=lambda: [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + ) + _fsdp_shard_conditions: list = field( + default_factory=lambda: [_is_transformer_layer, _is_embeddings] + ) + + +@dataclass +class CLIPVisionArchConfig(ImageEncoderArchConfig): + hidden_size: int = 768 + intermediate_size: int = 3072 + projection_dim: int = 512 + num_hidden_layers: int = 12 + num_attention_heads: int = 12 + num_channels: int = 3 + image_size: int = 224 + patch_size: int = 32 + hidden_act: str = "quick_gelu" + layer_norm_eps: float = 1e-5 + dropout: float = 0.0 + attention_dropout: float = 0.0 + initializer_range: float = 0.02 + initializer_factor: float = 1.0 + stacked_params_mapping: list[tuple[str, str, str]] = field( + default_factory=lambda: [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + ) + + +@dataclass +class CLIPTextConfig(TextEncoderConfig): + arch_config: TextEncoderArchConfig = field(default_factory=CLIPTextArchConfig) + + num_hidden_layers_override: int | None = None + require_post_norm: bool | None = None + prefix: str = "clip" + + +@dataclass +class CLIPVisionConfig(ImageEncoderConfig): + arch_config: ImageEncoderArchConfig = field(default_factory=CLIPVisionArchConfig) + + num_hidden_layers_override: int | None = None + require_post_norm: bool | None = None + prefix: str = "clip" diff --git a/python/sglang/multimodal_gen/configs/models/encoders/llama.py 
b/python/sglang/multimodal_gen/configs/models/encoders/llama.py new file mode 100644 index 000000000..41d98cab2 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/encoders/llama.py @@ -0,0 +1,69 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.encoders.base import ( + TextEncoderArchConfig, + TextEncoderConfig, +) + + +def _is_transformer_layer(n: str, m) -> bool: + return "layers" in n and str.isdigit(n.split(".")[-1]) + + +def _is_embeddings(n: str, m) -> bool: + return n.endswith("embed_tokens") + + +def _is_final_norm(n: str, m) -> bool: + return n.endswith("norm") + + +@dataclass +class LlamaArchConfig(TextEncoderArchConfig): + vocab_size: int = 32000 + hidden_size: int = 4096 + intermediate_size: int = 11008 + num_hidden_layers: int = 32 + num_attention_heads: int = 32 + num_key_value_heads: int | None = None + hidden_act: str = "silu" + max_position_embeddings: int = 2048 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-6 + use_cache: bool = True + pad_token_id: int = 0 + bos_token_id: int = 1 + eos_token_id: int = 2 + pretraining_tp: int = 1 + tie_word_embeddings: bool = False + rope_theta: float = 10000.0 + rope_scaling: float | None = None + attention_bias: bool = False + attention_dropout: float = 0.0 + mlp_bias: bool = False + head_dim: int | None = None + hidden_state_skip_layer: int = 2 + text_len: int = 256 + stacked_params_mapping: list[tuple[str, str, str]] = field( + default_factory=lambda: [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), # type: ignore + (".gate_up_proj", ".up_proj", 1), # type: ignore + ] + ) + _fsdp_shard_conditions: list = field( + default_factory=lambda: [_is_transformer_layer, _is_embeddings, _is_final_norm] + ) + + +@dataclass +class LlamaConfig(TextEncoderConfig): + arch_config: TextEncoderArchConfig = field(default_factory=LlamaArchConfig) + + prefix: str = "llama" diff --git a/python/sglang/multimodal_gen/configs/models/encoders/qwen_image.py b/python/sglang/multimodal_gen/configs/models/encoders/qwen_image.py new file mode 100644 index 000000000..0a5f245f4 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/encoders/qwen_image.py @@ -0,0 +1,67 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.encoders.base import ( + TextEncoderArchConfig, + TextEncoderConfig, +) + + +def _is_transformer_layer(n: str, m) -> bool: + return "layers" in n and str.isdigit(n.split(".")[-1]) + + +def _is_embeddings(n: str, m) -> bool: + return n.endswith("embed_tokens") + + +def _is_final_norm(n: str, m) -> bool: + return n.endswith("norm") + + +@dataclass +class QwenImageArchConfig(TextEncoderArchConfig): + vocab_size: int = 32000 + hidden_size: int = 4096 + intermediate_size: int = 11008 + num_hidden_layers: int = 32 + num_attention_heads: int = 32 + num_key_value_heads: int | None = None + hidden_act: str = "silu" + max_position_embeddings: int = 2048 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-6 + use_cache: bool = True + pad_token_id: int = -1 + eos_token_id: int = 2 + pretraining_tp: int = 1 + tie_word_embeddings: bool = False + rope_theta: float = 10000.0 + rope_scaling: float | None = 
None + attention_bias: bool = False + attention_dropout: float = 0.0 + mlp_bias: bool = False + head_dim: int | None = None + hidden_state_skip_layer: int = 2 + text_len: int = 256 + stacked_params_mapping: list[tuple[str, str, str]] = field( + default_factory=lambda: [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), # type: ignore + (".gate_up_proj", ".up_proj", 1), # type: ignore + ] + ) + _fsdp_shard_conditions: list = field( + default_factory=lambda: [_is_transformer_layer, _is_embeddings, _is_final_norm] + ) + + +@dataclass +class Qwen2_5VLConfig(TextEncoderConfig): + arch_config: TextEncoderArchConfig = field(default_factory=QwenImageArchConfig) + # prefix: str = "qwen_image" diff --git a/python/sglang/multimodal_gen/configs/models/encoders/t5.py b/python/sglang/multimodal_gen/configs/models/encoders/t5.py new file mode 100644 index 000000000..3fd9b2f1a --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/encoders/t5.py @@ -0,0 +1,86 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.encoders.base import ( + TextEncoderArchConfig, + TextEncoderConfig, +) + + +def _is_transformer_layer(n: str, m) -> bool: + return "block" in n and str.isdigit(n.split(".")[-1]) + + +def _is_embeddings(n: str, m) -> bool: + return n.endswith("shared") + + +def _is_final_layernorm(n: str, m) -> bool: + return n.endswith("final_layer_norm") + + +@dataclass +class T5ArchConfig(TextEncoderArchConfig): + vocab_size: int = 32128 + d_model: int = 512 + d_kv: int = 64 + d_ff: int = 2048 + num_layers: int = 6 + num_decoder_layers: int | None = None + num_heads: int = 8 + relative_attention_num_buckets: int = 32 + relative_attention_max_distance: int = 128 + dropout_rate: float = 0.1 + layer_norm_epsilon: float = 1e-6 + initializer_factor: float = 1.0 + feed_forward_proj: str = "relu" + dense_act_fn: str = "" + is_gated_act: bool = False + is_encoder_decoder: bool = True + use_cache: bool = True + pad_token_id: int = 0 + eos_token_id: int = 1 + classifier_dropout: float = 0.0 + text_len: int = 512 + stacked_params_mapping: list[tuple[str, str, str]] = field( + default_factory=lambda: [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q", "q"), + (".qkv_proj", ".k", "k"), + (".qkv_proj", ".v", "v"), + ] + ) + _fsdp_shard_conditions: list = field( + default_factory=lambda: [ + _is_transformer_layer, + _is_embeddings, + _is_final_layernorm, + ] + ) + + # Referenced from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/configuration_t5.py + def __post_init__(self): + super().__post_init__() + act_info = self.feed_forward_proj.split("-") + self.dense_act_fn: str = act_info[-1] + self.is_gated_act: bool = act_info[0] == "gated" + if self.feed_forward_proj == "gated-gelu": + self.dense_act_fn = "gelu_new" + + self.tokenizer_kwargs = { + "padding": "max_length", + "truncation": True, + "max_length": self.text_len, + "add_special_tokens": True, + "return_attention_mask": True, + "return_tensors": "pt", + } + + +@dataclass +class T5Config(TextEncoderConfig): + arch_config: TextEncoderArchConfig = field(default_factory=T5ArchConfig) + + prefix: str = "t5" diff --git a/python/sglang/multimodal_gen/configs/models/vaes/__init__.py b/python/sglang/multimodal_gen/configs/models/vaes/__init__.py new file mode 
100644 index 000000000..e9b478618 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/vaes/__init__.py @@ -0,0 +1,11 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from sglang.multimodal_gen.configs.models.vaes.hunyuanvae import HunyuanVAEConfig +from sglang.multimodal_gen.configs.models.vaes.stepvideovae import StepVideoVAEConfig +from sglang.multimodal_gen.configs.models.vaes.wanvae import WanVAEConfig + +__all__ = [ + "HunyuanVAEConfig", + "WanVAEConfig", + "StepVideoVAEConfig", +] diff --git a/python/sglang/multimodal_gen/configs/models/vaes/base.py b/python/sglang/multimodal_gen/configs/models/vaes/base.py new file mode 100644 index 000000000..3e31760e2 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/vaes/base.py @@ -0,0 +1,158 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +import argparse +import dataclasses +from dataclasses import dataclass, field +from typing import Any + +import torch + +from sglang.multimodal_gen.configs.models.base import ArchConfig, ModelConfig +from sglang.multimodal_gen.runtime.models.vision_utils import get_default_height_width +from sglang.multimodal_gen.utils import StoreBoolean + + +@dataclass +class VAEArchConfig(ArchConfig): + scaling_factor: float | torch.Tensor = 0 + + temporal_compression_ratio: int = 4 + # or vae_scale_factor? + spatial_compression_ratio: int = 8 + + +@dataclass +class VAEConfig(ModelConfig): + arch_config: VAEArchConfig = field(default_factory=VAEArchConfig) + + # sgl-diffusionVAE-specific parameters + load_encoder: bool = True + load_decoder: bool = True + + tile_sample_min_height: int = 256 + tile_sample_min_width: int = 256 + tile_sample_min_num_frames: int = 16 + tile_sample_stride_height: int = 192 + tile_sample_stride_width: int = 192 + tile_sample_stride_num_frames: int = 12 + blend_num_frames: int = 0 + + use_tiling: bool = True + use_temporal_tiling: bool = True + use_parallel_tiling: bool = True + use_temporal_scaling_frames: bool = True + + def __post_init__(self): + self.blend_num_frames = ( + self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames + ) + + def post_init(self): + pass + + # returns width, height + def calculate_dimensions( + self, image, vae_scale_factor, width, height + ) -> tuple[int, int]: + height, width = get_default_height_width(image, vae_scale_factor, height, width) + return width, height + + @staticmethod + def add_cli_args(parser: Any, prefix: str = "vae-config") -> Any: + """Add CLI arguments for VAEConfig fields""" + parser.add_argument( + f"--{prefix}.load-encoder", + action=StoreBoolean, + dest=f"{prefix.replace('-', '_')}.load_encoder", + default=VAEConfig.load_encoder, + help="Whether to load the VAE encoder", + ) + parser.add_argument( + f"--{prefix}.load-decoder", + action=StoreBoolean, + dest=f"{prefix.replace('-', '_')}.load_decoder", + default=VAEConfig.load_decoder, + help="Whether to load the VAE decoder", + ) + parser.add_argument( + f"--{prefix}.tile-sample-min-height", + type=int, + dest=f"{prefix.replace('-', '_')}.tile_sample_min_height", + default=VAEConfig.tile_sample_min_height, + help="Minimum height for VAE tile sampling", + ) + parser.add_argument( + f"--{prefix}.tile-sample-min-width", + type=int, + dest=f"{prefix.replace('-', '_')}.tile_sample_min_width", + default=VAEConfig.tile_sample_min_width, + help="Minimum width for VAE tile sampling", + ) + parser.add_argument( + f"--{prefix}.tile-sample-min-num-frames", + type=int, + 
dest=f"{prefix.replace('-', '_')}.tile_sample_min_num_frames", + default=VAEConfig.tile_sample_min_num_frames, + help="Minimum number of frames for VAE tile sampling", + ) + parser.add_argument( + f"--{prefix}.tile-sample-stride-height", + type=int, + dest=f"{prefix.replace('-', '_')}.tile_sample_stride_height", + default=VAEConfig.tile_sample_stride_height, + help="Stride height for VAE tile sampling", + ) + parser.add_argument( + f"--{prefix}.tile-sample-stride-width", + type=int, + dest=f"{prefix.replace('-', '_')}.tile_sample_stride_width", + default=VAEConfig.tile_sample_stride_width, + help="Stride width for VAE tile sampling", + ) + parser.add_argument( + f"--{prefix}.tile-sample-stride-num-frames", + type=int, + dest=f"{prefix.replace('-', '_')}.tile_sample_stride_num_frames", + default=VAEConfig.tile_sample_stride_num_frames, + help="Stride number of frames for VAE tile sampling", + ) + parser.add_argument( + f"--{prefix}.blend-num-frames", + type=int, + dest=f"{prefix.replace('-', '_')}.blend_num_frames", + default=VAEConfig.blend_num_frames, + help="Number of frames to blend for VAE tile sampling", + ) + parser.add_argument( + f"--{prefix}.use-tiling", + action=StoreBoolean, + dest=f"{prefix.replace('-', '_')}.use_tiling", + default=VAEConfig.use_tiling, + help="Whether to use tiling for VAE", + ) + parser.add_argument( + f"--{prefix}.use-temporal-tiling", + action=StoreBoolean, + dest=f"{prefix.replace('-', '_')}.use_temporal_tiling", + default=VAEConfig.use_temporal_tiling, + help="Whether to use temporal tiling for VAE", + ) + parser.add_argument( + f"--{prefix}.use-parallel-tiling", + action=StoreBoolean, + dest=f"{prefix.replace('-', '_')}.use_parallel_tiling", + default=VAEConfig.use_parallel_tiling, + help="Whether to use parallel tiling for VAE", + ) + + return parser + + @classmethod + def from_cli_args(cls, args: argparse.Namespace) -> "VAEConfig": + kwargs = {} + for attr in dataclasses.fields(cls): + value = getattr(args, attr.name, None) + if value is not None: + kwargs[attr.name] = value + return cls(**kwargs) diff --git a/python/sglang/multimodal_gen/configs/models/vaes/flux.py b/python/sglang/multimodal_gen/configs/models/vaes/flux.py new file mode 100644 index 000000000..0b56149d9 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/vaes/flux.py @@ -0,0 +1,50 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig, VAEConfig + + +@dataclass +class FluxVAEArchConfig(VAEArchConfig): + spatial_compression_ratio: int = 1 + + base_dim: int = 96 + decoder_base_dim: int | None = None + z_dim: int = 16 + dim_mult: tuple[int, ...] = (1, 2, 4, 4) + num_res_blocks: int = 2 + attn_scales: tuple[float, ...] = () + temperal_downsample: tuple[bool, ...] 
= (False, True, True) + dropout: float = 0.0 + + is_residual: bool = False + in_channels: int = 3 + out_channels: int = 3 + patch_size: int | None = None + scale_factor_temporal: int = 4 + scale_factor_spatial: int = 8 + clip_output: bool = True + + +@dataclass +class FluxVAEConfig(VAEConfig): + arch_config: FluxVAEArchConfig = field(default_factory=FluxVAEArchConfig) + + use_feature_cache: bool = True + + use_tiling: bool = False + use_temporal_tiling: bool = False + use_parallel_tiling: bool = False + + def __post_init__(self): + self.blend_num_frames = ( + self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames + ) * 2 + + def post_init(self): + self.arch_config.vae_scale_factor = 2 ** ( + len(self.arch_config.block_out_channels) - 1 + ) + self.arch_config.spatial_compression_ratio = self.arch_config.vae_scale_factor diff --git a/python/sglang/multimodal_gen/configs/models/vaes/hunyuanvae.py b/python/sglang/multimodal_gen/configs/models/vaes/hunyuanvae.py new file mode 100644 index 000000000..601b72d57 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/vaes/hunyuanvae.py @@ -0,0 +1,41 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig, VAEConfig + + +@dataclass +class HunyuanVAEArchConfig(VAEArchConfig): + in_channels: int = 3 + out_channels: int = 3 + latent_channels: int = 16 + down_block_types: tuple[str, ...] = ( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ) + up_block_types: tuple[str, ...] = ( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ) + block_out_channels: tuple[int, ...] = (128, 256, 512, 512) + layers_per_block: int = 2 + act_fn: str = "silu" + norm_num_groups: int = 32 + scaling_factor: float = 0.476986 + spatial_compression_ratio: int = 8 + temporal_compression_ratio: int = 4 + mid_block_add_attention: bool = True + + def __post_init__(self): + self.spatial_compression_ratio: int = 2 ** (len(self.block_out_channels) - 1) + + +@dataclass +class HunyuanVAEConfig(VAEConfig): + arch_config: VAEArchConfig = field(default_factory=HunyuanVAEArchConfig) diff --git a/python/sglang/multimodal_gen/configs/models/vaes/qwenimage.py b/python/sglang/multimodal_gen/configs/models/vaes/qwenimage.py new file mode 100644 index 000000000..26375f351 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/vaes/qwenimage.py @@ -0,0 +1,61 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit import calculate_dimensions + +from sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig, VAEConfig + + +@dataclass +class QwenImageVAEArchConfig(VAEArchConfig): + spatial_compression_ratio: int = 1 + + base_dim: int = 96 + decoder_base_dim: int | None = None + z_dim: int = 16 + dim_mult: tuple[int, ...] = (1, 2, 4, 4) + num_res_blocks: int = 2 + attn_scales: tuple[float, ...] = () + temperal_downsample: tuple[bool, ...] 
= (False, True, True) + dropout: float = 0.0 + + is_residual: bool = False + in_channels: int = 3 + out_channels: int = 3 + patch_size: int | None = None + scale_factor_temporal: int = 4 + scale_factor_spatial: int = 8 + clip_output: bool = True + + def __post_init__(self): + self.vae_scale_factor = 2 ** len(self.temperal_downsample) + + +@dataclass +class QwenImageVAEConfig(VAEConfig): + arch_config: QwenImageVAEArchConfig = field(default_factory=QwenImageVAEArchConfig) + + use_feature_cache: bool = True + + use_tiling: bool = False + use_temporal_tiling: bool = False + use_parallel_tiling: bool = False + + def calculate_dimensions(self, image, vae_scale_factor, width, height): + width = image.size[0] + height = image.size[1] + width, height, _ = calculate_dimensions(1024 * 1024, width / height) + return width, height + + def __post_init__(self): + self.blend_num_frames = ( + self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames + ) * 2 + + def post_init(self): + self.arch_config.vae_scale_factor = 2 ** ( + len(self.arch_config.temperal_downsample) + ) + self.arch_config.spatial_compression_ratio = self.arch_config.vae_scale_factor diff --git a/python/sglang/multimodal_gen/configs/models/vaes/stepvideovae.py b/python/sglang/multimodal_gen/configs/models/vaes/stepvideovae.py new file mode 100644 index 000000000..6794e9792 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/vaes/stepvideovae.py @@ -0,0 +1,31 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig, VAEConfig + + +@dataclass +class StepVideoVAEArchConfig(VAEArchConfig): + in_channels: int = 3 + out_channels: int = 3 + z_channels: int = 64 + num_res_blocks: int = 2 + version: int = 2 + frame_len: int = 17 + world_size: int = 1 + + spatial_compression_ratio: int = 16 + temporal_compression_ratio: int = 8 + + scaling_factor: float = 1.0 + + +@dataclass +class StepVideoVAEConfig(VAEConfig): + arch_config: VAEArchConfig = field(default_factory=StepVideoVAEArchConfig) + use_tiling: bool = False + use_temporal_tiling: bool = False + use_parallel_tiling: bool = False + use_temporal_scaling_frames: bool = False diff --git a/python/sglang/multimodal_gen/configs/models/vaes/wanvae.py b/python/sglang/multimodal_gen/configs/models/vaes/wanvae.py new file mode 100644 index 000000000..a1bd77ebf --- /dev/null +++ b/python/sglang/multimodal_gen/configs/models/vaes/wanvae.py @@ -0,0 +1,88 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +import torch + +from sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig, VAEConfig + + +@dataclass +class WanVAEArchConfig(VAEArchConfig): + base_dim: int = 96 + decoder_base_dim: int | None = None + z_dim: int = 16 + dim_mult: tuple[int, ...] = (1, 2, 4, 4) + num_res_blocks: int = 2 + attn_scales: tuple[float, ...] = () + temperal_downsample: tuple[bool, ...] = (False, True, True) + dropout: float = 0.0 + latents_mean: tuple[float, ...] = ( + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ) + latents_std: tuple[float, ...] 
= ( + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ) + is_residual: bool = False + in_channels: int = 3 + out_channels: int = 3 + patch_size: int | None = None + scale_factor_temporal: int = 4 + scale_factor_spatial: int = 8 + clip_output: bool = True + + def __post_init__(self): + self.scaling_factor: torch.tensor = 1.0 / torch.tensor(self.latents_std).view( + 1, self.z_dim, 1, 1, 1 + ) + self.shift_factor: torch.tensor = torch.tensor(self.latents_mean).view( + 1, self.z_dim, 1, 1, 1 + ) + self.temporal_compression_ratio = self.scale_factor_temporal + self.spatial_compression_ratio = self.scale_factor_spatial + + +@dataclass +class WanVAEConfig(VAEConfig): + arch_config: WanVAEArchConfig = field(default_factory=WanVAEArchConfig) + use_feature_cache: bool = True + + use_tiling: bool = False + use_temporal_tiling: bool = False + use_parallel_tiling: bool = False + + def __post_init__(self): + self.blend_num_frames = ( + self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames + ) * 2 diff --git a/python/sglang/multimodal_gen/configs/pipelines/__init__.py b/python/sglang/multimodal_gen/configs/pipelines/__init__.py new file mode 100644 index 000000000..5db869f31 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/pipelines/__init__.py @@ -0,0 +1,37 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from sglang.multimodal_gen.configs.pipelines.base import ( + PipelineConfig, + SlidingTileAttnConfig, +) +from sglang.multimodal_gen.configs.pipelines.flux import FluxPipelineConfig +from sglang.multimodal_gen.configs.pipelines.hunyuan import ( + FastHunyuanConfig, + HunyuanConfig, +) +from sglang.multimodal_gen.configs.pipelines.registry import ( + get_pipeline_config_cls_from_name, +) +from sglang.multimodal_gen.configs.pipelines.stepvideo import StepVideoT2VConfig +from sglang.multimodal_gen.configs.pipelines.wan import ( + SelfForcingWanT2V480PConfig, + WanI2V480PConfig, + WanI2V720PConfig, + WanT2V480PConfig, + WanT2V720PConfig, +) + +__all__ = [ + "HunyuanConfig", + "FastHunyuanConfig", + "FluxPipelineConfig", + "PipelineConfig", + "SlidingTileAttnConfig", + "WanT2V480PConfig", + "WanI2V480PConfig", + "WanT2V720PConfig", + "WanI2V720PConfig", + "StepVideoT2VConfig", + "SelfForcingWanT2V480PConfig", + "get_pipeline_config_cls_from_name", +] diff --git a/python/sglang/multimodal_gen/configs/pipelines/base.py b/python/sglang/multimodal_gen/configs/pipelines/base.py new file mode 100644 index 000000000..9451f59d3 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/pipelines/base.py @@ -0,0 +1,485 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +import json +from collections.abc import Callable +from dataclasses import asdict, dataclass, field, fields +from enum import Enum +from typing import Any, cast + +import torch +from diffusers.image_processor import VaeImageProcessor + +from sglang.multimodal_gen.configs.models import ( + DiTConfig, + EncoderConfig, + ModelConfig, + VAEConfig, +) +from sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput +from sglang.multimodal_gen.configs.utils import update_config_from_args +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import ( + FlexibleArgumentParser, + StoreBoolean, + shallow_asdict, +) + +logger = init_logger(__name__) + + +class STA_Mode(str, Enum): 
+ """STA (Sliding Tile Attention) modes.""" + + STA_INFERENCE = "STA_inference" + STA_SEARCHING = "STA_searching" + STA_TUNING = "STA_tuning" + STA_TUNING_CFG = "STA_tuning_cfg" + NONE = None + + +def preprocess_text(prompt: str) -> str: + return prompt + + +def postprocess_text(output: BaseEncoderOutput, _text_inputs) -> torch.tensor: + raise NotImplementedError + + +# config for a single pipeline +@dataclass +class PipelineConfig: + """Base configuration for all pipeline architectures.""" + + model_path: str = "" + pipeline_config_path: str | None = None + + is_image_gen: bool = False + + # generation parameters + # controls the timestep embedding generation + should_use_guidance: bool = True + embedded_cfg_scale: float = 6.0 + flow_shift: float | None = None + disable_autocast: bool = False + + # Model configuration + dit_config: DiTConfig = field(default_factory=DiTConfig) + dit_precision: str = "bf16" + + # VAE configuration + vae_config: VAEConfig = field(default_factory=VAEConfig) + vae_precision: str = "fp32" + vae_tiling: bool = True + vae_sp: bool = True + + # Image encoder configuration + image_encoder_config: EncoderConfig = field(default_factory=EncoderConfig) + image_encoder_precision: str = "fp32" + + # Text encoder configuration + DEFAULT_TEXT_ENCODER_PRECISIONS = ("fp32",) + text_encoder_configs: tuple[EncoderConfig, ...] = field( + default_factory=lambda: (EncoderConfig(),) + ) + # See PRECISION_TO_TYPE for detailed mapping + text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("fp32",)) + text_encoder_extra_args: list[dict] = field(default_factory=lambda: [{}]) + + # image encoding + image_encoder_extra_args: dict = field(default_factory=lambda: {}) + + def postprocess_image(self, image): + return image.last_hidden_state + + preprocess_text_funcs: tuple[Callable[[str], str], ...] = field( + default_factory=lambda: (preprocess_text,) + ) + postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.tensor], ...] 
= ( + field(default_factory=lambda: (postprocess_text,)) + ) + + # StepVideo specific parameters + pos_magic: str | None = None + neg_magic: str | None = None + timesteps_scale: bool | None = None + + # STA (Sliding Tile Attention) parameters + mask_strategy_file_path: str | None = None + STA_mode: STA_Mode = STA_Mode.STA_INFERENCE + skip_time_steps: int = 15 + + # DMD parameters + dmd_denoising_steps: list[int] | None = field(default=None) + + # Wan2.2 TI2V parameters + ti2v_task: bool = False + i2v_task: bool = False + ti2i_task: bool = False + boundary_ratio: float | None = None + + # Compilation + # enable_torch_compile: bool = False + + def slice_noise_pred(self, noise, latents): + return noise + + def set_width_and_height(self, width, height, image): + """ + image: input image + """ + return width, height + + # called in ImageEncodingStage, preprocess the image + def preprocess_image(self, image, image_processor: VaeImageProcessor): + return image + + def prepare_latent_shape(self, batch, batch_size, num_frames): + height = batch.height // self.vae_config.arch_config.spatial_compression_ratio + width = batch.width // self.vae_config.arch_config.spatial_compression_ratio + + # Calculate latent shape + shape = ( + batch_size, + self.dit_config.num_channels_latents, + num_frames, + height, + width, + ) + + return shape + + # called after latents are prepared + def pack_latents(self, latents, batch_size, batch): + return latents + + def get_pos_prompt_embeds(self, batch): + return batch.prompt_embeds + + def get_neg_prompt_embeds(self, batch): + return batch.negative_prompt_embeds + + def post_denoising_loop(self, latents, batch): + return latents + + def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype): + return {} + + def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype): + return {} + + @staticmethod + def add_cli_args( + parser: FlexibleArgumentParser, prefix: str = "" + ) -> FlexibleArgumentParser: + prefix_with_dot = f"{prefix}." 
if (prefix.strip() != "") else "" + + # model_path will be conflicting with the model_path in ServerArgs, + # so we add it separately if prefix is not empty + if prefix_with_dot != "": + parser.add_argument( + f"--{prefix_with_dot}model-path", + type=str, + dest=f"{prefix_with_dot.replace('-', '_')}model_path", + default=PipelineConfig.model_path, + help="Path to the pretrained model", + ) + + parser.add_argument( + f"--{prefix_with_dot}pipeline-config-path", + type=str, + dest=f"{prefix_with_dot.replace('-', '_')}pipeline_config_path", + default=PipelineConfig.pipeline_config_path, + help="Path to the pipeline config", + ) + parser.add_argument( + f"--{prefix_with_dot}embedded-cfg-scale", + type=float, + dest=f"{prefix_with_dot.replace('-', '_')}embedded_cfg_scale", + default=PipelineConfig.embedded_cfg_scale, + help="Embedded CFG scale", + ) + parser.add_argument( + f"--{prefix_with_dot}flow-shift", + type=float, + dest=f"{prefix_with_dot.replace('-', '_')}flow_shift", + default=PipelineConfig.flow_shift, + help="Flow shift parameter", + ) + + # DiT configuration + parser.add_argument( + f"--{prefix_with_dot}dit-precision", + type=str, + dest=f"{prefix_with_dot.replace('-', '_')}dit_precision", + default=PipelineConfig.dit_precision, + choices=["fp32", "fp16", "bf16"], + help="Precision for the DiT model", + ) + + # VAE configuration + parser.add_argument( + f"--{prefix_with_dot}vae-precision", + type=str, + dest=f"{prefix_with_dot.replace('-', '_')}vae_precision", + default=PipelineConfig.vae_precision, + choices=["fp32", "fp16", "bf16"], + help="Precision for VAE", + ) + parser.add_argument( + f"--{prefix_with_dot}vae-tiling", + action=StoreBoolean, + dest=f"{prefix_with_dot.replace('-', '_')}vae_tiling", + default=PipelineConfig.vae_tiling, + help="Enable VAE tiling", + ) + parser.add_argument( + f"--{prefix_with_dot}vae-sp", + action=StoreBoolean, + dest=f"{prefix_with_dot.replace('-', '_')}vae_sp", + help="Enable VAE spatial parallelism", + ) + + # Text encoder configuration + parser.add_argument( + f"--{prefix_with_dot}text-encoder-precisions", + nargs="+", + type=str, + dest=f"{prefix_with_dot.replace('-', '_')}text_encoder_precisions", + default=PipelineConfig.DEFAULT_TEXT_ENCODER_PRECISIONS, + choices=["fp32", "fp16", "bf16"], + help="Precision for each text encoder", + ) + + # Image encoder configuration + parser.add_argument( + f"--{prefix_with_dot}image-encoder-precision", + type=str, + dest=f"{prefix_with_dot.replace('-', '_')}image_encoder_precision", + default=PipelineConfig.image_encoder_precision, + choices=["fp32", "fp16", "bf16"], + help="Precision for image encoder", + ) + parser.add_argument( + f"--{prefix_with_dot}pos_magic", + type=str, + dest=f"{prefix_with_dot.replace('-', '_')}pos_magic", + default=PipelineConfig.pos_magic, + help="Positive magic prompt for sampling, used in stepvideo", + ) + parser.add_argument( + f"--{prefix_with_dot}neg_magic", + type=str, + dest=f"{prefix_with_dot.replace('-', '_')}neg_magic", + default=PipelineConfig.neg_magic, + help="Negative magic prompt for sampling, used in stepvideo", + ) + parser.add_argument( + f"--{prefix_with_dot}timesteps_scale", + type=bool, + dest=f"{prefix_with_dot.replace('-', '_')}timesteps_scale", + default=PipelineConfig.timesteps_scale, + help="Bool for applying scheduler scale in set_timesteps, used in stepvideo", + ) + + # DMD parameters + parser.add_argument( + f"--{prefix_with_dot}dmd-denoising-steps", + type=parse_int_list, + default=PipelineConfig.dmd_denoising_steps, + help="Comma-separated list of 
denoising steps (e.g., '1000,757,522')", + ) + + # Add VAE configuration arguments + from sglang.multimodal_gen.configs.models.vaes.base import VAEConfig + + VAEConfig.add_cli_args(parser, prefix=f"{prefix_with_dot}vae-config") + + # Add DiT configuration arguments + from sglang.multimodal_gen.configs.models.dits.base import DiTConfig + + DiTConfig.add_cli_args(parser, prefix=f"{prefix_with_dot}dit-config") + + return parser + + def update_config_from_dict(self, args: dict[str, Any], prefix: str = "") -> None: + prefix_with_dot = f"{prefix}." if (prefix.strip() != "") else "" + update_config_from_args(self, args, prefix, pop_args=True) + update_config_from_args( + self.vae_config, args, f"{prefix_with_dot}vae_config", pop_args=True + ) + update_config_from_args( + self.dit_config, args, f"{prefix_with_dot}dit_config", pop_args=True + ) + + @classmethod + def from_pretrained(cls, model_path: str) -> "PipelineConfig": + """ + use the pipeline class setting from model_path to match the pipeline config + """ + from sglang.multimodal_gen.configs.pipelines.registry import ( + get_pipeline_config_cls_from_name, + ) + + pipeline_config_cls = get_pipeline_config_cls_from_name(model_path) + + return cast(PipelineConfig, pipeline_config_cls(model_path=model_path)) + + @classmethod + def from_kwargs( + cls, kwargs: dict[str, Any], config_cli_prefix: str = "" + ) -> "PipelineConfig": + """ + Load PipelineConfig from kwargs Dictionary. + kwargs: dictionary of kwargs + config_cli_prefix: prefix of CLI arguments for this PipelineConfig instance + """ + from sglang.multimodal_gen.configs.pipelines.registry import ( + get_pipeline_config_cls_from_name, + ) + + prefix_with_dot = ( + f"{config_cli_prefix}." if (config_cli_prefix.strip() != "") else "" + ) + model_path: str | None = kwargs.get( + prefix_with_dot + "model_path", None + ) or kwargs.get("model_path") + pipeline_config_or_path: str | PipelineConfig | dict[str, Any] | None = ( + kwargs.get(prefix_with_dot + "pipeline_config", None) + or kwargs.get("pipeline_config") + ) + if model_path is None: + raise ValueError("model_path is required in kwargs") + + # 1. Get the pipeline config class from the registry + pipeline_config_cls = get_pipeline_config_cls_from_name(model_path) + + # 2. Instantiate PipelineConfig + if pipeline_config_cls is None: + logger.warning( + "Couldn't find pipeline config for %s. Using the default pipeline config.", + model_path, + ) + pipeline_config = cls() + else: + pipeline_config = pipeline_config_cls() + + # 3. Load PipelineConfig from a json file or a PipelineConfig object if provided + if isinstance(pipeline_config_or_path, str): + pipeline_config.load_from_json(pipeline_config_or_path) + kwargs[prefix_with_dot + "pipeline_config_path"] = pipeline_config_or_path + elif isinstance(pipeline_config_or_path, PipelineConfig): + pipeline_config = pipeline_config_or_path + elif isinstance(pipeline_config_or_path, dict): + pipeline_config.update_pipeline_config(pipeline_config_or_path) + + # 4. Update PipelineConfig from CLI arguments if provided + kwargs[prefix_with_dot + "model_path"] = model_path + pipeline_config.update_config_from_dict(kwargs, config_cli_prefix) + return pipeline_config + + def check_pipeline_config(self) -> None: + if self.vae_sp and not self.vae_tiling: + raise ValueError( + "Currently enabling vae_sp requires enabling vae_tiling, please set --vae-tiling to True." 
+ ) + + if len(self.text_encoder_configs) != len(self.text_encoder_precisions): + raise ValueError( + f"Length of text encoder configs ({len(self.text_encoder_configs)}) must be equal to length of text encoder precisions ({len(self.text_encoder_precisions)})" + ) + + if len(self.text_encoder_configs) != len(self.preprocess_text_funcs): + raise ValueError( + f"Length of text encoder configs ({len(self.text_encoder_configs)}) must be equal to length of text preprocessing functions ({len(self.preprocess_text_funcs)})" + ) + + if len(self.preprocess_text_funcs) != len(self.postprocess_text_funcs): + raise ValueError( + f"Length of text postprocess functions ({len(self.postprocess_text_funcs)}) must be equal to length of text preprocessing functions ({len(self.preprocess_text_funcs)})" + ) + + def dump_to_json(self, file_path: str): + output_dict = shallow_asdict(self) + del_keys = [] + for key, value in output_dict.items(): + if isinstance(value, ModelConfig): + model_dict = asdict(value) + # Model Arch Config should be hidden away from the users + model_dict.pop("arch_config") + output_dict[key] = model_dict + elif isinstance(value, tuple) and all( + isinstance(v, ModelConfig) for v in value + ): + model_dicts = [] + for v in value: + model_dict = asdict(v) + # Model Arch Config should be hidden away from the users + model_dict.pop("arch_config") + model_dicts.append(model_dict) + output_dict[key] = model_dicts + elif isinstance(value, tuple) and all(callable(f) for f in value): + # Skip dumping functions + del_keys.append(key) + + for key in del_keys: + output_dict.pop(key, None) + + with open(file_path, "w") as f: + json.dump(output_dict, f, indent=2) + + def load_from_json(self, file_path: str): + with open(file_path) as f: + input_pipeline_dict = json.load(f) + self.update_pipeline_config(input_pipeline_dict) + + def update_pipeline_config(self, source_pipeline_dict: dict[str, Any]) -> None: + for f in fields(self): + key = f.name + if key in source_pipeline_dict: + current_value = getattr(self, key) + new_value = source_pipeline_dict[key] + + # If it's a nested ModelConfig, update it recursively + if isinstance(current_value, ModelConfig): + current_value.update_model_config(new_value) + elif isinstance(current_value, tuple) and all( + isinstance(v, ModelConfig) for v in current_value + ): + assert len(current_value) == len( + new_value + ), "Users shouldn't delete or add text encoder config objects in your json" + for target_config, source_config in zip( + current_value, new_value, strict=True + ): + target_config.update_model_config(source_config) + else: + setattr(self, key, new_value) + + if hasattr(self, "__post_init__"): + self.__post_init__() + + +@dataclass +class SlidingTileAttnConfig(PipelineConfig): + """Configuration for sliding tile attention.""" + + # Override any BaseConfig defaults as needed + # Add sliding tile specific parameters + window_size: int = 16 + stride: int = 8 + + # You can provide custom defaults for inherited fields + height: int = 576 + width: int = 1024 + + # Additional configuration specific to sliding tile attention + pad_to_square: bool = False + use_overlap_optimization: bool = True + + +def parse_int_list(value: str) -> list[int]: + """Parse a comma-separated string of integers into a list.""" + if not value: + return [] + return [int(x.strip()) for x in value.split(",")] diff --git a/python/sglang/multimodal_gen/configs/pipelines/flux.py b/python/sglang/multimodal_gen/configs/pipelines/flux.py new file mode 100644 index 000000000..a5348ec25 --- 
/dev/null +++ b/python/sglang/multimodal_gen/configs/pipelines/flux.py @@ -0,0 +1,174 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from dataclasses import dataclass, field +from typing import Callable + +import torch + +from sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig +from sglang.multimodal_gen.configs.models.dits.flux import FluxConfig +from sglang.multimodal_gen.configs.models.encoders import ( + BaseEncoderOutput, + CLIPTextConfig, + T5Config, +) +from sglang.multimodal_gen.configs.models.vaes.flux import FluxVAEConfig +from sglang.multimodal_gen.configs.pipelines.base import PipelineConfig, preprocess_text +from sglang.multimodal_gen.configs.pipelines.hunyuan import ( + clip_postprocess_text, + clip_preprocess_text, +) + + +def t5_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor: + return outputs.last_hidden_state + + +@dataclass +class FluxPipelineConfig(PipelineConfig): + # FIXME: duplicate with SamplingParams.guidance_scale? + embedded_cfg_scale: float = 3.5 + + is_image_gen: bool = True + + vae_tiling: bool = False + + vae_sp: bool = False + + dit_config: DiTConfig = field(default_factory=FluxConfig) + # VAE + vae_config: VAEConfig = field(default_factory=FluxVAEConfig) + + # Text encoding stage + text_encoder_configs: tuple[EncoderConfig, ...] = field( + default_factory=lambda: (CLIPTextConfig(), T5Config()) + ) + + text_encoder_precisions: tuple[str, ...] = field( + default_factory=lambda: ("bf16", "bf16") + ) + + preprocess_text_funcs: tuple[Callable[[str], str], ...] = field( + default_factory=lambda: (clip_preprocess_text, preprocess_text), + ) + + postprocess_text_funcs: tuple[Callable[[str], str], ...] = field( + default_factory=lambda: (clip_postprocess_text, t5_postprocess_text) + ) + + text_encoder_extra_args: list[dict] = field( + default_factory=lambda: [ + dict( + max_length=77, + padding="max_length", + truncation=True, + return_overflowing_tokens=False, + return_length=False, + ), + None, + ] + ) + + def prepare_latent_shape(self, batch, batch_size, num_frames): + height = 2 * ( + batch.height // (self.vae_config.arch_config.vae_scale_factor * 2) + ) + width = 2 * (batch.width // (self.vae_config.arch_config.vae_scale_factor * 2)) + num_channels_latents = self.dit_config.arch_config.in_channels // 4 + shape = (batch_size, num_channels_latents, height, width) + return shape + + def pack_latents(self, latents, batch_size, batch): + height = 2 * ( + batch.height // (self.vae_config.arch_config.vae_scale_factor * 2) + ) + width = 2 * (batch.width // (self.vae_config.arch_config.vae_scale_factor * 2)) + num_channels_latents = self.dit_config.arch_config.in_channels // 4 + # pack latents + latents = latents.view( + batch_size, num_channels_latents, height // 2, 2, width // 2, 2 + ) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape( + batch_size, (height // 2) * (width // 2), num_channels_latents * 4 + ) + return latents + + def get_pos_prompt_embeds(self, batch): + return batch.prompt_embeds[1] + + def get_neg_prompt_embeds(self, batch): + return batch.negative_prompt_embeds[1] + + def _prepare_latent_image_ids(self, original_height, original_width, device): + vae_scale_factor = self.vae_config.arch_config.vae_scale_factor + height = int(original_height) // (vae_scale_factor * 2) + width = int(original_width) // (vae_scale_factor * 2) + latent_image_ids = torch.zeros(height, width, 3, device=device) + latent_image_ids[..., 1] = ( + latent_image_ids[..., 1] + 
torch.arange(height, device=device)[:, None] + ) + latent_image_ids[..., 2] = ( + latent_image_ids[..., 2] + torch.arange(width, device=device)[None, :] + ) + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = ( + latent_image_ids.shape + ) + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids + + def get_freqs_cis(self, prompt_embeds, width, height, device, rotary_emb): + txt_ids = torch.zeros(prompt_embeds.shape[1], 3, device=device) + img_ids = self._prepare_latent_image_ids( + original_height=height, + original_width=width, + device=device, + ) + ids = torch.cat([txt_ids, img_ids], dim=0).to(device=device) + # NOTE(mick): prepare it here, to avoid unnecessary computations + freqs_cis = rotary_emb.forward(ids) + return freqs_cis + + def post_denoising_loop(self, latents, batch): + # unpack latents for flux + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + batch_size = latents.shape[0] + channels = latents.shape[-1] + vae_scale_factor = self.vae_config.arch_config.vae_scale_factor + height = 2 * (int(batch.height) // (vae_scale_factor * 2)) + width = 2 * (int(batch.width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + return latents + + def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype): + return { + "freqs_cis": self.get_freqs_cis( + batch.prompt_embeds[1], batch.width, batch.height, device, rotary_emb + ), + "pooled_projections": ( + batch.pooled_embeds[0] if batch.pooled_embeds else None + ), + } + + def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype): + return { + "freqs_cis": self.get_freqs_cis( + batch.negative_prompt_embeds[1], + batch.width, + batch.height, + device, + rotary_emb, + ), + "pooled_projections": ( + batch.neg_pooled_embeds[0] if batch.neg_pooled_embeds else None + ), + } diff --git a/python/sglang/multimodal_gen/configs/pipelines/hunyuan.py b/python/sglang/multimodal_gen/configs/pipelines/hunyuan.py new file mode 100644 index 000000000..73ede7d07 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/pipelines/hunyuan.py @@ -0,0 +1,109 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import TypedDict + +import torch + +from sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig +from sglang.multimodal_gen.configs.models.dits import HunyuanVideoConfig +from sglang.multimodal_gen.configs.models.encoders import ( + BaseEncoderOutput, + CLIPTextConfig, + LlamaConfig, +) +from sglang.multimodal_gen.configs.models.vaes import HunyuanVAEConfig +from sglang.multimodal_gen.configs.pipelines.base import PipelineConfig + +PROMPT_TEMPLATE_ENCODE_VIDEO = ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." 
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" +) + + +class PromptTemplate(TypedDict): + template: str + crop_start: int + + +prompt_template_video: PromptTemplate = { + "template": PROMPT_TEMPLATE_ENCODE_VIDEO, + "crop_start": 95, +} + + +def llama_preprocess_text(prompt: str) -> str: + return prompt_template_video["template"].format(prompt) + + +def llama_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.tensor: + hidden_state_skip_layer = 2 + assert outputs.hidden_states is not None + hidden_states: tuple[torch.Tensor, ...] = outputs.hidden_states + last_hidden_state: torch.tensor = hidden_states[-(hidden_state_skip_layer + 1)] + crop_start = prompt_template_video.get("crop_start", -1) + last_hidden_state = last_hidden_state[:, crop_start:] + return last_hidden_state + + +def clip_preprocess_text(prompt: str) -> str: + return prompt + + +def clip_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.tensor: + pooler_output: torch.tensor = outputs.pooler_output + return pooler_output + + +@dataclass +class HunyuanConfig(PipelineConfig): + """Base configuration for HunYuan pipeline architecture.""" + + # HunyuanConfig-specific parameters with defaults + # DiT + dit_config: DiTConfig = field(default_factory=HunyuanVideoConfig) + # VAE + vae_config: VAEConfig = field(default_factory=HunyuanVAEConfig) + # Denoising stage + embedded_cfg_scale: int = 6 + flow_shift: int = 7 + + # Text encoding stage + text_encoder_configs: tuple[EncoderConfig, ...] = field( + default_factory=lambda: (LlamaConfig(), CLIPTextConfig()) + ) + preprocess_text_funcs: tuple[Callable[[str], str], ...] = field( + default_factory=lambda: (llama_preprocess_text, clip_preprocess_text) + ) + postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.tensor], ...] = ( + field(default_factory=lambda: (llama_postprocess_text, clip_postprocess_text)) + ) + + # Precision for each component + dit_precision: str = "bf16" + vae_precision: str = "fp16" + text_encoder_precisions: tuple[str, ...] 
= field( + default_factory=lambda: ("fp16", "fp16") + ) + + def __post_init__(self): + self.vae_config.load_encoder = False + self.vae_config.load_decoder = True + + +@dataclass +class FastHunyuanConfig(HunyuanConfig): + """Configuration specifically optimized for FastHunyuan weights.""" + + # Override HunyuanConfig defaults + flow_shift: int = 17 + + # No need to re-specify guidance_scale or embedded_cfg_scale as they + # already have the desired values from HunyuanConfig diff --git a/python/sglang/multimodal_gen/configs/pipelines/qwen_image.py b/python/sglang/multimodal_gen/configs/pipelines/qwen_image.py new file mode 100644 index 000000000..e4b702c84 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/pipelines/qwen_image.py @@ -0,0 +1,299 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from dataclasses import dataclass, field +from typing import Callable + +import torch +from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit import calculate_dimensions + +from sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig +from sglang.multimodal_gen.configs.models.dits.qwenimage import QwenImageDitConfig +from sglang.multimodal_gen.configs.models.encoders.qwen_image import Qwen2_5VLConfig +from sglang.multimodal_gen.configs.models.vaes.qwenimage import QwenImageVAEConfig +from sglang.multimodal_gen.configs.pipelines.base import PipelineConfig + + +def _extract_masked_hidden(hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + +def qwen_image_preprocess_text(prompt): + prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + + template = prompt_template_encode + txt = template.format(prompt) + return txt + + +def qwen_image_postprocess_text(outputs, _text_inputs, drop_idx=34): + # squeeze the batch dim + hidden_states = outputs.hidden_states[-1] + split_hidden_states = _extract_masked_hidden( + hidden_states, _text_inputs.attention_mask + ) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [ + torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) + for u in split_hidden_states + ] + ) + return prompt_embeds + + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents +def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view( + batch_size, num_channels_latents, height // 2, 2, width // 2, 2 + ) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape( + batch_size, (height // 2) * (width // 2), num_channels_latents * 4 + ) + + return latents + + +@dataclass +class QwenImagePipelineConfig(PipelineConfig): + should_use_guidance: bool = False + + is_image_gen: bool = True + + vae_tiling: bool = False + + vae_sp: bool = False + + dit_config: DiTConfig = field(default_factory=QwenImageDitConfig) + # VAE + vae_config: VAEConfig = field(default_factory=QwenImageVAEConfig) + + # Text encoding stage + text_encoder_configs: tuple[EncoderConfig, ...] 
= field( + default_factory=lambda: (Qwen2_5VLConfig(),) + ) + + text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("bf16",)) + + preprocess_text_funcs: tuple[Callable[[str], str], ...] = field( + default_factory=lambda: (qwen_image_preprocess_text,) + ) + + postprocess_text_funcs: tuple[Callable[[str], str], ...] = field( + default_factory=lambda: (qwen_image_postprocess_text,) + ) + text_encoder_extra_args: list[dict] = field( + default_factory=lambda: [ + dict( + padding=True, + truncation=True, + ), + None, + ] + ) + + def get_vae_scale_factor(self): + return self.vae_config.arch_config.vae_scale_factor + + def prepare_latent_shape(self, batch, batch_size, num_frames): + height = 2 * ( + batch.height // (self.vae_config.arch_config.vae_scale_factor * 2) + ) + width = 2 * (batch.width // (self.vae_config.arch_config.vae_scale_factor * 2)) + num_channels_latents = self.dit_config.arch_config.in_channels // 4 + shape = (batch_size, num_channels_latents, height, width) + return shape + + def pack_latents(self, latents, batch_size, batch): + height = 2 * ( + batch.height // (self.vae_config.arch_config.vae_scale_factor * 2) + ) + width = 2 * (batch.width // (self.vae_config.arch_config.vae_scale_factor * 2)) + num_channels_latents = self.dit_config.arch_config.in_channels // 4 + # pack latents + # _pack_latents(latents, batch_size, num_channels_latents, height, width) + latents = latents.view( + batch_size, num_channels_latents, height // 2, 2, width // 2, 2 + ) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape( + batch_size, (height // 2) * (width // 2), num_channels_latents * 4 + ) + return latents + + @staticmethod + def get_freqs_cis(img_shapes, txt_seq_lens, rotary_emb, device, dtype): + img_freqs, txt_freqs = rotary_emb(img_shapes, txt_seq_lens, device=device) + + img_cos, img_sin = ( + img_freqs.real.to(dtype=dtype), + img_freqs.imag.to(dtype=dtype), + ) + txt_cos, txt_sin = ( + txt_freqs.real.to(dtype=dtype), + txt_freqs.imag.to(dtype=dtype), + ) + return (img_cos, img_sin), (txt_cos, txt_sin) + + def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype): + batch_size = batch.latents.shape[0] + vae_scale_factor = self.vae_config.arch_config.vae_scale_factor + + img_shapes = [ + [ + ( + 1, + batch.height // vae_scale_factor // 2, + batch.width // vae_scale_factor // 2, + ) + ] + ] * batch_size + txt_seq_lens = [batch.prompt_embeds[0].shape[1]] + return { + "img_shapes": img_shapes, + "txt_seq_lens": txt_seq_lens, + "freqs_cis": QwenImagePipelineConfig.get_freqs_cis( + img_shapes, txt_seq_lens, rotary_emb, device, dtype + ), + } + + def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype): + batch_size = batch.latents.shape[0] + vae_scale_factor = self.vae_config.arch_config.vae_scale_factor + + img_shapes = [ + [ + ( + 1, + batch.height // vae_scale_factor // 2, + batch.width // vae_scale_factor // 2, + ) + ] + ] * batch_size + + txt_seq_lens = [batch.negative_prompt_embeds[0].shape[1]] + return { + "img_shapes": img_shapes, + "txt_seq_lens": txt_seq_lens, + "freqs_cis": QwenImagePipelineConfig.get_freqs_cis( + img_shapes, txt_seq_lens, rotary_emb, device, dtype + ), + } + + def post_denoising_loop(self, latents, batch): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
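+        # e.g. with 8x VAE compression, a 1024x1024 request is packed into a
+        # (batch, 64 * 64, channels) token sequence; the reshapes below unpack it
+        # to a (batch, channels // 4, 1, 128, 128) latent before VAE decoding.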
+ batch_size = latents.shape[0] + channels = latents.shape[-1] + vae_scale_factor = self.vae_config.arch_config.vae_scale_factor + height = 2 * (int(batch.height) // (vae_scale_factor * 2)) + width = 2 * (int(batch.width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + return latents + + +class QwenImageEditPipelineConfig(QwenImagePipelineConfig): + ti2i_task = True + + def prepare_pos_cond_kwargs(self, batch, device, rotary_emb, dtype): + # TODO: lots of duplications here + batch_size = batch.latents.shape[0] + height = batch.height + width = batch.width + image = batch.pil_image + image_size = image[0].size if isinstance(image, list) else image.size + calculated_width, calculated_height, _ = calculate_dimensions( + 1024 * 1024, image_size[0] / image_size[1] + ) + vae_scale_factor = self.get_vae_scale_factor() + img_shapes = [ + [ + (1, height // vae_scale_factor // 2, width // vae_scale_factor // 2), + ( + 1, + calculated_height // vae_scale_factor // 2, + calculated_width // vae_scale_factor // 2, + ), + ] + ] * batch_size + txt_seq_lens = [batch.prompt_embeds[0].shape[1]] + return { + "img_shapes": img_shapes, + "txt_seq_lens": txt_seq_lens, + "freqs_cis": QwenImagePipelineConfig.get_freqs_cis( + img_shapes, txt_seq_lens, rotary_emb, device, dtype + ), + } + + def prepare_neg_cond_kwargs(self, batch, device, rotary_emb, dtype): + batch_size = batch.latents.shape[0] + height = batch.height + width = batch.width + image = batch.pil_image + image_size = image[0].size if isinstance(image, list) else image.size + calculated_width, calculated_height, _ = calculate_dimensions( + 1024 * 1024, image_size[0] / image_size[1] + ) + vae_scale_factor = self.get_vae_scale_factor() + img_shapes = [ + [ + (1, height // vae_scale_factor // 2, width // vae_scale_factor // 2), + ( + 1, + calculated_height // vae_scale_factor // 2, + calculated_width // vae_scale_factor // 2, + ), + ] + ] * batch_size + + txt_seq_lens = [batch.negative_prompt_embeds[0].shape[1]] + return { + "img_shapes": img_shapes, + "txt_seq_lens": txt_seq_lens, + "freqs_cis": QwenImagePipelineConfig.get_freqs_cis( + img_shapes, txt_seq_lens, rotary_emb, device, dtype + ), + } + + def prepare_latent_shape(self, batch, batch_size, num_frames): + vae_scale_factor = self.vae_config.arch_config.vae_scale_factor + height = 2 * (batch.height // (vae_scale_factor * 2)) + + width = 2 * (batch.width // (vae_scale_factor * 2)) + num_channels_latents = self.dit_config.arch_config.in_channels // 4 + shape = (batch_size, 1, num_channels_latents, height, width) + return shape + + def preprocess_image(self, image, image_processor): + image_size = image[0].size if isinstance(image, list) else image.size + calculated_width, calculated_height, _ = calculate_dimensions( + 1024 * 1024, image_size[0] / image_size[1] + ) + image = image_processor.resize(image, calculated_height, calculated_width) + return image + + def set_width_and_height(self, width, height, image): + image_size = image[0].size if isinstance(image, list) else image.size + calculated_width, calculated_height, _ = calculate_dimensions( + 1024 * 1024, image_size[0] / image_size[1] + ) + height = height or calculated_height + width = width or calculated_width + + multiple_of = self.get_vae_scale_factor() * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + 
return width, height + + def slice_noise_pred(self, noise, latents): + noise = noise[:, : latents.size(1)] + return noise diff --git a/python/sglang/multimodal_gen/configs/pipelines/registry.py b/python/sglang/multimodal_gen/configs/pipelines/registry.py new file mode 100644 index 000000000..b9c223399 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/pipelines/registry.py @@ -0,0 +1,168 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +"""Registry for pipeline weight-specific configurations.""" + +import os +from collections.abc import Callable + +from sglang.multimodal_gen.configs.pipelines.base import PipelineConfig +from sglang.multimodal_gen.configs.pipelines.flux import FluxPipelineConfig +from sglang.multimodal_gen.configs.pipelines.hunyuan import ( + FastHunyuanConfig, + HunyuanConfig, +) +from sglang.multimodal_gen.configs.pipelines.qwen_image import ( + QwenImageEditPipelineConfig, + QwenImagePipelineConfig, +) +from sglang.multimodal_gen.configs.pipelines.stepvideo import StepVideoT2VConfig + +# isort: off +from sglang.multimodal_gen.configs.pipelines.wan import ( + FastWan2_1_T2V_480P_Config, + FastWan2_2_TI2V_5B_Config, + Wan2_2_I2V_A14B_Config, + Wan2_2_T2V_A14B_Config, + Wan2_2_TI2V_5B_Config, + WanI2V480PConfig, + WanI2V720PConfig, + WanT2V480PConfig, + WanT2V720PConfig, + SelfForcingWanT2V480PConfig, +) +from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import ( + verify_model_config_and_directory, + maybe_download_model_index, +) + +# isort: on +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + +# Registry maps specific model weights to their config classes +PIPE_NAME_TO_CONFIG: dict[str, type[PipelineConfig]] = { + "FastVideo/FastHunyuan-diffusers": FastHunyuanConfig, + "hunyuanvideo-community/HunyuanVideo": HunyuanConfig, + "Wan-AI/Wan2.1-T2V-1.3B-Diffusers": WanT2V480PConfig, + "weizhou03/Wan2.1-Fun-1.3B-InP-Diffusers": WanI2V480PConfig, + "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers": WanI2V480PConfig, + "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers": WanI2V720PConfig, + "Wan-AI/Wan2.1-T2V-14B-Diffusers": WanT2V720PConfig, + "FastVideo/FastWan2.1-T2V-1.3B-Diffusers": FastWan2_1_T2V_480P_Config, + "FastVideo/FastWan2.1-T2V-14B-480P-Diffusers": FastWan2_1_T2V_480P_Config, + "FastVideo/FastWan2.2-TI2V-5B-Diffusers": FastWan2_2_TI2V_5B_Config, + "FastVideo/stepvideo-t2v-diffusers": StepVideoT2VConfig, + "FastVideo/Wan2.1-VSA-T2V-14B-720P-Diffusers": WanT2V720PConfig, + "wlsaidhi/SFWan2.1-T2V-1.3B-Diffusers": SelfForcingWanT2V480PConfig, + "Wan-AI/Wan2.2-TI2V-5B-Diffusers": Wan2_2_TI2V_5B_Config, + "Wan-AI/Wan2.2-T2V-A14B-Diffusers": Wan2_2_T2V_A14B_Config, + "Wan-AI/Wan2.2-I2V-A14B-Diffusers": Wan2_2_I2V_A14B_Config, + # Add other specific weight variants + "black-forest-labs/FLUX.1-dev": FluxPipelineConfig, + "Qwen/Qwen-Image": QwenImagePipelineConfig, + "Qwen/Qwen-Image-Edit": QwenImageEditPipelineConfig, +} + +# For determining pipeline type from model ID +PIPELINE_DETECTOR: dict[str, Callable[[str], bool]] = { + "hunyuan": lambda id: "hunyuan" in id.lower(), + "wanpipeline": lambda id: "wanpipeline" in id.lower(), + "wanimagetovideo": lambda id: "wanimagetovideo" in id.lower(), + "wandmdpipeline": lambda id: "wandmdpipeline" in id.lower(), + "wancausaldmdpipeline": lambda id: "wancausaldmdpipeline" in id.lower(), + "stepvideo": lambda id: "stepvideo" in id.lower(), + "qwenimage": lambda id: "qwen-image" in id.lower() and "edit" not in id.lower(), 
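+    # NOTE: the "qwenimage" detector excludes "edit" so that Qwen-Image-Edit models
+    # are matched only by the dedicated "qwenimageedit" detector below.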
+ "qwenimageedit": lambda id: "qwen-image-edit" in id.lower(), + # Add other pipeline architecture detectors +} + +# Fallback configs when exact match isn't found but architecture is detected +PIPELINE_FALLBACK_CONFIG: dict[str, type[PipelineConfig]] = { + "hunyuan": HunyuanConfig, # Base Hunyuan config as fallback for any Hunyuan variant + "wanpipeline": WanT2V480PConfig, # Base Wan config as fallback for any Wan variant + "wanimagetovideo": WanI2V480PConfig, + "wandmdpipeline": FastWan2_1_T2V_480P_Config, + "wancausaldmdpipeline": SelfForcingWanT2V480PConfig, + "stepvideo": StepVideoT2VConfig, + "qwenimage": QwenImagePipelineConfig, + "qwenimageedit": QwenImageEditPipelineConfig, + # Other fallbacks by architecture +} + + +def get_pipeline_config_cls_from_name( + pipeline_name_or_path: str, +) -> type[PipelineConfig]: + """Get the appropriate configuration class for a given pipeline name or path. + + This function implements a multi-step lookup process to find the most suitable + configuration class for a given pipeline. It follows this order: + 1. Exact match in the PIPE_NAME_TO_CONFIG + 2. Partial match in the PIPE_NAME_TO_CONFIG + 3. Fallback to class name in the model_index.json + 4. else raise an error + + Args: + pipeline_name_or_path (str): The name or path of the pipeline. This can be: + - A registered model ID (e.g., "FastVideo/FastHunyuan-diffusers") + - A local path to a model directory + - A model ID that will be downloaded + + Returns: + Type[PipelineConfig]: The configuration class that best matches the pipeline. + This will be one of: + - A specific weight configuration class if an exact match is found + - A fallback configuration class based on the pipeline architecture + - The base PipelineConfig class if no matches are found + + Note: + - For local paths, the function will verify the model configuration + - For remote models, it will attempt to download the model index + - Warning messages are logged when falling back to less specific configurations + """ + + pipeline_config_cls: type[PipelineConfig] | None = None + + # First try exact match for specific weights + if pipeline_name_or_path in PIPE_NAME_TO_CONFIG: + pipeline_config_cls = PIPE_NAME_TO_CONFIG[pipeline_name_or_path] + + if pipeline_config_cls is None: + # Try partial matches (for local paths that might include the weight ID) + for registered_id, config_class in PIPE_NAME_TO_CONFIG.items(): + if registered_id in pipeline_name_or_path: + pipeline_config_cls = config_class + break + + # If no match, try to use the fallback config + if pipeline_config_cls is None: + if os.path.exists(pipeline_name_or_path): + config = verify_model_config_and_directory(pipeline_name_or_path) + else: + config = maybe_download_model_index(pipeline_name_or_path) + logger.warning( + "Trying to use the config from the model_index.json. sgl-diffusion may not correctly identify the optimal config for this model in this situation." + ) + + pipeline_name = config["_class_name"] + # Try to determine pipeline architecture for fallback + for pipeline_type, detector in PIPELINE_DETECTOR.items(): + if detector(pipeline_name.lower()): + pipeline_config_cls = PIPELINE_FALLBACK_CONFIG.get(pipeline_type) + break + + if pipeline_config_cls is not None: + logger.warning( + "No match found for pipeline %s, using fallback config %s.", + pipeline_name_or_path, + pipeline_config_cls, + ) + + if pipeline_config_cls is None: + raise ValueError( + f"No match found for pipeline {pipeline_name_or_path}, please check the pipeline name or path." 
+ ) + + return pipeline_config_cls diff --git a/python/sglang/multimodal_gen/configs/pipelines/stepvideo.py b/python/sglang/multimodal_gen/configs/pipelines/stepvideo.py new file mode 100644 index 000000000..586e7542b --- /dev/null +++ b/python/sglang/multimodal_gen/configs/pipelines/stepvideo.py @@ -0,0 +1,36 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.models import DiTConfig, VAEConfig +from sglang.multimodal_gen.configs.models.dits import StepVideoConfig +from sglang.multimodal_gen.configs.models.vaes import StepVideoVAEConfig +from sglang.multimodal_gen.configs.pipelines.base import PipelineConfig + + +@dataclass +class StepVideoT2VConfig(PipelineConfig): + """Base configuration for StepVideo pipeline architecture.""" + + # WanConfig-specific parameters with defaults + # DiT + dit_config: DiTConfig = field(default_factory=StepVideoConfig) + # VAE + vae_config: VAEConfig = field(default_factory=StepVideoVAEConfig) + vae_tiling: bool = False + vae_sp: bool = False + + # Denoising stage + flow_shift: int = 13 + timesteps_scale: bool = False + pos_magic: str = ( + "超高清、HDR 视频、环境光、杜比全景声、画面稳定、流畅动作、逼真的细节、专业级构图、超现实主义、自然、生动、超细节、清晰。" + ) + neg_magic: str = ( + "画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。" + ) + + # Precision for each component + precision: str = "bf16" + vae_precision: str = "bf16" diff --git a/python/sglang/multimodal_gen/configs/pipelines/wan.py b/python/sglang/multimodal_gen/configs/pipelines/wan.py new file mode 100644 index 000000000..af6a697c2 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/pipelines/wan.py @@ -0,0 +1,190 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from collections.abc import Callable +from dataclasses import dataclass, field + +import torch + +from sglang.multimodal_gen.configs.models import DiTConfig, EncoderConfig, VAEConfig +from sglang.multimodal_gen.configs.models.dits import WanVideoConfig +from sglang.multimodal_gen.configs.models.encoders import ( + BaseEncoderOutput, + CLIPVisionConfig, + T5Config, +) +from sglang.multimodal_gen.configs.models.vaes import WanVAEConfig +from sglang.multimodal_gen.configs.pipelines.base import PipelineConfig + + +def t5_postprocess_text(outputs: BaseEncoderOutput, _text_inputs) -> torch.Tensor: + mask: torch.Tensor = outputs.attention_mask + hidden_state: torch.Tensor = outputs.last_hidden_state + seq_lens = mask.gt(0).sum(dim=1).long() + assert torch.isnan(hidden_state).sum() == 0 + prompt_embeds = [u[:v] for u, v in zip(hidden_state, seq_lens, strict=True)] + prompt_embeds_tensor: torch.Tensor = torch.stack( + [ + torch.cat([u, u.new_zeros(512 - u.size(0), u.size(1))]) + for u in prompt_embeds + ], + dim=0, + ) + return prompt_embeds_tensor + + +@dataclass +class WanT2V480PConfig(PipelineConfig): + """Base configuration for Wan T2V 1.3B pipeline architecture.""" + + # WanConfig-specific parameters with defaults + # DiT + dit_config: DiTConfig = field(default_factory=WanVideoConfig) + + # VAE + vae_config: VAEConfig = field(default_factory=WanVAEConfig) + vae_tiling: bool = False + vae_sp: bool = False + + # Denoising stage + flow_shift: float | None = 3.0 + + # Text encoding stage + text_encoder_configs: tuple[EncoderConfig, ...] = field( + default_factory=lambda: (T5Config(),) + ) + postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.Tensor], ...] 
= ( + field(default_factory=lambda: (t5_postprocess_text,)) + ) + + # Precision for each component + precision: str = "bf16" + vae_precision: str = "fp32" + text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("fp32",)) + + # WanConfig-specific added parameters + + def __post_init__(self): + self.vae_config.load_encoder = False + self.vae_config.load_decoder = True + + +@dataclass +class WanT2V720PConfig(WanT2V480PConfig): + """Base configuration for Wan T2V 14B 720P pipeline architecture.""" + + # WanConfig-specific parameters with defaults + + # Denoising stage + flow_shift: float | None = 5.0 + + +@dataclass +class WanI2V480PConfig(WanT2V480PConfig): + """Base configuration for Wan I2V 14B 480P pipeline architecture.""" + + # WanConfig-specific parameters with defaults + i2v_task: bool = True + # Precision for each component + image_encoder_config: EncoderConfig = field(default_factory=CLIPVisionConfig) + image_encoder_precision: str = "fp32" + + image_encoder_extra_args: dict = field( + default_factory=lambda: dict( + output_hidden_states=True, + ) + ) + + def postprocess_image(self, image): + return image.hidden_states[-2] + + def __post_init__(self) -> None: + self.vae_config.load_encoder = True + self.vae_config.load_decoder = True + + +@dataclass +class WanI2V720PConfig(WanI2V480PConfig): + """Base configuration for Wan I2V 14B 720P pipeline architecture.""" + + # WanConfig-specific parameters with defaults + + # Denoising stage + flow_shift: float | None = 5.0 + + +@dataclass +class FastWan2_1_T2V_480P_Config(WanT2V480PConfig): + """Base configuration for FastWan T2V 1.3B 480P pipeline architecture with DMD""" + + # WanConfig-specific parameters with defaults + + # Denoising stage + flow_shift: float | None = 8.0 + dmd_denoising_steps: list[int] | None = field( + default_factory=lambda: [1000, 757, 522] + ) + + +@dataclass +class Wan2_2_TI2V_5B_Config(WanT2V480PConfig): + flow_shift: float | None = 5.0 + ti2v_task: bool = True + expand_timesteps: bool = True + # ti2v, 5B + vae_stride = (4, 16, 16) + + def prepare_latent_shape(self, batch, batch_size, num_frames): + F = num_frames + z_dim = self.vae_config.arch_config.z_dim + vae_stride = self.vae_stride + oh = batch.height + ow = batch.width + shape = (z_dim, F, oh // vae_stride[1], ow // vae_stride[2]) + + return shape + + def __post_init__(self) -> None: + self.vae_config.load_encoder = True + self.vae_config.load_decoder = True + self.dit_config.expand_timesteps = self.expand_timesteps + + +@dataclass +class FastWan2_2_TI2V_5B_Config(Wan2_2_TI2V_5B_Config): + flow_shift: float | None = 5.0 + dmd_denoising_steps: list[int] | None = field( + default_factory=lambda: [1000, 757, 522] + ) + + +@dataclass +class Wan2_2_T2V_A14B_Config(WanT2V480PConfig): + flow_shift: float | None = 12.0 + boundary_ratio: float | None = 0.875 + + def __post_init__(self) -> None: + self.dit_config.boundary_ratio = self.boundary_ratio + + +@dataclass +class Wan2_2_I2V_A14B_Config(WanI2V480PConfig): + flow_shift: float | None = 5.0 + boundary_ratio: float | None = 0.900 + + def __post_init__(self) -> None: + super().__post_init__() + self.dit_config.boundary_ratio = self.boundary_ratio + + +# ============================================= +# ============= Causal Self-Forcing ============= +# ============================================= +@dataclass +class SelfForcingWanT2V480PConfig(WanT2V480PConfig): + is_causal: bool = True + flow_shift: float | None = 5.0 + dmd_denoising_steps: list[int] | None = field( + default_factory=lambda: 
[1000, 750, 500, 250] + ) + warp_denoising_step: bool = True diff --git a/python/sglang/multimodal_gen/configs/sample/__init__.py b/python/sglang/multimodal_gen/configs/sample/__init__.py new file mode 100644 index 000000000..13bf24ce5 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/sample/__init__.py @@ -0,0 +1,5 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from sglang.multimodal_gen.configs.sample.base import SamplingParams + +__all__ = ["SamplingParams"] diff --git a/python/sglang/multimodal_gen/configs/sample/base.py b/python/sglang/multimodal_gen/configs/sample/base.py new file mode 100644 index 000000000..27da2f9f5 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/sample/base.py @@ -0,0 +1,494 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +import argparse +import dataclasses +import hashlib +import json +import os.path +import re +import time +import unicodedata +import uuid +from copy import deepcopy +from dataclasses import dataclass +from enum import Enum, auto +from typing import Any + +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import align_to + +logger = init_logger(__name__) + + +def _json_safe(obj: Any): + """ + Recursively convert objects to JSON-serializable forms. + - Enums -> their name + - Sets/Tuples -> lists + - Dicts/Lists -> recursively processed + """ + if isinstance(obj, Enum): + return obj.name + if isinstance(obj, dict): + return {k: _json_safe(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple, set)): + return [_json_safe(v) for v in obj] + return obj + + +def generate_request_id() -> str: + return str(uuid.uuid4()) + + +def _sanitize_filename(name: str, replacement: str = "_", max_length: int = 150) -> str: + """Create a filesystem- and ffmpeg-friendly filename. + + - Normalize to ASCII (drop accents and unsupported chars) + - Replace spaces with underscores + - Replace any char not in [A-Za-z0-9_.-] with replacement + - Collapse multiple underscores + - Trim leading/trailing dots/underscores and limit length + """ + normalized = unicodedata.normalize("NFKD", name) + ascii_name = normalized.encode("ascii", "ignore").decode("ascii") + ascii_name = ascii_name.replace(" ", "_") + ascii_name = re.sub(r"[^A-Za-z0-9._-]", replacement, ascii_name) + ascii_name = re.sub(r"_+", "_", ascii_name).strip("._") + if not ascii_name: + ascii_name = "output" + if max_length and len(ascii_name) > max_length: + ascii_name = ascii_name[:max_length] + return ascii_name + + +class DataType(Enum): + IMAGE = auto() + VIDEO = auto() + + def get_default_extension(self) -> str: + if self == DataType.IMAGE: + return "jpg" + else: + return "mp4" + + +@dataclass +class SamplingParams: + """ + Sampling parameters for generation. 
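+
+    Defaults target text-to-video generation (125 frames at 24 fps); __post_init__
+    switches data_type to IMAGE when num_frames == 1 and falls back to a 1280x720
+    resolution when width/height are not provided.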
+ """ + + data_type: DataType = DataType.VIDEO + + request_id: str | None = None + + # All fields below are copied from ForwardBatch + + # Image inputs + image_path: str | None = None + + # Text inputs + prompt: str | list[str] | None = None + negative_prompt: str = ( + "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + ) + prompt_path: str | None = None + output_path: str = "outputs/" + output_file_name: str | None = None + + # Batch info + num_outputs_per_prompt: int = 1 + seed: int = 1024 + + # Original dimensions (before VAE scaling) + num_frames: int = 125 + num_frames_round_down: bool = ( + False # Whether to round down num_frames if it's not divisible by num_gpus + ) + height: int | None = None + width: int | None = None + # NOTE: this is temporary, we need a way to know if width or height is not provided, or do the image resize earlier + height_not_provided: bool = False + width_not_provided: bool = False + fps: int = 24 + + # Denoising parameters + num_inference_steps: int = 50 + guidance_scale: float = 1.0 + guidance_rescale: float = 0.0 + boundary_ratio: float | None = None + + # TeaCache parameters + enable_teacache: bool = False + + # Profiling + profile: bool = False + num_profiled_timesteps: int = 2 + + # Debugging + debug: bool = False + + # Misc + save_output: bool = True + return_frames: bool = False + return_trajectory_latents: bool = False # returns all latents for each timestep + return_trajectory_decoded: bool = False # returns decoded latents for each timestep + + def set_output_file_ext(self): + # add extension if needed + if not any( + self.output_file_name.endswith(ext) + for ext in [".mp4", ".jpg", ".png", ".webp"] + ): + self.output_file_name = ( + f"{self.output_file_name}.{self.data_type.get_default_extension()}" + ) + + def set_output_file_name(self): + # settle output_file_name + if ( + self.output_file_name is None + and self.prompt + and isinstance(self.prompt, str) + ): + # generate a random filename + # get a hash of current params + params_dict = dataclasses.asdict(self) + # Avoid recursion + params_dict["output_file_name"] = "" + + # Convert to a stable JSON string + params_str = json.dumps(_json_safe(params_dict), sort_keys=True) + # Create a hash + hasher = hashlib.sha256() + hasher.update(params_str.encode("utf-8")) + param_hash = hasher.hexdigest()[:8] + + timestamp = time.strftime("%Y%m%d-%H%M%S") + base = f"{self.prompt[:100]}_{timestamp}_{param_hash}" + self.output_file_name = base + + if self.output_file_name is None: + timestamp = time.strftime("%Y%m%d-%H%M%S") + self.output_file_name = f"output_{timestamp}" + + self.output_file_name = _sanitize_filename(self.output_file_name) + + # Ensure a proper extension is present + self.set_output_file_ext() + + def __post_init__(self) -> None: + assert self.num_frames >= 1 + self.data_type = DataType.VIDEO if self.num_frames > 1 else DataType.IMAGE + + if self.width is None: + self.width_not_provided = True + self.width = 1280 + if self.height is None: + self.height_not_provided = True + self.height = 720 + + def check_sampling_param(self): + if self.prompt_path and not self.prompt_path.endswith(".txt"): + raise ValueError("prompt_path must be a txt file") + + def 
update(self, source_dict: dict[str, Any]) -> None: + for key, value in source_dict.items(): + if hasattr(self, key): + setattr(self, key, value) + else: + logger.exception("%s has no attribute %s", type(self).__name__, key) + + self.__post_init__() + + @classmethod + def from_pretrained(cls, model_path: str, **kwargs) -> "SamplingParams": + from sglang.multimodal_gen.configs.sample.registry import ( + get_sampling_param_cls_for_name, + ) + + sampling_cls = get_sampling_param_cls_for_name(model_path) + logger.debug(f"Using pretrained SamplingParam: {sampling_cls}") + if sampling_cls is not None: + sampling_params: SamplingParams = sampling_cls(**kwargs) + else: + logger.warning( + "Couldn't find an optimal sampling param for %s. Using the default sampling param.", + model_path, + ) + sampling_params = cls(**kwargs) + return sampling_params + + def from_user_sampling_params(self, user_params): + sampling_params = deepcopy(self) + sampling_params._merge_with_user_params(user_params) + return sampling_params + + @staticmethod + def add_cli_args(parser: Any) -> Any: + """Add CLI arguments for SamplingParam fields""" + parser.add_argument("--data-type", type=str, nargs="+", default=DataType.VIDEO) + parser.add_argument( + "--num-frames-round-down", + action="store_true", + default=SamplingParams.num_frames_round_down, + ) + parser.add_argument( + "--enable-teacache", + action="store_true", + default=SamplingParams.enable_teacache, + ) + parser.add_argument( + "--profile", + action="store_true", + default=SamplingParams.profile, + help="Enable torch profiler for denoising stage", + ) + parser.add_argument( + "--debug", + action="store_true", + default=SamplingParams.debug, + help="", + ) + parser.add_argument( + "--num-profiled-timesteps", + type=int, + default=SamplingParams.num_profiled_timesteps, + help="Number of timesteps to profile after warmup", + ) + parser.add_argument( + "--prompt", + type=str, + default=SamplingParams.prompt, + help="Text prompt for generation", + ) + parser.add_argument( + "--negative-prompt", + type=str, + default=SamplingParams.negative_prompt, + help="Negative text prompt for generation", + ) + parser.add_argument( + "--prompt-path", + type=str, + default=SamplingParams.prompt_path, + help="Path to a text file containing the prompt", + ) + parser.add_argument( + "--output-path", + type=str, + default=SamplingParams.output_path, + help="Path to save the generated image/video", + ) + parser.add_argument( + "--output-file-name", + type=str, + default=SamplingParams.output_file_name, + help="Name of the output file", + ) + parser.add_argument( + "--num-outputs-per-prompt", + type=int, + default=SamplingParams.num_outputs_per_prompt, + help="Number of outputs to generate per prompt", + ) + parser.add_argument( + "--seed", + type=int, + default=SamplingParams.seed, + help="Random seed for generation", + ) + parser.add_argument( + "--num-frames", + type=int, + default=SamplingParams.num_frames, + help="Number of frames to generate", + ) + parser.add_argument( + "--height", + type=int, + default=SamplingParams.height, + help="Height of generated output", + ) + parser.add_argument( + "--width", + type=int, + default=SamplingParams.width, + help="Width of generated output", + ) + parser.add_argument( + "--fps", + type=int, + default=SamplingParams.fps, + help="Frames per second for saved output", + ) + parser.add_argument( + "--num-inference-steps", + type=int, + default=SamplingParams.num_inference_steps, + help="Number of denoising steps", + ) + parser.add_argument( + 
"--guidance-scale", + type=float, + default=SamplingParams.guidance_scale, + help="Classifier-free guidance scale", + ) + parser.add_argument( + "--guidance-rescale", + type=float, + default=SamplingParams.guidance_rescale, + help="Guidance rescale factor", + ) + parser.add_argument( + "--boundary-ratio", + type=float, + default=SamplingParams.boundary_ratio, + help="Boundary timestep ratio", + ) + parser.add_argument( + "--save-output", + action="store_true", + default=SamplingParams.save_output, + help="Whether to save the output to disk", + ) + parser.add_argument( + "--no-save-output", + action="store_false", + dest="save_output", + help="Don't save the output to disk", + ) + parser.add_argument( + "--return-frames", + action="store_true", + default=SamplingParams.return_frames, + help="Whether to return the raw frames", + ) + parser.add_argument( + "--image-path", + type=str, + default=SamplingParams.image_path, + help="Path to input image for image-to-video generation", + ) + parser.add_argument( + "--moba-config-path", + type=str, + default=None, + help="Path to a JSON file containing V-MoBA specific configurations.", + ) + parser.add_argument( + "--return-trajectory-latents", + action="store_true", + default=SamplingParams.return_trajectory_latents, + help="Whether to return the trajectory", + ) + parser.add_argument( + "--return-trajectory-decoded", + action="store_true", + default=SamplingParams.return_trajectory_decoded, + help="Whether to return the decoded trajectory", + ) + return parser + + @classmethod + def from_cli_args(cls, args: argparse.Namespace): + attrs = [attr.name for attr in dataclasses.fields(cls)] + args.height_not_provided = False + args.width_not_provided = False + return cls(**{attr: getattr(args, attr) for attr in attrs}) + + def output_file_path(self): + return os.path.join(self.output_path, self.output_file_name) + + def _merge_with_user_params(self, user_params): + """ + Merges parameters from a user-provided SamplingParams object. + + This method updates the current object with values from `user_params`, + but skips any fields that are explicitly defined in the current object's + subclass. This is to preserve model-specific optimal parameters. + It also skips fields that the user has not changed from the default + in `user_params`. + """ + if user_params is None: + return + + # Get fields defined directly in the subclass (not inherited) + subclass_defined_fields = set(type(self).__annotations__.keys()) + + # Compare against current instance to avoid constructing a default instance + default_params = SamplingParams() + + for field in dataclasses.fields(user_params): + field_name = field.name + user_value = getattr(user_params, field_name) + default_value = getattr(default_params, field_name) + + # A field is considered user-modified if its value is different from + # the default, with an exception for `output_file_name` which is + # auto-generated with a random component. 
+ is_user_modified = ( + user_value != default_value + if field_name != "output_file_name" + else user_params.output_file_path is not None + ) + if is_user_modified and field_name not in subclass_defined_fields: + if hasattr(self, field_name): + setattr(self, field_name, user_value) + + self.__post_init__() + + @property + def n_tokens(self) -> int: + # Calculate latent sizes + if self.height and self.width: + latents_size = [ + (self.num_frames - 1) // 4 + 1, + self.height // 8, + self.width // 8, + ] + n_tokens = latents_size[0] * latents_size[1] * latents_size[2] + else: + n_tokens = -1 + return n_tokens + + def output_file_path(self): + return os.path.join(self.output_path, self.output_file_name) + + def log(self, server_args: ServerArgs): + # TODO: in some cases (e.g., TI2I), height and weight might be undecided at this moment + if self.height: + target_height = align_to(self.height, 16) + else: + target_height = -1 + if self.width: + target_width = align_to(self.width, 16) + else: + target_width = -1 + + # Log sampling parameters + debug_str = f"""Sampling params: + height: {target_height} + width: {target_width} + num_frames: {self.num_frames} + prompt: {self.prompt} + neg_prompt: {self.negative_prompt} + seed: {self.seed} + infer_steps: {self.num_inference_steps} + num_outputs_per_prompt: {self.num_outputs_per_prompt} + guidance_scale: {self.guidance_scale} + embedded_guidance_scale: {server_args.pipeline_config.embedded_cfg_scale} + n_tokens: {self.n_tokens} + flow_shift: {server_args.pipeline_config.flow_shift} + image_path: {self.image_path} + save_output: {self.save_output} + output_file_path: {self.output_file_path()} + """ # type: ignore[attr-defined] + logger.info(debug_str) + + +@dataclass +class CacheParams: + cache_type: str = "none" diff --git a/python/sglang/multimodal_gen/configs/sample/flux.py b/python/sglang/multimodal_gen/configs/sample/flux.py new file mode 100644 index 000000000..4c96467fb --- /dev/null +++ b/python/sglang/multimodal_gen/configs/sample/flux.py @@ -0,0 +1,18 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +from sglang.multimodal_gen.configs.sample.base import SamplingParams + + +@dataclass +class FluxSamplingParams(SamplingParams): + # Video parameters + # height: int = 1024 + # width: int = 1024 + num_frames: int = 1 + # Denoising stage + guidance_scale: float = 1.0 + negative_prompt: str = None + num_inference_steps: int = 50 diff --git a/python/sglang/multimodal_gen/configs/sample/hunyuan.py b/python/sglang/multimodal_gen/configs/sample/hunyuan.py new file mode 100644 index 000000000..266d665e2 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/sample/hunyuan.py @@ -0,0 +1,37 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.sample.base import SamplingParams +from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams + + +@dataclass +class HunyuanSamplingParams(SamplingParams): + num_inference_steps: int = 50 + + num_frames: int = 125 + height: int = 720 + width: int = 1280 + fps: int = 24 + + guidance_scale: float = 1.0 + + teacache_params: TeaCacheParams = field( + default_factory=lambda: TeaCacheParams( + teacache_thresh=0.15, + coefficients=[ + 7.33226126e02, + -4.01131952e02, + 6.75869174e01, + -3.14987800e00, + 9.61237896e-02, + ], + ) + ) + + +@dataclass +class 
FastHunyuanSamplingParam(HunyuanSamplingParams): + num_inference_steps: int = 6 diff --git a/python/sglang/multimodal_gen/configs/sample/qwenimage.py b/python/sglang/multimodal_gen/configs/sample/qwenimage.py new file mode 100644 index 000000000..282b66d8f --- /dev/null +++ b/python/sglang/multimodal_gen/configs/sample/qwenimage.py @@ -0,0 +1,18 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +from sglang.multimodal_gen.configs.sample.base import SamplingParams + + +@dataclass +class QwenImageSamplingParams(SamplingParams): + # Video parameters + # height: int = 1024 + # width: int = 1024 + negative_prompt: str = " " + num_frames: int = 1 + # Denoising stage + guidance_scale: float = 4.0 + num_inference_steps: int = 50 diff --git a/python/sglang/multimodal_gen/configs/sample/registry.py b/python/sglang/multimodal_gen/configs/sample/registry.py new file mode 100644 index 000000000..297901fc2 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/sample/registry.py @@ -0,0 +1,122 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +import os +from collections.abc import Callable +from typing import Any + +from sglang.multimodal_gen.configs.sample.flux import FluxSamplingParams +from sglang.multimodal_gen.configs.sample.hunyuan import ( + FastHunyuanSamplingParam, + HunyuanSamplingParams, +) +from sglang.multimodal_gen.configs.sample.qwenimage import QwenImageSamplingParams +from sglang.multimodal_gen.configs.sample.stepvideo import StepVideoT2VSamplingParams + +# isort: off +from sglang.multimodal_gen.configs.sample.wan import ( + FastWanT2V480PConfig, + Wan2_1_Fun_1_3B_InP_SamplingParams, + Wan2_2_I2V_A14B_SamplingParam, + Wan2_2_T2V_A14B_SamplingParam, + Wan2_2_TI2V_5B_SamplingParam, + WanI2V_14B_480P_SamplingParam, + WanI2V_14B_720P_SamplingParam, + WanT2V_1_3B_SamplingParams, + WanT2V_14B_SamplingParams, + SelfForcingWanT2V480PConfig, +) +from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import ( + maybe_download_model_index, + verify_model_config_and_directory, +) + +# isort: on +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) +# Registry maps specific model weights to their config classes +SAMPLING_PARAM_REGISTRY: dict[str, Any] = { + "FastVideo/FastHunyuan-diffusers": FastHunyuanSamplingParam, + "hunyuanvideo-community/HunyuanVideo": HunyuanSamplingParams, + "FastVideo/stepvideo-t2v-diffusers": StepVideoT2VSamplingParams, + # Wan2.1 + "Wan-AI/Wan2.1-T2V-1.3B-Diffusers": WanT2V_1_3B_SamplingParams, + "Wan-AI/Wan2.1-T2V-14B-Diffusers": WanT2V_14B_SamplingParams, + "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers": WanI2V_14B_480P_SamplingParam, + "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers": WanI2V_14B_720P_SamplingParam, + "weizhou03/Wan2.1-Fun-1.3B-InP-Diffusers": Wan2_1_Fun_1_3B_InP_SamplingParams, + # Wan2.2 + "Wan-AI/Wan2.2-TI2V-5B-Diffusers": Wan2_2_TI2V_5B_SamplingParam, + "FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers": Wan2_2_TI2V_5B_SamplingParam, + "Wan-AI/Wan2.2-T2V-A14B-Diffusers": Wan2_2_T2V_A14B_SamplingParam, + "Wan-AI/Wan2.2-I2V-A14B-Diffusers": Wan2_2_I2V_A14B_SamplingParam, + # FastWan2.1 + "FastVideo/FastWan2.1-T2V-1.3B-Diffusers": FastWanT2V480PConfig, + # FastWan2.2 + "FastVideo/FastWan2.2-TI2V-5B-Diffusers": Wan2_2_TI2V_5B_SamplingParam, + # Causal Self-Forcing Wan2.1 + "wlsaidhi/SFWan2.1-T2V-1.3B-Diffusers": SelfForcingWanT2V480PConfig, + # Add 
other specific weight variants + "black-forest-labs/FLUX.1-dev": FluxSamplingParams, + "Qwen/Qwen-Image": QwenImageSamplingParams, + "Qwen/Qwen-Image-Edit": QwenImageSamplingParams, +} + +# For determining pipeline type from model ID +SAMPLING_PARAM_DETECTOR: dict[str, Callable[[str], bool]] = { + "hunyuan": lambda id: "hunyuan" in id.lower(), + "wanpipeline": lambda id: "wanpipeline" in id.lower(), + "wanimagetovideo": lambda id: "wanimagetovideo" in id.lower(), + "stepvideo": lambda id: "stepvideo" in id.lower(), + # Add other pipeline architecture detectors + "flux": lambda id: "flux" in id.lower(), +} + +# Fallback configs when exact match isn't found but architecture is detected +SAMPLING_FALLBACK_PARAM: dict[str, Any] = { + "hunyuan": HunyuanSamplingParams, # Base Hunyuan config as fallback for any Hunyuan variant + "wanpipeline": WanT2V_1_3B_SamplingParams, # Base Wan config as fallback for any Wan variant + "wanimagetovideo": WanI2V_14B_480P_SamplingParam, + "stepvideo": StepVideoT2VSamplingParams, + # Other fallbacks by architecture + "flux": FluxSamplingParams, +} + + +def get_sampling_param_cls_for_name(pipeline_name_or_path: str) -> Any | None: + """Get the appropriate sampling param for specific pretrained weights.""" + + if os.path.exists(pipeline_name_or_path): + config = verify_model_config_and_directory(pipeline_name_or_path) + logger.warning( + "sgl-diffusion may not correctly identify the optimal sampling param for this model, as the local directory may have been renamed." + ) + else: + config = maybe_download_model_index(pipeline_name_or_path) + + pipeline_name = config["_class_name"] + + # First try exact match for specific weights + if pipeline_name_or_path in SAMPLING_PARAM_REGISTRY: + return SAMPLING_PARAM_REGISTRY[pipeline_name_or_path] + + # Try partial matches (for local paths that might include the weight ID) + for registered_id, config_class in SAMPLING_PARAM_REGISTRY.items(): + if registered_id in pipeline_name_or_path: + return config_class + + # If no match, try to use the fallback config + fallback_config = None + # Try to determine pipeline architecture for fallback + for pipeline_type, detector in SAMPLING_PARAM_DETECTOR.items(): + if detector(pipeline_name.lower()): + fallback_config = SAMPLING_FALLBACK_PARAM.get(pipeline_type) + break + + logger.warning( + "No match found for pipeline %s, using fallback sampling param %s.", + pipeline_name_or_path, + fallback_config, + ) + return fallback_config diff --git a/python/sglang/multimodal_gen/configs/sample/stepvideo.py b/python/sglang/multimodal_gen/configs/sample/stepvideo.py new file mode 100644 index 000000000..3f58ab3fe --- /dev/null +++ b/python/sglang/multimodal_gen/configs/sample/stepvideo.py @@ -0,0 +1,22 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass + +from sglang.multimodal_gen.configs.sample.base import SamplingParams + + +@dataclass +class StepVideoT2VSamplingParams(SamplingParams): + # Video parameters + height: int = 720 + width: int = 1280 + num_frames: int = 81 + + # Denoising stage + guidance_scale: float = 9.0 + num_inference_steps: int = 50 + + # neg magic and pos magic + # pos_magic: str = "超高清、HDR 视频、环境光、杜比全景声、画面稳定、流畅动作、逼真的细节、专业级构图、超现实主义、自然、生动、超细节、清晰。" + # neg_magic: str = "画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。" diff --git a/python/sglang/multimodal_gen/configs/sample/teacache.py b/python/sglang/multimodal_gen/configs/sample/teacache.py new file mode 100644 index 
000000000..bec0cf884 --- /dev/null +++ b/python/sglang/multimodal_gen/configs/sample/teacache.py @@ -0,0 +1,43 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.sample.base import CacheParams + + +@dataclass +class TeaCacheParams(CacheParams): + cache_type: str = "teacache" + teacache_thresh: float = 0.0 + coefficients: list[float] = field(default_factory=list) + + +@dataclass +class WanTeaCacheParams(CacheParams): + # Unfortunately, TeaCache is very different for Wan than other models + cache_type: str = "teacache" + teacache_thresh: float = 0.0 + use_ret_steps: bool = True + ret_steps_coeffs: list[float] = field(default_factory=list) + non_ret_steps_coeffs: list[float] = field(default_factory=list) + + @property + def coefficients(self) -> list[float]: + if self.use_ret_steps: + return self.ret_steps_coeffs + else: + return self.non_ret_steps_coeffs + + @property + def ret_steps(self) -> int: + if self.use_ret_steps: + return 5 * 2 + else: + return 1 * 2 + + def get_cutoff_steps(self, num_inference_steps: int) -> int: + if self.use_ret_steps: + return num_inference_steps * 2 + else: + return num_inference_steps * 2 - 2 diff --git a/python/sglang/multimodal_gen/configs/sample/wan.py b/python/sglang/multimodal_gen/configs/sample/wan.py new file mode 100644 index 000000000..da2d2a58a --- /dev/null +++ b/python/sglang/multimodal_gen/configs/sample/wan.py @@ -0,0 +1,217 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from sglang.multimodal_gen.configs.sample.base import SamplingParams +from sglang.multimodal_gen.configs.sample.teacache import WanTeaCacheParams + + +@dataclass +class WanT2V_1_3B_SamplingParams(SamplingParams): + # Video parameters + height: int = 480 + width: int = 832 + num_frames: int = 81 + fps: int = 16 + + # Denoising stage + guidance_scale: float = 3.0 + negative_prompt: str = ( + "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + ) + num_inference_steps: int = 50 + + teacache_params: WanTeaCacheParams = field( + default_factory=lambda: WanTeaCacheParams( + teacache_thresh=0.08, + ret_steps_coeffs=[ + -5.21862437e04, + 9.23041404e03, + -5.28275948e02, + 1.36987616e01, + -4.99875664e-02, + ], + non_ret_steps_coeffs=[ + 2.39676752e03, + -1.31110545e03, + 2.01331979e02, + -8.29855975e00, + 1.37887774e-01, + ], + ) + ) + + +@dataclass +class WanT2V_14B_SamplingParams(SamplingParams): + # Video parameters + height: int = 720 + width: int = 1280 + num_frames: int = 81 + fps: int = 16 + + # Denoising stage + guidance_scale: float = 5.0 + negative_prompt: str = ( + "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + ) + num_inference_steps: int = 50 + + 
teacache_params: WanTeaCacheParams = field( + default_factory=lambda: WanTeaCacheParams( + teacache_thresh=0.20, + use_ret_steps=False, + ret_steps_coeffs=[ + -3.03318725e05, + 4.90537029e04, + -2.65530556e03, + 5.87365115e01, + -3.15583525e-01, + ], + non_ret_steps_coeffs=[ + -5784.54975374, + 5449.50911966, + -1811.16591783, + 256.27178429, + -13.02252404, + ], + ) + ) + + +@dataclass +class WanI2V_14B_480P_SamplingParam(WanT2V_1_3B_SamplingParams): + # Denoising stage + guidance_scale: float = 5.0 + num_inference_steps: int = 50 + # num_inference_steps: int = 40 + + teacache_params: WanTeaCacheParams = field( + default_factory=lambda: WanTeaCacheParams( + teacache_thresh=0.26, + ret_steps_coeffs=[ + -3.03318725e05, + 4.90537029e04, + -2.65530556e03, + 5.87365115e01, + -3.15583525e-01, + ], + non_ret_steps_coeffs=[ + -5784.54975374, + 5449.50911966, + -1811.16591783, + 256.27178429, + -13.02252404, + ], + ) + ) + + +@dataclass +class WanI2V_14B_720P_SamplingParam(WanT2V_14B_SamplingParams): + # Denoising stage + guidance_scale: float = 5.0 + num_inference_steps: int = 50 + # num_inference_steps: int = 40 + + teacache_params: WanTeaCacheParams = field( + default_factory=lambda: WanTeaCacheParams( + teacache_thresh=0.3, + ret_steps_coeffs=[ + -3.03318725e05, + 4.90537029e04, + -2.65530556e03, + 5.87365115e01, + -3.15583525e-01, + ], + non_ret_steps_coeffs=[ + -5784.54975374, + 5449.50911966, + -1811.16591783, + 256.27178429, + -13.02252404, + ], + ) + ) + + +@dataclass +class FastWanT2V480PConfig(WanT2V_1_3B_SamplingParams): + # DMD parameters + # dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 757, 522]) + num_inference_steps: int = 3 + num_frames: int = 61 + height: int = 448 + width: int = 832 + fps: int = 16 + + +# ============================================= +# ============= Wan2.1 Fun Models ============= +# ============================================= +@dataclass +class Wan2_1_Fun_1_3B_InP_SamplingParams(SamplingParams): + """Sampling parameters for Wan2.1 Fun 1.3B InP model.""" + + height: int = 480 + width: int = 832 + num_frames: int = 81 + fps: int = 16 + negative_prompt: str | None = ( + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + ) + guidance_scale: float = 6.0 + num_inference_steps: int = 50 + + +# ============================================= +# ============= Wan2.2 TI2V Models ============= +# ============================================= +@dataclass +class Wan2_2_Base_SamplingParams(SamplingParams): + """Sampling parameters for Wan2.2 TI2V 5B model.""" + + negative_prompt: str | None = ( + "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + ) + + +@dataclass +class Wan2_2_TI2V_5B_SamplingParam(Wan2_2_Base_SamplingParams): + """Sampling parameters for Wan2.2 TI2V 5B model.""" + + height: int = 704 + width: int = 1280 + num_frames: int = 121 + fps: int = 24 + guidance_scale: float = 5.0 + num_inference_steps: int = 50 + + +@dataclass +class Wan2_2_T2V_A14B_SamplingParam(Wan2_2_Base_SamplingParams): + guidance_scale: float = 4.0 # high_noise + guidance_scale_2: float = 3.0 # low_noise + num_inference_steps: int = 40 + fps: int = 16 + # NOTE(will): default boundary timestep is tracked by PipelineConfig, but + # can be overridden during sampling + + +@dataclass +class Wan2_2_I2V_A14B_SamplingParam(Wan2_2_Base_SamplingParams): + guidance_scale: float = 3.5 # 
high_noise + guidance_scale_2: float = 3.5 # low_noise + num_inference_steps: int = 40 + fps: int = 16 + # NOTE(will): default boundary timestep is tracked by PipelineConfig, but + # can be overridden during sampling + + +# ============================================= +# ============= Causal Self-Forcing ============= +# ============================================= +@dataclass +class SelfForcingWanT2V480PConfig(WanT2V_1_3B_SamplingParams): + pass diff --git a/python/sglang/multimodal_gen/configs/utils.py b/python/sglang/multimodal_gen/configs/utils.py new file mode 100644 index 000000000..d2cc69adb --- /dev/null +++ b/python/sglang/multimodal_gen/configs/utils.py @@ -0,0 +1,61 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import argparse +from typing import Any + + +def update_config_from_args( + config: Any, args_dict: dict[str, Any], prefix: str = "", pop_args: bool = False +) -> bool: + """ + Update configuration object from arguments dictionary. + + Args: + config: The configuration object to update + args_dict: Dictionary containing arguments + prefix: Prefix for the configuration parameters in the args_dict. + If None, assumes direct attribute mapping without prefix. + """ + # Handle top-level attributes (no prefix) + args_not_to_remove = [ + "model_path", + ] + args_to_remove = [] + if prefix.strip() == "": + for key, value in args_dict.items(): + if hasattr(config, key) and value is not None: + if key == "text_encoder_precisions" and isinstance(value, list): + setattr(config, key, tuple(value)) + else: + setattr(config, key, value) + if pop_args: + args_to_remove.append(key) + else: + # Handle nested attributes with prefix + prefix_with_dot = f"{prefix}." + for key, value in args_dict.items(): + if key.startswith(prefix_with_dot) and value is not None: + attr_name = key[len(prefix_with_dot) :] + if hasattr(config, attr_name): + setattr(config, attr_name, value) + if pop_args: + args_to_remove.append(key) + + if pop_args: + for key in args_to_remove: + if key not in args_not_to_remove: + args_dict.pop(key) + + return len(args_to_remove) > 0 + + +def clean_cli_args(args: argparse.Namespace) -> dict[str, Any]: + """ + Clean the arguments by removing the ones that not explicitly provided by the user. 
+ """ + provided_args = {} + for k, v in vars(args).items(): + if v is not None and hasattr(args, "_provided") and k in args._provided: + provided_args[k] = v + + return provided_args diff --git a/python/sglang/multimodal_gen/configs/wan_1.3B_t2v_pipeline.json b/python/sglang/multimodal_gen/configs/wan_1.3B_t2v_pipeline.json new file mode 100644 index 000000000..724c9cebd --- /dev/null +++ b/python/sglang/multimodal_gen/configs/wan_1.3B_t2v_pipeline.json @@ -0,0 +1,41 @@ +{ + "embedded_cfg_scale": 6.0, + "flow_shift": 3, + "dit_cpu_offload": true, + "disable_autocast": false, + "precision": "bf16", + "vae_precision": "fp32", + "vae_tiling": false, + "vae_sp": false, + "vae_config": { + "load_encoder": false, + "load_decoder": true, + "tile_sample_min_height": 256, + "tile_sample_min_width": 256, + "tile_sample_min_num_frames": 16, + "tile_sample_stride_height": 192, + "tile_sample_stride_width": 192, + "tile_sample_stride_num_frames": 12, + "blend_num_frames": 8, + "use_tiling": false, + "use_temporal_tiling": false, + "use_parallel_tiling": false, + "use_feature_cache": true + }, + "dit_config": { + "prefix": "Wan", + "quant_config": null + }, + "text_encoder_precisions": [ + "fp32" + ], + "text_encoder_configs": [ + { + "prefix": "t5", + "quant_config": null, + "lora_config": null + } + ], + "mask_strategy_file_path": null, + "enable_torch_compile": false +} diff --git a/python/sglang/multimodal_gen/configs/wan_14B_i2v_480p_pipeline.json b/python/sglang/multimodal_gen/configs/wan_14B_i2v_480p_pipeline.json new file mode 100644 index 000000000..3bb7b3e2a --- /dev/null +++ b/python/sglang/multimodal_gen/configs/wan_14B_i2v_480p_pipeline.json @@ -0,0 +1,49 @@ +{ + "embedded_cfg_scale": 6.0, + "flow_shift": 3, + "dit_cpu_offload": true, + "disable_autocast": false, + "precision": "bf16", + "vae_precision": "fp32", + "vae_tiling": false, + "vae_sp": false, + "vae_config": { + "load_encoder": true, + "load_decoder": true, + "tile_sample_min_height": 256, + "tile_sample_min_width": 256, + "tile_sample_min_num_frames": 16, + "tile_sample_stride_height": 192, + "tile_sample_stride_width": 192, + "tile_sample_stride_num_frames": 12, + "blend_num_frames": 8, + "use_tiling": false, + "use_temporal_tiling": false, + "use_parallel_tiling": false, + "use_feature_cache": true + }, + "dit_config": { + "prefix": "Wan", + "quant_config": null + }, + "text_encoder_precisions": [ + "fp32" + ], + "text_encoder_configs": [ + { + "prefix": "t5", + "quant_config": null, + "lora_config": null + } + ], + "mask_strategy_file_path": null, + "enable_torch_compile": false, + "image_encoder_config": { + "prefix": "clip", + "quant_config": null, + "lora_config": null, + "num_hidden_layers_override": null, + "require_post_norm": null + }, + "image_encoder_precision": "fp32" +} diff --git a/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/README.md b/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/README.md new file mode 100644 index 000000000..3fc24e366 --- /dev/null +++ b/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/README.md @@ -0,0 +1,31 @@ +# Attention Kernel Used in sgl-diffusion + +## VMoBA: Mixture-of-Block Attention for Video Diffusion Models (VMoBA) + +### Installation +Please ensure that you have installed FlashAttention version **2.7.1 or higher**, as some interfaces have changed in recent releases. 
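+
+If FlashAttention is not already present, it can typically be installed from PyPI. The exact command depends on your CUDA/PyTorch build; the line below is only a suggested starting point (the `--no-build-isolation` flag follows the flash-attn project's own install notes):
+
+```bash
+pip install "flash-attn>=2.7.1" --no-build-isolation
+```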
+
+### Usage
+
+To use `moba_attn_varlen`, either install the package from source or import it directly from the project root:
+
+**Install from source:**
+```bash
+python setup.py install
+```
+
+**Import after installation:**
+```python
+from vmoba import moba_attn_varlen
+```
+
+**Or import directly from the project root (without installing):**
+```python
+from csrc.attn.vmoba_attn.vmoba import moba_attn_varlen
+```
+
+### Verify the installation
+
+```bash
+python csrc/attn/vmoba_attn/vmoba/vmoba.py
+```
diff --git a/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/setup.py b/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/setup.py
new file mode 100644
index 000000000..3a1bdb67f
--- /dev/null
+++ b/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/setup.py
@@ -0,0 +1,26 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from setuptools import find_packages, setup
+
+PACKAGE_NAME = "vmoba"
+VERSION = "0.0.0"
+AUTHOR = "JianzongWu"
+DESCRIPTION = "VMoBA: Mixture-of-Block Attention for Video Diffusion Models"
+URL = "https://github.com/KwaiVGI/VMoBA"
+
+setup(
+    name=PACKAGE_NAME,
+    version=VERSION,
+    author=AUTHOR,
+    description=DESCRIPTION,
+    url=URL,
+    packages=find_packages(),
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "License :: OSI Approved :: Apache Software License",
+    ],
+    python_requires=">=3.12",
+    install_requires=[
+        "flash-attn >= 2.7.1",
+    ],
+)
diff --git a/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/tests/test_vmoba_attn.py b/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/tests/test_vmoba_attn.py
new file mode 100644
index 000000000..f4304bda4
--- /dev/null
+++ b/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/tests/test_vmoba_attn.py
@@ -0,0 +1,137 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import random
+
+import pytest
+import torch
+from csrc.attn.vmoba_attn.vmoba import moba_attn_varlen
+
+
+def generate_test_data(
+    batch_size, total_seqlen, num_heads, head_dim, dtype, device="cuda"
+):
+    """
+    Generates random data for testing the variable-length attention function.
+ """ + torch.manual_seed(42) + random.seed(42) + torch.cuda.manual_seed_all(42) + + # Generate sequence lengths for each item in the batch + if batch_size > 1: + # Ensure sequence lengths are reasonably distributed + avg_seqlen = total_seqlen // batch_size + seqlens = [ + random.randint(avg_seqlen // 2, avg_seqlen + avg_seqlen // 2) + for _ in range(batch_size - 1) + ] + remaining_len = total_seqlen - sum(seqlens) + if remaining_len > 0: + seqlens.append(remaining_len) + else: # Adjust if sum exceeds total_seqlen + seqlens.append(avg_seqlen) + current_sum = sum(seqlens) + seqlens[-1] -= current_sum - total_seqlen + # Ensure all lengths are positive + seqlens = [max(1, s) for s in seqlens] + # Final adjustment to match total_seqlen + seqlens[-1] += total_seqlen - sum(seqlens) + + else: + seqlens = [total_seqlen] + + cu_seqlens = torch.tensor( + [0] + list(torch.cumsum(torch.tensor(seqlens), 0)), + device=device, + dtype=torch.int32, + ) + max_seqlen = max(seqlens) if seqlens else 0 + + q = torch.randn( + (total_seqlen, num_heads, head_dim), + dtype=dtype, + device=device, + requires_grad=False, + ) + k = torch.randn( + (total_seqlen, num_heads, head_dim), + dtype=dtype, + device=device, + requires_grad=False, + ) + v = torch.randn( + (total_seqlen, num_heads, head_dim), + dtype=dtype, + device=device, + requires_grad=False, + ) + + return q, k, v, cu_seqlens, max_seqlen + + +@pytest.mark.parametrize("batch_size", [1, 2]) +@pytest.mark.parametrize("total_seqlen", [512, 1024]) +@pytest.mark.parametrize("num_heads", [8]) +@pytest.mark.parametrize("head_dim", [64]) +@pytest.mark.parametrize("moba_chunk_size", [64]) +@pytest.mark.parametrize("moba_topk", [2, 4]) +@pytest.mark.parametrize("select_mode", ["topk", "threshold"]) +@pytest.mark.parametrize("threshold_type", ["query_head", "head_global", "overall"]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_moba_attn_varlen_forward( + batch_size, + total_seqlen, + num_heads, + head_dim, + moba_chunk_size, + moba_topk, + select_mode, + threshold_type, + dtype, +): + """ + Tests the forward pass of moba_attn_varlen for basic correctness. + It checks output shape, dtype, and for the presence of NaNs/Infs. + """ + if dtype == torch.float32: + pytest.skip("float32 is not supported in flash attention") + + q, k, v, cu_seqlens, max_seqlen = generate_test_data( + batch_size, total_seqlen, num_heads, head_dim, dtype + ) + + # Ensure chunk size is not larger than the smallest sequence length + min_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).min().item() + if moba_chunk_size > min_seqlen: + pytest.skip( + "moba_chunk_size is larger than the minimum sequence length in the batch" + ) + + try: + output = moba_attn_varlen( + q=q, + k=k, + v=v, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + moba_chunk_size=moba_chunk_size, + moba_topk=moba_topk, + select_mode=select_mode, + threshold_type=threshold_type, + simsum_threshold=0.5, # A reasonable default for threshold mode + ) + except Exception as e: + pytest.fail(f"moba_attn_varlen forward pass failed with exception: {e}") + + # 1. Check output shape + assert ( + output.shape == q.shape + ), f"Expected output shape {q.shape}, but got {output.shape}" + + # 2. Check output dtype + assert ( + output.dtype == q.dtype + ), f"Expected output dtype {q.dtype}, but got {output.dtype}" + + # 3. 
Check for NaNs or Infs in the output + assert torch.all(torch.isfinite(output)), "Output contains NaN or Inf values" diff --git a/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/vmoba/__init__.py b/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/vmoba/__init__.py new file mode 100644 index 000000000..8119387c3 --- /dev/null +++ b/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/vmoba/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +from .vmoba import moba_attn_varlen, process_moba_input, process_moba_output diff --git a/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/vmoba/vmoba.py b/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/vmoba/vmoba.py new file mode 100644 index 000000000..8a29360a9 --- /dev/null +++ b/python/sglang/multimodal_gen/csrc/attn/vmoba_attn/vmoba/vmoba.py @@ -0,0 +1,1086 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapt from https://github.com/KwaiVGI/VMoBA/blob/main/src/vmoba.py + +import random +import time +from typing import Tuple + +import torch + +try: + from flash_attn import ( # Use the new flash attention function + flash_attn_varlen_func, + ) + from flash_attn.flash_attn_interface import ( + _flash_attn_varlen_backward, + _flash_attn_varlen_forward, + ) +except ImportError: + + def _unsupported(*args, **kwargs): + raise ImportError( + "flash-attn is not installed. Please install it, e.g., `pip install flash-attn`." + ) + + _flash_attn_varlen_forward = _unsupported + _flash_attn_varlen_backward = _unsupported + flash_attn_varlen_func = _unsupported + +from functools import lru_cache + +from einops import rearrange + + +@lru_cache(maxsize=16) +def calc_chunks(cu_seqlen, moba_chunk_size): + """ + Calculate chunk boundaries. + + For vision tasks we include all chunks (even the last one which might be shorter) + so that every chunk can be selected. + """ + batch_sizes = cu_seqlen[1:] - cu_seqlen[:-1] + batch_num_chunk = (batch_sizes + (moba_chunk_size - 1)) // moba_chunk_size + cu_num_chunk = torch.ones( + batch_num_chunk.numel() + 1, + device=cu_seqlen.device, + dtype=batch_num_chunk.dtype, + ) + cu_num_chunk[1:] = batch_num_chunk.cumsum(dim=0) + num_chunk = cu_num_chunk[-1] + chunk_sizes = torch.full( + (num_chunk + 1,), moba_chunk_size, dtype=torch.int32, device=cu_seqlen.device + ) + chunk_sizes[0] = 0 + batch_last_chunk_size = batch_sizes - (batch_num_chunk - 1) * moba_chunk_size + chunk_sizes[cu_num_chunk[1:]] = batch_last_chunk_size + cu_chunk = chunk_sizes.cumsum(dim=-1, dtype=torch.int32) + chunk_to_batch = torch.zeros( + (num_chunk,), dtype=torch.int32, device=cu_seqlen.device + ) + chunk_to_batch[cu_num_chunk[1:-1]] = 1 + chunk_to_batch = chunk_to_batch.cumsum(dim=0, dtype=torch.int32) + + # Do not filter out any chunk + filtered_chunk_indices = torch.arange( + num_chunk, device=cu_seqlen.device, dtype=torch.int32 + ) + num_filtered_chunk = num_chunk + + return cu_chunk, filtered_chunk_indices, num_filtered_chunk, chunk_to_batch + + +# --- Threshold Selection Helper Functions --- + + +def _select_threshold_query_head( + gate: torch.Tensor, + valid_gate_mask: torch.Tensor, + gate_self_chunk_mask: torch.Tensor, + simsum_threshold: float, +) -> torch.Tensor: + """ + Selects chunks for each pair based on threshold. + Normalization and sorting happen along the chunk dimension (dim=0). 
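+
+    Sketch of the selection rule, per (query, head) pair: gate scores are
+    min-max normalized across chunks, the query's own (self) chunk is always
+    kept, and the remaining chunks are added in descending score order until
+    the self chunk's share plus their cumulative share reaches
+    `simsum_threshold` of the total normalized weight.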
+ """ + C, H, S = gate.shape + eps = 1e-6 + + # LSE‐style normalization per (across chunks) + gate_masked = torch.where(valid_gate_mask, gate, -torch.inf) # Use -inf for max + gate_min_val = torch.where(valid_gate_mask, gate, torch.inf) # Use +inf for min + + row_min = gate_min_val.amin(dim=0) # (H, S) + row_max = gate_masked.amax(dim=0) # (H, S) + denom = row_max - row_min + denom = torch.where( + denom <= eps, torch.ones_like(denom), denom + ) # avoid divide‑by‑zero + + gate_norm = (gate - row_min.unsqueeze(0)) / denom.unsqueeze(0) + gate_norm = torch.where(valid_gate_mask, gate_norm, 0.0) # (C, H, S) + + # 1) pull out the self‐chunk’s normalized weight for each + self_norm = (gate_norm * gate_self_chunk_mask).sum(dim=0) # (H, S) + + # 2) compute how much more normalized weight we need beyond self + total_norm_sum = gate_norm.sum(dim=0) # (H, S) + remain_ratio = simsum_threshold - self_norm / (total_norm_sum + eps) # (H, S) + remain_ratio = torch.clamp( + remain_ratio, min=0.0 + ) # if already ≥ thresh, no extra needed + + # 3) zero out the self‐chunk in a copy, so we only sort “others” + others_norm = gate_norm.clone() + others_norm[gate_self_chunk_mask] = 0.0 + + # 4) sort the other chunks by descending norm, per + sorted_norm, sorted_idx = torch.sort( + others_norm, descending=True, dim=0 + ) # (C, H, S) + + # 5) cumulative‑sum the sorted norms per + cumsum_others = sorted_norm.cumsum(dim=0) # (C, H, S) + + # 6) for each , find the smallest k where cumsum_ratio ≥ remain_ratio + ratio = cumsum_others / (total_norm_sum.unsqueeze(0) + eps) # (C, H, S) + cond = ratio >= remain_ratio.unsqueeze(0) # (C, H, S) boolean mask + any_cond = cond.any(dim=0) # (H, S) + # Find the index of the first True value along dim 0. If none, use C-1. + cutoff = torch.where( + any_cond, + cond.float().argmax(dim=0), + torch.full_like(any_cond, fill_value=C - 1), + ) # (H, S) + + # 7) build a mask in sorted order up to that cutoff + idx_range = torch.arange(C, device=gate.device).view(-1, 1, 1) # (C, 1, 1) + sorted_mask = idx_range <= cutoff.unsqueeze(0) # (C, H, S) + + # 8) scatter it back to original chunk order + others_mask = torch.zeros_like(gate, dtype=torch.bool) + others_mask.scatter_(0, sorted_idx, sorted_mask) + + # 9) finally, include every self‐chunk plus all selected others + final_gate_mask = valid_gate_mask & (others_mask | gate_self_chunk_mask) + + return final_gate_mask + + +def _select_threshold_block( + gate: torch.Tensor, + valid_gate_mask: torch.Tensor, + gate_self_chunk_mask: torch.Tensor, + simsum_threshold: float, +) -> torch.Tensor: + """ + Selects pairs for each block based on threshold. + Normalization and sorting happen across the head and sequence dimensions (dim=1, 2). 
+ """ + C, H, S = gate.shape + HS = H * S + eps = 1e-6 + + # LSE‐style normalization per block (across heads and queries) + gate_masked = torch.where(valid_gate_mask, gate, -torch.inf) # Use -inf for max + gate_min_val = torch.where(valid_gate_mask, gate, torch.inf) # Use +inf for min + + block_max = gate_masked.amax(dim=(1, 2), keepdim=True) # (C, 1, 1) + block_min = gate_min_val.amin(dim=(1, 2), keepdim=True) # (C, 1, 1) + block_denom = block_max - block_min + block_denom = torch.where( + block_denom <= eps, torch.ones_like(block_denom), block_denom + ) # (C, 1, 1) + + gate_norm = (gate - block_min) / block_denom # (C, H, S) + gate_norm = torch.where(valid_gate_mask, gate_norm, 0.0) # (C, H, S) + + # 1) identify normalized weights of entries that *are* self-chunks (from query perspective) + self_norm_entries = gate_norm * gate_self_chunk_mask # (C, H, S) + # Sum these weights *per block* + self_norm_sum_per_block = self_norm_entries.sum(dim=(1, 2)) # (C,) + + # 2) compute how much more normalized weight each block needs beyond its self-chunk contributions + total_norm_sum_per_block = gate_norm.sum(dim=(1, 2)) # (C,) + remain_ratio = simsum_threshold - self_norm_sum_per_block / ( + total_norm_sum_per_block + eps + ) # (C,) + remain_ratio = torch.clamp(remain_ratio, min=0.0) # (C,) + + # 3) zero out the self‐chunk entries in a copy, so we only sort “others” + others_norm = gate_norm.clone() + others_norm[gate_self_chunk_mask] = 0.0 # Zero out self entries + + # 4) sort the other pairs by descending norm, per block + others_flat = others_norm.contiguous().view(C, HS) # (C, H*S) + sorted_others_flat, sorted_indices_flat = torch.sort( + others_flat, dim=1, descending=True + ) # (C, H*S) + + # 5) cumulative‑sum the sorted norms per block + cumsum_others_flat = sorted_others_flat.cumsum(dim=1) # (C, H*S) + + # 6) for each block, find the smallest k where cumsum_ratio ≥ remain_ratio + ratio_flat = cumsum_others_flat / ( + total_norm_sum_per_block.unsqueeze(1) + eps + ) # (C, H*S) + cond_flat = ratio_flat >= remain_ratio.unsqueeze(1) # (C, H*S) boolean mask + any_cond = cond_flat.any(dim=1) # (C,) + # Find the index of the first True value along dim 1. If none, use HS-1. + cutoff_flat = torch.where( + any_cond, + cond_flat.float().argmax(dim=1), + torch.full_like(any_cond, fill_value=HS - 1), + ) # (C,) + + # 7) build a mask in sorted order up to that cutoff per block + idx_range_flat = torch.arange(HS, device=gate.device).unsqueeze(0) # (1, H*S) + sorted_mask_flat = idx_range_flat <= cutoff_flat.unsqueeze(1) # (C, H*S) + + # 8) scatter it back to original order per block + others_mask_flat = torch.zeros_like(others_flat, dtype=torch.bool) # (C, H*S) + others_mask_flat.scatter_(1, sorted_indices_flat, sorted_mask_flat) + others_mask = others_mask_flat.view(C, H, S) # (C, H, S) + + # 9) finally, include every self‐chunk entry plus all selected others + final_gate_mask = valid_gate_mask & (others_mask | gate_self_chunk_mask) + + return final_gate_mask + + +def _select_threshold_overall( + gate: torch.Tensor, + valid_gate_mask: torch.Tensor, + gate_self_chunk_mask: torch.Tensor, + simsum_threshold: float, +) -> torch.Tensor: + """ + Selects triplets globally based on threshold. + Normalization and sorting happen across all valid entries. 
+ """ + C, H, S = gate.shape + CHS = C * H * S + eps = 1e-6 + + # LSE‐style normalization globally across all valid entries + gate_masked = torch.where(valid_gate_mask, gate, -torch.inf) # Use -inf for max + gate_min_val = torch.where(valid_gate_mask, gate, torch.inf) # Use +inf for min + + overall_max = gate_masked.max() # scalar + overall_min = gate_min_val.min() # scalar + overall_denom = overall_max - overall_min + overall_denom = torch.where( + overall_denom <= eps, + torch.tensor(1.0, device=gate.device, dtype=gate.dtype), + overall_denom, + ) + + gate_norm = (gate - overall_min) / overall_denom # (C, H, S) + gate_norm = torch.where(valid_gate_mask, gate_norm, 0.0) # (C, H, S) + + # 1) identify normalized weights of entries that *are* self-chunks + self_norm_entries = gate_norm * gate_self_chunk_mask # (C, H, S) + # Sum these weights globally + self_norm_sum_overall = self_norm_entries.sum() # scalar + + # 2) compute how much more normalized weight is needed globally beyond self-chunk contributions + total_norm_sum_overall = gate_norm.sum() # scalar + remain_ratio = simsum_threshold - self_norm_sum_overall / ( + total_norm_sum_overall + eps + ) # scalar + remain_ratio = torch.clamp(remain_ratio, min=0.0) # scalar + + # 3) zero out the self‐chunk entries in a copy, so we only sort “others” + others_norm = gate_norm.clone() + others_norm[gate_self_chunk_mask] = 0.0 # Zero out self entries + + # 4) sort all other entries by descending norm, globally + others_flat = others_norm.flatten() # (C*H*S,) + valid_others_mask_flat = ( + valid_gate_mask.flatten() & ~gate_self_chunk_mask.flatten() + ) # Mask for valid, non-self entries + + # Only sort the valid 'other' entries + valid_others_indices = torch.where(valid_others_mask_flat)[0] + valid_others_values = others_flat[valid_others_indices] + + sorted_others_values, sort_perm = torch.sort( + valid_others_values, descending=True + ) # (N_valid_others,) + sorted_original_indices = valid_others_indices[ + sort_perm + ] # Original indices in C*H*S space, sorted by value + + # 5) cumulative‑sum the sorted valid 'other' norms globally + cumsum_others_values = sorted_others_values.cumsum(dim=0) # (N_valid_others,) + + # 6) find the smallest k where cumsum_ratio ≥ remain_ratio globally + ratio_values = cumsum_others_values / ( + total_norm_sum_overall + eps + ) # (N_valid_others,) + cond_values = ratio_values >= remain_ratio # (N_valid_others,) boolean mask + any_cond = cond_values.any() # scalar + + # Find the index of the first True value in the *sorted* list. If none, use all valid others. 
+ cutoff_idx_in_sorted = torch.where( + any_cond, + cond_values.float().argmax(dim=0), + torch.tensor( + len(sorted_others_values) - 1, device=gate.device, dtype=torch.long + ), + ) + + # 7) build a mask selecting the top-k others based on the cutoff + # Select the original indices corresponding to the top entries in the sorted list + selected_other_indices = sorted_original_indices[: cutoff_idx_in_sorted + 1] + + # 8) create the mask in the original flat shape + others_mask_flat = torch.zeros_like(others_flat, dtype=torch.bool) # (C*H*S,) + if selected_other_indices.numel() > 0: # Check if any 'other' indices were selected + others_mask_flat[selected_other_indices] = True + others_mask = others_mask_flat.view(C, H, S) # (C, H, S) + + # 9) finally, include every self‐chunk entry plus all selected others + final_gate_mask = valid_gate_mask & (others_mask | gate_self_chunk_mask) + + return final_gate_mask + + +def _select_threshold_head_global( + gate: torch.Tensor, + valid_gate_mask: torch.Tensor, + gate_self_chunk_mask: torch.Tensor, + simsum_threshold: float, +) -> torch.Tensor: + """ + Selects globally for each head based on threshold. + """ + C, H, S = gate.shape + eps = 1e-6 + + # 1) LSE‐style normalization per head (across chunks and sequence dims) + gate_masked = torch.where(valid_gate_mask, gate, -torch.inf) + gate_min_val = torch.where(valid_gate_mask, gate, torch.inf) + + max_per_head = gate_masked.amax(dim=(0, 2), keepdim=True) # (1, H, 1) + min_per_head = gate_min_val.amin(dim=(0, 2), keepdim=True) # (1, H, 1) + denom = max_per_head - min_per_head + denom = torch.where(denom <= eps, torch.ones_like(denom), denom) + + gate_norm = (gate - min_per_head) / denom + gate_norm = torch.where(valid_gate_mask, gate_norm, 0.0) # (C, H, S) + + # 2) sum normalized self‐chunk contributions per head + self_norm_sum = (gate_norm * gate_self_chunk_mask).sum(dim=(0, 2)) # (H,) + + # 3) total normalized sum per head + total_norm_sum = gate_norm.sum(dim=(0, 2)) # (H,) + + # 4) how much more normalized weight needed per head + remain_ratio = simsum_threshold - self_norm_sum / (total_norm_sum + eps) # (H,) + remain_ratio = torch.clamp(remain_ratio, min=0.0) + + # 5) zero out self‐chunk entries to focus on "others" + others_norm = gate_norm.clone() + others_norm[gate_self_chunk_mask] = 0.0 # (C, H, S) + + # 6) flatten chunk and sequence dims, per head + CS = C * S + others_flat = others_norm.permute(1, 0, 2).reshape(H, CS) # (H, C*S) + valid_flat = ( + (valid_gate_mask & ~gate_self_chunk_mask).permute(1, 0, 2).reshape(H, CS) + ) # (H, C*S) + + # 7) vectorized selection of “others” per head + masked_flat = torch.where(valid_flat, others_flat, torch.zeros_like(others_flat)) + sorted_vals, sorted_idx = torch.sort( + masked_flat, dim=1, descending=True + ) # (H, C*S) + + cumsum_vals = sorted_vals.cumsum(dim=1) # (H, C*S) + ratio_vals = cumsum_vals / (total_norm_sum.unsqueeze(1) + eps) # (H, C*S) + cond = ratio_vals >= remain_ratio.unsqueeze(1) # (H, C*S) + + has_cutoff = cond.any(dim=1) # (H,) + default = torch.full((H,), CS - 1, device=gate.device, dtype=torch.long) + cutoff = torch.where(has_cutoff, cond.float().argmax(dim=1), default) # (H,) + + idx_range = torch.arange(CS, device=gate.device).unsqueeze(0) # (1, C*S) + sorted_mask = idx_range <= cutoff.unsqueeze(1) # (H, C*S) + + selected_flat = torch.zeros_like(valid_flat) # (H, C*S) + selected_flat.scatter_(1, sorted_idx, sorted_mask) # (H, C*S) + + # 8) reshape selection mask back to (C, H, S) + others_mask = selected_flat.reshape(H, C, 
S).permute(1, 0, 2) # (C, H, S) + + # 9) include self‐chunks plus selected others, and obey valid mask + final_gate_mask = valid_gate_mask & (gate_self_chunk_mask | others_mask) + + return final_gate_mask + + +class MixedAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + self_attn_cu_seqlen, + moba_q, + moba_kv, + moba_cu_seqlen_q, + moba_cu_seqlen_kv, + max_seqlen, + moba_chunk_size, + moba_q_sh_indices, + ): + ctx.max_seqlen = max_seqlen + ctx.moba_chunk_size = moba_chunk_size + ctx.softmax_scale = softmax_scale = q.shape[-1] ** (-0.5) + + # Non-causal self-attention branch + # return out, softmax_lse, S_dmask, rng_state + self_attn_out_sh, self_attn_lse_hs, _, _ = _flash_attn_varlen_forward( + q=q, + k=k, + v=v, + cu_seqlens_q=self_attn_cu_seqlen, + cu_seqlens_k=self_attn_cu_seqlen, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, + causal=False, + dropout_p=0.0, + ) + # MOBA attention branch (non-causal) + moba_attn_out, moba_attn_lse_hs, _, _ = _flash_attn_varlen_forward( + q=moba_q, + k=moba_kv[:, 0], + v=moba_kv[:, 1], + cu_seqlens_q=moba_cu_seqlen_q, + cu_seqlens_k=moba_cu_seqlen_kv, + max_seqlen_q=max_seqlen, + max_seqlen_k=moba_chunk_size, + softmax_scale=softmax_scale, + causal=False, + dropout_p=0.0, + ) + + self_attn_lse_sh = self_attn_lse_hs.t().contiguous() + moba_attn_lse = moba_attn_lse_hs.t().contiguous() + + output = torch.zeros( + (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32 + ) + output_2d = output.view(-1, q.shape[2]) + + max_lse_1d = self_attn_lse_sh.view(-1) + max_lse_1d = max_lse_1d.index_reduce( + 0, moba_q_sh_indices, moba_attn_lse.view(-1), "amax" + ) + self_attn_lse_sh = self_attn_lse_sh - max_lse_1d.view_as(self_attn_lse_sh) + moba_attn_lse = ( + moba_attn_lse.view(-1) + .sub(max_lse_1d.index_select(0, moba_q_sh_indices)) + .reshape_as(moba_attn_lse) + ) + + mixed_attn_se_sh = self_attn_lse_sh.exp() + moba_attn_se = moba_attn_lse.exp() + + mixed_attn_se_sh.view(-1).index_add_( + 0, moba_q_sh_indices, moba_attn_se.view(-1) + ) + mixed_attn_lse_sh = mixed_attn_se_sh.log() + + # Combine self-attention output + factor = (self_attn_lse_sh - mixed_attn_lse_sh).exp() # [S, H] + self_attn_out_sh = self_attn_out_sh * factor.unsqueeze(-1) + output_2d += self_attn_out_sh.reshape_as(output_2d) + + # Combine MOBA attention output + mixed_attn_lse = ( + mixed_attn_lse_sh.view(-1) + .index_select(0, moba_q_sh_indices) + .view_as(moba_attn_lse) + ) + factor = (moba_attn_lse - mixed_attn_lse).exp() # [S, H] + moba_attn_out = moba_attn_out * factor.unsqueeze(-1) + raw_attn_out = moba_attn_out.view(-1, moba_attn_out.shape[-1]) + output_2d.index_add_(0, moba_q_sh_indices, raw_attn_out) + output = output.to(q.dtype) + mixed_attn_lse_sh = mixed_attn_lse_sh + max_lse_1d.view_as(mixed_attn_se_sh) + ctx.save_for_backward( + output, + mixed_attn_lse_sh, + q, + k, + v, + self_attn_cu_seqlen, + moba_q, + moba_kv, + moba_cu_seqlen_q, + moba_cu_seqlen_kv, + moba_q_sh_indices, + ) + + return output + + @staticmethod + def backward(ctx, d_output): + + max_seqlen = ctx.max_seqlen + moba_chunk_size = ctx.moba_chunk_size + softmax_scale = ctx.softmax_scale + + ( + output, + mixed_attn_vlse_sh, + q, + k, + v, + self_attn_cu_seqlen, + moba_q, + moba_kv, + moba_cu_seqlen_q, + moba_cu_seqlen_kv, + moba_q_sh_indices, + ) = ctx.saved_tensors + + d_output = d_output.contiguous() + + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + _ = _flash_attn_varlen_backward( + 
dout=d_output, + q=q, + k=k, + v=v, + out=output, + softmax_lse=mixed_attn_vlse_sh.t().contiguous(), + dq=dq, + dk=dk, + dv=dv, + cu_seqlens_q=self_attn_cu_seqlen, + cu_seqlens_k=self_attn_cu_seqlen, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, + causal=False, + dropout_p=0.0, + softcap=0.0, + alibi_slopes=None, + deterministic=True, + window_size_left=-1, + window_size_right=-1, + ) + + headdim = q.shape[-1] + d_moba_output = ( + d_output.view(-1, headdim).index_select(0, moba_q_sh_indices).unsqueeze(1) + ) + moba_output = ( + output.view(-1, headdim).index_select(0, moba_q_sh_indices).unsqueeze(1) + ) + + mixed_attn_vlse = ( + mixed_attn_vlse_sh.view(-1).index_select(0, moba_q_sh_indices).view(1, -1) + ) + + dmq = torch.empty_like(moba_q) + dmkv = torch.empty_like(moba_kv) + _ = _flash_attn_varlen_backward( + dout=d_moba_output, + q=moba_q, + k=moba_kv[:, 0], + v=moba_kv[:, 1], + out=moba_output, + softmax_lse=mixed_attn_vlse, + dq=dmq, + dk=dmkv[:, 0], + dv=dmkv[:, 1], + cu_seqlens_q=moba_cu_seqlen_q, + cu_seqlens_k=moba_cu_seqlen_kv, + max_seqlen_q=max_seqlen, + max_seqlen_k=moba_chunk_size, + softmax_scale=softmax_scale, + causal=False, + dropout_p=0.0, + softcap=0.0, + alibi_slopes=None, + deterministic=True, + window_size_left=-1, + window_size_right=-1, + ) + + return dq, dk, dv, None, dmq, dmkv, None, None, None, None, None + + +def moba_attn_varlen( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + moba_chunk_size: int, + moba_topk: int, + select_mode: str = "threshold", # "topk" or "threshold" + simsum_threshold: float = 0.25, + threshold_type: str = "query_head", +) -> torch.Tensor: + """ + Accelerated MOBA attention for vision tasks with proper LSE normalization. + + This version: + - Splits KV into chunks. + - For each query head, selects the top-k relevant KV chunks (including the self chunk) + by amplifying the diagonal (self-chunk) logits. + - Aggregates the attention outputs from the selected chunks using a log-sum-exp + reduction so that attending to each query over the selected chunks is equivalent + to the original algorithm. + """ + # Stack keys and values. + kv = torch.stack((k, v), dim=1) + seqlen, num_head, head_dim = q.shape + + # Compute chunk boundaries. + cu_chunk, filtered_chunk_indices, num_filtered_chunk, chunk_to_batch = calc_chunks( + cu_seqlens, moba_chunk_size + ) + + self_attn_cu_seqlen = cu_chunk + + # Update top-k selection to include the self chunk. + moba_topk = min(moba_topk, num_filtered_chunk) + + # --- Build filtered KV from chunks --- + chunk_starts = cu_chunk[filtered_chunk_indices] # [num_filtered_chunk] + chunk_ends = cu_chunk[filtered_chunk_indices + 1] # [num_filtered_chunk] + chunk_lengths = chunk_ends - chunk_starts # [num_filtered_chunk] + max_chunk_len = int(chunk_lengths.max().item()) + + range_tensor = torch.arange( + max_chunk_len, device=kv.device, dtype=chunk_starts.dtype + ).unsqueeze(0) + indices = chunk_starts.unsqueeze(1) + range_tensor + indices = torch.clamp(indices, max=kv.shape[0] - 1) + valid_mask = range_tensor < chunk_lengths.unsqueeze(1) + gathered = kv[indices.view(-1)].view( + num_filtered_chunk, max_chunk_len, *kv.shape[1:] + ) + gathered = gathered * valid_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).type_as( + gathered + ) + + # Compute key_gate_weight over valid tokens. 
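+    # key_gate_weight is the masked mean key of each chunk, per head: padded
+    # positions are excluded from the sum and the result is divided by the
+    # true chunk length. It serves as the chunk's representative vector that
+    # is scored against every query below.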
+ key_values = gathered[ + :, :, 0 + ].float() # [num_filtered_chunk, max_chunk_len, num_head, head_dim] + valid_mask_exp = valid_mask.unsqueeze(-1).unsqueeze(-1) + key_sum = (key_values * valid_mask_exp).sum(dim=1) + divisor = valid_mask.sum(dim=1).unsqueeze(-1).unsqueeze(-1) + key_gate_weight = key_sum / divisor # [num_filtered_chunk, num_head, head_dim] + + # Compute gate logits between key_gate_weight and queries. + q_float = q.float() + # gate = torch.einsum("nhd,shd->nhs", key_gate_weight, q_float) # [num_filtered_chunk, num_head, seqlen] + gate = torch.bmm( + key_gate_weight.permute(1, 0, 2), q_float.permute(1, 0, 2).transpose(1, 2) + ).permute(1, 0, 2) + + # Amplify the diagonal (self chunk) contributions. + gate_seq_idx = ( + torch.arange(seqlen, device=q.device, dtype=torch.int32) + .unsqueeze(0) + .expand(num_filtered_chunk, seqlen) + ) + chunk_start = cu_chunk[filtered_chunk_indices] # [num_filtered_chunk] + chunk_end = cu_chunk[filtered_chunk_indices + 1] # [num_filtered_chunk] + gate_self_chunk_mask = ( + ( + (gate_seq_idx >= chunk_start.unsqueeze(1)) + & (gate_seq_idx < chunk_end.unsqueeze(1)) + ) + .unsqueeze(1) + .expand(-1, num_head, -1) + ) + amplification_factor = 1e9 # Example factor; adjust as needed. + origin_gate = gate.clone() + gate = gate.clone() + if select_mode == "topk": + gate[gate_self_chunk_mask] += amplification_factor + + # Exclude positions that are outside the valid batch boundaries. + batch_starts = cu_seqlens[chunk_to_batch[filtered_chunk_indices]] + batch_ends = cu_seqlens[chunk_to_batch[filtered_chunk_indices] + 1] + gate_batch_start_mask = gate_seq_idx < batch_starts.unsqueeze(1) + gate_batch_end_mask = gate_seq_idx >= batch_ends.unsqueeze(1) + gate_inf_mask = gate_batch_start_mask | gate_batch_end_mask + gate.masked_fill_(gate_inf_mask.unsqueeze(1), -float("inf")) + + if select_mode == "topk": + # We amplify self‐chunk in gate already, so self entries will rank highest. 
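+        # valid_gate_mask marks entries whose query lies inside the chunk's
+        # batch segment; positions masked to -inf above are treated as
+        # invalid for selection.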
+ valid_gate_mask = gate != -float("inf") + if threshold_type == "query_head": + # === per‐ top-k across chunks (original behavior) === + # gate: (C, H, S) + _, gate_topk_idx = torch.topk( + gate, k=moba_topk, dim=0, largest=True, sorted=False + ) + gate_idx_mask = torch.zeros_like(gate, dtype=torch.bool) + gate_idx_mask.scatter_(0, gate_topk_idx, True) + gate_mask = valid_gate_mask & gate_idx_mask + elif threshold_type == "overall": + # === global top-k across all (chunk, head, seq) entries === + C, H, S = gate.shape + flat_gate = gate.flatten() + flat_mask = valid_gate_mask.flatten() + flat_gate_masked = torch.where(flat_mask, flat_gate, -float("inf")) + # pick topk global entries + vals, idx = torch.topk( + flat_gate_masked, k=moba_topk * H * S, largest=True, sorted=False + ) + others_mask_flat = torch.zeros_like(flat_mask, dtype=torch.bool) + others_mask_flat[idx] = True + gate_mask = (valid_gate_mask.flatten() & others_mask_flat).view(gate.shape) + elif threshold_type == "head_global": + # per-head top-k across all chunks and sequence positions + C, H, S = gate.shape + CS = C * S + flat_gate = gate.permute(1, 0, 2).reshape(H, CS) + flat_valid = valid_gate_mask.permute(1, 0, 2).reshape(H, CS) + flat_gate_masked = torch.where( + flat_valid, flat_gate, torch.full_like(flat_gate, -float("inf")) + ) + # pick top-k indices per head + _, topk_idx = torch.topk( + flat_gate_masked, k=moba_topk * S, dim=1, largest=True, sorted=False + ) + gate_idx_flat = torch.zeros_like(flat_valid, dtype=torch.bool) + gate_idx_flat.scatter_(1, topk_idx, True) + gate_mask = gate_idx_flat.reshape(H, C, S).permute(1, 0, 2) + else: + raise ValueError( + f"Invalid threshold_type for topk: {threshold_type}. " + "Choose 'query_head', 'block', or 'overall'." + ) + elif select_mode == "threshold": + # Delegate to the specific thresholding function + valid_gate_mask = gate != -float("inf") # (num_chunk, num_head, seqlen) + if threshold_type == "query_head": + gate_mask = _select_threshold_query_head( + gate, valid_gate_mask, gate_self_chunk_mask, simsum_threshold + ) + elif threshold_type == "block": + gate_mask = _select_threshold_block( + gate, valid_gate_mask, gate_self_chunk_mask, simsum_threshold + ) + elif threshold_type == "overall": + gate_mask = _select_threshold_overall( + gate, valid_gate_mask, gate_self_chunk_mask, simsum_threshold + ) + elif threshold_type == "head_global": + gate_mask = _select_threshold_head_global( + gate, valid_gate_mask, gate_self_chunk_mask, simsum_threshold + ) + else: + raise ValueError( + f"Invalid threshold_type: {threshold_type}. Choose 'query_head', 'block', or 'overall'." + ) + else: + raise ValueError( + f"Invalid select_mode: {select_mode}. Choose 'topk' or 'threshold'." + ) + + # eliminate self_chunk in MoBA branch + gate_mask = gate_mask & ~gate_self_chunk_mask + # if gate_mask is all false, perform flash_attn instead + if gate_mask.sum() == 0: + return flash_attn_varlen_func( + q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=False + ) + + # Determine which query positions are selected. + # nonzero_indices has shape [N, 3] where each row is [chunk_index, head_index, seq_index]. + moba_q_indices = gate_mask.reshape(gate_mask.shape[0], -1).nonzero(as_tuple=True)[ + -1 + ] # [(h s k)] + moba_q_sh_indices = (moba_q_indices % seqlen) * num_head + ( + moba_q_indices // seqlen + ) + moba_q = ( + rearrange(q, "s h d -> (h s) d").index_select(0, moba_q_indices).unsqueeze(1) + ) + + # Build cumulative sequence lengths for the selected queries. 
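+    # moba_seqlen_q holds, for every (chunk, head) "expert", the number of
+    # query positions that selected that chunk; experts with zero selected
+    # queries are filtered out before the varlen metadata is built.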
+ moba_seqlen_q = gate_mask.sum(dim=-1).flatten() + q_zero_mask = moba_seqlen_q == 0 + valid_expert_mask = ~q_zero_mask + if q_zero_mask.sum() > 0: + moba_seqlen_q = moba_seqlen_q[valid_expert_mask] + moba_cu_seqlen_q = torch.cat( + ( + torch.tensor([0], device=q.device, dtype=moba_seqlen_q.dtype), + moba_seqlen_q.cumsum(dim=0), + ), + dim=0, + ).to(torch.int32) + + # Rearrange gathered KV for the MOBA branch. + experts_tensor = rearrange(gathered, "nc cl two h d -> (nc h) cl two d") + valid_expert_lengths = ( + chunk_lengths.unsqueeze(1) + .expand(num_filtered_chunk, num_head) + .reshape(-1) + .to(torch.int32) + ) + if q_zero_mask.sum() > 0: + experts_tensor = experts_tensor[valid_expert_mask] + valid_expert_lengths = valid_expert_lengths[valid_expert_mask] + + seq_range = torch.arange( + experts_tensor.shape[1], device=experts_tensor.device + ).unsqueeze(0) + mask = seq_range < valid_expert_lengths.unsqueeze(1) + moba_kv = experts_tensor[mask] # Shape: ((nc h cl_valid) two d) + moba_kv = moba_kv.unsqueeze(2) # Shape: ((nc h cl_valid) two 1 d) + + moba_cu_seqlen_kv = torch.cat( + [ + torch.zeros(1, device=experts_tensor.device, dtype=torch.int32), + valid_expert_lengths.cumsum(dim=0), + ], + dim=0, + ).to(torch.int32) + + assert ( + moba_cu_seqlen_kv.shape == moba_cu_seqlen_q.shape + ), f"Mismatch between moba_cu_seqlen_kv.shape and moba_cu_seqlen_q.shape: {moba_cu_seqlen_kv.shape} vs {moba_cu_seqlen_q.shape}" + + return MixedAttention.apply( + q, + k, + v, + self_attn_cu_seqlen, + moba_q, + moba_kv, + moba_cu_seqlen_q, + moba_cu_seqlen_kv, + max_seqlen, + moba_chunk_size, + moba_q_sh_indices, + ) + + +def process_moba_input( + x, + patch_resolution, + chunk_size, +): + """ + Process inputs for the attention function. + + Args: + x (torch.Tensor): Input tensor with shape [batch_size, num_patches, num_heads, head_dim]. + patch_resolution (tuple): Tuple containing the patch resolution (t, h, w). + chunk_size (int): Size of the chunk. (maybe tuple or int, according to chunk type) + + Returns: + torch.Tensor: Processed input tensor. 
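+ Note: when `chunk_size` is a 2- or 3-tuple, tokens are permuted so that each
+ chunk becomes contiguous along the sequence dimension; `process_moba_output`
+ applies the inverse permutation. The returned `moba_chunk_size` is the number
+ of tokens in each chunk after this reshaping.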
+ """ + if isinstance(chunk_size, float) or isinstance(chunk_size, int): + moba_chunk_size = int(chunk_size * patch_resolution[1] * patch_resolution[2]) + else: + assert isinstance( + chunk_size, (Tuple, list) + ), f"chunk_size should be a tuple, list, or int, now it is: {type(chunk_size)}" + if len(chunk_size) == 2: + assert ( + patch_resolution[1] % chunk_size[0] == 0 + and patch_resolution[2] % chunk_size[1] == 0 + ), f"spatial patch_resolution {patch_resolution[1:]} should be divisible by 2d chunk_size {chunk_size}" + nch, ncw = ( + patch_resolution[1] // chunk_size[0], + patch_resolution[2] // chunk_size[1], + ) + x = rearrange( + x, + "b (t nch ch ncw cw) n d -> b (nch ncw t ch cw) n d", + t=patch_resolution[0], + nch=nch, + ncw=ncw, + ch=chunk_size[0], + cw=chunk_size[1], + ) + moba_chunk_size = patch_resolution[0] * chunk_size[0] * chunk_size[1] + elif len(chunk_size) == 3: + assert ( + patch_resolution[0] % chunk_size[0] == 0 + and patch_resolution[1] % chunk_size[1] == 0 + and patch_resolution[2] % chunk_size[2] == 0 + ), f"patch_resolution {patch_resolution} should be divisible by 3d chunk_size {chunk_size}" + nct, nch, ncw = ( + patch_resolution[0] // chunk_size[0], + patch_resolution[1] // chunk_size[1], + patch_resolution[2] // chunk_size[2], + ) + x = rearrange( + x, + "b (nct ct nch ch ncw cw) n d -> b (nct nch ncw ct ch cw) n d", + nct=nct, + nch=nch, + ncw=ncw, + ct=chunk_size[0], + ch=chunk_size[1], + cw=chunk_size[2], + ) + moba_chunk_size = chunk_size[0] * chunk_size[1] * chunk_size[2] + else: + raise ValueError( + f"chunk_size should be a int, or a tuple of length 2 or 3, now it is: {len(chunk_size)}" + ) + + return x, moba_chunk_size + + +def process_moba_output( + x, + patch_resolution, + chunk_size, +): + if isinstance(chunk_size, float) or isinstance(chunk_size, int): + pass + elif len(chunk_size) == 2: + x = rearrange( + x, + "b (nch ncw t ch cw) n d -> b (t nch ch ncw cw) n d", + nch=patch_resolution[1] // chunk_size[0], + ncw=patch_resolution[2] // chunk_size[1], + t=patch_resolution[0], + ch=chunk_size[0], + cw=chunk_size[1], + ) + elif len(chunk_size) == 3: + x = rearrange( + x, + "b (nct nch ncw ct ch cw) n d -> b (nct ct nch ch ncw cw) n d", + nct=patch_resolution[0] // chunk_size[0], + nch=patch_resolution[1] // chunk_size[1], + ncw=patch_resolution[2] // chunk_size[2], + ct=chunk_size[0], + ch=chunk_size[1], + cw=chunk_size[2], + ) + + return x + + +# TEST +def generate_data(batch_size, seqlen, num_head, head_dim, dtype): + random.seed(0) + torch.manual_seed(0) + torch.cuda.manual_seed(0) + device = torch.cuda.current_device() + + q = torch.randn((batch_size, seqlen, num_head, head_dim), requires_grad=True).to( + dtype=dtype, device="cuda" + ) + k = torch.randn((batch_size, seqlen, num_head, head_dim), requires_grad=True).to( + dtype=dtype, device="cuda" + ) + v = torch.randn((batch_size, seqlen, num_head, head_dim), requires_grad=True).to( + dtype=dtype, device="cuda" + ) + print(f"q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}") + cu_seqlens = torch.arange( + 0, q.shape[0] * q.shape[1] + 1, q.shape[1], dtype=torch.int32, device="cuda" + ) + max_seqlen = q.shape[1] + q = rearrange(q, "b s ... -> (b s) ...") + k = rearrange(k, "b s ... -> (b s) ...") + v = rearrange(v, "b s ... 
-> (b s) ...") + + return q, k, v, cu_seqlens, max_seqlen + + +def test_attn_varlen_moba_speed( + batch, + head, + seqlen, + head_dim, + moba_chunk_size, + moba_topk, + dtype=torch.bfloat16, + select_mode="threshold", + simsum_threshold=0.25, + threshold_type="query_head", +): + """Speed test comparing flash_attn vs moba_attention""" + # Get data + q, k, v, cu_seqlen, max_seqlen = generate_data(batch, seqlen, head, head_dim, dtype) + print( + f"batch:{batch} head:{head} seqlen:{seqlen} chunk:{moba_chunk_size} topk:{moba_topk} select_mode: {select_mode} simsum_threshold:{simsum_threshold}" + ) + vo_grad = torch.randn_like(q) + + # Warmup + warmup_iters = 3 + perf_test_iters = 10 + + # Warmup + for _ in range(warmup_iters): + o = flash_attn_varlen_func( + q, k, v, cu_seqlen, cu_seqlen, max_seqlen, max_seqlen, causal=False + ) + torch.autograd.backward(o, vo_grad) + + torch.cuda.synchronize() + start_flash = time.perf_counter() + for _ in range(perf_test_iters): + o = flash_attn_varlen_func( + q, k, v, cu_seqlen, cu_seqlen, max_seqlen, max_seqlen, causal=False + ) + torch.autograd.backward(o, vo_grad) + + torch.cuda.synchronize() + time_flash = (time.perf_counter() - start_flash) / perf_test_iters * 1000 + + # Warmup + for _ in range(warmup_iters): + om = moba_attn_varlen( + q, + k, + v, + cu_seqlen, + max_seqlen, + moba_chunk_size=moba_chunk_size, + moba_topk=moba_topk, + select_mode=select_mode, + simsum_threshold=simsum_threshold, + threshold_type=threshold_type, + ) + torch.autograd.backward(om, vo_grad) + + torch.cuda.synchronize() + start_moba = time.perf_counter() + for _ in range(perf_test_iters): + om = moba_attn_varlen( + q, + k, + v, + cu_seqlen, + max_seqlen, + moba_chunk_size=moba_chunk_size, + moba_topk=moba_topk, + select_mode=select_mode, + simsum_threshold=simsum_threshold, + threshold_type=threshold_type, + ) + torch.autograd.backward(om, vo_grad) + + torch.cuda.synchronize() + time_moba = (time.perf_counter() - start_moba) / perf_test_iters * 1000 + + print(f"Flash: {time_flash:.2f}ms, MoBA: {time_moba:.2f}ms") + print(f"Speedup: {time_flash / time_moba:.2f}x") + + +if __name__ == "__main__": + """ + CUDA_VISIBLE_DEVICES=1 \ + python -u csrc/attn/vmoba_attn/vmoba/vmoba.py + """ + test_attn_varlen_moba_speed( + batch=1, + head=12, + seqlen=32760, + head_dim=128, + moba_chunk_size=32760 // 3 // 6 // 4, + moba_topk=3, + select_mode="threshold", + simsum_threshold=0.3, + threshold_type="query_head", + ) diff --git a/python/sglang/multimodal_gen/docs/cli.md b/python/sglang/multimodal_gen/docs/cli.md new file mode 100644 index 000000000..2a37e7050 --- /dev/null +++ b/python/sglang/multimodal_gen/docs/cli.md @@ -0,0 +1,274 @@ +# sgl-diffusion CLI Inference + +The sgl-diffusion CLI provides a quick way to access the sgl-diffusion inference pipeline for image and video generation. + +## Prerequisites + +- A working sgl-diffusion installation and the `sgl-diffusion` CLI available in `$PATH`. +- Python 3.10+ if you plan to use the OpenAI Python SDK. 
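+
+As a quick sanity check of the prerequisites above, you can confirm that the CLI entry point and the optional OpenAI SDK are available (the commands below assume a standard installation):
+
+```bash
+# Show the options of the generate subcommand to confirm the CLI is on $PATH
+sglang generate --help
+
+# Optional: confirm the OpenAI Python SDK used in the Serve examples below is importable
+python -c "import openai; print(openai.__version__)"
+```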
+ + +## Supported Arguments + +### Server Arguments + +- `--model-path {MODEL_PATH}`: Path to the model or model ID +- `--num-gpus {NUM_GPUS}`: Number of GPUs to use +- `--tp-size {TP_SIZE}`: Tensor parallelism size (only for the encoder; should not be larger than 1 if text encoder offload is enabled, as layer-wise offload plus prefetch is faster) +- `--sp-size {SP_SIZE}`: Sequence parallelism size (typically should match the number of GPUs) +- `--ulysses-degree {ULYSSES_DEGREE}`: The degree of DeepSpeed-Ulysses-style SP in USP +- `--ring-degree {RING_DEGREE}`: The degree of ring attention-style SP in USP + + +### Sampling Parameters + +- `--prompt {PROMPT}`: Text description for the video you want to generate +- `--num-inference-steps {STEPS}`: Number of denoising steps +- `--negative-prompt {PROMPT}`: Negative prompt to guide generation away from certain concepts +- `--seed {SEED}`: Random seed for reproducible generation + + +#### Image/Video Configuration + +- `--height {HEIGHT}`: Height of the generated output +- `--width {WIDTH}`: Width of the generated output +- `--num-frames {NUM_FRAMES}`: Number of frames to generate +- `--fps {FPS}`: Frames per second for the saved output, if this is a video-generation task + + +#### Output Options + +- `--output-path {PATH}`: Directory to save the generated video +- `--save-output`: Whether to save the image/video to disk +- `--return-frames`: Whether to return the raw frames + +### Using Configuration Files + +Instead of specifying all parameters on the command line, you can use a configuration file: + +```bash +sglang generate --config {CONFIG_FILE_PATH} +``` + +The configuration file should be in JSON or YAML format with the same parameter names as the CLI options. Command-line arguments take precedence over settings in the configuration file, allowing you to override specific values while keeping the rest from the configuration file. + +Example configuration file (config.json): + +```json +{ + "model_path": "FastVideo/FastHunyuan-diffusers", + "prompt": "A beautiful woman in a red dress walking down a street", + "output_path": "outputs/", + "num_gpus": 2, + "sp_size": 2, + "tp_size": 1, + "num_frames": 45, + "height": 720, + "width": 1280, + "num_inference_steps": 6, + "seed": 1024, + "fps": 24, + "precision": "bf16", + "vae_precision": "fp16", + "vae_tiling": true, + "vae_sp": true, + "vae_config": { + "load_encoder": false, + "load_decoder": true, + "tile_sample_min_height": 256, + "tile_sample_min_width": 256 + }, + "text_encoder_precisions": [ + "fp16", + "fp16" + ], + "mask_strategy_file_path": null, + "enable_torch_compile": false +} +``` + +Or using YAML format (config.yaml): + +```yaml +model_path: "FastVideo/FastHunyuan-diffusers" +prompt: "A beautiful woman in a red dress walking down a street" +output_path: "outputs/" +num_gpus: 2 +sp_size: 2 +tp_size: 1 +num_frames: 45 +height: 720 +width: 1280 +num_inference_steps: 6 +seed: 1024 +fps: 24 +precision: "bf16" +vae_precision: "fp16" +vae_tiling: true +vae_sp: true +vae_config: + load_encoder: false + load_decoder: true + tile_sample_min_height: 256 + tile_sample_min_width: 256 +text_encoder_precisions: + - "fp16" + - "fp16" +mask_strategy_file_path: null +enable_torch_compile: false +``` + + +To see all the options, you can use the `--help` flag: + +```bash +sglang generate --help +``` + +## Serve + +Launch the sgl-diffusion HTTP server and interact with it using the OpenAI SDK and curl. The server implements an OpenAI-compatible subset for Videos under the `/v1/videos` namespace. 
+ +### Start the server + +Use the following command to launch the server: + +```bash +SERVER_ARGS=( + --model-path Wan-AI/Wan2.1-T2V-1.3B-Diffusers + --text-encoder-cpu-offload + --pin-cpu-memory + --num-gpus 4 + --ulysses-degree=2 + --ring-degree=2 +) + +sglang serve $SERVER_ARGS +``` + +- **--model-path**: Which model to load. The example uses `Wan-AI/Wan2.1-T2V-1.3B-Diffusers`. +- **--port**: HTTP port to listen on (the default here is `30010`). + +Wait until the port is listening. In CI, the tests probe `127.0.0.1:30010` before sending requests. + +### OpenAI Python SDK usage + +Initialize the client with a dummy API key and point `base_url` to your local server: + +```python +from openai import OpenAI + +client = OpenAI(api_key="sk-proj-1234567890", base_url="http://localhost:30010/v1") +``` + +- **Create a video** + +```python +video = client.videos.create(prompt="A calico cat playing a piano on stage", size="1280x720") +print(video.id, video.status) +``` + +Response example fields include `id`, `status` (e.g., `queued` → `completed`), `size`, and `seconds`. + +- **List videos** + +```python +videos = client.videos.list() +for item in videos.data: + print(item.id, item.status) +``` + +- **Poll for completion and download content** + +```python +import time + +video = client.videos.create(prompt="A calico cat playing a piano on stage", size="1280x720") +video_id = video.id + +# Simple polling loop +while True: + page = client.videos.list() + item = next((v for v in page.data if v.id == video_id), None) + if item and item.status == "completed": + break + time.sleep(5) + +# Download binary content (MP4) +resp = client.videos.download_content(video_id=video_id) +content = resp.read() # bytes +with open("output.mp4", "wb") as f: + f.write(content) +``` + +### curl examples + +- **Create a video** + +```bash +curl -sS -X POST "http://localhost:30010/v1/videos" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-proj-1234567890" \ + -d '{ + "prompt": "A calico cat playing a piano on stage", + "size": "1280x720" + }' +``` + +- **List videos** + +```bash +curl -sS -X GET "http://localhost:30010/v1/videos" \ + -H "Authorization: Bearer sk-proj-1234567890" +``` + +- **Download video content** + +```bash +curl -sS -L "http://localhost:30010/v1/videos//content" \ + -H "Authorization: Bearer sk-proj-1234567890" \ + -o output.mp4 +``` + +### API surface implemented here + +The server exposes these endpoints (OpenAPI tag `videos`): + +- `POST /v1/videos` — Create a generation job and return a queued `video` object. +- `GET /v1/videos` — List jobs. +- `GET /v1/videos/{video_id}/content` — Download binary content when ready (e.g., MP4). + +### Reference + +- OpenAI Videos API reference: `https://platform.openai.com/docs/api-reference/videos` + +## Generate + +Run a one-off generation task without launching a persistent server. + +To use it, pass both server arguments and sampling parameters in one command, after the `generate` subcommand, for example: + +```bash +SERVER_ARGS=( + --model-path Wan-AI/Wan2.2-T2V-A14B-Diffusers + --text-encoder-cpu-offload + --pin-cpu-memory + --num-gpus 4 + --ulysses-degree=2 + --ring-degree=2 +) + +SAMPLING_ARGS=( + --prompt "A curious raccoon" + --save-output + --output-path outputs + --output-file-name "A curious raccoon.mp4" +) + +sglang generate $SERVER_ARGS $SAMPLING_ARGS +``` + +Once the generation task has finished, the server will shut down automatically. + +> [!NOTE] +> The HTTP server-related arguments are ignored in this subcommand. 
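+
+The configuration files described earlier also work with one-off generation. A minimal sketch, assuming the `config.yaml` from the configuration-file section exists in the working directory:
+
+```bash
+# Command-line arguments take precedence over the configuration file,
+# so only the prompt is overridden here.
+sglang generate --config config.yaml --prompt "A curious raccoon"
+```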
diff --git a/python/sglang/multimodal_gen/docs/install.md b/python/sglang/multimodal_gen/docs/install.md new file mode 100644 index 000000000..8b8321533 --- /dev/null +++ b/python/sglang/multimodal_gen/docs/install.md @@ -0,0 +1,52 @@ +# Install sgl-diffusion + +You can install sgl-diffusion using one of the methods below. + +This page primarily applies to common NVIDIA GPU platforms. + +## Method 1: With pip or uv + +It is recommended to use uv for a faster installation: + +```bash +pip install --upgrade pip +pip install uv +uv pip install sglang[.diffusion] --prerelease=allow +``` + +## Method 2: From source + +```bash +# Use the latest release branch +git clone https://github.com/sgl-project/sglang.git +cd sglang + +# Install the Python packages +pip install --upgrade pip +pip install -e "python/.[diffusion]" + +# With uv +uv pip install --prerelease=allow -e "python/.[diffusion]" +``` + +**Quick fixes for common problems:** + +- If you want to develop sgl-diffusion, it is recommended to use Docker. The Docker image is `lmsysorg/sgl-diffusion:latest`. + +## Method 3: Using Docker + +The Docker images are available on Docker Hub at [lmsysorg/sgl-diffusion](), built from the [Dockerfile](https://github.com/sgl-project/sgl-diffusion/tree/main/docker). +Replace `` below with your HuggingFace Hub [token](https://huggingface.co/docs/hub/en/security-tokens). + +```bash +docker run --gpus all \ + --shm-size 32g \ + -p 30000:30000 \ + -v ~/.cache/huggingface:/root/.cache/huggingface \ + --env "HF_TOKEN=" \ + --ipc=host \ + lmsysorg/sglang:diffusion \ + sglang generate --model-path black-forest-labs/FLUX.1-dev \ + --prompt "A logo With Bold Large text: SGL Diffusion" \ + --save-output +``` diff --git a/python/sglang/multimodal_gen/docs/support_matrix.md b/python/sglang/multimodal_gen/docs/support_matrix.md new file mode 100644 index 000000000..bbbacb83d --- /dev/null +++ b/python/sglang/multimodal_gen/docs/support_matrix.md @@ -0,0 +1,46 @@ +# Compatibility Matrix + +The table below shows every supported model and the optimizations supported for them. + +The symbols used have the following meanings: + +- ✅ = Full compatibility +- ❌ = No compatibility +- ⭕ = Does not apply to this model + +## Models x Optimization + +The `HuggingFace Model ID` can be passed directly to `from_pretrained()` methods, and sgl-diffusion will use the optimal +default parameters when initializing and generating videos. + +### Video Generation Models + +| Model Name | Hugging Face Model ID | Resolutions | TeaCache | Sliding Tile Attn | Sage Attn | Video Sparse Attention (VSA) | +|:-----------------------------|:--------------------------------------------------|:---------------------------------------------|:--------:|:-----------------:|:---------:|:----------------------------:| +| FastWan2.1 T2V 1.3B | `FastVideo/FastWan2.1-T2V-1.3B-Diffusers` | 480p | ⭕ | ⭕ | ⭕ | ✅ | +| FastWan2.2 TI2V 5B Full Attn | `FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers` | 720p | ⭕ | ⭕ | ⭕ | ✅ | +| Wan2.2 TI2V 5B | `Wan-AI/Wan2.2-TI2V-5B-Diffusers` | 720p | ⭕ | ⭕ | ✅ | ⭕ | +| Wan2.2 T2V A14B | `Wan-AI/Wan2.2-T2V-A14B-Diffusers` | 480p
720p | ❌ | ❌ | ✅ | ⭕ |
+| Wan2.2 I2V A14B | `Wan-AI/Wan2.2-I2V-A14B-Diffusers` | 480p<br>720p | ❌ | ❌ | ✅ | ⭕ |
+| HunyuanVideo | `hunyuanvideo-community/HunyuanVideo` | 720×1280<br>544×960 | ❌ | ✅ | ✅ | ⭕ |
+| FastHunyuan | `FastVideo/FastHunyuan-diffusers` | 720×1280
544×960 | ❌ | ✅ | ✅ | ⭕ | +| Wan2.1 T2V 1.3B | `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` | 480p | ✅ | ✅ | ✅ | ⭕ | +| Wan2.1 T2V 14B | `Wan-AI/Wan2.1-T2V-14B-Diffusers` | 480p, 720p | ✅ | ✅ | ✅ | ⭕ | +| Wan2.1 I2V 480P | `Wan-AI/Wan2.1-I2V-14B-480P-Diffusers` | 480p | ✅ | ✅ | ✅ | ⭕ | +| Wan2.1 I2V 720P | `Wan-AI/Wan2.1-I2V-14B-720P-Diffusers` | 720p | ✅ | ✅ | ✅ | ⭕ | + +**Note**: Wan2.2 TI2V 5B has some quality issues when performing I2V generation. We are working on fixing this issue. + +### Image Generation Models + +| Model Name | HuggingFace Model ID | Resolutions | TeaCache | Sage Attn | +|:----------------|:-------------------------------|:---------------|:--------:|:---------:| +| FLUX.1-dev | `black-forest-labs/FLUX.1-dev` | Any resolution | ❌ | ❌ | +| Qwen Image | `Qwen/Qwen-Image` | Any resolution | ❌ | ❌ | +| Qwen Image Edit | `Qwen/Qwen-Image-Edit` | Any resolution | ❌ | ❌ | + +## Special requirements + +### Sliding Tile Attention + +- Currently, only Hopper GPUs (H100s) are supported. diff --git a/python/sglang/multimodal_gen/envs.py b/python/sglang/multimodal_gen/envs.py new file mode 100644 index 000000000..387ecde2b --- /dev/null +++ b/python/sglang/multimodal_gen/envs.py @@ -0,0 +1,326 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo +import importlib.util + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/envs.py +import logging +import os +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +import diffusers +import torch +from packaging import version + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + SGL_DIFFUSION_RINGBUFFER_WARNING_INTERVAL: int = 60 + SGL_DIFFUSION_NCCL_SO_PATH: str | None = None + LD_LIBRARY_PATH: str | None = None + LOCAL_RANK: int = 0 + CUDA_VISIBLE_DEVICES: str | None = None + SGL_DIFFUSION_CACHE_ROOT: str = os.path.expanduser("~/.cache/sgl_diffusion") + SGL_DIFFUSION_CONFIG_ROOT: str = os.path.expanduser("~/.config/sgl_diffusion") + SGL_DIFFUSION_CONFIGURE_LOGGING: int = 1 + SGL_DIFFUSION_LOGGING_LEVEL: str = "INFO" + SGL_DIFFUSION_LOGGING_PREFIX: str = "" + SGL_DIFFUSION_LOGGING_CONFIG_PATH: str | None = None + SGL_DIFFUSION_TRACE_FUNCTION: int = 0 + SGL_DIFFUSION_WORKER_MULTIPROC_METHOD: str = "fork" + SGL_DIFFUSION_TARGET_DEVICE: str = "cuda" + MAX_JOBS: str | None = None + NVCC_THREADS: str | None = None + CMAKE_BUILD_TYPE: str | None = None + VERBOSE: bool = False + SGL_DIFFUSION_SERVER_DEV_MODE: bool = False + SGL_DIFFUSION_STAGE_LOGGING: bool = False + + +def _is_hip(): + has_rocm = torch.version.hip is not None + return has_rocm + + +def _is_cuda(): + has_cuda = torch.version.cuda is not None + return has_cuda + + +def _is_musa(): + try: + if hasattr(torch, "musa") and torch.musa.is_available(): + return True + except ModuleNotFoundError: + return False + + +def _is_mps(): + return torch.backends.mps.is_available() + + +class PackagesEnvChecker: + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(PackagesEnvChecker, cls).__new__(cls) + cls._instance.initialize() + return cls._instance + + def initialize(self): + self.packages_info = { + "has_aiter": self.check_aiter(), + "diffusers_version": self.check_diffusers_version(), + } + + def check_aiter(self): + """ + Checks whether ROCm AITER library is installed + """ + try: + + logger.info("Using AITER as the attention library") + return True + except: + if _is_hip(): + logger.warning( + f'Using AMD GPUs, but library "aiter" is not 
installed, ' + "defaulting to other attention mechanisms" + ) + return False + + def check_flash_attn(self): + if not torch.cuda.is_available(): + return False + if _is_musa(): + logger.info( + "Flash Attention library is not supported on MUSA for the moment." + ) + return False + try: + return True + except ImportError: + logger.warning( + f'Flash Attention library "flash_attn" not found, ' + f"using pytorch attention implementation" + ) + return False + + def check_long_ctx_attn(self): + if not torch.cuda.is_available(): + return False + try: + return importlib.util.find_spec("yunchang") is not None + except ImportError: + logger.warning( + f'Ring Flash Attention library "yunchang" not found, ' + f"using pytorch attention implementation" + ) + return False + + def check_diffusers_version(self): + if version.parse( + version.parse(diffusers.__version__).base_version + ) < version.parse("0.30.0"): + raise RuntimeError( + f"Diffusers version: {version.parse(version.parse(diffusers.__version__).base_version)} is not supported," + f"please upgrade to version > 0.30.0" + ) + return version.parse(version.parse(diffusers.__version__).base_version) + + def get_packages_info(self): + return self.packages_info + + +PACKAGES_CHECKER = PackagesEnvChecker() + + +def get_default_cache_root() -> str: + return os.getenv( + "XDG_CACHE_HOME", + os.path.join(os.path.expanduser("~"), ".cache"), + ) + + +def get_default_config_root() -> str: + return os.getenv( + "XDG_CONFIG_HOME", + os.path.join(os.path.expanduser("~"), ".config"), + ) + + +def maybe_convert_int(value: str | None) -> int | None: + if value is None: + return None + return int(value) + + +# The begin-* and end* here are used by the documentation generator +# to extract the used env vars. + +# begin-env-vars-definition + +environment_variables: dict[str, Callable[[], Any]] = { + # ================== Installation Time Env Vars ================== + # Target device of sgl-diffusion, supporting [cuda (by default), + # rocm, neuron, cpu, openvino] + "SGL_DIFFUSION_TARGET_DEVICE": lambda: os.getenv( + "SGL_DIFFUSION_TARGET_DEVICE", "cuda" + ), + # Maximum number of compilation jobs to run in parallel. + # By default this is the number of CPUs + "MAX_JOBS": lambda: os.getenv("MAX_JOBS", None), + # Number of threads to use for nvcc + # By default this is 1. + # If set, `MAX_JOBS` will be reduced to avoid oversubscribing the CPU. + "NVCC_THREADS": lambda: os.getenv("NVCC_THREADS", None), + # If set, sgl_diffusion will use precompiled binaries (*.so) + "SGL_DIFFUSION_USE_PRECOMPILED": lambda: bool( + os.environ.get("SGL_DIFFUSION_USE_PRECOMPILED") + ) + or bool(os.environ.get("SGL_DIFFUSION_PRECOMPILED_WHEEL_LOCATION")), + # CMake build type + # If not set, defaults to "Debug" or "RelWithDebInfo" + # Available options: "Debug", "Release", "RelWithDebInfo" + "CMAKE_BUILD_TYPE": lambda: os.getenv("CMAKE_BUILD_TYPE"), + # If set, sgl_diffusion will print verbose logs during installation + "VERBOSE": lambda: bool(int(os.getenv("VERBOSE", "0"))), + # Root directory for FASTVIDEO configuration files + # Defaults to `~/.config/sgl_diffusion` unless `XDG_CONFIG_HOME` is set + # Note that this not only affects how sgl_diffusion finds its configuration files + # during runtime, but also affects how sgl_diffusion installs its configuration + # files during **installation**. 
+ "SGL_DIFFUSION_CONFIG_ROOT": lambda: os.path.expanduser( + os.getenv( + "SGL_DIFFUSION_CONFIG_ROOT", + os.path.join(get_default_config_root(), "sgl_diffusion"), + ) + ), + # ================== Runtime Env Vars ================== + # Root directory for FASTVIDEO cache files + # Defaults to `~/.cache/sgl_diffusion` unless `XDG_CACHE_HOME` is set + "SGL_DIFFUSION_CACHE_ROOT": lambda: os.path.expanduser( + os.getenv( + "SGL_DIFFUSION_CACHE_ROOT", + os.path.join(get_default_cache_root(), "sgl_diffusion"), + ) + ), + # Interval in seconds to log a warning message when the ring buffer is full + "SGL_DIFFUSION_RINGBUFFER_WARNING_INTERVAL": lambda: int( + os.environ.get("SGL_DIFFUSION_RINGBUFFER_WARNING_INTERVAL", "60") + ), + # Path to the NCCL library file. It is needed because nccl>=2.19 brought + # by PyTorch contains a bug: https://github.com/NVIDIA/nccl/issues/1234 + "SGL_DIFFUSION_NCCL_SO_PATH": lambda: os.environ.get( + "SGL_DIFFUSION_NCCL_SO_PATH", None + ), + # when `SGL_DIFFUSION_NCCL_SO_PATH` is not set, sgl_diffusion will try to find the nccl + # library file in the locations specified by `LD_LIBRARY_PATH` + "LD_LIBRARY_PATH": lambda: os.environ.get("LD_LIBRARY_PATH", None), + # Internal flag to enable Dynamo fullgraph capture + "SGL_DIFFUSION_TEST_DYNAMO_FULLGRAPH_CAPTURE": lambda: bool( + os.environ.get("SGL_DIFFUSION_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0" + ), + # local rank of the process in the distributed setting, used to determine + # the GPU device id + "LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")), + # used to control the visible devices in the distributed setting + "CUDA_VISIBLE_DEVICES": lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None), + # timeout for each iteration in the engine + "SGL_DIFFUSION_ENGINE_ITERATION_TIMEOUT_S": lambda: int( + os.environ.get("SGL_DIFFUSION_ENGINE_ITERATION_TIMEOUT_S", "60") + ), + # Logging configuration + # If set to 0, sgl_diffusion will not configure logging + # If set to 1, sgl_diffusion will configure logging using the default configuration + # or the configuration file specified by SGL_DIFFUSION_LOGGING_CONFIG_PATH + "SGL_DIFFUSION_CONFIGURE_LOGGING": lambda: int( + os.getenv("SGL_DIFFUSION_CONFIGURE_LOGGING", "1") + ), + "SGL_DIFFUSION_LOGGING_CONFIG_PATH": lambda: os.getenv( + "SGL_DIFFUSION_LOGGING_CONFIG_PATH" + ), + # this is used for configuring the default logging level + "SGL_DIFFUSION_LOGGING_LEVEL": lambda: os.getenv( + "SGL_DIFFUSION_LOGGING_LEVEL", "INFO" + ), + # if set, SGL_DIFFUSION_LOGGING_PREFIX will be prepended to all log messages + "SGL_DIFFUSION_LOGGING_PREFIX": lambda: os.getenv( + "SGL_DIFFUSION_LOGGING_PREFIX", "" + ), + # Trace function calls + # If set to 1, sgl_diffusion will trace function calls + # Useful for debugging + "SGL_DIFFUSION_TRACE_FUNCTION": lambda: int( + os.getenv("SGL_DIFFUSION_TRACE_FUNCTION", "0") + ), + # Path to the attention configuration file. Only used for sliding tile + # attention for now. + "SGL_DIFFUSION_ATTENTION_CONFIG": lambda: ( + None + if os.getenv("SGL_DIFFUSION_ATTENTION_CONFIG", None) is None + else os.path.expanduser(os.getenv("SGL_DIFFUSION_ATTENTION_CONFIG", ".")) + ), + # Use dedicated multiprocess context for workers. + # Both spawn and fork work + "SGL_DIFFUSION_WORKER_MULTIPROC_METHOD": lambda: os.getenv( + "SGL_DIFFUSION_WORKER_MULTIPROC_METHOD", "fork" + ), + # Enables torch profiler if set. Path to the directory where torch profiler + # traces are saved. Note that it must be an absolute path. 
+ "SGL_DIFFUSION_TORCH_PROFILER_DIR": lambda: ( + None + if os.getenv("SGL_DIFFUSION_TORCH_PROFILER_DIR", None) is None + else os.path.expanduser(os.getenv("SGL_DIFFUSION_TORCH_PROFILER_DIR", ".")) + ), + # If set, sgl_diffusion will run in development mode, which will enable + # some additional endpoints for developing and debugging, + # e.g. `/reset_prefix_cache` + "SGL_DIFFUSION_SERVER_DEV_MODE": lambda: bool( + int(os.getenv("SGL_DIFFUSION_SERVER_DEV_MODE", "0")) + ), + # If set, sgl_diffusion will enable stage logging, which will print the time + # taken for each stage + "SGL_DIFFUSION_STAGE_LOGGING": lambda: bool( + int(os.getenv("SGL_DIFFUSION_STAGE_LOGGING", "0")) + ), +} + + +# end-env-vars-definition + + +def __getattr__(name: str): + # lazy evaluation of environment variables + if name in environment_variables: + return environment_variables[name]() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__(): + return list(environment_variables.keys()) + + +def get_torch_distributed_backend() -> str: + if torch.cuda.is_available(): + return "nccl" + elif _is_musa(): + return "mccl" + elif _is_mps(): + return "gloo" + else: + raise NotImplementedError( + "No Accelerators(AMD/NV/MTT GPU, AMD MI instinct accelerators) available" + ) + + +def get_device(local_rank: int) -> torch.device: + if torch.cuda.is_available(): + return torch.device("cuda", local_rank) + elif _is_musa(): + return torch.device("musa", local_rank) + elif _is_mps(): + return torch.device("mps") + else: + return torch.device("cpu") diff --git a/python/sglang/multimodal_gen/runtime/architectures/basic/__init__.py b/python/sglang/multimodal_gen/runtime/architectures/basic/__init__.py new file mode 100644 index 000000000..9f0e4bc62 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/basic/__init__.py @@ -0,0 +1,8 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Basic inference pipelines for sglang.multimodal_gen. + +This package contains basic pipelines for video and image generation. +""" diff --git a/python/sglang/multimodal_gen/runtime/architectures/basic/flux/__init__.py b/python/sglang/multimodal_gen/runtime/architectures/basic/flux/__init__.py new file mode 100644 index 000000000..af2eb7d10 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/basic/flux/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/python/sglang/multimodal_gen/runtime/architectures/basic/flux/flux.py b/python/sglang/multimodal_gen/runtime/architectures/basic/flux/flux.py new file mode 100644 index 000000000..d88e554db --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/basic/flux/flux.py @@ -0,0 +1,126 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Hunyuan video diffusion pipeline implementation. + +This module contains an implementation of the Hunyuan video diffusion pipeline +using the modular pipeline architecture. 
+""" + +from sglang.multimodal_gen.runtime.pipelines import ComposedPipelineBase, Req +from sglang.multimodal_gen.runtime.pipelines.stages import ( + ConditioningStage, + DecodingStage, + DenoisingStage, + InputValidationStage, + LatentPreparationStage, + TextEncodingStage, + TimestepPreparationStage, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +# TODO(will): move PRECISION_TO_TYPE to better place + +logger = init_logger(__name__) + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +def prepare_mu(batch: Req, server_args: ServerArgs): + height = batch.height + width = batch.width + vae_scale_factor = ( + server_args.pipeline_config.vae_config.arch_config.vae_scale_factor + ) + image_seq_len = (int(height) // vae_scale_factor) * (int(width) // vae_scale_factor) + + mu = calculate_shift( + image_seq_len, + # hard code, since scheduler_config is not in PipelineConfig now + 256, + 4096, + 0.5, + 1.15, + ) + return "mu", mu + + +class FluxPipeline(ComposedPipelineBase): + pipeline_name = "FluxPipeline" + + _required_config_modules = [ + "text_encoder", + "text_encoder_2", + "tokenizer", + "tokenizer_2", + "vae", + "transformer", + "scheduler", + ] + + def create_pipeline_stages(self, server_args: ServerArgs): + """Set up pipeline stages with proper dependency injection.""" + + self.add_stage( + stage_name="input_validation_stage", stage=InputValidationStage() + ) + + self.add_stage( + stage_name="prompt_encoding_stage_primary", + stage=TextEncodingStage( + text_encoders=[ + self.get_module("text_encoder"), + self.get_module("text_encoder_2"), + ], + tokenizers=[ + self.get_module("tokenizer"), + self.get_module("tokenizer_2"), + ], + ), + ) + + self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage()) + + self.add_stage( + stage_name="timestep_preparation_stage", + stage=TimestepPreparationStage( + scheduler=self.get_module("scheduler"), + prepare_extra_set_timesteps_kwargs=[prepare_mu], + ), + ) + + self.add_stage( + stage_name="latent_preparation_stage", + stage=LatentPreparationStage( + scheduler=self.get_module("scheduler"), + transformer=self.get_module("transformer"), + ), + ) + + self.add_stage( + stage_name="denoising_stage", + stage=DenoisingStage( + transformer=self.get_module("transformer"), + scheduler=self.get_module("scheduler"), + ), + ) + + self.add_stage( + stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")) + ) + + +EntryClass = FluxPipeline diff --git a/python/sglang/multimodal_gen/runtime/architectures/basic/hunyuan/__init__.py b/python/sglang/multimodal_gen/runtime/architectures/basic/hunyuan/__init__.py new file mode 100644 index 000000000..af2eb7d10 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/basic/hunyuan/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/python/sglang/multimodal_gen/runtime/architectures/basic/hunyuan/hunyuan_pipeline.py b/python/sglang/multimodal_gen/runtime/architectures/basic/hunyuan/hunyuan_pipeline.py new file mode 100644 index 000000000..ffc2c6eec --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/basic/hunyuan/hunyuan_pipeline.py @@ -0,0 +1,93 @@ +# Copied and 
adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Hunyuan video diffusion pipeline implementation. + +This module contains an implementation of the Hunyuan video diffusion pipeline +using the modular pipeline architecture. +""" + + +from sglang.multimodal_gen.runtime.pipelines import ComposedPipelineBase +from sglang.multimodal_gen.runtime.pipelines.stages import ( + ConditioningStage, + DecodingStage, + DenoisingStage, + InputValidationStage, + LatentPreparationStage, + TextEncodingStage, + TimestepPreparationStage, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +# TODO(will): move PRECISION_TO_TYPE to better place + +logger = init_logger(__name__) + + +class HunyuanVideoPipeline(ComposedPipelineBase): + + pipeline_name = "HunyuanVideoPipeline" + + _required_config_modules = [ + "text_encoder", + "text_encoder_2", + "tokenizer", + "tokenizer_2", + "vae", + "transformer", + "scheduler", + ] + + def create_pipeline_stages(self, server_args: ServerArgs): + """Set up pipeline stages with proper dependency injection.""" + + self.add_stage( + stage_name="input_validation_stage", stage=InputValidationStage() + ) + + self.add_stage( + stage_name="prompt_encoding_stage_primary", + stage=TextEncodingStage( + text_encoders=[ + self.get_module("text_encoder"), + self.get_module("text_encoder_2"), + ], + tokenizers=[ + self.get_module("tokenizer"), + self.get_module("tokenizer_2"), + ], + ), + ) + + self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage()) + + self.add_stage( + stage_name="timestep_preparation_stage", + stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")), + ) + + self.add_stage( + stage_name="latent_preparation_stage", + stage=LatentPreparationStage( + scheduler=self.get_module("scheduler"), + transformer=self.get_module("transformer"), + ), + ) + + self.add_stage( + stage_name="denoising_stage", + stage=DenoisingStage( + transformer=self.get_module("transformer"), + scheduler=self.get_module("scheduler"), + ), + ) + + self.add_stage( + stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")) + ) + + +EntryClass = HunyuanVideoPipeline diff --git a/python/sglang/multimodal_gen/runtime/architectures/basic/qwen_image/__init__.py b/python/sglang/multimodal_gen/runtime/architectures/basic/qwen_image/__init__.py new file mode 100644 index 000000000..af2eb7d10 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/basic/qwen_image/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/python/sglang/multimodal_gen/runtime/architectures/basic/qwen_image/qwen_image.py b/python/sglang/multimodal_gen/runtime/architectures/basic/qwen_image/qwen_image.py new file mode 100644 index 000000000..649a7f74d --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/basic/qwen_image/qwen_image.py @@ -0,0 +1,196 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Hunyuan video diffusion pipeline implementation. + +This module contains an implementation of the Hunyuan video diffusion pipeline +using the modular pipeline architecture. 
+""" +from diffusers.image_processor import VaeImageProcessor + +from sglang.multimodal_gen.runtime.pipelines import ComposedPipelineBase, Req +from sglang.multimodal_gen.runtime.pipelines.stages import ( + ConditioningStage, + DecodingStage, + DenoisingStage, + ImageEncodingStage, + ImageVAEEncodingStage, + InputValidationStage, + LatentPreparationStage, + TextEncodingStage, + TimestepPreparationStage, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +# TODO(will): move PRECISION_TO_TYPE to better place + +logger = init_logger(__name__) + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +def prepare_mu(batch: Req, server_args: ServerArgs): + height = batch.height + width = batch.width + vae_scale_factor = server_args.pipeline_config.vae_config.vae_scale_factor + image_seq_len = (int(height) // vae_scale_factor) * (int(width) // vae_scale_factor) + + mu = calculate_shift( + image_seq_len, + # hard code, since scheduler_config is not in PipelineConfig now + 256, + 4096, + 0.5, + 1.15, + ) + return "mu", mu + + +class QwenImagePipeline(ComposedPipelineBase): + pipeline_name = "QwenImagePipeline" + + _required_config_modules = [ + "text_encoder", + "tokenizer", + "vae", + "transformer", + "scheduler", + ] + + def create_pipeline_stages(self, server_args: ServerArgs): + """Set up pipeline stages with proper dependency injection.""" + + self.add_stage( + stage_name="input_validation_stage", stage=InputValidationStage() + ) + + self.add_stage( + stage_name="prompt_encoding_stage_primary", + stage=TextEncodingStage( + text_encoders=[ + self.get_module("text_encoder"), + ], + tokenizers=[ + self.get_module("tokenizer"), + ], + ), + ) + + self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage()) + + self.add_stage( + stage_name="timestep_preparation_stage", + stage=TimestepPreparationStage( + scheduler=self.get_module("scheduler"), + prepare_extra_set_timesteps_kwargs=[prepare_mu], + ), + ) + + self.add_stage( + stage_name="latent_preparation_stage", + stage=LatentPreparationStage( + scheduler=self.get_module("scheduler"), + transformer=self.get_module("transformer"), + ), + ) + + self.add_stage( + stage_name="denoising_stage", + stage=DenoisingStage( + transformer=self.get_module("transformer"), + scheduler=self.get_module("scheduler"), + ), + ) + + self.add_stage( + stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")) + ) + + +class QwenImageEditPipeline(ComposedPipelineBase): + pipeline_name = "QwenImageEditPipeline" + + _required_config_modules = [ + "processor", + "scheduler", + "text_encoder", + "tokenizer", + "transformer", + "vae", + ] + + def create_pipeline_stages(self, server_args: ServerArgs): + """Set up pipeline stages with proper dependency injection.""" + + self.add_stage( + stage_name="input_validation_stage", stage=InputValidationStage() + ) + + self.add_stage( + stage_name="prompt_encoding_stage_primary", + stage=ImageEncodingStage( + image_processor=self.get_module("processor"), + text_encoder=self.get_module("text_encoder"), + vae_image_processor=VaeImageProcessor( + vae_scale_factor=server_args.pipeline_config.vae_config.arch_config.vae_scale_factor + * 2 + ), + ), + ) + + self.add_stage( + 
stage_name="image_encoding_stage_primary", + stage=ImageVAEEncodingStage( + vae_image_processor=VaeImageProcessor( + vae_scale_factor=server_args.pipeline_config.vae_config.arch_config.vae_scale_factor + * 2 + ), + vae=self.get_module("vae"), + ), + ) + + self.add_stage( + stage_name="timestep_preparation_stage", + stage=TimestepPreparationStage( + scheduler=self.get_module("scheduler"), + prepare_extra_set_timesteps_kwargs=[prepare_mu], + ), + ) + + self.add_stage( + stage_name="latent_preparation_stage", + stage=LatentPreparationStage( + scheduler=self.get_module("scheduler"), + transformer=self.get_module("transformer"), + ), + ) + + self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage()) + + self.add_stage( + stage_name="denoising_stage", + stage=DenoisingStage( + transformer=self.get_module("transformer"), + scheduler=self.get_module("scheduler"), + ), + ) + + self.add_stage( + stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")) + ) + + +EntryClass = [QwenImagePipeline, QwenImageEditPipeline] diff --git a/python/sglang/multimodal_gen/runtime/architectures/basic/stepvideo/__init__.py b/python/sglang/multimodal_gen/runtime/architectures/basic/stepvideo/__init__.py new file mode 100644 index 000000000..af2eb7d10 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/basic/stepvideo/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/python/sglang/multimodal_gen/runtime/architectures/basic/stepvideo/stepvideo_pipeline.py b/python/sglang/multimodal_gen/runtime/architectures/basic/stepvideo/stepvideo_pipeline.py new file mode 100644 index 000000000..bb0e4ef9f --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/basic/stepvideo/stepvideo_pipeline.py @@ -0,0 +1,182 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# type: ignore +# SPDX-License-Identifier: Apache-2.0 +""" +Hunyuan video diffusion pipeline implementation. + +This module contains an implementation of the Hunyuan video diffusion pipeline +using the modular pipeline architecture. 
+""" + +import os +from typing import Any + +import torch +from huggingface_hub import hf_hub_download + +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.loader.component_loader import ( + PipelineComponentLoader, +) +from sglang.multimodal_gen.runtime.models.encoders.bert import ( + HunyuanClip, # type: ignore +) +from sglang.multimodal_gen.runtime.models.encoders.stepllm import STEP1TextEncoder +from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import ( + ComposedPipelineBase, +) +from sglang.multimodal_gen.runtime.pipelines.lora_pipeline import LoRAPipeline +from sglang.multimodal_gen.runtime.pipelines.stages import ( + DecodingStage, + DenoisingStage, + InputValidationStage, + LatentPreparationStage, + StepvideoPromptEncodingStage, + TimestepPreparationStage, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class StepVideoPipeline(LoRAPipeline, ComposedPipelineBase): + pipeline_name = "StepVideoPipeline" + + _required_config_modules = ["transformer", "scheduler", "vae"] + + def create_pipeline_stages(self, server_args: ServerArgs): + """Set up pipeline stages with proper dependency injection.""" + + self.add_stage( + stage_name="input_validation_stage", stage=InputValidationStage() + ) + + self.add_stage( + stage_name="prompt_encoding_stage", + stage=StepvideoPromptEncodingStage( + stepllm=self.get_module("text_encoder"), + clip=self.get_module("text_encoder_2"), + ), + ) + + self.add_stage( + stage_name="timestep_preparation_stage", + stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")), + ) + + self.add_stage( + stage_name="latent_preparation_stage", + stage=LatentPreparationStage( + scheduler=self.get_module("scheduler"), + transformer=self.get_module("transformer"), + ), + ) + + self.add_stage( + stage_name="denoising_stage", + stage=DenoisingStage( + transformer=self.get_module("transformer"), + scheduler=self.get_module("scheduler"), + ), + ) + + self.add_stage( + stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")) + ) + + def build_llm(self, model_dir, device) -> torch.nn.Module: + text_encoder = ( + STEP1TextEncoder(model_dir, max_length=320).to(torch.bfloat16).eval() + ) + return text_encoder + + def build_clip(self, model_dir, device) -> HunyuanClip: + clip = HunyuanClip(model_dir, max_length=77).eval() + return clip + + def initialize_pipeline(self, server_args: ServerArgs): + """ + Initialize the pipeline. + """ + target_device = get_local_torch_device() + llm_dir = os.path.join(self.model_path, "step_llm") + clip_dir = os.path.join(self.model_path, "hunyuan_clip") + text_enc = self.build_llm(llm_dir, target_device) + clip_enc = self.build_clip(clip_dir, target_device) + self.add_module("text_encoder", text_enc) + self.add_module("text_encoder_2", clip_enc) + lib_path = ( + os.path.join( + server_args.model_path, + "lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so", + ) + if os.path.isdir(server_args.model_path) # local checkout + else hf_hub_download( + repo_id=server_args.model_path, + filename="lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so", + ) + ) + torch.ops.load_library(lib_path) + + def load_modules( + self, + server_args: ServerArgs, + loaded_modules: dict[str, torch.nn.Module] | None = None, + ) -> dict[str, Any]: + """ + Load the modules from the config. 
+ """ + model_index = self._load_config() + logger.info("Loading pipeline modules from config: %s", model_index) + + # remove keys that are not pipeline modules + model_index.pop("_class_name") + model_index.pop("_diffusers_version") + + # some sanity checks + assert ( + len(model_index) > 1 + ), "model_index.json must contain at least one pipeline module" + + required_modules = ["transformer", "scheduler", "vae"] + for module_name in required_modules: + if module_name not in model_index: + raise ValueError( + f"model_index.json must contain a {module_name} module" + ) + logger.info("Diffusers config passed sanity checks") + + # all the component models used by the pipeline + modules = {} + for module_name, ( + transformers_or_diffusers, + architecture, + ) in model_index.items(): + component_model_path = os.path.join(self.model_path, module_name) + module = PipelineComponentLoader.load_module( + module_name=module_name, + component_model_path=component_model_path, + transformers_or_diffusers=transformers_or_diffusers, + server_args=server_args, + ) + logger.info("Loaded module %s from %s", module_name, component_model_path) + + if module_name in modules: + logger.warning("Overwriting module %s", module_name) + modules[module_name] = module + + required_modules = self.required_config_modules + # Check if all required modules were loaded + for module_name in required_modules: + if module_name not in modules or modules[module_name] is None: + raise ValueError( + f"Required module {module_name} was not loaded properly" + ) + + return modules + + +EntryClass = StepVideoPipeline diff --git a/python/sglang/multimodal_gen/runtime/architectures/basic/wan/__init__.py b/python/sglang/multimodal_gen/runtime/architectures/basic/wan/__init__.py new file mode 100644 index 000000000..af2eb7d10 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/basic/wan/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/python/sglang/multimodal_gen/runtime/architectures/basic/wan/wan_causal_dmd_pipeline.py b/python/sglang/multimodal_gen/runtime/architectures/basic/wan/wan_causal_dmd_pipeline.py new file mode 100644 index 000000000..6e1f59be2 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/basic/wan/wan_causal_dmd_pipeline.py @@ -0,0 +1,78 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Wan causal DMD pipeline implementation. + +This module wires the causal DMD denoising stage into the modular pipeline. 
+""" + +from sglang.multimodal_gen.runtime.pipelines import ComposedPipelineBase, LoRAPipeline + +# isort: off +from sglang.multimodal_gen.runtime.pipelines.stages import ( + ConditioningStage, + DecodingStage, + CausalDMDDenoisingStage, + InputValidationStage, + LatentPreparationStage, + TextEncodingStage, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +# isort: on + +logger = init_logger(__name__) + + +class WanCausalDMDPipeline(LoRAPipeline, ComposedPipelineBase): + pipeline_name = "WanPipeline" + + _required_config_modules = [ + "text_encoder", + "tokenizer", + "vae", + "transformer", + "scheduler", + ] + + def create_pipeline_stages(self, server_args: ServerArgs) -> None: + """Set up pipeline stages with proper dependency injection.""" + + self.add_stage( + stage_name="input_validation_stage", stage=InputValidationStage() + ) + + self.add_stage( + stage_name="prompt_encoding_stage", + stage=TextEncodingStage( + text_encoders=[self.get_module("text_encoder")], + tokenizers=[self.get_module("tokenizer")], + ), + ) + + self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage()) + + self.add_stage( + stage_name="latent_preparation_stage", + stage=LatentPreparationStage( + scheduler=self.get_module("scheduler"), + transformer=self.get_module("transformer", None), + ), + ) + + self.add_stage( + stage_name="denoising_stage", + stage=CausalDMDDenoisingStage( + transformer=self.get_module("transformer"), + scheduler=self.get_module("scheduler"), + ), + ) + + self.add_stage( + stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")) + ) + + +EntryClass = WanCausalDMDPipeline diff --git a/python/sglang/multimodal_gen/runtime/architectures/basic/wan/wan_dmd_pipeline.py b/python/sglang/multimodal_gen/runtime/architectures/basic/wan/wan_dmd_pipeline.py new file mode 100644 index 000000000..2b13408e3 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/basic/wan/wan_dmd_pipeline.py @@ -0,0 +1,98 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Wan video diffusion pipeline implementation. + +This module contains an implementation of the Wan video diffusion pipeline +using the modular pipeline architecture. +""" + +from sglang.multimodal_gen.runtime.models.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, +) +from sglang.multimodal_gen.runtime.pipelines import ComposedPipelineBase, LoRAPipeline +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +# isort: off +from sglang.multimodal_gen.runtime.pipelines.stages import ( + ConditioningStage, + DecodingStage, + DmdDenoisingStage, + InputValidationStage, + LatentPreparationStage, + TextEncodingStage, + TimestepPreparationStage, +) + +# isort: on + +logger = init_logger(__name__) + + +class WanDMDPipeline(LoRAPipeline, ComposedPipelineBase): + """ + Wan video diffusion pipeline with LoRA support. 
+ """ + + pipeline_name = "WanDMDPipeline" + + _required_config_modules = [ + "text_encoder", + "tokenizer", + "vae", + "transformer", + "scheduler", + ] + + def initialize_pipeline(self, server_args: ServerArgs): + + self.modules["scheduler"] = FlowMatchEulerDiscreteScheduler( + shift=server_args.pipeline_config.flow_shift + ) + + def create_pipeline_stages(self, server_args: ServerArgs) -> None: + """Set up pipeline stages with proper dependency injection.""" + + self.add_stage( + stage_name="input_validation_stage", stage=InputValidationStage() + ) + + self.add_stage( + stage_name="prompt_encoding_stage", + stage=TextEncodingStage( + text_encoders=[self.get_module("text_encoder")], + tokenizers=[self.get_module("tokenizer")], + ), + ) + + self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage()) + + self.add_stage( + stage_name="timestep_preparation_stage", + stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")), + ) + + self.add_stage( + stage_name="latent_preparation_stage", + stage=LatentPreparationStage( + scheduler=self.get_module("scheduler"), + transformer=self.get_module("transformer", None), + ), + ) + + self.add_stage( + stage_name="denoising_stage", + stage=DmdDenoisingStage( + transformer=self.get_module("transformer"), + scheduler=self.get_module("scheduler"), + ), + ) + + self.add_stage( + stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")) + ) + + +EntryClass = WanDMDPipeline diff --git a/python/sglang/multimodal_gen/runtime/architectures/basic/wan/wan_i2v_dmd_pipeline.py b/python/sglang/multimodal_gen/runtime/architectures/basic/wan/wan_i2v_dmd_pipeline.py new file mode 100644 index 000000000..b0e256457 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/basic/wan/wan_i2v_dmd_pipeline.py @@ -0,0 +1,113 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Wan video diffusion pipeline implementation. + +This module contains an implementation of the Wan video diffusion pipeline +using the modular pipeline architecture. 
+""" + +from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import ( + ComposedPipelineBase, +) +from sglang.multimodal_gen.runtime.pipelines.lora_pipeline import LoRAPipeline +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +# isort: off +from sglang.multimodal_gen.runtime.pipelines.stages import ( + ImageEncodingStage, + ConditioningStage, + DecodingStage, + DmdDenoisingStage, + ImageVAEEncodingStage, + InputValidationStage, + LatentPreparationStage, + TextEncodingStage, + TimestepPreparationStage, +) + +# isort: on +from sglang.multimodal_gen.runtime.models.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, +) + +logger = init_logger(__name__) + + +class WanImageToVideoDmdPipeline(LoRAPipeline, ComposedPipelineBase): + pipeline_name = "WanCausalDMDPipeline" + + _required_config_modules = [ + "text_encoder", + "tokenizer", + "vae", + "transformer", + "scheduler", + "image_encoder", + "image_processor", + ] + + def initialize_pipeline(self, server_args: ServerArgs): + self.modules["scheduler"] = FlowMatchEulerDiscreteScheduler( + shift=server_args.pipeline_config.flow_shift + ) + + def create_pipeline_stages(self, server_args: ServerArgs): + """Set up pipeline stages with proper dependency injection.""" + + self.add_stage( + stage_name="input_validation_stage", stage=InputValidationStage() + ) + + self.add_stage( + stage_name="prompt_encoding_stage", + stage=TextEncodingStage( + text_encoders=[self.get_module("text_encoder")], + tokenizers=[self.get_module("tokenizer")], + ), + ) + + self.add_stage( + stage_name="image_encoding_stage", + stage=ImageEncodingStage( + image_encoder=self.get_module("image_encoder"), + image_processor=self.get_module("image_processor"), + ), + ) + + self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage()) + + self.add_stage( + stage_name="timestep_preparation_stage", + stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")), + ) + + self.add_stage( + stage_name="latent_preparation_stage", + stage=LatentPreparationStage( + scheduler=self.get_module("scheduler"), + transformer=self.get_module("transformer"), + ), + ) + + self.add_stage( + stage_name="image_latent_preparation_stage", + stage=ImageVAEEncodingStage(vae=self.get_module("vae")), + ) + + self.add_stage( + stage_name="denoising_stage", + stage=DmdDenoisingStage( + transformer=self.get_module("transformer"), + scheduler=self.get_module("scheduler"), + ), + ) + + self.add_stage( + stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")) + ) + + +EntryClass = WanImageToVideoDmdPipeline diff --git a/python/sglang/multimodal_gen/runtime/architectures/basic/wan/wan_i2v_pipeline.py b/python/sglang/multimodal_gen/runtime/architectures/basic/wan/wan_i2v_pipeline.py new file mode 100644 index 000000000..9477d4009 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/basic/wan/wan_i2v_pipeline.py @@ -0,0 +1,118 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Wan video diffusion pipeline implementation. + +This module contains an implementation of the Wan video diffusion pipeline +using the modular pipeline architecture. 
+""" + +from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import ( + ComposedPipelineBase, +) +from sglang.multimodal_gen.runtime.pipelines.lora_pipeline import LoRAPipeline +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +# isort: off +from sglang.multimodal_gen.runtime.pipelines.stages import ( + ImageEncodingStage, + ConditioningStage, + DecodingStage, + DenoisingStage, + ImageVAEEncodingStage, + InputValidationStage, + LatentPreparationStage, + TextEncodingStage, + TimestepPreparationStage, +) + +# isort: on +from sglang.multimodal_gen.runtime.models.schedulers.scheduling_flow_unipc_multistep import ( + FlowUniPCMultistepScheduler, +) + +logger = init_logger(__name__) + + +class WanImageToVideoPipeline(LoRAPipeline, ComposedPipelineBase): + pipeline_name = "WanImageToVideoPipeline" + + _required_config_modules = [ + "text_encoder", + "tokenizer", + "vae", + "transformer", + "scheduler", + "image_encoder", + "image_processor", + ] + + def initialize_pipeline(self, server_args: ServerArgs): + self.modules["scheduler"] = FlowUniPCMultistepScheduler( + shift=server_args.pipeline_config.flow_shift + ) + + def create_pipeline_stages(self, server_args: ServerArgs): + """Set up pipeline stages with proper dependency injection.""" + + self.add_stage( + stage_name="input_validation_stage", stage=InputValidationStage() + ) + + self.add_stage( + stage_name="prompt_encoding_stage", + stage=TextEncodingStage( + text_encoders=[self.get_module("text_encoder")], + tokenizers=[self.get_module("tokenizer")], + ), + ) + + if ( + self.get_module("image_encoder") is not None + and self.get_module("image_processor") is not None + ): + self.add_stage( + stage_name="image_encoding_stage", + stage=ImageEncodingStage( + image_encoder=self.get_module("image_encoder"), + image_processor=self.get_module("image_processor"), + ), + ) + + self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage()) + + self.add_stage( + stage_name="timestep_preparation_stage", + stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")), + ) + + self.add_stage( + stage_name="latent_preparation_stage", + stage=LatentPreparationStage( + scheduler=self.get_module("scheduler"), + transformer=self.get_module("transformer"), + ), + ) + + self.add_stage( + stage_name="image_latent_preparation_stage", + stage=ImageVAEEncodingStage(vae=self.get_module("vae")), + ) + + self.add_stage( + stage_name="denoising_stage", + stage=DenoisingStage( + transformer=self.get_module("transformer"), + transformer_2=self.get_module("transformer_2"), + scheduler=self.get_module("scheduler"), + ), + ) + + self.add_stage( + stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")) + ) + + +EntryClass = WanImageToVideoPipeline diff --git a/python/sglang/multimodal_gen/runtime/architectures/basic/wan/wan_pipeline.py b/python/sglang/multimodal_gen/runtime/architectures/basic/wan/wan_pipeline.py new file mode 100644 index 000000000..be49674d6 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/basic/wan/wan_pipeline.py @@ -0,0 +1,98 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Wan video diffusion pipeline implementation. + +This module contains an implementation of the Wan video diffusion pipeline +using the modular pipeline architecture. 
+""" + +from sglang.multimodal_gen.runtime.models.schedulers.scheduling_flow_unipc_multistep import ( + FlowUniPCMultistepScheduler, +) +from sglang.multimodal_gen.runtime.pipelines import ComposedPipelineBase, LoRAPipeline +from sglang.multimodal_gen.runtime.pipelines.stages import ( + ConditioningStage, + DecodingStage, + DenoisingStage, + InputValidationStage, + LatentPreparationStage, + TextEncodingStage, + TimestepPreparationStage, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class WanPipeline(LoRAPipeline, ComposedPipelineBase): + """ + Wan video diffusion pipeline with LoRA support. + """ + + pipeline_name = "WanImageToVideoPipeline" + + _required_config_modules = [ + "text_encoder", + "tokenizer", + "vae", + "transformer", + "scheduler", + ] + + def initialize_pipeline(self, server_args: ServerArgs): + # We use UniPCMScheduler from Wan2.1 official repo, not the one in diffusers. + self.modules["scheduler"] = FlowUniPCMultistepScheduler( + shift=server_args.pipeline_config.flow_shift + ) + + def create_pipeline_stages(self, server_args: ServerArgs) -> None: + """Set up pipeline stages with proper dependency injection.""" + + self.add_stage( + stage_name="input_validation_stage", stage=InputValidationStage() + ) + + self.add_stage( + stage_name="prompt_encoding_stage", + stage=TextEncodingStage( + text_encoders=[self.get_module("text_encoder")], + tokenizers=[self.get_module("tokenizer")], + ), + ) + + self.add_stage(stage_name="conditioning_stage", stage=ConditioningStage()) + + self.add_stage( + stage_name="timestep_preparation_stage", + stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")), + ) + + self.add_stage( + stage_name="latent_preparation_stage", + stage=LatentPreparationStage( + scheduler=self.get_module("scheduler"), + transformer=self.get_module("transformer", None), + ), + ) + + self.add_stage( + stage_name="denoising_stage", + stage=DenoisingStage( + transformer=self.get_module("transformer"), + transformer_2=self.get_module("transformer_2", None), + scheduler=self.get_module("scheduler"), + vae=self.get_module("vae"), + pipeline=self, + ), + ) + + self.add_stage( + stage_name="decoding_stage", + stage=DecodingStage(vae=self.get_module("vae"), pipeline=self), + ) + + +EntryClass = WanPipeline diff --git a/python/sglang/multimodal_gen/runtime/architectures/preprocess/__init__.py b/python/sglang/multimodal_gen/runtime/architectures/preprocess/__init__.py new file mode 100644 index 000000000..af2eb7d10 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/preprocess/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_pipeline_base.py b/python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_pipeline_base.py new file mode 100644 index 000000000..e82dd5716 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_pipeline_base.py @@ -0,0 +1,433 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +import os +from typing import Any + +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from sglang.multimodal_gen.dataset import getdataset +from 
sglang.multimodal_gen.dataset.dataloader.parquet_io import ( + ParquetDatasetWriter, + records_to_table, +) +from sglang.multimodal_gen.dataset.preprocessing_datasets import PreprocessBatch +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import ( + ComposedPipelineBase, +) +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.pipelines.stages import TextEncodingStage +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class BasePreprocessPipeline(ComposedPipelineBase): + """Base class for preprocessing pipelines that handles common functionality.""" + + def create_pipeline_stages(self, server_args: ServerArgs): + """Set up pipeline stages with proper dependency injection.""" + self.add_stage( + stage_name="prompt_encoding_stage", + stage=TextEncodingStage( + text_encoders=[self.get_module("text_encoder")], + tokenizers=[self.get_module("tokenizer")], + ), + ) + + @torch.no_grad() + def forward( + self, + batch: Req, + server_args: ServerArgs, + args, + ): + if not self.post_init_called: + self.post_init() + + # Initialize class variables for data sharing + self.video_data: dict[str, Any] = {} # Store video metadata and paths + self.latent_data: dict[str, Any] = {} # Store latent tensors + self.preprocess_video_and_text(server_args, args) + + def get_extra_features( + self, valid_data: dict[str, Any], server_args: ServerArgs + ) -> dict[str, Any]: + """Get additional features specific to the pipeline type. Override in subclasses.""" + return {} + + def get_pyarrow_schema(self) -> pa.Schema: + """Return the PyArrow schema for this pipeline. Must be overridden.""" + raise NotImplementedError + + def get_schema_fields(self) -> list[str]: + """Get the schema fields for the pipeline type.""" + return [f.name for f in self.get_pyarrow_schema()] + + def create_record_for_schema( + self, preprocess_batch: PreprocessBatch, schema: pa.Schema, strict: bool = False + ) -> dict[str, Any]: + """Create a record for the Parquet dataset using a generic schema-based approach. 
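+
+        Schema fields are filled by naming convention: "<name>_bytes",
+        "<name>_shape" and "<name>_dtype" are derived from the tensor attribute
+        "<name>" on the batch; known numeric metadata fields are cast to
+        int/float, and all remaining fields are treated as strings. Missing
+        values fall back to empty or zero defaults.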
+ + Args: + preprocess_batch: The batch containing the data to extract + schema: PyArrow schema defining the expected fields + strict: If True, raises an exception when required fields are missing or unfilled + + Returns: + Dictionary record matching the schema + + Raises: + ValueError: If strict=True and required fields are missing or unfilled + """ + record = {} + unfilled_fields = [] + + for field in schema.names: + field_filled = False + + if field.endswith("_bytes"): + # Handle binary tensor data - convert numpy array or tensor to bytes + tensor_name = field.replace("_bytes", "") + tensor_data = getattr(preprocess_batch, tensor_name, None) + if tensor_data is not None: + try: + if hasattr(tensor_data, "numpy"): # torch tensor + record[field] = tensor_data.cpu().numpy().tobytes() + field_filled = True + elif hasattr(tensor_data, "tobytes"): # numpy array + record[field] = tensor_data.tobytes() + field_filled = True + else: + raise ValueError( + f"Unsupported tensor type for field {field}: {type(tensor_data)}" + ) + except Exception as e: + if strict: + raise ValueError( + f"Failed to convert tensor {tensor_name} to bytes: {e}" + ) from e + record[field] = b"" # Empty bytes for missing data + else: + record[field] = b"" # Empty bytes for missing data + + elif field.endswith("_shape"): + # Handle tensor shape info + tensor_name = field.replace("_shape", "") + tensor_data = getattr(preprocess_batch, tensor_name, None) + if tensor_data is not None and hasattr(tensor_data, "shape"): + record[field] = list(tensor_data.shape) + field_filled = True + else: + record[field] = [] + + elif field.endswith("_dtype"): + # Handle tensor dtype info + tensor_name = field.replace("_dtype", "") + tensor_data = getattr(preprocess_batch, tensor_name, None) + if tensor_data is not None and hasattr(tensor_data, "dtype"): + record[field] = str(tensor_data.dtype) + field_filled = True + else: + record[field] = "unknown" + + elif field in ["width", "height", "num_frames"]: + # Handle integer metadata fields + value = getattr(preprocess_batch, field, None) + if value is not None: + try: + record[field] = int(value) + field_filled = True + except (ValueError, TypeError) as e: + if strict: + raise ValueError( + f"Failed to convert field {field} to int: {e}" + ) from e + record[field] = 0 + else: + record[field] = 0 + + elif field in ["duration_sec", "fps"]: + # Handle float metadata fields + # Map schema field names to batch attribute names + attr_name = "duration" if field == "duration_sec" else field + value = getattr(preprocess_batch, attr_name, None) + if value is not None: + try: + record[field] = float(value) + field_filled = True + except (ValueError, TypeError) as e: + if strict: + raise ValueError( + f"Failed to convert field {field} to float: {e}" + ) from e + record[field] = 0.0 + else: + record[field] = 0.0 + + else: + # Handle string fields (id, file_name, caption, media_type, etc.) 
+ # Map common schema field names to batch attribute names + attr_name = field + if field == "caption": + attr_name = "text" + elif field == "file_name": + attr_name = "path" + elif field == "id": + # Generate ID from path if available + path_value = getattr(preprocess_batch, "path", None) + if path_value: + import os + + record[field] = os.path.basename(path_value).split(".")[0] + field_filled = True + else: + record[field] = "" + continue + elif field == "media_type": + # Determine media type from path + path_value = getattr(preprocess_batch, "path", None) + if path_value: + record[field] = ( + "video" if path_value.endswith(".mp4") else "image" + ) + field_filled = True + else: + record[field] = "" + continue + + value = getattr(preprocess_batch, attr_name, None) + if value is not None: + record[field] = str(value) + field_filled = True + else: + record[field] = "" + + # Track unfilled fields + if not field_filled: + unfilled_fields.append(field) + + # Handle strict mode + if strict and unfilled_fields: + raise ValueError(f"Required fields were not filled: {unfilled_fields}") + + # Log unfilled fields as warning if not in strict mode + if unfilled_fields: + logger.warning( + "Some fields were not filled and got default values: %s", + unfilled_fields, + ) + + return record + + def create_record( + self, + video_name: str, + vae_latent: np.ndarray, + text_embedding: np.ndarray, + valid_data: dict[str, Any], + idx: int, + extra_features: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Create a record for the Parquet dataset.""" + record = { + "id": video_name, + "vae_latent_bytes": vae_latent.tobytes(), + "vae_latent_shape": list(vae_latent.shape), + "vae_latent_dtype": str(vae_latent.dtype), + "text_embedding_bytes": text_embedding.tobytes(), + "text_embedding_shape": list(text_embedding.shape), + "text_embedding_dtype": str(text_embedding.dtype), + "file_name": video_name, + "caption": valid_data["text"][idx] if len(valid_data["text"]) > 0 else "", + "media_type": "video", + "width": ( + valid_data["pixel_values"][idx].shape[-2] + if len(valid_data["pixel_values"]) > 0 + else 0 + ), + "height": ( + valid_data["pixel_values"][idx].shape[-1] + if len(valid_data["pixel_values"]) > 0 + else 0 + ), + "num_frames": vae_latent.shape[1] if len(vae_latent.shape) > 1 else 0, + "duration_sec": ( + float(valid_data["duration"][idx]) + if len(valid_data["duration"]) > 0 + else 0.0 + ), + "fps": float(valid_data["fps"][idx]) if len(valid_data["fps"]) > 0 else 0.0, + } + if extra_features: + record.update(extra_features) + return record + + def preprocess_video_and_text(self, server_args: ServerArgs, args): + os.makedirs(args.output_dir, exist_ok=True) + # Create directory for combined data + combined_parquet_dir = os.path.join(args.output_dir, "combined_parquet_dataset") + os.makedirs(combined_parquet_dir, exist_ok=True) + local_rank = int(os.getenv("RANK", 0)) + + # Get how many samples have already been processed + start_idx = 0 + for root, _, files in os.walk(combined_parquet_dir): + for file in files: + if file.endswith(".parquet"): + table = pq.read_table(os.path.join(root, file)) + start_idx += table.num_rows + + # Loading dataset + train_dataset = getdataset(args) + + train_dataloader = DataLoader( + train_dataset, + batch_size=args.preprocess_video_batch_size, + num_workers=args.dataloader_num_workers, + ) + + num_processed_samples = 0 + # Add progress bar for video preprocessing + pbar = tqdm( + train_dataloader, + desc="Processing videos", + unit="batch", + disable=local_rank != 0, 
+ ) + + for batch_idx, data in enumerate(pbar): + if data is None: + continue + + with torch.inference_mode(): + # Filter out invalid samples (those with all zeros) + valid_indices = [] + for i, pixel_values in enumerate(data["pixel_values"]): + if not torch.all(pixel_values == 0): # Check if all values are zero + valid_indices.append(i) + num_processed_samples += len(valid_indices) + + if not valid_indices: + continue + + # Create new batch with only valid samples + valid_data = { + "pixel_values": torch.stack( + [data["pixel_values"][i] for i in valid_indices] + ), + "text": [data["text"][i] for i in valid_indices], + "path": [data["path"][i] for i in valid_indices], + "fps": [data["fps"][i] for i in valid_indices], + "duration": [data["duration"][i] for i in valid_indices], + } + + # VAE + with torch.autocast("cuda", dtype=torch.float32): + latents = ( + self.get_module("vae") + .encode(valid_data["pixel_values"].to(get_local_torch_device())) + .mean + ) + + # Get extra features if needed + extra_features = self.get_extra_features(valid_data, server_args) + + batch_captions = valid_data["text"] + batch = Req( + data_type="video", + prompt=batch_captions, + prompt_embeds=[], + prompt_attention_mask=[], + ) + assert hasattr(self, "prompt_encoding_stage") + result_batch = self.prompt_encoding_stage(batch, server_args) + prompt_embeds, prompt_attention_mask = ( + result_batch.prompt_embeds[0], + result_batch.prompt_attention_mask[0], + ) + assert prompt_embeds.shape[0] == prompt_attention_mask.shape[0] + + # Get sequence lengths from attention masks (number of 1s) + seq_lens = prompt_attention_mask.sum(dim=1) + + non_padded_embeds = [] + non_padded_masks = [] + + # Process each item in the batch + for i in range(prompt_embeds.size(0)): + seq_len = seq_lens[i].item() + # Slice the embeddings and masks to keep only non-padding parts + non_padded_embeds.append(prompt_embeds[i, :seq_len]) + non_padded_masks.append(prompt_attention_mask[i, :seq_len]) + + # Update the tensors with non-padded versions + prompt_embeds = non_padded_embeds + prompt_attention_mask = non_padded_masks + + # Prepare batch data for Parquet dataset + batch_data = [] + + # Add progress bar for saving outputs + save_pbar = tqdm( + enumerate(valid_data["path"]), + desc="Saving outputs", + unit="item", + leave=False, + ) + for idx, video_path in save_pbar: + # Get the corresponding latent and info using video name + latent = latents[idx].cpu() + video_name = os.path.basename(video_path).split(".")[0] + + # Convert tensors to numpy arrays + vae_latent = latent.cpu().numpy() + text_embedding = prompt_embeds[idx].cpu().numpy() + + # Get extra features for this sample if needed + sample_extra_features = {} + if extra_features: + for key, value in extra_features.items(): + if isinstance(value, torch.Tensor): + sample_extra_features[key] = value[idx].cpu().numpy() + else: + sample_extra_features[key] = value[idx] + + # Create record for Parquet dataset + record = self.create_record( + video_name=video_name, + vae_latent=vae_latent, + text_embedding=text_embedding, + valid_data=valid_data, + idx=idx, + extra_features=sample_extra_features, + ) + batch_data.append(record) + + if batch_data: + write_pbar = tqdm( + total=1, desc="Writing to Parquet dataset", unit="batch" + ) + table = records_to_table(batch_data, self.get_pyarrow_schema()) + write_pbar.update(1) + write_pbar.close() + + if not hasattr(self, "dataset_writer"): + self.dataset_writer = ParquetDatasetWriter( + out_dir=combined_parquet_dir, + 
samples_per_file=args.samples_per_file, + ) + self.dataset_writer.append_table(table) + logger.info("Collected batch with %s samples", len(table)) + + if num_processed_samples >= args.flush_frequency: + written = self.dataset_writer.flush() + logger.info("Flushed %s samples to parquet", written) + num_processed_samples = 0 diff --git a/python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_pipeline_i2v.py b/python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_pipeline_i2v.py new file mode 100644 index 000000000..2c6e8dbbc --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_pipeline_i2v.py @@ -0,0 +1,247 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +I2V Data Preprocessing pipeline implementation. + +This module contains an implementation of the I2V Data Preprocessing pipeline +using the modular pipeline architecture. +""" +from typing import Any + +import numpy as np +import torch +from PIL import Image + +from sglang.multimodal_gen.dataset.dataloader.schema import pyarrow_schema_i2v +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context +from sglang.multimodal_gen.runtime.pipelines.preprocess.preprocess_pipeline_base import ( + BasePreprocessPipeline, +) +from sglang.multimodal_gen.runtime.pipelines.stages import ( + ImageEncodingStage, + TextEncodingStage, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs + + +class PreprocessPipeline_I2V(BasePreprocessPipeline): + """I2V preprocessing pipeline implementation.""" + + _required_config_modules = [ + "text_encoder", + "tokenizer", + "vae", + "image_encoder", + "image_processor", + ] + + def create_pipeline_stages(self, server_args: ServerArgs): + self.add_stage( + stage_name="prompt_encoding_stage", + stage=TextEncodingStage( + text_encoders=[self.get_module("text_encoder")], + tokenizers=[self.get_module("tokenizer")], + ), + ) + + self.add_stage( + stage_name="image_encoding_stage", + stage=ImageEncodingStage( + image_encoder=self.get_module("image_encoder"), + image_processor=self.get_module("image_processor"), + ), + ) + + def get_pyarrow_schema(self): + """Return the PyArrow schema for I2V pipeline.""" + return pyarrow_schema_i2v + + def get_extra_features( + self, valid_data: dict[str, Any], server_args: ServerArgs + ) -> dict[str, Any]: + + # TODO(will): move these to cpu at some point + self.get_module("image_encoder").to(get_local_torch_device()) + self.get_module("vae").to(get_local_torch_device()) + + features = {} + """Get CLIP features from the first frame of each video.""" + first_frame = valid_data["pixel_values"][:, :, 0, :, :].permute( + 0, 2, 3, 1 + ) # (B, C, T, H, W) -> (B, H, W, C) + _, _, num_frames, height, width = valid_data["pixel_values"].shape + # latent_height = height // self.get_module( + # "vae").spatial_compression_ratio + # latent_width = width // self.get_module("vae").spatial_compression_ratio + + processed_images = [] + # Frame has values between -1 and 1 + for frame in first_frame: + frame = (frame + 1) * 127.5 + frame_pil = Image.fromarray(frame.cpu().numpy().astype(np.uint8)) + processed_img = self.get_module("image_processor")( + images=frame_pil, return_tensors="pt" + ) + processed_images.append(processed_img) + + # Get CLIP features + pixel_values = torch.cat( + [img["pixel_values"] for img in processed_images], dim=0 + 
).to(get_local_torch_device()) + with torch.no_grad(): + image_inputs = {"pixel_values": pixel_values} + with set_forward_context(current_timestep=0, attn_metadata=None): + clip_features = self.get_module("image_encoder")(**image_inputs) + clip_features = clip_features.last_hidden_state + + features["clip_feature"] = clip_features + """Get VAE features from the first frame of each video""" + video_conditions = [] + for frame in first_frame: + processed_img = frame.to(device="cpu", dtype=torch.float32) + processed_img = processed_img.unsqueeze(0).permute(0, 3, 1, 2).unsqueeze(2) + # (B, H, W, C) -> (B, C, 1, H, W) + video_condition = torch.cat( + [ + processed_img, + processed_img.new_zeros( + processed_img.shape[0], + processed_img.shape[1], + num_frames - 1, + height, + width, + ), + ], + dim=2, + ) + video_condition = video_condition.to( + device=get_local_torch_device(), dtype=torch.float32 + ) + video_conditions.append(video_condition) + + video_conditions = torch.cat(video_conditions, dim=0) + + with torch.autocast(device_type="cuda", dtype=torch.float32, enabled=True): + encoder_outputs = self.get_module("vae").encode(video_conditions) + + latent_condition = encoder_outputs.mean + if ( + hasattr(self.get_module("vae"), "shift_factor") + and self.get_module("vae").shift_factor is not None + ): + if isinstance(self.get_module("vae").shift_factor, torch.Tensor): + latent_condition -= self.get_module("vae").shift_factor.to( + latent_condition.device, latent_condition.dtype + ) + else: + latent_condition -= self.get_module("vae").shift_factor + + if isinstance(self.get_module("vae").scaling_factor, torch.Tensor): + latent_condition = latent_condition * self.get_module( + "vae" + ).scaling_factor.to(latent_condition.device, latent_condition.dtype) + else: + latent_condition = latent_condition * self.get_module("vae").scaling_factor + + # mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, + # latent_width) + # mask_lat_size[:, :, list(range(1, num_frames))] = 0 + # first_frame_mask = mask_lat_size[:, :, 0:1] + # first_frame_mask = torch.repeat_interleave( + # first_frame_mask, + # dim=2, + # repeats=self.get_module("vae").temporal_compression_ratio) + # mask_lat_size = torch.concat( + # [first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + # mask_lat_size = mask_lat_size.view( + # batch_size, -1, + # self.get_module("vae").temporal_compression_ratio, latent_height, + # latent_width) + # mask_lat_size = mask_lat_size.transpose(1, 2) + # mask_lat_size = mask_lat_size.to(latent_condition.device) + + # image_latent = torch.concat([mask_lat_size, latent_condition], dim=1) + + features["first_frame_latent"] = latent_condition + + return features + + def create_record( + self, + video_name: str, + vae_latent: np.ndarray, + text_embedding: np.ndarray, + valid_data: dict[str, Any], + idx: int, + extra_features: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """Create a record for the Parquet dataset with CLIP features.""" + record = super().create_record( + video_name=video_name, + vae_latent=vae_latent, + text_embedding=text_embedding, + valid_data=valid_data, + idx=idx, + extra_features=extra_features, + ) + + if extra_features and "clip_feature" in extra_features: + clip_feature = extra_features["clip_feature"] + record.update( + { + "clip_feature_bytes": clip_feature.tobytes(), + "clip_feature_shape": list(clip_feature.shape), + "clip_feature_dtype": str(clip_feature.dtype), + } + ) + else: + record.update( + { + "clip_feature_bytes": b"", + "clip_feature_shape": [], 
+ "clip_feature_dtype": "", + } + ) + + if extra_features and "first_frame_latent" in extra_features: + first_frame_latent = extra_features["first_frame_latent"] + record.update( + { + "first_frame_latent_bytes": first_frame_latent.tobytes(), + "first_frame_latent_shape": list(first_frame_latent.shape), + "first_frame_latent_dtype": str(first_frame_latent.dtype), + } + ) + else: + record.update( + { + "first_frame_latent_bytes": b"", + "first_frame_latent_shape": [], + "first_frame_latent_dtype": "", + } + ) + + if extra_features and "pil_image" in extra_features: + pil_image = extra_features["pil_image"] + record.update( + { + "pil_image_bytes": pil_image.tobytes(), + "pil_image_shape": list(pil_image.shape), + "pil_image_dtype": str(pil_image.dtype), + } + ) + else: + record.update( + { + "pil_image_bytes": b"", + "pil_image_shape": [], + "pil_image_dtype": "", + } + ) + + return record + + +EntryClass = PreprocessPipeline_I2V diff --git a/python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_pipeline_ode_trajectory.py b/python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_pipeline_ode_trajectory.py new file mode 100644 index 000000000..950b38c36 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_pipeline_ode_trajectory.py @@ -0,0 +1,355 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +ODE Trajectory Data Preprocessing pipeline implementation. + +This module contains an implementation of the ODE Trajectory Data Preprocessing pipeline +using the modular pipeline architecture. + +Sec 4.3 of CausVid paper: https://arxiv.org/pdf/2412.07772 +""" + +import os +from collections.abc import Iterator +from typing import Any + +import pyarrow as pa +import torch +from torch.utils.data import DataLoader +from torchdata.stateful_dataloader import StatefulDataLoader +from tqdm import tqdm + +from sglang.multimodal_gen.configs.sample import SamplingParams +from sglang.multimodal_gen.dataset import gettextdataset +from sglang.multimodal_gen.dataset.dataloader.parquet_io import ( + ParquetDatasetWriter, + records_to_table, +) +from sglang.multimodal_gen.dataset.dataloader.record_schema import ( + ode_text_only_record_creator, +) +from sglang.multimodal_gen.dataset.dataloader.schema import ( + pyarrow_schema_ode_trajectory_text_only, +) +from sglang.multimodal_gen.runtime.models.schedulers.scheduling_self_forcing_flow_match import ( + SelfForcingFlowMatchScheduler, +) +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.pipelines.preprocess.preprocess_pipeline_base import ( + BasePreprocessPipeline, +) +from sglang.multimodal_gen.runtime.pipelines.stages import ( + DecodingStage, + DenoisingStage, + InputValidationStage, + LatentPreparationStage, + TextEncodingStage, + TimestepPreparationStage, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import save_decoded_latents_as_video, shallow_asdict + +logger = init_logger(__name__) + + +class PreprocessPipeline_ODE_Trajectory(BasePreprocessPipeline): + """ODE Trajectory preprocessing pipeline implementation.""" + + _required_config_modules = [ + "text_encoder", + "tokenizer", + "vae", + "transformer", + "scheduler", + ] + preprocess_dataloader: StatefulDataLoader + preprocess_loader_iter: Iterator[dict[str, Any]] + pbar: Any + 
num_processed_samples: int + + def get_pyarrow_schema(self) -> pa.Schema: + """Return the PyArrow schema for ODE Trajectory pipeline.""" + return pyarrow_schema_ode_trajectory_text_only + + def create_pipeline_stages(self, server_args: ServerArgs): + """Set up pipeline stages with proper dependency injection.""" + assert server_args.pipeline_config.flow_shift == 5 + self.modules["scheduler"] = SelfForcingFlowMatchScheduler( + shift=server_args.pipeline_config.flow_shift, + sigma_min=0.0, + extra_one_step=True, + ) + self.modules["scheduler"].set_timesteps( + num_inference_steps=48, denoising_strength=1.0 + ) + + self.add_stage( + stage_name="input_validation_stage", stage=InputValidationStage() + ) + self.add_stage( + stage_name="prompt_encoding_stage", + stage=TextEncodingStage( + text_encoders=[self.get_module("text_encoder")], + tokenizers=[self.get_module("tokenizer")], + ), + ) + self.add_stage( + stage_name="timestep_preparation_stage", + stage=TimestepPreparationStage(scheduler=self.get_module("scheduler")), + ) + self.add_stage( + stage_name="latent_preparation_stage", + stage=LatentPreparationStage( + scheduler=self.get_module("scheduler"), + transformer=self.get_module("transformer", None), + ), + ) + self.add_stage( + stage_name="denoising_stage", + stage=DenoisingStage( + transformer=self.get_module("transformer"), + scheduler=self.get_module("scheduler"), + pipeline=self, + ), + ) + self.add_stage( + stage_name="decoding_stage", stage=DecodingStage(vae=self.get_module("vae")) + ) + + def preprocess_text_and_trajectory(self, server_args: ServerArgs, args): + """Preprocess text-only data and generate trajectory information.""" + + for batch_idx, data in enumerate(self.pbar): + if data is None: + continue + + with torch.inference_mode(): + # For text-only processing, we only need text data + # Filter out samples without text + valid_indices = [] + for i, text in enumerate(data["text"]): + if text and text.strip(): # Check if text is not empty + valid_indices.append(i) + self.num_processed_samples += len(valid_indices) + + if not valid_indices: + continue + + # Create new batch with only valid samples (text-only) + valid_data = { + "text": [data["text"][i] for i in valid_indices], + "path": [data["path"][i] for i in valid_indices], + } + + # Add fps and duration if available in data + if "fps" in data: + valid_data["fps"] = [data["fps"][i] for i in valid_indices] + if "duration" in data: + valid_data["duration"] = [ + data["duration"][i] for i in valid_indices + ] + + batch_captions = valid_data["text"] + # Encode text using the standalone TextEncodingStage API + prompt_embeds_list, prompt_masks_list = ( + self.prompt_encoding_stage.encode_text( + batch_captions, + server_args, + encoder_index=[0], + return_attention_mask=True, + ) + ) + prompt_embeds = prompt_embeds_list[0] + prompt_attention_masks = prompt_masks_list[0] + assert prompt_embeds.shape[0] == prompt_attention_masks.shape[0] + + sampling_params = SamplingParams.from_pretrained(args.model_path) + + # encode negative prompt for trajectory collection + if ( + sampling_params.guidance_scale > 1 + and sampling_params.negative_prompt is not None + ): + negative_prompt_embeds_list, negative_prompt_masks_list = ( + self.prompt_encoding_stage.encode_text( + sampling_params.negative_prompt, + server_args, + encoder_index=[0], + return_attention_mask=True, + ) + ) + negative_prompt_embed = negative_prompt_embeds_list[0][0] + negative_prompt_attention_mask = negative_prompt_masks_list[0][0] + else: + negative_prompt_embed = None 
+ negative_prompt_attention_mask = None + + trajectory_latents = [] + trajectory_timesteps = [] + trajectory_decoded = [] + + for i, (prompt_embed, prompt_attention_mask) in enumerate( + zip(prompt_embeds, prompt_attention_masks, strict=False) + ): + prompt_embed = prompt_embed.unsqueeze(0) + prompt_attention_mask = prompt_attention_mask.unsqueeze(0) + + # Collect the trajectory data (text-to-video generation) + batch = Req( + **shallow_asdict(sampling_params), + ) + batch.prompt_embeds = [prompt_embed] + batch.prompt_attention_mask = [prompt_attention_mask] + batch.negative_prompt_embeds = [negative_prompt_embed] + batch.negative_attention_mask = [negative_prompt_attention_mask] + batch.num_inference_steps = 48 + batch.return_trajectory_latents = True + # Enabling this will save the decoded trajectory videos. + # Used for debugging. + batch.return_trajectory_decoded = False + batch.height = args.max_height + batch.width = args.max_width + batch.fps = args.train_fps + batch.guidance_scale = 6.0 + batch.do_classifier_free_guidance = True + + result_batch = self.input_validation_stage(batch, server_args) + result_batch = self.timestep_preparation_stage(batch, server_args) + result_batch = self.latent_preparation_stage( + result_batch, server_args + ) + result_batch = self.denoising_stage(result_batch, server_args) + result_batch = self.decoding_stage(result_batch, server_args) + + trajectory_latents.append(result_batch.trajectory_latents.cpu()) + trajectory_timesteps.append(result_batch.trajectory_timesteps.cpu()) + trajectory_decoded.append(result_batch.trajectory_decoded) + + # Prepare extra features for text-only processing + extra_features = { + "trajectory_latents": trajectory_latents, + "trajectory_timesteps": trajectory_timesteps, + } + + if batch.return_trajectory_decoded: + for i, decoded_frames in enumerate(trajectory_decoded): + for j, decoded_frame in enumerate(decoded_frames): + save_decoded_latents_as_video( + decoded_frame, + f"decoded_videos/trajectory_decoded_{i}_{j}.mp4", + args.train_fps, + ) + + # Prepare batch data for Parquet dataset + batch_data: list[dict[str, Any]] = [] + + # Add progress bar for saving outputs + save_pbar = tqdm( + enumerate(valid_data["path"]), + desc="Saving outputs", + unit="item", + leave=False, + ) + + for idx, video_path in save_pbar: + video_name = os.path.basename(video_path).split(".")[0] + + # Convert tensors to numpy arrays + text_embedding = prompt_embeds[idx].cpu().numpy() + + # Get extra features for this sample + sample_extra_features = {} + if extra_features: + for key, value in extra_features.items(): + if isinstance(value, torch.Tensor): + sample_extra_features[key] = value[idx].cpu().numpy() + else: + assert isinstance(value, list) + if isinstance(value[idx], torch.Tensor): + sample_extra_features[key] = ( + value[idx].cpu().float().numpy() + ) + else: + sample_extra_features[key] = value[idx] + + # Create record for Parquet dataset (text-only ODE schema) + record: dict[str, Any] = ode_text_only_record_creator( + video_name=video_name, + text_embedding=text_embedding, + caption=valid_data["text"][idx], + trajectory_latents=sample_extra_features["trajectory_latents"], + trajectory_timesteps=sample_extra_features[ + "trajectory_timesteps" + ], + ) + batch_data.append(record) + + if batch_data: + write_pbar = tqdm( + total=1, desc="Writing to Parquet dataset", unit="batch" + ) + table = records_to_table(batch_data, self.get_pyarrow_schema()) + write_pbar.update(1) + write_pbar.close() + + if not hasattr(self, "dataset_writer"): + 
self.dataset_writer = ParquetDatasetWriter( + out_dir=self.combined_parquet_dir, + samples_per_file=args.samples_per_file, + ) + self.dataset_writer.append_table(table) + + logger.info("Collected batch with %s samples", len(table)) + + if self.num_processed_samples >= args.flush_frequency: + written = self.dataset_writer.flush() + logger.info("Flushed %s samples to parquet", written) + self.num_processed_samples = 0 + + # Final flush for any remaining samples + if hasattr(self, "dataset_writer"): + written = self.dataset_writer.flush(write_remainder=True) + if written: + logger.info("Final flush wrote %s samples", written) + + def forward(self, batch: Req, server_args: ServerArgs, args): + if not self.post_init_called: + self.post_init() + + self.local_rank = int(os.getenv("RANK", 0)) + os.makedirs(args.output_dir, exist_ok=True) + # Create directory for combined data + self.combined_parquet_dir = os.path.join( + args.output_dir, "combined_parquet_dataset" + ) + os.makedirs(self.combined_parquet_dir, exist_ok=True) + + # Loading dataset + train_dataset = gettextdataset(args) + + self.preprocess_dataloader = DataLoader( + train_dataset, + batch_size=args.preprocess_video_batch_size, + num_workers=args.dataloader_num_workers, + ) + + self.preprocess_loader_iter = iter(self.preprocess_dataloader) + + self.num_processed_samples = 0 + # Add progress bar for video preprocessing + self.pbar = tqdm( + self.preprocess_loader_iter, + desc="Processing videos", + unit="batch", + disable=self.local_rank != 0, + ) + + # Initialize class variables for data sharing + self.video_data: dict[str, Any] = {} # Store video metadata and paths + self.latent_data: dict[str, Any] = {} # Store latent tensors + self.preprocess_text_and_trajectory(server_args, args) + + +EntryClass = PreprocessPipeline_ODE_Trajectory diff --git a/python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_pipeline_t2v.py b/python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_pipeline_t2v.py new file mode 100644 index 000000000..d47ab9aec --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_pipeline_t2v.py @@ -0,0 +1,26 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +T2V Data Preprocessing pipeline implementation. + +This module contains an implementation of the T2V Data Preprocessing pipeline +using the modular pipeline architecture. +""" +from sglang.multimodal_gen.dataset.dataloader.schema import pyarrow_schema_t2v +from sglang.multimodal_gen.runtime.pipelines.preprocess.preprocess_pipeline_base import ( + BasePreprocessPipeline, +) + + +class PreprocessPipeline_T2V(BasePreprocessPipeline): + """T2V preprocessing pipeline implementation.""" + + _required_config_modules = ["text_encoder", "tokenizer", "vae"] + + def get_pyarrow_schema(self): + """Return the PyArrow schema for T2V pipeline.""" + return pyarrow_schema_t2v + + +EntryClass = PreprocessPipeline_T2V diff --git a/python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_pipeline_text.py b/python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_pipeline_text.py new file mode 100644 index 000000000..3906f09a5 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_pipeline_text.py @@ -0,0 +1,200 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Text-only Data Preprocessing pipeline implementation. 
+ +This module contains an implementation of the Text-only Data Preprocessing pipeline +using the modular pipeline architecture, based on the ODE Trajectory preprocessing. +""" + +import os +from collections.abc import Iterator +from typing import Any + +import torch +from torch.utils.data import DataLoader +from torchdata.stateful_dataloader import StatefulDataLoader +from tqdm import tqdm + +from sglang.multimodal_gen.dataset import gettextdataset +from sglang.multimodal_gen.dataset.dataloader.parquet_io import ( + ParquetDatasetWriter, + records_to_table, +) +from sglang.multimodal_gen.dataset.dataloader.record_schema import ( + text_only_record_creator, +) +from sglang.multimodal_gen.dataset.dataloader.schema import pyarrow_schema_text_only +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.pipelines.preprocess.preprocess_pipeline_base import ( + BasePreprocessPipeline, +) +from sglang.multimodal_gen.runtime.pipelines.stages import TextEncodingStage +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class PreprocessPipeline_Text(BasePreprocessPipeline): + """Text-only preprocessing pipeline implementation.""" + + _required_config_modules = ["text_encoder", "tokenizer"] + preprocess_dataloader: StatefulDataLoader + preprocess_loader_iter: Iterator[dict[str, Any]] + pbar: Any + num_processed_samples: int = 0 + + def get_pyarrow_schema(self): + """Return the PyArrow schema for text-only pipeline.""" + return pyarrow_schema_text_only + + def create_pipeline_stages(self, server_args: ServerArgs): + """Set up pipeline stages with proper dependency injection.""" + self.add_stage( + stage_name="prompt_encoding_stage", + stage=TextEncodingStage( + text_encoders=[self.get_module("text_encoder")], + tokenizers=[self.get_module("tokenizer")], + ), + ) + + def preprocess_text_only(self, server_args: ServerArgs, args): + """Preprocess text-only data.""" + + for batch_idx, data in enumerate(self.pbar): + if data is None: + continue + + with torch.inference_mode(): + # For text-only processing, we only need text data + # Filter out samples without text + valid_indices = [] + for i, text in enumerate(data["text"]): + if text and text.strip(): # Check if text is not empty + valid_indices.append(i) + self.num_processed_samples += len(valid_indices) + + if not valid_indices: + continue + + # Create new batch with only valid samples (text-only) + valid_data = { + "text": [data["text"][i] for i in valid_indices], + "path": [data["path"][i] for i in valid_indices], + } + + batch_captions = valid_data["text"] + # Encode text using the standalone TextEncodingStage API + prompt_embeds_list, prompt_masks_list = ( + self.prompt_encoding_stage.encode_text( + batch_captions, + server_args, + encoder_index=[0], + return_attention_mask=True, + ) + ) + prompt_embeds = prompt_embeds_list[0] + prompt_attention_masks = prompt_masks_list[0] + assert prompt_embeds.shape[0] == prompt_attention_masks.shape[0] + + logger.info("===== prompt_embeds: %s", prompt_embeds.shape) + logger.info( + "===== prompt_attention_masks: %s", prompt_attention_masks.shape + ) + + # Prepare batch data for Parquet dataset + batch_data = [] + + # Add progress bar for saving outputs + save_pbar = tqdm( + enumerate(valid_data["path"]), + desc="Saving outputs", + unit="item", + leave=False, + ) + + for idx, text_path in save_pbar: + text_name = 
os.path.basename(text_path).split(".")[0] + + # Convert tensors to numpy arrays + text_embedding = prompt_embeds[idx].cpu().numpy() + + # Create record for Parquet dataset (text-only schema) + record = text_only_record_creator( + text_name=text_name, + text_embedding=text_embedding, + caption=valid_data["text"][idx], + ) + batch_data.append(record) + + if batch_data: + write_pbar = tqdm( + total=1, desc="Writing to Parquet dataset", unit="batch" + ) + table = records_to_table(batch_data, pyarrow_schema_text_only) + write_pbar.update(1) + write_pbar.close() + + if not hasattr(self, "dataset_writer"): + self.dataset_writer = ParquetDatasetWriter( + out_dir=self.combined_parquet_dir, + samples_per_file=args.samples_per_file, + ) + self.dataset_writer.append_table(table) + + logger.info("Collected batch with %s samples", len(table)) + + if self.num_processed_samples >= args.flush_frequency: + written = self.dataset_writer.flush() + logger.info("Flushed %s samples to parquet", written) + self.num_processed_samples = 0 + + # Final flush for any remaining samples + if hasattr(self, "dataset_writer"): + written = self.dataset_writer.flush(write_remainder=True) + if written: + logger.info("Final flush wrote %s samples", written) + + # Text-only record creation moved to sglang.multimodal_gen.dataset.dataloader.record_schema + + def forward(self, batch: Req, server_args: ServerArgs, args): + if not self.post_init_called: + self.post_init() + + self.local_rank = int(os.getenv("RANK", 0)) + os.makedirs(args.output_dir, exist_ok=True) + # Create directory for combined data + self.combined_parquet_dir = os.path.join( + args.output_dir, "combined_parquet_dataset" + ) + os.makedirs(self.combined_parquet_dir, exist_ok=True) + + # Loading text dataset + train_dataset = gettextdataset(args) + + self.preprocess_dataloader = DataLoader( + train_dataset, + batch_size=args.preprocess_video_batch_size, + num_workers=args.dataloader_num_workers, + ) + + self.preprocess_loader_iter = iter(self.preprocess_dataloader) + + self.num_processed_samples = 0 + # Add progress bar for text preprocessing + self.pbar = tqdm( + self.preprocess_loader_iter, + desc="Processing text", + unit="batch", + disable=self.local_rank != 0, + ) + + # Initialize class variables for data sharing + self.text_data: dict[str, Any] = {} # Store text metadata and paths + + self.preprocess_text_only(server_args, args) + + +EntryClass = PreprocessPipeline_Text diff --git a/python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_stages.py b/python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_stages.py new file mode 100644 index 000000000..126ab05d6 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/preprocess/preprocess_stages.py @@ -0,0 +1,134 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import random +from collections.abc import Callable +from typing import cast + +import numpy as np +import torch +import torchvision +from einops import rearrange +from torchvision import transforms + +from sglang.multimodal_gen.configs.configs import VideoLoaderType +from sglang.multimodal_gen.dataset.transform import ( + CenterCropResizeVideo, + TemporalRandomCrop, +) +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import ( + PreprocessBatch, + Req, +) +from sglang.multimodal_gen.runtime.pipelines.stages.base import PipelineStage +from sglang.multimodal_gen.runtime.server_args import ServerArgs, WorkloadType + + +class VideoTransformStage(PipelineStage): + """ 
+ Crop a video in temporal dimension. + """ + + def __init__( + self, + train_fps: int, + num_frames: int, + max_height: int, + max_width: int, + do_temporal_sample: bool, + ) -> None: + self.train_fps = train_fps + self.num_frames = num_frames + if do_temporal_sample: + self.temporal_sample_fn: Callable | None = TemporalRandomCrop(num_frames) + else: + self.temporal_sample_fn = None + + self.video_transform = transforms.Compose( + [ + CenterCropResizeVideo((max_height, max_width)), + ] + ) + + def forward(self, batch: Req, server_args: ServerArgs) -> Req: + batch = cast(PreprocessBatch, batch) + assert isinstance(batch.fps, list) + assert isinstance(batch.num_frames, list) + + if batch.data_type != "video": + return batch + + if len(batch.video_loader) == 0: + raise ValueError("Video loader is not set") + + video_pixel_batch = [] + + for i in range(len(batch.video_loader)): + frame_interval = batch.fps[i] / self.train_fps + start_frame_idx = 0 + frame_indices = np.arange( + start_frame_idx, batch.num_frames[i], frame_interval + ).astype(int) + if len(frame_indices) > self.num_frames: + if self.temporal_sample_fn is not None: + begin_index, end_index = self.temporal_sample_fn(len(frame_indices)) + frame_indices = frame_indices[begin_index:end_index] + else: + frame_indices = frame_indices[: self.num_frames] + + if ( + server_args.preprocess_config.video_loader_type + == VideoLoaderType.TORCHCODEC + ): + video = batch.video_loader[i].get_frames_at(frame_indices).data + elif ( + server_args.preprocess_config.video_loader_type + == VideoLoaderType.TORCHVISION + ): + video, _, _ = torchvision.io.read_video( + batch.video_loader[i], output_format="TCHW" + ) + video = video[frame_indices] + else: + raise ValueError( + f"Invalid video loader type: {server_args.preprocess_config.video_loader_type}" + ) + video = self.video_transform(video) + video_pixel_batch.append(video) + + video_pixel_values = torch.stack(video_pixel_batch) + video_pixel_values = rearrange(video_pixel_values, "b t c h w -> b c t h w") + video_pixel_values = video_pixel_values.to(torch.uint8) + + if server_args.workload_type == WorkloadType.I2V: + batch.pil_image = video_pixel_values[:, :, 0, :, :] + + video_pixel_values = video_pixel_values.float() / 255.0 + batch.latents = video_pixel_values + batch.num_frames = [video_pixel_values.shape[2]] * len(batch.video_loader) + batch.height = [video_pixel_values.shape[3]] * len(batch.video_loader) + batch.width = [video_pixel_values.shape[4]] * len(batch.video_loader) + return cast(Req, batch) + + +class TextTransformStage(PipelineStage): + """ + Process text data according to the cfg rate. 
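+
+    When a sample carries several candidate captions, one is drawn at random;
+    with probability cfg_uncondition_drop_rate the chosen caption is replaced
+    by an empty string, producing the unconditional examples used for
+    classifier-free guidance training.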
+ """ + + def __init__(self, cfg_uncondition_drop_rate: float, seed: int) -> None: + self.cfg_rate = cfg_uncondition_drop_rate + self.rng = random.Random(seed) + + def forward(self, batch: Req, server_args: ServerArgs) -> Req: + batch = cast(PreprocessBatch, batch) + + prompts = [] + for prompt in batch.prompt: + if not isinstance(prompt, list): + prompt = [prompt] + prompt = self.rng.choice(prompt) + prompt = prompt if self.rng.random() > self.cfg_rate else "" + prompts.append(prompt) + + batch.prompt = prompts + return cast(Req, batch) diff --git a/python/sglang/multimodal_gen/runtime/architectures/preprocess/v1_preprocess.py b/python/sglang/multimodal_gen/runtime/architectures/preprocess/v1_preprocess.py new file mode 100644 index 000000000..8a160069a --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/preprocess/v1_preprocess.py @@ -0,0 +1,147 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import argparse +import os +from typing import Any + +from sglang.multimodal_gen import PipelineConfig +from sglang.multimodal_gen.configs.models.vaes import WanVAEConfig +from sglang.multimodal_gen.runtime.architectures.preprocess.preprocess_pipeline_i2v import ( + PreprocessPipeline_I2V, +) +from sglang.multimodal_gen.runtime.architectures.preprocess.preprocess_pipeline_ode_trajectory import ( + PreprocessPipeline_ODE_Trajectory, +) +from sglang.multimodal_gen.runtime.architectures.preprocess.preprocess_pipeline_t2v import ( + PreprocessPipeline_T2V, +) +from sglang.multimodal_gen.runtime.architectures.preprocess.preprocess_pipeline_text import ( + PreprocessPipeline_Text, +) +from sglang.multimodal_gen.runtime.distributed import ( + get_world_size, + maybe_init_distributed_environment_and_model_parallel, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_model +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +def main(args) -> None: + args.model_path = maybe_download_model(args.model_path) + maybe_init_distributed_environment_and_model_parallel(1, 1) + num_gpus = int(os.environ["WORLD_SIZE"]) + assert num_gpus == 1, "Only support 1 GPU" + + pipeline_config = PipelineConfig.from_pretrained(args.model_path) + + kwargs: dict[str, Any] = {} + if args.preprocess_task == "text_only": + kwargs = { + "text_encoder_cpu_offload": False, + } + else: + # Full config for video/image processing + kwargs = { + "vae_precision": "fp32", + "vae_config": WanVAEConfig(load_encoder=True, load_decoder=True), + } + pipeline_config.update_config_from_dict(kwargs) + + server_args = ServerArgs( + model_path=args.model_path, + num_gpus=get_world_size(), + dit_cpu_offload=False, + vae_cpu_offload=False, + text_encoder_cpu_offload=False, + pipeline_config=pipeline_config, + ) + if args.preprocess_task == "t2v": + PreprocessPipeline = PreprocessPipeline_T2V + elif args.preprocess_task == "i2v": + PreprocessPipeline = PreprocessPipeline_I2V + elif args.preprocess_task == "text_only": + PreprocessPipeline = PreprocessPipeline_Text + elif args.preprocess_task == "ode_trajectory": + assert args.flow_shift is not None, "flow_shift is required for ode_trajectory" + server_args.pipeline_config.flow_shift = args.flow_shift + PreprocessPipeline = PreprocessPipeline_ODE_Trajectory + else: + raise ValueError( + f"Invalid preprocess task: {args.preprocess_task}. 
" + f"Valid options: t2v, i2v, ode_trajectory, text_only" + ) + + logger.info( + "Preprocess task: %s using %s", + args.preprocess_task, + PreprocessPipeline.__name__, + ) + + pipeline = PreprocessPipeline(args.model_path, server_args) + pipeline.forward(batch=None, server_args=server_args, args=args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # dataset & dataloader + parser.add_argument("--model_path", type=str, default="data/mochi") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--data_merge_path", type=str, required=True) + parser.add_argument("--num_frames", type=int, default=163) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=1, + help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--preprocess_video_batch_size", + type=int, + default=2, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument("--samples_per_file", type=int, default=64) + parser.add_argument( + "--flush_frequency", + type=int, + default=256, + help="how often to save to parquet files", + ) + parser.add_argument( + "--num_latent_t", type=int, default=28, help="Number of latent timesteps." + ) + parser.add_argument("--max_height", type=int, default=480) + parser.add_argument("--max_width", type=int, default=848) + parser.add_argument("--video_length_tolerance_range", type=int, default=2.0) + parser.add_argument("--group_frame", action="store_true") # TODO + parser.add_argument("--group_resolution", action="store_true") # TODO + parser.add_argument("--flow_shift", type=float, default=None) + parser.add_argument( + "--preprocess_task", + type=str, + default="t2v", + choices=["t2v", "i2v", "text_only", "ode_trajectory"], + help="Type of preprocessing task to run", + ) + parser.add_argument("--train_fps", type=int, default=30) + parser.add_argument("--use_image_num", type=int, default=0) + parser.add_argument("--text_max_length", type=int, default=256) + parser.add_argument("--speed_factor", type=float, default=1.0) + parser.add_argument("--drop_short_ratio", type=float, default=1.0) + parser.add_argument("--do_temporal_sample", default=False, action="store_true") + # text encoder & vae & diffusion model + parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl") + parser.add_argument("--cache_dir", type=str, default="./cache_dir") + parser.add_argument("--training_cfg_rate", type=float, default=0.0) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="The output directory where the model predictions and checkpoints will be written.", + ) + + args = parser.parse_args() + main(args) diff --git a/python/sglang/multimodal_gen/runtime/architectures/preprocess/v1_preprocessing_new.py b/python/sglang/multimodal_gen/runtime/architectures/preprocess/v1_preprocessing_new.py new file mode 100644 index 000000000..59f03618b --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/preprocess/v1_preprocessing_new.py @@ -0,0 +1,26 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from sglang.multimodal_gen.runtime.distributed import ( + maybe_init_distributed_environment_and_model_parallel, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.runtime.workflow.workflow_base import WorkflowBase +from sglang.multimodal_gen.utils 
import FlexibleArgumentParser + +logger = init_logger(__name__) + + +def main(server_args: ServerArgs) -> None: + maybe_init_distributed_environment_and_model_parallel(1, 1) + preprocess_workflow_cls = WorkflowBase.get_workflow_cls(server_args) + preprocess_workflow = preprocess_workflow_cls(server_args) + preprocess_workflow.run() + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser = ServerArgs.add_cli_args(parser) + args = parser.parse_args() + server_args = ServerArgs.from_cli_args(args) + main(server_args) diff --git a/python/sglang/multimodal_gen/runtime/architectures/preprocess/wan/__init__.py b/python/sglang/multimodal_gen/runtime/architectures/preprocess/wan/__init__.py new file mode 100644 index 000000000..af2eb7d10 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/preprocess/wan/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/python/sglang/multimodal_gen/runtime/architectures/preprocess/wan/wan_preprocess_pipelines.py b/python/sglang/multimodal_gen/runtime/architectures/preprocess/wan/wan_preprocess_pipelines.py new file mode 100644 index 000000000..47ec436ff --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/architectures/preprocess/wan/wan_preprocess_pipelines.py @@ -0,0 +1,118 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import ( + ComposedPipelineBase, +) +from sglang.multimodal_gen.runtime.pipelines.preprocess.preprocess_stages import ( + TextTransformStage, + VideoTransformStage, +) +from sglang.multimodal_gen.runtime.pipelines.stages import ( + EncodingStage, + ImageEncodingStage, + TextEncodingStage, +) +from sglang.multimodal_gen.runtime.pipelines.stages.image_encoding import ( + ImageVAEEncodingStage, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs + + +class PreprocessPipelineI2V(ComposedPipelineBase): + _required_config_modules = [ + "image_encoder", + "image_processor", + "text_encoder", + "tokenizer", + "vae", + ] + + def create_pipeline_stages(self, server_args: ServerArgs): + assert server_args.preprocess_config is not None + self.add_stage( + stage_name="text_transform_stage", + stage=TextTransformStage( + cfg_uncondition_drop_rate=server_args.preprocess_config.training_cfg_rate, + seed=server_args.preprocess_config.seed, + ), + ) + self.add_stage( + stage_name="prompt_encoding_stage", + stage=TextEncodingStage( + text_encoders=[self.get_module("text_encoder")], + tokenizers=[self.get_module("tokenizer")], + ), + ) + self.add_stage( + stage_name="video_transform_stage", + stage=VideoTransformStage( + train_fps=server_args.preprocess_config.train_fps, + num_frames=server_args.preprocess_config.num_frames, + max_height=server_args.preprocess_config.max_height, + max_width=server_args.preprocess_config.max_width, + do_temporal_sample=server_args.preprocess_config.do_temporal_sample, + ), + ) + if ( + self.get_module("image_encoder") is not None + and self.get_module("image_processor") is not None + ): + self.add_stage( + stage_name="image_encoding_stage", + stage=ImageEncodingStage( + image_encoder=self.get_module("image_encoder"), + image_processor=self.get_module("image_processor"), + ), + ) + self.add_stage( + stage_name="image_vae_encoding_stage", + stage=ImageVAEEncodingStage( + vae=self.get_module("vae"), + ), + ) + self.add_stage( + stage_name="video_encoding_stage", + stage=EncodingStage( + vae=self.get_module("vae"), + ), + ) + + 
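+# Stages above are added in the order they run for I2V preprocessing: text
+# transform, prompt encoding, video transform, image encoding, and VAE
+# encoding. A rough usage sketch (the constructor/forward signature is assumed
+# to follow v1_preprocess.py and ComposedPipelineBase, and may differ):
+#
+#     pipeline = PreprocessPipelineI2V(model_path, server_args)
+#     pipeline.forward(batch=None, server_args=server_args, args=args)
+
+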
+class PreprocessPipelineT2V(ComposedPipelineBase): + _required_config_modules = ["text_encoder", "tokenizer", "vae"] + + def create_pipeline_stages(self, server_args: ServerArgs): + assert server_args.preprocess_config is not None + self.add_stage( + stage_name="text_transform_stage", + stage=TextTransformStage( + cfg_uncondition_drop_rate=server_args.preprocess_config.training_cfg_rate, + seed=server_args.preprocess_config.seed, + ), + ) + self.add_stage( + stage_name="prompt_encoding_stage", + stage=TextEncodingStage( + text_encoders=[self.get_module("text_encoder")], + tokenizers=[self.get_module("tokenizer")], + ), + ) + self.add_stage( + stage_name="video_transform_stage", + stage=VideoTransformStage( + train_fps=server_args.preprocess_config.train_fps, + num_frames=server_args.preprocess_config.num_frames, + max_height=server_args.preprocess_config.max_height, + max_width=server_args.preprocess_config.max_width, + do_temporal_sample=server_args.preprocess_config.do_temporal_sample, + ), + ) + self.add_stage( + stage_name="video_encoding_stage", + stage=EncodingStage( + vae=self.get_module("vae"), + ), + ) + + +EntryClass = [PreprocessPipelineI2V, PreprocessPipelineT2V] diff --git a/python/sglang/multimodal_gen/runtime/distributed/__init__.py b/python/sglang/multimodal_gen/runtime/distributed/__init__.py new file mode 100644 index 000000000..9edfd5c6f --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/distributed/__init__.py @@ -0,0 +1,55 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +from sglang.multimodal_gen.runtime.distributed.communication_op import * +from sglang.multimodal_gen.runtime.distributed.group_coordinator import ( + get_local_torch_device, +) +from sglang.multimodal_gen.runtime.distributed.parallel_state import ( + cleanup_dist_env_and_memory, + get_dp_group, + get_dp_rank, + get_dp_world_size, + get_sp_group, + get_sp_parallel_rank, + get_sp_world_size, + get_tp_group, + get_tp_rank, + get_tp_world_size, + get_world_group, + get_world_rank, + get_world_size, + init_distributed_environment, + initialize_model_parallel, + maybe_init_distributed_environment_and_model_parallel, + model_parallel_is_initialized, +) +from sglang.multimodal_gen.runtime.distributed.utils import * + +__all__ = [ + # Initialization + "init_distributed_environment", + "initialize_model_parallel", + "cleanup_dist_env_and_memory", + "model_parallel_is_initialized", + "maybe_init_distributed_environment_and_model_parallel", + # World group + "get_world_group", + "get_world_rank", + "get_world_size", + # Data parallel group + "get_dp_group", + "get_dp_rank", + "get_dp_world_size", + # Sequence parallel group + "get_sp_group", + "get_sp_parallel_rank", + "get_sp_world_size", + # Tensor parallel group + "get_tp_group", + "get_tp_rank", + "get_tp_world_size", + # Get torch device + "get_local_torch_device", +] diff --git a/python/sglang/multimodal_gen/runtime/distributed/communication_op.py b/python/sglang/multimodal_gen/runtime/distributed/communication_op.py new file mode 100644 index 000000000..61672ca45 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/distributed/communication_op.py @@ -0,0 +1,55 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/communication_op.py + +import torch +import torch.distributed + +from sglang.multimodal_gen.runtime.distributed.parallel_state import 
( + get_cfg_group, + get_sp_group, + get_tp_group, +) + + +def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: + """All-reduce the input tensor across model parallel group.""" + return get_tp_group().all_reduce(input_) + + +def tensor_model_parallel_all_gather( + input_: torch.Tensor, dim: int = -1 +) -> torch.Tensor: + """All-gather the input tensor across model parallel group.""" + return get_tp_group().all_gather(input_, dim) + + +# TODO: remove model, make it sequence_parallel +def sequence_model_parallel_all_to_all_4D( + input_: torch.Tensor, scatter_dim: int = 2, gather_dim: int = 1 +) -> torch.Tensor: + """All-to-all communication of 4D tensors (e.g. QKV matrices) across sequence parallel group.""" + return get_sp_group().all_to_all_4D(input_, scatter_dim, gather_dim) + + +def sequence_model_parallel_all_gather( + input_: torch.Tensor, dim: int = -1 +) -> torch.Tensor: + """All-gather the input tensor across model parallel group.""" + return get_sp_group().all_gather(input_, dim) + + +def cfg_model_parallel_all_gather( + input_: torch.Tensor, dim: int = -1, separate_tensors: bool = False +) -> torch.Tensor: + """All-gather the input tensor across model parallel group.""" + return get_cfg_group().all_gather(input_, dim, separate_tensors) + + +def cfg_model_parallel_all_reduce( + input_: torch.Tensor, + op: torch._C._distributed_c10d.ReduceOp = torch._C._distributed_c10d.ReduceOp.SUM, +) -> torch.Tensor: + """All-reduce the input tensor across CFG parallel group.""" + return get_cfg_group().all_reduce(input_, op=op) diff --git a/python/sglang/multimodal_gen/runtime/distributed/device_communicators/__init__.py b/python/sglang/multimodal_gen/runtime/distributed/device_communicators/__init__.py new file mode 100644 index 000000000..af2eb7d10 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/distributed/device_communicators/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/python/sglang/multimodal_gen/runtime/distributed/device_communicators/base_device_communicator.py b/python/sglang/multimodal_gen/runtime/distributed/device_communicators/base_device_communicator.py new file mode 100644 index 000000000..01bdf1c29 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/distributed/device_communicators/base_device_communicator.py @@ -0,0 +1,297 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/base_device_communicator.py + +from typing import Any + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed import ProcessGroup, ReduceOp + + +class DistributedAutograd: + """Collection of autograd functions for distributed operations. + + This class provides custom autograd functions for distributed operations like all_reduce, + all_gather, and all_to_all. Each operation is implemented as a static inner class with + proper forward and backward implementations. + """ + + class AllReduce(torch.autograd.Function): + """Differentiable all_reduce operation. + + The gradient of all_reduce is another all_reduce operation since the operation + combines values from all ranks equally. 
+ """ + + @staticmethod + def forward( + ctx: Any, + group: ProcessGroup, + input_: Tensor, + op: dist.ReduceOp | None = None, + ) -> Tensor: + ctx.group = group + ctx.op = op + output = input_.clone() + dist.all_reduce(output, group=group, op=op) + return output + + @staticmethod + def backward(ctx: Any, grad_output: Tensor) -> tuple[None, Tensor, None]: + grad_output = grad_output.clone() + dist.all_reduce(grad_output, group=ctx.group, op=ctx.op) + return None, grad_output, None + + class AllGather(torch.autograd.Function): + """Differentiable all_gather operation. + + The operation gathers tensors from all ranks and concatenates them along a specified dimension. + The backward pass uses reduce_scatter to efficiently distribute gradients back to source ranks. + """ + + @staticmethod + def forward( + ctx: Any, group: ProcessGroup, input_: Tensor, world_size: int, dim: int + ) -> Tensor: + ctx.group = group + ctx.world_size = world_size + ctx.dim = dim + ctx.input_shape = input_.shape + + input_size = input_.size() + output_size = (input_size[0] * world_size,) + input_size[1:] + output_tensor = torch.empty( + output_size, dtype=input_.dtype, device=input_.device + ) + + dist.all_gather_into_tensor(output_tensor, input_, group=group) + + output_tensor = output_tensor.reshape((world_size,) + input_size) + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape( + input_size[:dim] + + (world_size * input_size[dim],) + + input_size[dim + 1 :] + ) + return output_tensor + + @staticmethod + def backward(ctx: Any, grad_output: Tensor) -> tuple[None, Tensor, None, None]: + # Split the gradient tensor along the gathered dimension + dim_size = grad_output.size(ctx.dim) // ctx.world_size + grad_chunks = grad_output.reshape( + grad_output.shape[: ctx.dim] + + (ctx.world_size, dim_size) + + grad_output.shape[ctx.dim + 1 :] + ) + grad_chunks = grad_chunks.movedim(ctx.dim, 0) + + # Each rank only needs its corresponding gradient + grad_input = torch.empty( + ctx.input_shape, dtype=grad_output.dtype, device=grad_output.device + ) + dist.reduce_scatter_tensor( + grad_input, grad_chunks.contiguous(), group=ctx.group + ) + + return None, grad_input, None, None + + class AllToAll4D(torch.autograd.Function): + """Differentiable all_to_all operation specialized for 4D tensors. + + This operation is particularly useful for attention operations where we need to + redistribute data across ranks for efficient parallel processing. + + The operation supports two modes: + 1. scatter_dim=2, gather_dim=1: Used for redistributing attention heads + 2. 
scatter_dim=1, gather_dim=2: Used for redistributing sequence dimensions + """ + + @staticmethod + def forward( + ctx: Any, + group: ProcessGroup, + input_: Tensor, + world_size: int, + scatter_dim: int, + gather_dim: int, + ) -> Tensor: + ctx.group = group + ctx.world_size = world_size + ctx.scatter_dim = scatter_dim + ctx.gather_dim = gather_dim + + if world_size == 1: + return input_ + + assert ( + input_.dim() == 4 + ), f"input must be 4D tensor, got {input_.dim()} and shape {input_.shape}" + + if scatter_dim == 2 and gather_dim == 1: + bs, shard_seqlen, hn, hd = input_.shape + seqlen = shard_seqlen * world_size + shard_hn = hn // world_size + + input_ = input_.transpose(0, 2).contiguous() # hn, shard_seqlen, bs, hd + output = torch.empty_like(input_) + + dist.all_to_all_single( + output, input_, group=group + ) # hn, shard_seqlen, bs, hd + + output = torch.cat( + output.split(shard_hn), dim=1 + ) # sharded hn, seqlen, bs, hd + + output = output.transpose( + 0, 2 + ).contiguous() # bs, seqlen, sharded_hn, hd + + return output + elif scatter_dim == 1 and gather_dim == 2: + bs, seqlen, shard_hn, hd = input_.shape + hn = shard_hn * world_size + shard_seqlen = seqlen // world_size + + input_ = input_.transpose(0, 2).contiguous() # shard_hn, seqlen, bs, hd + + input_ = ( + input_.reshape(shard_hn, world_size, shard_seqlen, bs, hd) + .transpose(0, 1) + .reshape(shard_hn * world_size, shard_seqlen, bs, hd) + .contiguous() + ) + + output = torch.empty_like(input_) + + dist.all_to_all_single(output, input_, group=group) + + output = output.transpose( + 0, 2 + ).contiguous() # bs, seqlen, sharded_hn, hd + + return output + else: + raise RuntimeError( + f"Invalid scatter_dim={scatter_dim}, gather_dim={gather_dim}. " + f"Only (scatter_dim=2, gather_dim=1) and (scatter_dim=1, gather_dim=2) are supported." + ) + + @staticmethod + def backward( + ctx: Any, grad_output: Tensor + ) -> tuple[None, Tensor, None, None, None]: + if ctx.world_size == 1: + return None, grad_output, None, None, None + + # For backward pass, we swap scatter_dim and gather_dim + output = DistributedAutograd.AllToAll4D.apply( + ctx.group, grad_output, ctx.world_size, ctx.gather_dim, ctx.scatter_dim + ) + return None, output, None, None, None + + +class DeviceCommunicatorBase: + """ + Base class for device-specific communicator with autograd support. + It can use the `cpu_group` to initialize the communicator. + If the device has PyTorch integration (PyTorch can recognize its + communication backend), the `device_group` will also be given. 
+ """ + + def __init__( + self, + cpu_group: ProcessGroup, + device: torch.device | None = None, + device_group: ProcessGroup | None = None, + unique_name: str = "", + ): + self.device = device or torch.device("cpu") + self.cpu_group = cpu_group + self.device_group = device_group + self.unique_name = unique_name + self.rank = dist.get_rank(cpu_group) + self.world_size = dist.get_world_size(cpu_group) + self.ranks = dist.get_process_group_ranks(cpu_group) + self.global_rank = dist.get_rank() + self.global_world_size = dist.get_world_size() + self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank) + + def all_reduce( + self, input_: torch.Tensor, op: dist.ReduceOp | None = ReduceOp.SUM + ) -> torch.Tensor: + """Performs an all_reduce operation with gradient support.""" + return DistributedAutograd.AllReduce.apply(self.device_group, input_, op) + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + """Performs an all_gather operation with gradient support.""" + if dim < 0: + dim += input_.dim() + return DistributedAutograd.AllGather.apply( + self.device_group, input_, self.world_size, dim + ) + + def all_to_all_4D( + self, input_: torch.Tensor, scatter_dim: int = 2, gather_dim: int = 1 + ) -> torch.Tensor: + """Performs a 4D all-to-all operation with gradient support.""" + return DistributedAutograd.AllToAll4D.apply( + self.device_group, input_, self.world_size, scatter_dim, gather_dim + ) + + def gather( + self, input_: torch.Tensor, dst: int = 0, dim: int = -1 + ) -> torch.Tensor | None: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. 
+ torch.distributed.gather( + input_, gather_list, dst=self.ranks[dst], group=self.device_group + ) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def send(self, tensor: torch.Tensor, dst: int | None = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the local rank of the destination rank.""" + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + torch.distributed.send(tensor, self.ranks[dst], self.device_group) + + def recv( + self, size: torch.Size, dtype: torch.dtype, src: int | None = None + ) -> torch.Tensor: + """Receives a tensor from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + if src is None: + src = (self.rank_in_group - 1) % self.world_size + + tensor = torch.empty(size, dtype=dtype, device=self.device) + torch.distributed.recv(tensor, self.ranks[src], self.device_group) + return tensor + + def destroy(self) -> None: + pass diff --git a/python/sglang/multimodal_gen/runtime/distributed/device_communicators/cpu_communicator.py b/python/sglang/multimodal_gen/runtime/distributed/device_communicators/cpu_communicator.py new file mode 100644 index 000000000..434cf384d --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/distributed/device_communicators/cpu_communicator.py @@ -0,0 +1,161 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/distributed/device_communicators/cpu_communicator.py + +import os + +import torch +from torch.distributed import ProcessGroup + +from .base_device_communicator import DeviceCommunicatorBase + + +class CpuCommunicator(DeviceCommunicatorBase): + + def __init__( + self, + cpu_group: ProcessGroup, + device: torch.device | None = None, + device_group: ProcessGroup | None = None, + unique_name: str = "", + ): + from sglang.multimodal_gen.runtime.platforms import current_platform + from sglang.multimodal_gen.runtime.platforms.interface import CpuArchEnum + + super().__init__(cpu_group, device, device_group, unique_name) + self.dist_module = torch.distributed + + if ( + (current_platform.get_cpu_architecture() == CpuArchEnum.X86) + and hasattr(torch.ops._C, "init_shm_manager") + and unique_name.startswith("tp") + ): + self.dist_module = _CPUSHMDistributed(self) + + def all_reduce( + self, + input_: torch.Tensor, + op: torch.distributed.ReduceOp | None = torch.distributed.ReduceOp.SUM, + ) -> torch.Tensor: + self.dist_module.all_reduce(input_, group=self.device_group, op=op) + return input_ + + def gather( + self, input_: torch.Tensor, dst: int = 0, dim: int = -1 + ) -> torch.Tensor | None: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + + # Gather. 
+ self.dist_module.gather( + input_, gather_list, dst=self.ranks[dst], group=self.device_group + ) + + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # NOTE: we have to use concat-style all-gather here, + # stack-style all-gather has compatibility issues with + # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 + output_size = (input_size[0] * self.world_size,) + input_size[1:] + # Allocate output tensor. + output_tensor = torch.empty( + output_size, dtype=input_.dtype, device=input_.device + ) + # All-gather. + self.dist_module.all_gather_into_tensor( + output_tensor, input_, group=self.device_group + ) + + # Reshape + output_tensor = output_tensor.reshape((self.world_size,) + input_size) + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape( + input_size[:dim] + + (self.world_size * input_size[dim],) + + input_size[dim + 1 :] + ) + return output_tensor + + +class _CPUSHMDistributed: + + def __init__(self, communicator: CpuCommunicator): + instance_identifier = os.environ["VLLM_DIST_IDENT"] + unique_name = communicator.unique_name + instance_identifier = f"{instance_identifier}-{unique_name}" + self.communicator = communicator + + group_ranks = [str(rank) for rank in self.communicator.ranks] + shm_group_identifier = f"[{'-'.join(group_ranks)}]" + self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm" + + self.handle = self._init_cpu_shm() + + def _init_cpu_shm(self) -> int: + handle = torch.ops._C.init_shm_manager( + self.group_name, + self.communicator.world_size, + self.communicator.rank, + ) + torch.distributed.barrier(self.communicator.device_group) + torch.ops._C.join_shm_manager( + handle, + self.group_name, + ) + torch.distributed.barrier(self.communicator.device_group) + + return int(handle) + + def all_reduce( + self, input: torch.Tensor, group: ProcessGroup | None = None + ) -> None: + torch.ops._C.shm_allreduce(self.handle, input) + + def gather( + self, + input: torch.Tensor, + gather_list: list[torch.Tensor] | None, + dst: int = -1, + group: ProcessGroup | None = None, + ) -> None: + # Note: different from the torch gather, here we use local dst rank. 
+ torch.ops._C.shm_gather( + self.handle, + input, + gather_list, + torch.distributed.get_group_rank(group, dst), + ) + + def all_gather_into_tensor( + self, + output: torch.Tensor, + input: torch.Tensor, + group: ProcessGroup | None = None, + ) -> None: + torch.ops._C.shm_all_gather(self.handle, input, output) diff --git a/python/sglang/multimodal_gen/runtime/distributed/device_communicators/cuda_communicator.py b/python/sglang/multimodal_gen/runtime/distributed/device_communicators/cuda_communicator.py new file mode 100644 index 000000000..c128c69fc --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/distributed/device_communicators/cuda_communicator.py @@ -0,0 +1,79 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/cuda_communicator.py + +import torch +from torch.distributed import ProcessGroup + +from sglang.multimodal_gen.runtime.distributed.device_communicators.base_device_communicator import ( + DeviceCommunicatorBase, +) + + +class CudaCommunicator(DeviceCommunicatorBase): + + def __init__( + self, + cpu_group: ProcessGroup, + device: torch.device | None = None, + device_group: ProcessGroup | None = None, + unique_name: str = "", + ): + super().__init__(cpu_group, device, device_group, unique_name) + + from sglang.multimodal_gen.runtime.distributed.device_communicators.pynccl import ( + PyNcclCommunicator, + ) + + self.pynccl_comm: PyNcclCommunicator | None = None + if self.world_size > 1: + self.pynccl_comm = PyNcclCommunicator( + group=self.cpu_group, + device=self.device, + ) + + def all_reduce(self, input_, op: torch.distributed.ReduceOp | None = None): + pynccl_comm = self.pynccl_comm + assert pynccl_comm is not None + out = pynccl_comm.all_reduce(input_, op=op) + if out is None: + # fall back to the default all-reduce using PyTorch. + # this usually happens during testing. + # when we run the model, allreduce only happens for the TP + # group, where we always have either custom allreduce or pynccl. 
+ out = input_.clone() + torch.distributed.all_reduce(out, group=self.device_group, op=op) + return out + + def send(self, tensor: torch.Tensor, dst: int | None = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the local rank of the destination rank.""" + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.send(tensor, dst) + else: + torch.distributed.send(tensor, self.ranks[dst], self.device_group) + + def recv( + self, size: torch.Size, dtype: torch.dtype, src: int | None = None + ) -> torch.Tensor: + """Receives a tensor from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + if src is None: + src = (self.rank_in_group - 1) % self.world_size + + tensor = torch.empty(size, dtype=dtype, device=self.device) + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.recv(tensor, src) + else: + torch.distributed.recv(tensor, self.ranks[src], self.device_group) + return tensor + + def destroy(self) -> None: + if self.pynccl_comm is not None: + self.pynccl_comm = None diff --git a/python/sglang/multimodal_gen/runtime/distributed/device_communicators/pynccl.py b/python/sglang/multimodal_gen/runtime/distributed/device_communicators/pynccl.py new file mode 100644 index 000000000..0ab2e1adb --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/distributed/device_communicators/pynccl.py @@ -0,0 +1,258 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/pynccl.py + +# ===================== import region ===================== +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup, ReduceOp + +from sglang.multimodal_gen.runtime.distributed.device_communicators.pynccl_wrapper import ( + NCCLLibrary, + buffer_type, + cudaStream_t, + ncclComm_t, + ncclDataTypeEnum, + ncclRedOpTypeEnum, + ncclUniqueId, +) +from sglang.multimodal_gen.runtime.distributed.utils import StatelessProcessGroup +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import current_stream + +logger = init_logger(__name__) + + +class PyNcclCommunicator: + + def __init__( + self, + group: ProcessGroup | StatelessProcessGroup, + device: int | str | torch.device, + library_path: str | None = None, + ): + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the PyNcclCommunicator to. If None, + it will be bind to f"cuda:{local_rank}". + library_path: the path to the NCCL library. If None, it will + use the default library path. + It is the caller's responsibility to make sure each communicator + is bind to a unique device. + """ + if not isinstance(group, StatelessProcessGroup): + assert dist.is_initialized() + assert ( + dist.get_backend(group) != dist.Backend.NCCL + ), "PyNcclCommunicator should be attached to a non-NCCL group." 
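+            # The wrapped group is used only for CPU-side bootstrap (rank/size
+            # queries and broadcasting the NCCL unique id below); this class
+            # creates and manages its own NCCL communicator.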
+ # note: this rank is the rank in the group + self.rank = dist.get_rank(group) + self.world_size = dist.get_world_size(group) + else: + self.rank = group.rank + self.world_size = group.world_size + + self.group = group + + # if world_size == 1, no need to create communicator + if self.world_size == 1: + self.available = False + self.disabled = True + return + try: + self.nccl = NCCLLibrary(library_path) + except Exception: + # disable because of missing NCCL library + # e.g. in a non-GPU environment + self.available = False + self.disabled = True + return + + self.available = True + self.disabled = False + + logger.info("sgl-diffusion is using nccl==%s", self.nccl.ncclGetVersion()) + + if self.rank == 0: + # get the unique id from NCCL + self.unique_id = self.nccl.ncclGetUniqueId() + else: + # construct an empty unique id + self.unique_id = ncclUniqueId() + + if not isinstance(group, StatelessProcessGroup): + tensor = torch.ByteTensor(list(self.unique_id.internal)) + ranks = dist.get_process_group_ranks(group) + # arg `src` in `broadcast` is the global rank + dist.broadcast(tensor, src=ranks[0], group=group) + byte_list = tensor.tolist() + for i, byte in enumerate(byte_list): + self.unique_id.internal[i] = byte + else: + self.unique_id = group.broadcast_obj(self.unique_id, src=0) + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + # nccl communicator and stream will use this device + # `torch.cuda.device` is a context manager that changes the + # current cuda device to the specified one + with torch.cuda.device(device): + self.comm: ncclComm_t = self.nccl.ncclCommInitRank( + self.world_size, self.unique_id, self.rank + ) + + stream = current_stream() + # A small all_reduce for warmup. 
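+            # (Presumably exercising the fresh communicator once here lets any
+            # lazy NCCL setup complete before the first real collective.)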
+ data = torch.zeros(1, device=device) + self.all_reduce(data) + if stream is not None: + stream.synchronize() + del data + + def all_reduce( + self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, stream=None + ) -> torch.Tensor: + if self.disabled: + return None + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert in_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {in_tensor.device}" + ) + + out_tensor = torch.empty_like(in_tensor) + + if stream is None: + stream = current_stream() + self.nccl.ncclAllReduce( + buffer_type(in_tensor.data_ptr()), + buffer_type(out_tensor.data_ptr()), + in_tensor.numel(), + ncclDataTypeEnum.from_torch(in_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), + self.comm, + cudaStream_t(stream.cuda_stream), + ) + return out_tensor + + def all_gather( + self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None + ): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}" + ) + if stream is None: + stream = current_stream() + self.nccl.ncclAllGather( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), + input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + def reduce_scatter( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None, + ): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}" + ) + if stream is None: + stream = current_stream() + self.nccl.ncclReduceScatter( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), + output_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + def send(self, tensor: torch.Tensor, dst: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}" + ) + if stream is None: + stream = current_stream() + self.nccl.ncclSend( + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + dst, + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + def recv(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}" + ) + if stream is None: + stream = current_stream() + self.nccl.ncclRecv( + buffer_type(tensor.data_ptr()), + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + src, + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + def broadcast(self, tensor: torch.Tensor, src: int, 
stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}" + ) + if stream is None: + stream = current_stream() + if src == self.rank: + sendbuff = buffer_type(tensor.data_ptr()) + # NCCL requires the sender also to have a receive buffer + recvbuff = buffer_type(tensor.data_ptr()) + else: + sendbuff = buffer_type() + recvbuff = buffer_type(tensor.data_ptr()) + self.nccl.ncclBroadcast( + sendbuff, + recvbuff, + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + src, + self.comm, + cudaStream_t(stream.cuda_stream), + ) diff --git a/python/sglang/multimodal_gen/runtime/distributed/device_communicators/pynccl_wrapper.py b/python/sglang/multimodal_gen/runtime/distributed/device_communicators/pynccl_wrapper.py new file mode 100644 index 000000000..40da43f49 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/distributed/device_communicators/pynccl_wrapper.py @@ -0,0 +1,450 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/device_communicators/pynccl_wrapper.py + +# This file is a pure Python wrapper for the NCCL library. +# The main purpose is to use NCCL combined with CUDA graph. +# Before writing this script, we tried the following approach: +# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself +# often gets stuck when initializing the NCCL communicator. +# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` +# contains many other potential cuda APIs, that are not allowed during +# capturing the CUDA graph. For further details, please check +# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . +# +# Another rejected idea is to write a C/C++ binding for NCCL. It is usually +# doable, but we often encounter issues related with nccl versions, and need +# to switch between different versions of NCCL. See +# https://github.com/NVIDIA/nccl/issues/1234 for more details. +# A C/C++ binding is not flexible enough to handle this. It requires +# recompilation of the code every time we want to switch between different +# versions. This current implementation, with a **pure** Python wrapper, is +# more flexible. We can easily switch between different versions of NCCL by +# changing the environment variable `SGL_DIFFUSION_NCCL_SO_PATH`, or the `so_file` +# variable in the code. 
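+#
+# A minimal usage sketch (illustration only; the higher-level
+# PyNcclCommunicator in pynccl.py is the intended entry point):
+#
+#     lib = NCCLLibrary()              # locates libnccl via find_nccl_library()
+#     print(lib.ncclGetVersion())      # e.g. "2.19.3"
+#     unique_id = lib.ncclGetUniqueId()
+#
+# Every ncclXxx wrapper below goes through NCCL_CHECK and raises RuntimeError
+# on a non-zero ncclResult_t.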
+ +# TODO(will): support SGL_DIFFUSION_NCCL_SO_PATH + +import ctypes +import platform +from dataclasses import dataclass +from typing import Any + +import torch +from torch.distributed import ReduceOp + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import find_nccl_library + +logger = init_logger(__name__) + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int +ncclComm_t = ctypes.c_void_p + + +class ncclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +cudaStream_t = ctypes.c_void_p +buffer_type = ctypes.c_void_p + +ncclDataType_t = ctypes.c_int + + +class ncclDataTypeEnum: + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + if dtype == torch.int8: + return cls.ncclInt8 + if dtype == torch.uint8: + return cls.ncclUint8 + if dtype == torch.int32: + return cls.ncclInt32 + if dtype == torch.int64: + return cls.ncclInt64 + if dtype == torch.float16: + return cls.ncclFloat16 + if dtype == torch.float32: + return cls.ncclFloat32 + if dtype == torch.float64: + return cls.ncclFloat64 + if dtype == torch.bfloat16: + return cls.ncclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +ncclRedOp_t = ctypes.c_int + + +class ncclRedOpTypeEnum: + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> int: + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +@dataclass +class Function: + name: str + restype: Any + argtypes: list[Any] + + +class NCCLLibrary: + exported_functions = [ + # const char* ncclGetErrorString(ncclResult_t result) + Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), + # ncclResult_t ncclGetVersion(int *version); + Function("ncclGetVersion", ncclResult_t, [ctypes.POINTER(ctypes.c_int)]), + # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); + Function("ncclGetUniqueId", ncclResult_t, [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommInitRank( + # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + # note that ncclComm_t is a pointer type, so the first argument + # is a pointer to a pointer + Function( + "ncclCommInitRank", + ncclResult_t, + [ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, ctypes.c_int], + ), + # ncclResult_t ncclAllReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function( + "ncclAllReduce", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclRedOp_t, + ncclComm_t, + cudaStream_t, + ], + ), + # ncclResult_t ncclAllGather( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclComm_t comm, + # cudaStream_t stream); + # note 
that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function( + "ncclAllGather", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclComm_t, + cudaStream_t, + ], + ), + # ncclResult_t ncclReduceScatter( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function( + "ncclReduceScatter", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclRedOp_t, + ncclComm_t, + cudaStream_t, + ], + ), + # ncclResult_t ncclSend( + # const void* sendbuff, size_t count, ncclDataType_t datatype, + # int dest, ncclComm_t comm, cudaStream_t stream); + Function( + "ncclSend", + ncclResult_t, + [ + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), + # ncclResult_t ncclRecv( + # void* recvbuff, size_t count, ncclDataType_t datatype, + # int src, ncclComm_t comm, cudaStream_t stream); + Function( + "ncclRecv", + ncclResult_t, + [ + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), + # ncclResult_t ncclBroadcast( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, int root, ncclComm_t comm, + # cudaStream_t stream); + Function( + "ncclBroadcast", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), + # be cautious! this is a collective call, it will block until all + # processes in the communicator have called this function. + # because Python object destruction can happen in random order, + # it is better not to call it at all. + # ncclResult_t ncclCommDestroy(ncclComm_t comm); + Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: dict[str, dict[str, Any]] = {} + + def __init__(self, so_file: str | None = None): + + so_file = so_file or find_nccl_library() + + try: + if so_file not in NCCLLibrary.path_to_dict_mapping: + lib = ctypes.CDLL(so_file) + NCCLLibrary.path_to_library_cache[so_file] = lib + self.lib = NCCLLibrary.path_to_library_cache[so_file] + except Exception as e: + logger.error( + "Failed to load NCCL library from %s ." + "It is expected if you are not running on NVIDIA/AMD GPUs." + "Otherwise, the nccl library might not exist, be corrupted " + "or it does not support the current platform %s." 
+ "If you already have the library, please set the " + "environment variable SGL_DIFFUSION_NCCL_SO_PATH" + " to point to the correct nccl library path.", + so_file, + platform.platform(), + ) + raise e + + if so_file not in NCCLLibrary.path_to_dict_mapping: + _funcs: dict[str, Any] = {} + for func in NCCLLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + NCCLLibrary.path_to_dict_mapping[so_file] = _funcs + self._funcs = NCCLLibrary.path_to_dict_mapping[so_file] + + def ncclGetErrorString(self, result: ncclResult_t) -> str: + return str(self._funcs["ncclGetErrorString"](result).decode("utf-8")) + + def NCCL_CHECK(self, result: ncclResult_t) -> None: + if result != 0: + error_str = self.ncclGetErrorString(result) + raise RuntimeError(f"NCCL error: {error_str}") + + def ncclGetVersion(self) -> str: + version = ctypes.c_int() + self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) + version_str = str(version.value) + # something like 21903 --> "2.19.3" + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + def ncclGetUniqueId(self) -> ncclUniqueId: + unique_id = ncclUniqueId() + self.NCCL_CHECK(self._funcs["ncclGetUniqueId"](ctypes.byref(unique_id))) + return unique_id + + def ncclCommInitRank( + self, world_size: int, unique_id: ncclUniqueId, rank: int + ) -> ncclComm_t: + comm = ncclComm_t() + self.NCCL_CHECK( + self._funcs["ncclCommInitRank"]( + ctypes.byref(comm), world_size, unique_id, rank + ) + ) + return comm + + def ncclAllReduce( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK( + self._funcs["ncclAllReduce"]( + sendbuff, recvbuff, count, datatype, op, comm, stream + ) + ) + + def ncclReduceScatter( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK( + self._funcs["ncclReduceScatter"]( + sendbuff, recvbuff, count, datatype, op, comm, stream + ) + ) + + def ncclAllGather( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + # `datatype` actually should be `ncclDataType_t` + # which is an aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK( + self._funcs["ncclAllGather"]( + sendbuff, recvbuff, count, datatype, comm, stream + ) + ) + + def ncclSend( + self, + sendbuff: buffer_type, + count: int, + datatype: int, + dest: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["ncclSend"](sendbuff, count, datatype, dest, comm, stream) + ) + + def ncclRecv( + self, + recvbuff: buffer_type, + count: int, + datatype: int, + src: int, + comm: ncclComm_t, + stream: 
cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream) + ) + + def ncclBroadcast( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + root: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["ncclBroadcast"]( + sendbuff, recvbuff, count, datatype, root, comm, stream + ) + ) + + def ncclCommDestroy(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + + +__all__ = [ + "NCCLLibrary", + "ncclDataTypeEnum", + "ncclRedOpTypeEnum", + "ncclUniqueId", + "ncclComm_t", + "cudaStream_t", + "buffer_type", +] diff --git a/python/sglang/multimodal_gen/runtime/distributed/group_coordinator.py b/python/sglang/multimodal_gen/runtime/distributed/group_coordinator.py new file mode 100644 index 000000000..dd42b8756 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/distributed/group_coordinator.py @@ -0,0 +1,1226 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# Copyright 2024 xDiT team. +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py +# Copyright 2023 The vLLM team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +import pickle +from collections import namedtuple +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.distributed +from torch.cuda import synchronize +from torch.distributed import Backend, ProcessGroup + +from sglang.multimodal_gen import envs +from sglang.multimodal_gen.runtime.distributed.device_communicators.base_device_communicator import ( + DeviceCommunicatorBase, +) +from sglang.multimodal_gen.runtime.distributed.device_communicators.cpu_communicator import ( + CpuCommunicator, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +try: + import torch_musa # noqa: F401 + from torch_musa.core.device import synchronize +except ModuleNotFoundError: + pass + +logger = init_logger(__name__) + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +_group_name_counter: dict[str, int] = {} + + +def get_local_torch_device() -> torch.device: + """Return the torch device for the current rank.""" + from sglang.multimodal_gen.runtime.platforms import current_platform + + return ( + torch.device(f"cuda:{envs.LOCAL_RANK}") + if current_platform.is_cuda_alike() + else torch.device("mps") + ) + + +def _get_unique_name(name: str) -> str: + """Get a unique name for the group. + Example: + _get_unique_name("tp") -> "tp:0" + _get_unique_name("tp") -> "tp:1" + """ + if name not in _group_name_counter: + _group_name_counter[name] = 0 + newname = f"{name}:{_group_name_counter[name]}" + _group_name_counter[name] += 1 + return newname + + +def _split_tensor_dict( + tensor_dict: Dict[str, Union[torch.Tensor, Any]], prefix: str = "" +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + + If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its + metadata will be "key1%key2". + """ + metadata_list: List[Tuple[str, Any]] = [] + tensor_list = [] + for key, value in tensor_dict.items(): + assert "%" not in key, ( + "Avoid having '%' in key " + "as it is used as a separator for nested entries." 
+ ) + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "cuda:0"). We only need the device type. + # receiving side will set the device index. + device = value.device.type + metadata_list.append( + ( + prefix + key, + TensorMetadata(device, value.dtype, value.size()), + ) + ) + tensor_list.append(value) + elif isinstance(value, dict): + if len(value) == 0: + metadata_list.append((prefix + key, value)) + inner_metadata_list, inner_tensor_list = _split_tensor_dict( + value, prefix + key + "%" + ) + metadata_list.extend(inner_metadata_list) + tensor_list.extend(inner_tensor_list) + else: + metadata_list.append((prefix + key, value)) + return metadata_list, tensor_list + + +def _update_nested_dict(nested_dict, flattened_key, value): + key_splits = flattened_key.split("%") + cur_dict = nested_dict + for k in key_splits[:-1]: + if k not in cur_dict: + cur_dict[k] = {} + cur_dict = cur_dict[k] + cur_dict[key_splits[-1]] = value + + +@dataclass +class GraphCaptureContext: + stream: torch.cuda.Stream | None + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and cuda graph mode). + """ + + # available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank in the current node, used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + use_device_communicator: bool # whether to use device communicator + device_communicator: DeviceCommunicatorBase # device communicator + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + use_device_communicator: bool = True, + use_message_queue_broadcaster: bool = False, + group_name: str | None = None, + ): + self.unique_name = _get_unique_name(group_name) + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. 
+ cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + assert self.cpu_group is not None, f"{group_ranks=}, {local_rank=}" + assert self.device_group is not None + + # TODO: fix it for other platforms + self.device = get_local_torch_device() + + from sglang.multimodal_gen.runtime.platforms import current_platform + + self.use_device_communicator = use_device_communicator + + self.device_communicator: DeviceCommunicatorBase = None # type: ignore + if use_device_communicator and self.world_size > 1: + # Platform-aware device communicator selection + if current_platform.is_cuda_alike(): + from sglang.multimodal_gen.runtime.distributed.device_communicators.cuda_communicator import ( + CudaCommunicator, + ) + + self.device_communicator = CudaCommunicator( + cpu_group=self.cpu_group, + device=self.device, + device_group=self.device_group, + unique_name=self.unique_name, + ) + else: + # For MPS and CPU, use the CPU communicator + self.device_communicator = CpuCommunicator( + cpu_group=self.cpu_group, + device=self.device, + device_group=self.device_group, + unique_name=self.unique_name, + ) + + self.mq_broadcaster = None + + # TODO(will): check if this is needed + # self.use_custom_op_call = current_platform.is_cuda_alike() + self.use_custom_op_call = False + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @property + def group_next_rank(self): + """Return the group rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group + 1) % world_size + + @property + def group_prev_rank(self): + """Return the group rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (rank_in_group - 1) % world_size + + @property + def skip_rank(self): + """Return the global rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(world_size - rank_in_group - 1) % world_size] + + @property + def group_skip_rank(self): + """Return the group rank of the process that skip connects with the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return (world_size - rank_in_group - 1) % world_size + + @contextmanager + def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None): + # Platform-aware graph capture 
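+        # On CUDA-like platforms, capture runs on a dedicated side stream that
+        # first waits on the current stream; other platforms (CPU/MPS) receive
+        # a GraphCaptureContext with stream=None.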
+ from sglang.multimodal_gen.runtime.platforms import current_platform + + if current_platform.is_cuda_alike(): + if graph_capture_context is None: + stream = torch.cuda.Stream() + graph_capture_context = GraphCaptureContext(stream) + else: + stream = graph_capture_context.stream + + # ensure all initialization operations complete before attempting to + # capture the graph on another stream + curr_stream = torch.cuda.current_stream() + if curr_stream != stream: + stream.wait_stream(curr_stream) + + with torch.cuda.stream(stream): + yield graph_capture_context + else: + # For non-CUDA platforms (MPS, CPU), just yield the context without stream management + if graph_capture_context is None: + # Create a dummy context for non-CUDA platforms + graph_capture_context = GraphCaptureContext(None) + yield graph_capture_context + + def all_to_all_4D( + self, input_: torch.Tensor, scatter_dim: int = 2, gather_dim: int = 1 + ) -> torch.Tensor: + if self.world_size == 1: + return input_ + return self.device_communicator.all_to_all_4D(input_, scatter_dim, gather_dim) + + def all_reduce( + self, + input_: torch.Tensor, + op=torch._C._distributed_c10d.ReduceOp.SUM, + async_op: bool = False, + ) -> torch.Tensor: + """ + NOTE: This operation will be applied in-place or out-of-place. + Always assume this function modifies its input, but use the return + value as the output. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + else: + torch.distributed.all_reduce( + input_, op=op, group=self.device_group, async_op=async_op + ) + return input_ + + def all_gather( + self, input_: torch.Tensor, dim: int = 0, separate_tensors: bool = False + ) -> Union[torch.Tensor, List[torch.Tensor]]: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + input_size = list(input_.size()) + input_size[0] *= world_size + output_tensor = torch.empty( + input_size, dtype=input_.dtype, device=input_.device + ) + # All-gather. + torch.distributed.all_gather_into_tensor( + output_tensor, input_, group=self.device_group + ) + if dim != 0: + input_size[0] //= world_size + output_tensor = output_tensor.reshape( + [ + world_size, + ] + + input_size + ) + output_tensor = output_tensor.movedim(0, dim) + + if separate_tensors: + tensor_list = [ + output_tensor.reshape(-1) + .narrow(0, input_.numel() * i, input_.numel()) + .view_as(input_) + for i in range(world_size) + ] + return tensor_list + else: + input_size = list(input_.size()) + input_size[dim] = input_size[dim] * world_size + # Reshape + output_tensor = output_tensor.reshape(input_size) + return output_tensor + + def gather(self, input_: torch.Tensor, dst: int = 0, dim: int = -1) -> torch.Tensor: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. 
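# --- Editorial sketch (not part of the patch): the shape bookkeeping performed by
# GroupCoordinator.all_gather above, simulated on a single process. Two fake per-rank
# tensors stand in for what all_gather_into_tensor would produce with world_size == 2.
import torch

world_size, dim = 2, 1
per_rank = [torch.full((3, 4), float(r)) for r in range(world_size)]  # each rank's input_
stacked = torch.cat(per_rank, dim=0)                         # gathered along dim 0: (6, 4)
moved = stacked.reshape([world_size, 3, 4]).movedim(0, dim)  # (3, 2, 4)
out = moved.reshape([3, 4 * world_size])                     # result for dim=1: (3, 8)
assert out.shape == (3, 8) and torch.equal(out[:, :4], per_rank[0])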
+ if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather( + input_, gather_list, dst=self.ranks[dst], group=self.device_group + ) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def broadcast(self, input_: torch.Tensor, src: int = 0, async_op: bool = False): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. + torch.distributed.broadcast( + input_, + src=self.ranks[src], + group=self.device_group, + async_op=async_op, + ) + return input_ + + def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj + if self.shm_broadcaster is not None: + assert src == 0, "Shared memory broadcaster only supports src=0" + return self.shm_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list( + [obj], src=self.ranks[src], group=self.cpu_group + ) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list( + recv, src=self.ranks[src], group=self.cpu_group + ) + return recv[0] + + def broadcast_object_list( + self, + obj_list: List[Any], + src: int = 0, + group: Optional[ProcessGroup] = None, + ): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. + torch.distributed.broadcast_object_list( + obj_list, src=self.ranks[src], group=self.device_group + ) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + assert dst != self.rank, ( + "Invalid destination rank. Destination rank is the same " + "as the current rank." + ) + + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor( + [object_tensor.numel()], dtype=torch.long, device="cpu" + ) + + # Send object size + + torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + assert src < self.world_size, f"Invalid src rank ({src})" + + assert ( + src != self.rank + ), "Invalid source rank. Source rank is the same as the current rank." + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv( + size_tensor, src=self.ranks[src], group=self.cpu_group + ) + + # Tensor to receive serialized objects into. 
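# --- Editorial sketch (not part of the patch): the wire format used by send_object /
# recv_object above. The payload is pickled into a uint8 tensor and preceded by a
# separate int64 size tensor, so the receiver can allocate a buffer before the second
# transfer; both messages travel over the gloo/CPU group.
import pickle
import torch

obj = {"step": 3, "prompt": "a cat"}
payload = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)  # second message
size = torch.tensor([payload.numel()], dtype=torch.long)          # first message
recv_buf = torch.empty(int(size.item()), dtype=torch.uint8)       # receiver-side allocation
recv_buf.copy_(payload)                                           # stands in for torch.distributed.recv
assert pickle.loads(recv_buf.numpy().tobytes()) == obj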
+ object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu", + ) + + rank_object = torch.distributed.recv( + object_tensor, src=self.ranks[src], group=self.cpu_group + ) + + assert ( + rank_object == rank_size + ), "Received object sender rank does not match the size sender rank." + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + src = self.ranks[src] + + rank = self.rank + if rank == src: + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, dict + ), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=metadata_group, async_op=True + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=group, async_op=True + ) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty( + value.size, dtype=value.dtype, device=value.device + ) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=metadata_group, async_op=True + ) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, src=src, group=group, async_op=True + ) + async_handles.append(handle) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def send_tensor_dict( + self, + tensor_dict: Dict[str, Union[torch.Tensor, Any]], + dst: Optional[int] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. 
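# --- Editorial sketch (not part of the patch): the metadata/tensor split that
# broadcast_tensor_dict and send_tensor_dict rely on. Non-tensor values ride along with
# the pickled metadata list, while each tensor is described only by (device type, dtype,
# size) so the receiving side can pre-allocate a buffer and receive the raw data with a
# plain collective.
from collections import namedtuple
import torch

TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
tensor_dict = {"latent": torch.zeros(2, 4), "step": 7}
metadata_list, tensor_list = [], []
for key, value in tensor_dict.items():
    if isinstance(value, torch.Tensor):
        metadata_list.append((key, TensorMetadata(value.device.type, value.dtype, value.size())))
        tensor_list.append(value)
    else:
        metadata_list.append((key, value))
assert metadata_list[1] == ("step", 7) and len(tensor_list) == 1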
+ if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + + if dst is None: + dst = self.group_next_rank + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, dict + ), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `send_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.send_object(metadata_list, dst=dst) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip sending empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send( + tensor, dst=self.ranks[dst], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.send(tensor, dst=self.ranks[dst], group=group) + return None + + def recv_tensor_dict( + self, src: Optional[int] = None + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = self.group_prev_rank + assert src < self.world_size, f"Invalid src rank ({src})" + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: Dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + _update_nested_dict(tensor_dict, key, tensor) + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv( + tensor, src=self.ranks[src], group=metadata_group + ) + else: + # use group for GPU tensors + torch.distributed.recv(tensor, src=self.ranks[src], group=group) + _update_nested_dict(tensor_dict, key, tensor) + else: + _update_nested_dict(tensor_dict, key, value) + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. 
+ """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the rank_in_group of the destination rank.""" + if dst is None: + dst = self.group_next_rank + + torch.distributed.send( + tensor, + self.ranks[dst], + group=( + self.device_groups[self.rank_in_group % 2] + if self.world_size == 2 + else self.device_group + ), + ) + + def recv( + self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None + ) -> torch.Tensor: + """Receives a tensor from the src rank.""" + """NOTE: `src` is the rank_in_group of the source rank.""" + if src is None: + src = self.group_prev_rank + + tensor = torch.empty(size, dtype=dtype, device=self.device) + torch.distributed.recv( + tensor, + self.ranks[src], + ( + self.device_groups[(self.rank_in_group + 1) % 2] + if self.world_size == 2 + else self.device_group + ), + ) + return tensor + + def destroy(self) -> None: + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + if self.device_communicator is not None: + self.device_communicator.destroy() + if self.mq_broadcaster is not None: + self.mq_broadcaster = None + + +class PipelineGroupCoordinator(GroupCoordinator): + """ + available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + difference between `local_rank` and `rank_in_group`: + if we have a group of size 4 across two nodes: + Process | Node | Rank | Local Rank | Rank in Group + 0 | 0 | 0 | 0 | 0 + 1 | 0 | 1 | 1 | 1 + 2 | 1 | 2 | 0 | 2 + 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + """ + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + group_name: str | None = None, + ): + super().__init__( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=torch_distributed_backend, + group_name=group_name, + ) + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + self.cpu_groups = [] + self.device_groups = [] + if len(group_ranks[0]) > 2 or len(group_ranks[0]) == 1: + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + # when pipeline parallelism is 2, we need to create two groups to avoid + # communication stall. + # *_group_0_1 represents the group for communication from device 0 to + # device 1. + # *_group_1_0 represents the group for communication from device 1 to + # device 0. 
+ elif len(group_ranks[0]) == 2: + for ranks in group_ranks: + device_group_0_1 = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + device_group_1_0 = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + cpu_group_0_1 = torch.distributed.new_group(ranks, backend="gloo") + cpu_group_1_0 = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_groups = [device_group_0_1, device_group_1_0] + self.cpu_groups = [cpu_group_0_1, cpu_group_1_0] + self.device_group = device_group_0_1 + self.cpu_group = cpu_group_0_1 + + assert self.cpu_group is not None + assert self.device_group is not None + + self.device = envs.get_device(local_rank) + + self.recv_buffer_set: bool = False + self.recv_tasks_queue: List[Tuple[str, int]] = [] + self.receiving_tasks: List[Tuple[torch.distributed.Work, str, int]] = [] + self.dtype: Optional[torch.dtype] = None + self.num_pipefusion_patches: Optional[int] = None + + self.recv_shape: Dict[str, Dict[int, torch.Size]] = {} + self.send_shape: Dict[str, Dict[int, torch.Size]] = {} + self.recv_buffer: Dict[str, Dict[int, torch.Size]] = {} + + self.skip_tensor_recv_buffer_set: bool = False + self.recv_skip_tasks_queue: List[Union[int, Tuple[str, int]]] = [] + self.receiving_skip_tasks: List[Tuple[torch.distributed.Work, str, int]] = [] + self.skip_tensor_recv_buffer: Optional[ + Union[List[torch.Tensor], torch.Tensor] + ] = None + self.skip_device_group = None + for ranks in group_ranks: + skip_device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend + ) + if self.rank in ranks: + self.skip_device_group = skip_device_group + assert self.skip_device_group is not None + + def reset_buffer(self): + self.recv_tasks_queue = [] + self.receiving_tasks = [] + self.recv_shape = {} + self.send_shape = {} + self.recv_buffer = {} + + self.recv_skip_tasks_queue = [] + self.receiving_skip_tasks = [] + self.skip_tensor_recv_buffer = {} + + def set_config(self, dtype: torch.dtype): + self.dtype = dtype + + def set_recv_buffer( + self, + num_pipefusion_patches: int, + patches_shape_list: List[List[int]], + feature_map_shape: List[int], + dtype: torch.dtype, + ): + assert isinstance(dtype, torch.dtype), "dtype must be a torch.dtype object" + assert ( + isinstance(num_pipefusion_patches, int) and num_pipefusion_patches >= 1 + ), "num_pipefusion_patches must be greater than or equal to 1" + self.dtype = dtype + self.num_pipefusion_patches = num_pipefusion_patches + self.recv_buffer = [ + torch.zeros(*shape, dtype=self.dtype, device=self.device) + for shape in patches_shape_list + ] + self.recv_buffer.append( + torch.zeros(*feature_map_shape, dtype=self.dtype, device=self.device) + ) + self.recv_buffer_set = True + + def set_extra_tensors_recv_buffer( + self, + name: str, + shape: List[int], + num_buffers: int = 1, + dtype: torch.dtype = torch.float16, + ): + self.extra_tensors_recv_buffer[name] = [ + torch.zeros(*shape, dtype=dtype, device=self.device) + for _ in range(num_buffers) + ] + + def _check_shape_and_buffer( + self, + tensor_send_to_next=None, + recv_prev=False, + name: Optional[str] = None, + segment_idx: int = 0, + ): + send_flag = False + name = name or "latent" + if tensor_send_to_next is not None: + shape_list = self.send_shape.get(name, None) + if shape_list is None: 
+ self.send_shape[name] = {segment_idx: tensor_send_to_next.shape} + send_flag = True + elif shape_list.get(segment_idx, None) is None: + self.send_shape[name][segment_idx] = tensor_send_to_next.shape + send_flag = True + + recv_flag = False + if recv_prev: + shape_list = self.recv_shape.get(name, None) + if shape_list is None: + recv_flag = True + elif shape_list.get(segment_idx, None) is None: + recv_flag = True + + recv_prev_shape = self._communicate_shapes( + tensor_send_to_next=tensor_send_to_next if send_flag else None, + recv_prev=recv_flag, + ) + + if recv_flag: + if self.recv_shape.get(name, None) is None: + self.recv_shape[name] = {segment_idx: recv_prev_shape} + else: + self.recv_shape[name][segment_idx] = recv_prev_shape + + if self.recv_buffer.get(name, None) is None: + self.recv_buffer[name] = { + segment_idx: torch.zeros( + recv_prev_shape, device=self.device, dtype=self.dtype + ) + } + else: + if self.recv_buffer[name].get(segment_idx, None) is not None: + logger.warning( + f"Recv buffer [name: {name}, segment_idx: {segment_idx}] already exist. updating..." + ) + self.recv_buffer[name][segment_idx] = torch.zeros( + recv_prev_shape, device=self.device, dtype=self.dtype + ) + + def _communicate_shapes(self, tensor_send_to_next=None, recv_prev=False): + """Communicate tensor shapes between stages. Used to communicate + tensor shapes before the actual tensor communication happens. + + Args: + tensor_send_next: tensor to send to next rank (no tensor sent if + set to None). + recv_prev: boolean for whether tensor should be received from + previous rank. + """ + + ops = [] + if recv_prev: + recv_prev_dim_tensor = torch.empty( + (1), device=self.device, dtype=torch.int64 + ) + recv_prev_dim_op = torch.distributed.P2POp( + torch.distributed.irecv, + recv_prev_dim_tensor, + self.prev_rank, + self.device_group, + ) + ops.append(recv_prev_dim_op) + + if tensor_send_to_next is not None: + send_next_dim_tensor = torch.tensor( + tensor_send_to_next.dim(), device=self.device, dtype=torch.int64 + ) + send_next_dim_op = torch.distributed.P2POp( + torch.distributed.isend, + send_next_dim_tensor, + self.next_rank, + self.device_group, + ) + ops.append(send_next_dim_op) + + if len(ops) > 0: + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + # To protect against race condition when using batch_isend_irecv(). + # should take this out once the bug with batch_isend_irecv is resolved. 
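# --- Editorial aside (not part of the patch): the two-phase handshake performed by
# _communicate_shapes (continued below). The receiver first learns the tensor rank
# (number of dims) so it can size the second buffer, then receives the shape itself;
# only afterwards is the payload transferred into a matching pre-allocated buffer.
import torch

payload = torch.zeros(2, 16, 64)                             # what the sender wants to ship
dim_msg = torch.tensor(payload.dim(), dtype=torch.int64)     # phase 1: number of dims
shape_msg = torch.tensor(payload.size(), dtype=torch.int64)  # phase 2: the shape itself
recv_buf = torch.empty(int(dim_msg.item()), dtype=torch.int64)
recv_buf.copy_(shape_msg)                                    # stands in for the second irecv
assert torch.Size(recv_buf.tolist()) == payload.shape        # phase 3 buffer can now be allocated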
+ synchronize() + + ops = [] + recv_prev_shape_tensor = None + if recv_prev: + recv_prev_shape_tensor = torch.empty( + torch.Size(recv_prev_dim_tensor), + device=self.device, + dtype=torch.int64, + ) + recv_prev_shape_op = torch.distributed.P2POp( + torch.distributed.irecv, + recv_prev_shape_tensor, + self.prev_rank, + self.device_group, + ) + ops.append(recv_prev_shape_op) + + if tensor_send_to_next is not None: + send_next_shape_tensor = torch.tensor( + tensor_send_to_next.size(), + device=self.device, + dtype=torch.int64, + ) + send_next_shape_op = torch.distributed.P2POp( + torch.distributed.isend, + send_next_shape_tensor, + self.next_rank, + self.device_group, + ) + ops.append(send_next_shape_op) + + if len(ops) > 0: + reqs = torch.distributed.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + synchronize() + + recv_prev_shape = [0, 0, 0] + if recv_prev_shape_tensor is not None: + recv_prev_shape = recv_prev_shape_tensor + return torch.Size(recv_prev_shape) + + def pipeline_send( + self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1 + ) -> None: + tensor = tensor.contiguous() + self._check_shape_and_buffer( + tensor_send_to_next=tensor, name=name, segment_idx=segment_idx + ) + self._pipeline_isend(tensor).wait() + + def pipeline_isend( + self, tensor: torch.Tensor, name: str = "latent", segment_idx: int = -1 + ) -> None: + tensor = tensor.contiguous() + self._check_shape_and_buffer( + tensor_send_to_next=tensor, name=name, segment_idx=segment_idx + ) + self._pipeline_isend(tensor) + + def pipeline_recv(self, idx: int = -1, name: str = "latent") -> torch.Tensor: + name = name or "latent" + self._check_shape_and_buffer(recv_prev=True, name=name, segment_idx=idx) + self._pipeline_irecv(self.recv_buffer[name][idx]).wait() + return self.recv_buffer[name][idx] + + def add_pipeline_recv_task(self, idx: int = -1, name: str = "latent"): + name = name or "latent" + self.recv_tasks_queue.append((name, idx)) + + def recv_next(self): + if len(self.recv_tasks_queue) == 0: + raise ValueError("No more tasks to receive") + elif len(self.recv_tasks_queue) > 0: + name, idx = self.recv_tasks_queue.pop(0) + self._check_shape_and_buffer(recv_prev=True, name=name, segment_idx=idx) + self.receiving_tasks.append( + (self._pipeline_irecv(self.recv_buffer[name][idx]), name, idx) + ) + + def get_pipeline_recv_data( + self, idx: int = -1, name: str = "latent" + ) -> torch.Tensor: + assert ( + len(self.receiving_tasks) > 0 + ), "No tasks to receive, call add_pipeline_recv_task first" + receiving_task = self.receiving_tasks.pop(0) + receiving_task[0].wait() + assert ( + receiving_task[1] == name and receiving_task[2] == idx + ), "Received tensor does not match the requested" + return self.recv_buffer[name][idx] + + def _pipeline_irecv(self, tensor: torch.tensor): + return torch.distributed.irecv( + tensor, + src=self.prev_rank, + group=( + self.device_groups[(self.rank_in_group + 1) % 2] + if self.world_size == 2 + else self.device_group + ), + ) + + def _pipeline_isend(self, tensor: torch.tensor): + return torch.distributed.isend( + tensor, + dst=self.next_rank, + group=( + self.device_groups[self.rank_in_group % 2] + if self.world_size == 2 + else self.device_group + ), + ) + + def set_skip_tensor_recv_buffer( + self, + patches_shape_list: List[List[int]], + feature_map_shape: List[int], + ): + self.skip_tensor_recv_buffer = [ + torch.zeros(*shape, dtype=self.dtype, device=self.device) + for shape in patches_shape_list + ] + self.skip_tensor_recv_buffer.append( + 
torch.zeros(*feature_map_shape, dtype=self.dtype, device=self.device) + ) + self.skip_tensor_recv_buffer_set = True + + def pipeline_send_skip(self, tensor: torch.Tensor) -> None: + tensor = tensor.contiguous() + self._pipeline_isend_skip(tensor).wait() + + def pipeline_isend_skip(self, tensor: torch.Tensor) -> None: + tensor = tensor.contiguous() + self._pipeline_isend_skip(tensor) + + def pipeline_recv_skip(self, idx: int = -1) -> torch.Tensor: + self._pipeline_irecv_skip(self.skip_tensor_recv_buffer[idx]).wait() + return self.skip_tensor_recv_buffer[idx] + + def add_pipeline_recv_skip_task(self, idx: int = -1): + self.recv_skip_tasks_queue.append(idx) + + def get_pipeline_recv_skip_data(self, idx: int = -1) -> torch.Tensor: + assert ( + len(self.receiving_skip_tasks) > 0 + ), "No tasks to receive, call add_pipeline_recv_skip_task first" + receiving_skip_task = self.receiving_skip_tasks.pop(0) + receiving_skip_task[0].wait() + assert ( + receiving_skip_task[2] == idx + ), "Received tensor does not match the requested" + return self.skip_tensor_recv_buffer[idx] + + def recv_skip_next(self): + if len(self.recv_skip_tasks_queue) == 0: + raise ValueError("No more tasks to receive") + elif len(self.recv_skip_tasks_queue) > 0: + task = self.recv_skip_tasks_queue.pop(0) + idx = task + self.receiving_skip_tasks.append( + ( + self._pipeline_irecv_skip(self.skip_tensor_recv_buffer[idx]), + None, + idx, + ) + ) + + def _pipeline_irecv_skip(self, tensor: torch.tensor): + return torch.distributed.irecv( + tensor, src=self.skip_rank, group=self.skip_device_group + ) + + def _pipeline_isend_skip(self, tensor: torch.tensor): + return torch.distributed.isend( + tensor, dst=self.skip_rank, group=self.skip_device_group + ) + + +class SequenceParallelGroupCoordinator(GroupCoordinator): + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + group_name: str | None = None, + **kwargs, + ): + super().__init__( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=torch_distributed_backend, + group_name=group_name, + ) + ulysses_group = kwargs.get("ulysses_group", None) + ring_group = kwargs.get("ring_group", None) + if ulysses_group is None: + raise RuntimeError( + f"Please pass argument 'ulysses_group' when calling init func of SequenceParallelGroupCoordinator" + ) + if ring_group is None: + raise RuntimeError( + f"Please pass argument 'ring_group' when calling init func of SequenceParallelGroupCoordinator" + ) + self.ulysses_group = ulysses_group + self.ring_group = ring_group + + self.ulysses_world_size = torch.distributed.get_world_size(self.ulysses_group) + self.ulysses_rank = torch.distributed.get_rank(self.ulysses_group) + self.ring_world_size = torch.distributed.get_world_size(self.ring_group) + self.ring_rank = torch.distributed.get_rank(self.ring_group) diff --git a/python/sglang/multimodal_gen/runtime/distributed/parallel_state.py b/python/sglang/multimodal_gen/runtime/distributed/parallel_state.py new file mode 100644 index 000000000..a99195aab --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/distributed/parallel_state.py @@ -0,0 +1,1144 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/parallel_state.py +# Copyright 2023 The vLLM team. 
+# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Adapted from +# Copyright 2024 xDiT team. +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/distributed/parallel_state.py +# Copyright 2023 The vLLM team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""sgl-diffusion distributed state. + +It takes over the control of the distributed environment from PyTorch. +The typical workflow is: + +- call `init_distributed_environment` to initialize the distributed environment. +- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to + initialize the model parallel groups. + +- any code dealing with the distributed stuff + +- call `destroy_model_parallel` to destroy the model parallel groups. +- call `destroy_distributed_environment` to destroy the distributed environment. + +If you only need to use the distributed environment without model parallelism, + you can skip the model parallel initialization and destruction steps. +""" +import contextlib +import os +import weakref +from collections import namedtuple +from collections.abc import Callable +from contextlib import contextmanager +from multiprocessing import shared_memory +from typing import Any, List, Optional +from unittest.mock import patch + +import torch +import torch.distributed +from torch.distributed import ProcessGroup + +import sglang.multimodal_gen.envs as envs +from sglang.multimodal_gen.runtime.distributed.utils import StatelessProcessGroup +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +from ..utils.distributed import RankGenerator +from .group_coordinator import ( + GroupCoordinator, + PipelineGroupCoordinator, + SequenceParallelGroupCoordinator, + get_local_torch_device, +) + +logger = init_logger(__name__) + +_WORLD: Optional[GroupCoordinator] = None +_TP: Optional[GroupCoordinator] = None +_SP: Optional[SequenceParallelGroupCoordinator] = None +_PP: Optional[PipelineGroupCoordinator] = None +_CFG: Optional[GroupCoordinator] = None +_DP: Optional[GroupCoordinator] = None +_DIT: Optional[GroupCoordinator] = None +_VAE: Optional[GroupCoordinator] = None + +logger = init_logger(__name__) + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +def _split_tensor_dict( + tensor_dict: dict[str, torch.Tensor | Any] +) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + """ + metadata_list: list[tuple[str, Any]] = [] + tensor_list: list[torch.Tensor] = [] + for key, value in tensor_dict.items(): + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "cuda:0"). We only need the device type. + # receiving side will set the device index. 
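# --- Editorial usage sketch (not part of the patch), illustrating the workflow described
# in the module docstring above; the names are taken from this file, the sizes are
# hypothetical, and the snippet assumes a torchrun launch (RANK / LOCAL_RANK / WORLD_SIZE
# set) plus the optional yunchang dependency:
#
#   from sglang.multimodal_gen.runtime.distributed.parallel_state import (
#       init_distributed_environment, initialize_model_parallel,
#       destroy_model_parallel, destroy_distributed_environment,
#   )
#
#   init_distributed_environment(world_size=4, rank=rank, local_rank=local_rank)
#   initialize_model_parallel(sequence_parallel_degree=2, ulysses_degree=2,
#                             classifier_free_guidance_degree=2)
#   ...                       # model / pipeline code that uses the parallel groups
#   destroy_model_parallel()
#   destroy_distributed_environment()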
+ device = value.device.type + metadata_list.append( + (key, TensorMetadata(device, value.dtype, value.size())) + ) + tensor_list.append(value) + else: + metadata_list.append((key, value)) + return metadata_list, tensor_list + + +_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {} + + +def _register_group(group: "GroupCoordinator") -> None: + _groups[group.unique_name] = weakref.ref(group) + + +def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group._all_reduce_out_place(tensor) + + +def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + return torch.empty_like(tensor) + + +_WORLD: GroupCoordinator | None = None +_NODE: GroupCoordinator | None = None + + +def get_world_group() -> GroupCoordinator: + assert _WORLD is not None, "world group is not initialized" + return _WORLD + + +def init_world_group( + ranks: list[int], local_rank: int, backend: str +) -> GroupCoordinator: + return GroupCoordinator( + group_ranks=[ranks], + local_rank=local_rank, + torch_distributed_backend=backend, + use_device_communicator=True, + group_name="world", + ) + + +# xDiT +def init_parallel_group_coordinator( + group_ranks: List[List[int]], + local_rank: int, + backend: str, + parallel_mode: str, + **kwargs, +) -> GroupCoordinator: + """ + Returns a Group Coordinator for the given parallel mode + """ + assert parallel_mode in [ + "data", + "pipeline", + "tensor", + "sequence", + "classifier_free_guidance", + ], f"parallel_mode {parallel_mode} is not supported" + if parallel_mode == "pipeline": + return PipelineGroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + group_name="pp_group", + ) + elif parallel_mode == "sequence": + return SequenceParallelGroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + group_name="sp_group", + **kwargs, + ) + else: + # fallback to GroupCoordinator + return GroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + group_name="cfg_group", + ) + + +# def init_parallel_group_coordinator( +# group_ranks: list[list[int]], +# local_rank: int, +# backend: str, +# use_message_queue_broadcaster: bool = False, +# group_name: str | None = None, +# ) -> GroupCoordinator: +# return GroupCoordinator( +# group_ranks=group_ranks, +# local_rank=local_rank, +# torch_distributed_backend=backend, +# use_device_communicator=True, +# use_message_queue_broadcaster=use_message_queue_broadcaster, +# group_name=group_name, +# ) + + +_TP: GroupCoordinator | None = None + + +def get_tp_group() -> GroupCoordinator: + assert _TP is not None, "tensor model parallel group is not initialized" + return _TP + + +_ENABLE_CUSTOM_ALL_REDUCE = True + + +def set_custom_all_reduce(enable: bool): + global _ENABLE_CUSTOM_ALL_REDUCE + _ENABLE_CUSTOM_ALL_REDUCE = enable + + +def init_distributed_environment( + world_size: int = 1, + rank: int = 0, + distributed_init_method: str = "env://", + local_rank: int = 0, + backend: str = "nccl", + device_id: torch.device | None = None, +): + # Determine the appropriate backend based on the platform + from sglang.multimodal_gen.runtime.platforms import current_platform + + if backend == "nccl" and not current_platform.is_cuda_alike(): + # Use gloo backend for non-CUDA platforms (MPS, 
CPU) + backend = "gloo" + logger.info("Using gloo backend for %s platform", current_platform.device_name) + + logger.debug( + "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", + world_size, + rank, + local_rank, + distributed_init_method, + backend, + ) + if not torch.distributed.is_initialized(): + assert distributed_init_method is not None, ( + "distributed_init_method must be provided when initializing " + "distributed environment" + ) + + # For MPS, don't pass device_id as it doesn't support device indices + extra_args = {} if current_platform.is_mps() else dict(device_id=device_id) + torch.distributed.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank, + **extra_args, + ) + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = envs.LOCAL_RANK + else: + local_rank = rank + global _WORLD + if _WORLD is None: + ranks = list(range(torch.distributed.get_world_size())) + _WORLD = init_world_group(ranks, local_rank, backend) + else: + assert ( + _WORLD.world_size == torch.distributed.get_world_size() + ), "world group already initialized with a different world size" + + +_SP: GroupCoordinator | None = None + + +def get_sp_group() -> SequenceParallelGroupCoordinator: + assert _SP is not None, "pipeline model parallel group is not initialized" + return _SP + + +_DP: GroupCoordinator | None = None + + +def get_dp_group() -> GroupCoordinator: + assert _DP is not None, "data parallel group is not initialized" + return _DP + + +# xDiT +def initialize_model_parallel( + data_parallel_size: int = 1, + classifier_free_guidance_degree: int = 1, + sequence_parallel_degree: Optional[int] = None, + ulysses_degree: int = 1, + ring_degree: int = 1, + tensor_parallel_degree: int = 1, + pipeline_parallel_degree: int = 1, + vae_parallel_size: int = 0, + backend: Optional[str] = None, +) -> None: + """ + Initialize model parallel groups. + + Arguments: + data_parallel_size: number of data parallelism groups. + classifier_free_guidance_degree: number of GPUs used for Classifier Free Guidance (CFG) + sequence_parallel_degree: number of GPUs used for sequence parallelism. sequence_parallel_degree = ulysses_degree * ring_degree + ulysses_degree: number of GPUs used for ulysses sequence parallelism. + ring_degree: number of GPUs used for ring sequence parallelism. + tensor_parallel_degree: number of GPUs used for tensor parallelism. + pipeline_parallel_degree: number of GPUs used for pipeline parallelism. + backend: distributed backend of pytorch collective comm. + + Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we + use 2 groups to parallelize the batch dim(dp), 2 groups to parallelize + split batch caused by CFG, and 2 GPUs to parallelize sequence. + + dp_degree (2) * cfg_degree (2) * sp_degree (2) * pp_degree (2) = 16. 
+ + The present function will create 8 data-parallel groups, + 8 CFG group, 8 pipeline-parallel group, and + 8 sequence-parallel groups: + 8 data-parallel groups: + [g0, g8], [g1, g9], [g2, g10], [g3, g11], + [g4, g12], [g5, g13], [g6, g14], [g7, g15] + 8 CFG-parallel groups: + [g0, g4], [g1, g5], [g2, g6], [g3, g7], + [g8, g12], [g9, g13], [g10, g14], [g11, g15] + 8 sequence-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7], + [g8, g9], [g10, g11], [g12, g13], [g14, g15] + 8 pipeline-parallel groups: + [g0, g2], [g4, g6], [g8, g10], [g12, g14], + [g1, g3], [g5, g7], [g9, g11], [g13, g15] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + + if backend is None: + backend = envs.get_torch_distributed_backend() + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend(get_world_group().device_group) + + dit_parallel_size = ( + data_parallel_size + * classifier_free_guidance_degree + * sequence_parallel_degree + * pipeline_parallel_degree + * tensor_parallel_degree + ) + + if world_size < dit_parallel_size: + raise RuntimeError( + f"world_size ({world_size}) is less than " + f"tensor_parallel_degree ({tensor_parallel_degree}) x " + f"pipeline_parallel_degree ({pipeline_parallel_degree}) x" + f"sequence_parallel_degree ({sequence_parallel_degree}) x" + f"classifier_free_guidance_degree " + f"({classifier_free_guidance_degree}) x" + f"data_parallel_degree ({data_parallel_size})" + ) + + rank_generator: RankGenerator = RankGenerator( + tensor_parallel_degree, + sequence_parallel_degree, + pipeline_parallel_degree, + classifier_free_guidance_degree, + data_parallel_size, + "tp-sp-pp-cfg-dp", + ) + global _DP + assert _DP is None, "data parallel group is already initialized" + _DP = init_parallel_group_coordinator( + group_ranks=rank_generator.get_ranks("dp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="data", + ) + + global _CFG + assert _CFG is None, "classifier_free_guidance group is already initialized" + _CFG = init_parallel_group_coordinator( + group_ranks=rank_generator.get_ranks("cfg"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="classifier_free_guidance", + ) + global _PP + assert _PP is None, "pipeline model parallel group is already initialized" + _PP = init_parallel_group_coordinator( + group_ranks=rank_generator.get_ranks("pp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="pipeline", + ) + + global _SP + assert _SP is None, "sequence parallel group is already initialized" + + from yunchang import set_seq_parallel_pg + from yunchang.globals import PROCESS_GROUP + + set_seq_parallel_pg( + sp_ulysses_degree=ulysses_degree, + sp_ring_degree=ring_degree, + rank=get_world_group().rank_in_group, + world_size=dit_parallel_size, + ) + + _SP = init_parallel_group_coordinator( + group_ranks=rank_generator.get_ranks("sp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="sequence", + ulysses_group=PROCESS_GROUP.ULYSSES_PG, + ring_group=PROCESS_GROUP.RING_PG, + ) + + global _TP + assert _TP is None, "Tensor parallel group is already initialized" + _TP = init_parallel_group_coordinator( + 
group_ranks=rank_generator.get_ranks("tp"), + local_rank=get_world_group().local_rank, + backend=backend, + parallel_mode="tensor", + ) + + if vae_parallel_size > 0: + init_vae_group(dit_parallel_size, vae_parallel_size, backend) + init_dit_group(dit_parallel_size, backend) + + +# + + +# def initialize_model_parallel( +# tensor_model_parallel_size: int = 1, +# sequence_model_parallel_size: int = 1, +# data_parallel_size: int = 1, +# backend: str | None = None, +# ) -> None: +# """ +# Initialize model parallel groups. +# +# Arguments: +# tensor_model_parallel_size: number of GPUs used for tensor model +# parallelism (used for language encoder). +# sequence_model_parallel_size: number of GPUs used for sequence model +# parallelism (used for DiT). +# """ +# # Get world size and rank. Ensure some consistencies. +# assert ( +# _WORLD is not None +# ), "world group is not initialized, please call init_distributed_environment first" +# world_size: int = get_world_size() +# backend = backend or torch.distributed.get_backend(get_world_group().device_group) +# assert ( +# world_size >= tensor_model_parallel_size +# ), f"world_size({world_size}) must be greater than or equal to tensor_model_parallel_size({tensor_model_parallel_size})" +# num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size +# global _TP +# assert _TP is None, "tensor model parallel group is already initialized" +# group_ranks = [] +# for i in range(num_tensor_model_parallel_groups): +# ranks = list( +# range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) +# ) +# group_ranks.append(ranks) +# +# # message queue broadcaster is only used in tensor model parallel group +# _TP = init_parallel_group_coordinator( +# group_ranks, +# get_world_group().local_rank, +# backend, +# use_message_queue_broadcaster=True, +# group_name="tp", +# ) +# +# # Build the sequence model-parallel groups. +# num_sequence_model_parallel_groups: int = world_size // sequence_model_parallel_size +# global _SP +# assert _SP is None, "sequence model parallel group is already initialized" +# group_ranks = [] +# +# # Since SP is incompatible with TP and PP, we can use a simpler group creation logic +# for i in range(num_sequence_model_parallel_groups): +# # Create groups of consecutive ranks +# ranks = list( +# range( +# i * sequence_model_parallel_size, (i + 1) * sequence_model_parallel_size +# ) +# ) +# group_ranks.append(ranks) +# +# _SP = init_parallel_group_coordinator( +# group_ranks, get_world_group().local_rank, backend, group_name="sp" +# ) +# +# # Build the data parallel groups. 
+# num_data_parallel_groups: int = sequence_model_parallel_size +# global _DP +# assert _DP is None, "data parallel group is already initialized" +# group_ranks = [] +# +# for i in range(num_data_parallel_groups): +# ranks = list(range(i, world_size, num_data_parallel_groups)) +# group_ranks.append(ranks) +# +# _DP = init_parallel_group_coordinator( +# group_ranks, get_world_group().local_rank, backend, group_name="dp" +# ) +# + + +def get_sp_world_size() -> int: + """Return world size for the sequence model parallel group.""" + return get_sp_group().world_size + + +def get_sp_parallel_rank() -> int: + """Return my rank for the sequence model parallel group.""" + return get_sp_group().rank_in_group + + +def get_world_size() -> int: + """Return world size for the world group.""" + return get_world_group().world_size + + +def get_world_rank() -> int: + """Return my rank for the world group.""" + return get_world_group().rank + + +def get_dp_world_size() -> int: + """Return world size for the data parallel group.""" + return get_dp_group().world_size + + +def get_dp_rank() -> int: + """Return my rank for the data parallel group.""" + return get_dp_group().rank_in_group + + +def maybe_init_distributed_environment_and_model_parallel( + tp_size: int, + sp_size: int, + enable_cfg_parallel: bool, + ulysses_degree: int = 1, + ring_degree: int = 1, + dp_size: int = 1, + distributed_init_method: str = "env://", +): + from sglang.multimodal_gen.runtime.platforms import current_platform + + if _WORLD is not None and model_parallel_is_initialized(): + # make sure the tp and sp sizes are correct + assert ( + get_tp_world_size() == tp_size + ), f"You are trying to initialize model parallel groups with size {tp_size}, but they are already initialized with size {get_tp_world_size()}" + assert ( + get_sp_world_size() == sp_size + ), f"You are trying to initialize model parallel groups with size {sp_size}, but they are already initialized with size {get_sp_world_size()}" + return + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + rank = int(os.environ.get("RANK", 0)) + device = get_local_torch_device() + logger.info( + "Initializing distributed environment with world_size=%d, device=%s", + world_size, + device, + main_process_only=False, + ) + + init_distributed_environment( + world_size=world_size, + rank=rank, + local_rank=local_rank, + distributed_init_method=distributed_init_method, + device_id=device, + ) + initialize_model_parallel( + data_parallel_size=dp_size, + classifier_free_guidance_degree=2 if enable_cfg_parallel else 1, + tensor_parallel_degree=tp_size, + ulysses_degree=ulysses_degree, + ring_degree=ring_degree, + sequence_parallel_degree=sp_size, + ) + + # Only set CUDA device if we're on a CUDA platform + if current_platform.is_cuda_alike(): + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + + +def model_parallel_is_initialized() -> bool: + """Check if tensor, sequence parallel groups are initialized.""" + return _TP is not None and _SP is not None and _DP is not None and _CFG is not None + + +_TP_STATE_PATCHED = False + + +@contextmanager +def patch_tensor_parallel_group(tp_group: GroupCoordinator): + """Patch the tp group temporarily until this function ends. + + This method is for draft workers of speculative decoding to run draft model + with different tp degree from that of target model workers. 
+ + Args: + tp_group (GroupCoordinator): the tp group coordinator + """ + global _TP_STATE_PATCHED + assert not _TP_STATE_PATCHED, "Should not call when it's already patched" + + _TP_STATE_PATCHED = True + old_tp_group = get_tp_group() + global _TP + _TP = tp_group + try: + yield + finally: + # restore the original state + _TP_STATE_PATCHED = False + _TP = old_tp_group + + +def get_tp_world_size() -> int: + """Return world size for the tensor model parallel group.""" + return get_tp_group().world_size + + +def get_tp_rank() -> int: + """Return my rank for the tensor model parallel group.""" + return get_tp_group().rank_in_group + + +def destroy_distributed_environment() -> None: + global _WORLD + if _WORLD: + _WORLD.destroy() + _WORLD = None + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + +def cleanup_dist_env_and_memory(shutdown_ray: bool = False): + destroy_model_parallel() + destroy_distributed_environment() + with contextlib.suppress(AssertionError): + torch.distributed.destroy_process_group() + if shutdown_ray: + import ray # Lazy import Ray + + ray.shutdown() + + +def is_the_same_node_as( + pg: ProcessGroup | StatelessProcessGroup, source_rank: int = 0 +) -> list[int]: + """ + This is a collective operation that returns if each rank is in the same node + as the source rank. It tests if processes are attached to the same + memory system (shared access to shared memory). + """ + if isinstance(pg, ProcessGroup): + assert ( + torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL + ), "in_the_same_node_as should be tested with a non-NCCL group." + # local rank inside the group + rank = torch.distributed.get_rank(group=pg) + world_size = torch.distributed.get_world_size(group=pg) + + # global ranks of the processes in the group + ranks = torch.distributed.get_process_group_ranks(pg) + else: + rank = pg.rank + world_size = pg.world_size + ranks = list(range(world_size)) + + # local tensor in each process to store the result + is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32) + + magic_message = b"magic_message" + shm = None + + try: + with contextlib.suppress(OSError): + if rank == source_rank: + # create a shared memory segment + shm = shared_memory.SharedMemory(create=True, size=128) + shm.buf[: len(magic_message)] = magic_message + if isinstance(pg, ProcessGroup): + torch.distributed.broadcast_object_list( + [shm.name], src=ranks[source_rank], group=pg + ) + else: + pg.broadcast_obj(shm.name, src=source_rank) + is_in_the_same_node[rank] = 1 + else: + # try to open the shared memory segment + if isinstance(pg, ProcessGroup): + recv = [None] + torch.distributed.broadcast_object_list( + recv, src=ranks[source_rank], group=pg + ) + name = recv[0] + else: + name = pg.broadcast_obj(None, src=source_rank) + # fix to https://stackoverflow.com/q/62748654/9191338 + # Python incorrectly tracks shared memory even if it is not + # created by the process. The following patch is a workaround. 
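# --- Editorial note (not part of the patch): interpreting the result of
# is_the_same_node_as (continued below). The returned list is aligned with the group's
# ranks; entry i is True when rank i could open the shared-memory segment created by
# source_rank, i.e. it shares a physical node with it. Hypothetical 4-rank group spread
# over two nodes:
same_node = [True, True, False, False]          # ranks 0-1 on node A, ranks 2-3 on node B
assert same_node[0] and not any(same_node[2:])  # with source_rank=0, only node-A ranks are True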
+ with patch( + "multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None, + ): + shm = shared_memory.SharedMemory(name=name) + if shm.buf[: len(magic_message)] == magic_message: + is_in_the_same_node[rank] = 1 + except Exception as e: + logger.error("Error ignored in is_in_the_same_node: %s", e) + finally: + if shm: + shm.close() + + if isinstance(pg, ProcessGroup): + torch.distributed.barrier(group=pg) + else: + pg.barrier() + + # clean up the shared memory segment + with contextlib.suppress(OSError): + if rank == source_rank and shm: + shm.unlink() + + if isinstance(pg, ProcessGroup): + torch.distributed.all_reduce(is_in_the_same_node, group=pg) + aggregated_data = is_in_the_same_node + else: + aggregated_data = torch.zeros_like(is_in_the_same_node) + for i in range(world_size): + rank_data = pg.broadcast_obj(is_in_the_same_node, src=i) + aggregated_data += rank_data + + return [x == 1 for x in aggregated_data.tolist()] + + +def initialize_tensor_parallel_group( + tensor_model_parallel_size: int = 1, + backend: str | None = None, + group_name_suffix: str = "", +) -> GroupCoordinator: + """Initialize a tensor parallel group for a specific model. + + This function creates a tensor parallel group that can be used with the + patch_tensor_parallel_group context manager. It allows different models + to use different tensor parallelism configurations. + + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model parallelism. + backend: communication backend to use. + group_name_suffix: optional suffix to make the group name unique. + + Returns: + A GroupCoordinator for tensor parallelism that can be used with + the patch_tensor_parallel_group context manager. + + Example usage: + ```python + # Initialize tensor parallel group for model1 + tp_group_model1 = initialize_tensor_parallel_group( + tensor_model_parallel_size=4, + group_name_suffix="model1" + ) + + # Use tensor parallelism for model1 + with patch_tensor_parallel_group(tp_group_model1): + # Run model1 with tensor parallelism + output1 = model1(input1) + ``` + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend(get_world_group().device_group) + + # Ensure the world size is compatible with the parallelism configuration + assert ( + world_size % tensor_model_parallel_size == 0 + ), f"World size ({world_size}) must be divisible by tensor_model_parallel_size ({tensor_model_parallel_size})" + + # Build the tensor model-parallel groups. + num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size + tp_group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = list( + range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) + ) + tp_group_ranks.append(ranks) + + # Create TP group coordinator with a unique name + group_name = f"tp_{group_name_suffix}" if group_name_suffix else "tp" + tp_group = init_parallel_group_coordinator( + tp_group_ranks, + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=True, + group_name=group_name, + ) + + return tp_group + + +def initialize_sequence_parallel_group( + sequence_model_parallel_size: int = 1, + backend: str | None = None, + group_name_suffix: str = "", +) -> GroupCoordinator: + """Initialize a sequence parallel group for a specific model. 
+ + This function creates a sequence parallel group that can be used with the + patch_sequence_parallel_group context manager. It allows different models + to use different sequence parallelism configurations. + + Arguments: + sequence_model_parallel_size: number of GPUs used for sequence model parallelism. + backend: communication backend to use. + group_name_suffix: optional suffix to make the group name unique. + + Returns: + A GroupCoordinator for sequence parallelism that can be used with + the patch_sequence_parallel_group context manager. + + Example usage: + ```python + # Initialize sequence parallel group for model2 + sp_group_model2 = initialize_sequence_parallel_group( + sequence_model_parallel_size=2, + group_name_suffix="model2" + ) + + # Use sequence parallelism for model2 + with patch_sequence_parallel_group(sp_group_model2): + # Run model2 with sequence parallelism + output2 = model2(input2) + ``` + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend(get_world_group().device_group) + + # Ensure the world size is compatible with the parallelism configuration + assert ( + world_size % sequence_model_parallel_size == 0 + ), f"World size ({world_size}) must be divisible by sequence_model_parallel_size ({sequence_model_parallel_size})" + + # Build the sequence model-parallel groups. + num_sequence_model_parallel_groups: int = world_size // sequence_model_parallel_size + sp_group_ranks = [] + + for i in range(num_sequence_model_parallel_groups): + # Create groups of consecutive ranks + ranks = list( + range( + i * sequence_model_parallel_size, (i + 1) * sequence_model_parallel_size + ) + ) + sp_group_ranks.append(ranks) + + # Create SP group coordinator with a unique name + group_name = f"sp_{group_name_suffix}" if group_name_suffix else "sp" + sp_group = init_parallel_group_coordinator( + sp_group_ranks, get_world_group().local_rank, backend, group_name=group_name + ) + + return sp_group + + +# * QUERY +def get_world_group() -> GroupCoordinator: + assert _WORLD is not None, "world group is not initialized" + return _WORLD + + +# TP +def get_tp_group() -> GroupCoordinator: + assert _TP is not None, "tensor model parallel group is not initialized" + return _TP + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return get_tp_group().world_size + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return get_tp_group().rank_in_group + + +def get_sequence_parallel_world_size(): + """Return world size for the sequence parallel group.""" + return get_sp_group().world_size + + +def get_sequence_parallel_rank(): + """Return my rank for the sequence parallel group.""" + return get_sp_group().rank_in_group + + +def get_ulysses_parallel_world_size(): + return get_sp_group().ulysses_world_size + + +def get_ulysses_parallel_rank(): + return get_sp_group().ulysses_rank + + +def get_ring_parallel_world_size(): + return get_sp_group().ring_world_size + + +def get_ring_parallel_rank(): + return get_sp_group().ring_rank + + +# PP +def get_pp_group() -> PipelineGroupCoordinator: + assert _PP is not None, "pipeline model parallel group is not initialized" + return _PP + + +def get_pipeline_parallel_world_size(): + """Return world size for the pipeline model parallel group.""" + return get_pp_group().world_size + + +def 
get_pipeline_parallel_rank():
+    """Return my rank for the pipeline model parallel group."""
+    return get_pp_group().rank_in_group
+
+
+def is_pipeline_first_stage():
+    """Return True if in the first pipeline model parallel stage, False otherwise."""
+    return get_pipeline_parallel_rank() == 0
+
+
+def is_pipeline_last_stage():
+    """Return True if in the last pipeline model parallel stage, False otherwise."""
+    return get_pipeline_parallel_rank() == (get_pipeline_parallel_world_size() - 1)
+
+
+# CFG
+def get_cfg_group() -> GroupCoordinator:
+    assert (
+        _CFG is not None
+    ), "classifier_free_guidance parallel group is not initialized"
+    return _CFG
+
+
+def get_classifier_free_guidance_world_size():
+    """Return world size for the classifier_free_guidance parallel group."""
+    return get_cfg_group().world_size
+
+
+def get_classifier_free_guidance_rank():
+    """Return my rank for the classifier_free_guidance parallel group."""
+    return get_cfg_group().rank_in_group
+
+
+# DP
+def get_dp_group() -> GroupCoordinator:
+    assert _DP is not None, "data parallel group is not initialized"
+    return _DP
+
+
+def get_data_parallel_world_size():
+    """Return world size for the data parallel group."""
+    return get_dp_group().world_size
+
+
+def get_data_parallel_rank():
+    """Return my rank for the data parallel group."""
+    return get_dp_group().rank_in_group
+
+
+def is_dp_last_group():
+    """Return True if in the last data parallel group, False otherwise."""
+    return (
+        get_sequence_parallel_rank() == (get_sequence_parallel_world_size() - 1)
+        and get_classifier_free_guidance_rank()
+        == (get_classifier_free_guidance_world_size() - 1)
+        and get_pipeline_parallel_rank() == (get_pipeline_parallel_world_size() - 1)
+    )
+
+
+def get_dit_world_size():
+    """Return world size for the DiT model (excluding VAE)."""
+    return (
+        get_data_parallel_world_size()
+        * get_classifier_free_guidance_world_size()
+        * get_sequence_parallel_world_size()
+        * get_pipeline_parallel_world_size()
+        * get_tensor_model_parallel_world_size()
+    )
+
+
+# VAE getter functions
+def get_vae_parallel_group() -> GroupCoordinator:
+    assert _VAE is not None, "VAE parallel group is not initialized"
+    return _VAE
+
+
+def get_vae_parallel_world_size():
+    """Return world size for the VAE parallel group."""
+    return get_vae_parallel_group().world_size
+
+
+def get_vae_parallel_rank():
+    """Return my rank for the VAE parallel group."""
+    return get_vae_parallel_group().rank_in_group
+
+
+# * SET
+
+
+def init_world_group(
+    ranks: List[int], local_rank: int, backend: str
+) -> GroupCoordinator:
+    return GroupCoordinator(
+        group_ranks=[ranks],
+        local_rank=local_rank,
+        torch_distributed_backend=backend,
+    )
+
+
+def model_parallel_is_initialized():
+    """Check whether all model parallel groups (DP, CFG, SP, PP, TP) are initialized."""
+    return (
+        _DP is not None
+        and _CFG is not None
+        and _SP is not None
+        and _PP is not None
+        and _TP is not None
+    )
+
+
+def init_dit_group(
+    dit_parallel_size: int,
+    backend: str,
+):
+    global _DIT
+    _DIT = torch.distributed.new_group(
+        ranks=list(range(dit_parallel_size)), backend=backend
+    )
+
+
+def get_dit_group():
+    assert _DIT is not None, "DIT group is not initialized"
+    return _DIT
+
+
+def init_vae_group(
+    dit_parallel_size: int,
+    vae_parallel_size: int,
+    backend: str,
+):
+    # VAE ranks follow the DiT ranks in the global rank ordering
+    global _VAE
+    assert _VAE is None, "VAE parallel group is already initialized"
+    vae_ranks = list(range(dit_parallel_size, dit_parallel_size + vae_parallel_size))
+    _VAE = 
torch.distributed.new_group(ranks=vae_ranks, backend=backend) + + +def destroy_model_parallel() -> None: + """Set the groups to none and destroy them.""" + global _TP + if _TP: + _TP.destroy() + _TP = None + + global _SP + if _SP: + _SP.destroy() + _SP = None + + global _DP + if _DP: + _DP.destroy() + _DP = None + + +# xDit +# def destroy_model_parallel(): +# """Set the groups to none and destroy them.""" +# global _DP +# if _DP: +# _DP.destroy() +# _DP = None +# +# global _CFG +# if _CFG: +# _CFG.destroy() +# _CFG = None +# +# global _SP +# if _SP: +# _SP.destroy() +# _SP = None +# +# global _TP +# if _TP: +# _TP.destroy() +# _TP = None +# +# global _PP +# if _PP: +# _PP.destroy() +# _PP = None +# +# global _VAE +# if _VAE: +# _VAE.destroy() +# _VAE = None + + +def destroy_distributed_environment(): + global _WORLD + if _WORLD: + _WORLD.destroy() + _WORLD = None + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() diff --git a/python/sglang/multimodal_gen/runtime/distributed/utils.py b/python/sglang/multimodal_gen/runtime/distributed/utils.py new file mode 100644 index 000000000..2d84f8b52 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/distributed/utils.py @@ -0,0 +1,195 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/distributed/utils.py + +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +import dataclasses +import pickle +import time +from collections import deque +from collections.abc import Sequence +from typing import Any + +import torch +from torch.distributed import TCPStore + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +def ensure_divisibility(numerator, denominator) -> None: + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, "{} is not divisible by {}".format( + numerator, denominator + ) + + +def divide(numerator: int, denominator: int) -> int: + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> Sequence[torch.Tensor]: + """Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = divide(tensor.size()[last_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # NOTE: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tuple(tensor_list) + + +@dataclasses.dataclass +class StatelessProcessGroup: + """A dataclass to hold a metadata store, and the rank, world_size of the + group. Only use it to communicate metadata between processes. + For data-plane communication, create NCCL-related objects. 
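+
+    A minimal usage sketch (hypothetical host/port; `rank` and `world_size` are
+    assumed to come from the launcher, and rank 0 hosts the TCPStore):
+
+        pg = StatelessProcessGroup.create(
+            host="127.0.0.1", port=29600, rank=rank, world_size=world_size
+        )
+        metadata = pg.all_gather_obj({"rank": rank})
+        pg.barrier()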
+ """ + + rank: int + world_size: int + store: torch._C._distributed_c10d.Store + data_expiration_seconds: int = 3600 # 1 hour + + # dst rank -> counter + send_dst_counter: dict[int, int] = dataclasses.field(default_factory=dict) + # src rank -> counter + recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict) + broadcast_send_counter: int = 0 + broadcast_recv_src_counter: dict[int, int] = dataclasses.field(default_factory=dict) + + # A deque to store the data entries, with key and timestamp. + entries: deque[tuple[str, float]] = dataclasses.field(default_factory=deque) + + def __post_init__(self): + assert self.rank < self.world_size + self.send_dst_counter = {i: 0 for i in range(self.world_size)} + self.recv_src_counter = {i: 0 for i in range(self.world_size)} + self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)} + + def send_obj(self, obj: Any, dst: int): + """Send an object to a destination rank.""" + self.expire_data() + key = f"send_to/{dst}/{self.send_dst_counter[dst]}" + self.store.set(key, pickle.dumps(obj)) + self.send_dst_counter[dst] += 1 + self.entries.append((key, time.perf_counter())) + + def expire_data(self) -> None: + """Expire data that is older than `data_expiration_seconds` seconds.""" + while self.entries: + # check the oldest entry + key, timestamp = self.entries[0] + if time.perf_counter() - timestamp > self.data_expiration_seconds: + self.store.delete_key(key) + self.entries.popleft() + else: + break + + def recv_obj(self, src: int) -> Any: + """Receive an object from a source rank.""" + obj = pickle.loads( + self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}") + ) + self.recv_src_counter[src] += 1 + return obj + + def broadcast_obj(self, obj: Any | None, src: int) -> Any: + """Broadcast an object from a source rank to all other ranks. + It does not clean up after all ranks have received the object. + Use it for limited times, e.g., for initialization. + """ + if self.rank == src: + self.expire_data() + key = f"broadcast_from/{src}/" f"{self.broadcast_send_counter}" + self.store.set(key, pickle.dumps(obj)) + self.broadcast_send_counter += 1 + self.entries.append((key, time.perf_counter())) + return obj + else: + key = f"broadcast_from/{src}/" f"{self.broadcast_recv_src_counter[src]}" + recv_obj = pickle.loads(self.store.get(key)) + self.broadcast_recv_src_counter[src] += 1 + return recv_obj + + def all_gather_obj(self, obj: Any) -> list[Any]: + """All gather an object from all ranks.""" + gathered_objs = [] + for i in range(self.world_size): + if i == self.rank: + gathered_objs.append(obj) + self.broadcast_obj(obj, src=self.rank) + else: + recv_obj = self.broadcast_obj(None, src=i) + gathered_objs.append(recv_obj) + return gathered_objs + + def barrier(self): + """A barrier to synchronize all ranks.""" + for i in range(self.world_size): + if i == self.rank: + self.broadcast_obj(None, src=self.rank) + else: + self.broadcast_obj(None, src=i) + + @staticmethod + def create( + host: str, + port: int, + rank: int, + world_size: int, + data_expiration_seconds: int = 3600, + ) -> "StatelessProcessGroup": + """A replacement for `torch.distributed.init_process_group` that does not + pollute the global state. + + If we have process A and process B called `torch.distributed.init_process_group` + to form a group, and then we want to form another group with process A, B, C, + D, it is not possible in PyTorch, because process A and process B have already + formed a group, and process C and process D cannot join that group. 
This + function is a workaround for this issue. + + `torch.distributed.init_process_group` is a global call, while this function + is a stateless call. It will return a `StatelessProcessGroup` object that can be + used for exchanging metadata. With this function, process A and process B + can call `StatelessProcessGroup.create` to form a group, and then process A, B, + C, and D can call `StatelessProcessGroup.create` to form another group. + """ # noqa + store = TCPStore( + host_name=host, + port=port, + world_size=world_size, + is_master=(rank == 0), + ) + + return StatelessProcessGroup( + rank=rank, + world_size=world_size, + store=store, + data_expiration_seconds=data_expiration_seconds, + ) diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/__init__.py b/python/sglang/multimodal_gen/runtime/entrypoints/__init__.py new file mode 100644 index 000000000..af2eb7d10 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/entrypoints/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/cli/__init__.py b/python/sglang/multimodal_gen/runtime/entrypoints/cli/__init__.py new file mode 100644 index 000000000..af2eb7d10 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/entrypoints/cli/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/cli/cli_types.py b/python/sglang/multimodal_gen/runtime/entrypoints/cli/cli_types.py new file mode 100644 index 000000000..2e5107ec0 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/entrypoints/cli/cli_types.py @@ -0,0 +1,28 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/types.py + +import argparse + +from sglang.multimodal_gen.utils import FlexibleArgumentParser + + +class CLISubcommand: + """Base class for CLI subcommands""" + + name: str + + def cmd(self, args: argparse.Namespace) -> None: + """Execute the command with the given arguments""" + raise NotImplementedError + + def validate(self, args: argparse.Namespace) -> None: + """Validate the arguments for this command""" + pass + + def subparser_init( + self, subparsers: argparse._SubParsersAction + ) -> FlexibleArgumentParser: + """Initialize the subparser for this command""" + raise NotImplementedError diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/cli/generate.py b/python/sglang/multimodal_gen/runtime/entrypoints/cli/generate.py new file mode 100644 index 000000000..22e53bdd6 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/entrypoints/cli/generate.py @@ -0,0 +1,103 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/serve.py + +import argparse +import dataclasses +import os +from typing import cast + +from sglang.multimodal_gen import DiffGenerator +from sglang.multimodal_gen.configs.sample.base import ( + SamplingParams, + generate_request_id, +) +from sglang.multimodal_gen.runtime.entrypoints.cli.cli_types import CLISubcommand +from sglang.multimodal_gen.runtime.entrypoints.cli.utils import ( + RaiseNotImplementedAction, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import 
init_logger +from sglang.multimodal_gen.utils import FlexibleArgumentParser + +logger = init_logger(__name__) + + +def add_multimodal_gen_generate_args(parser: argparse.ArgumentParser): + """Add the arguments for the generate command.""" + parser.add_argument( + "--config", + type=str, + default="", + required=False, + help="Read CLI options from a config JSON or YAML file. If provided, --model-path and --prompt are optional.", + ) + + parser = ServerArgs.add_cli_args(parser) + parser = SamplingParams.add_cli_args(parser) + + parser.add_argument( + "--text-encoder-configs", + action=RaiseNotImplementedAction, + help="JSON array of text encoder configurations (NOT YET IMPLEMENTED)", + ) + + return parser + + +def generate_cmd(args: argparse.Namespace): + """The entry point for the generate command.""" + # FIXME(mick): do not hard code + args.request_id = generate_request_id() + + server_args = ServerArgs.from_cli_args(args) + sampling_params = SamplingParams.from_cli_args(args) + sampling_params.request_id = generate_request_id() + generator = DiffGenerator.from_pretrained( + model_path=server_args.model_path, server_args=server_args + ) + + generator.generate(prompt=sampling_params.prompt, sampling_params=sampling_params) + + +class GenerateSubcommand(CLISubcommand): + """The `generate` subcommand for the sgl-diffusion CLI""" + + def __init__(self) -> None: + self.name = "generate" + super().__init__() + self.init_arg_names = self._get_init_arg_names() + self.generation_arg_names = self._get_generation_arg_names() + + def _get_init_arg_names(self) -> list[str]: + """Get names of arguments for DiffGenerator initialization""" + return ["num_gpus", "tp_size", "sp_size", "model_path"] + + def _get_generation_arg_names(self) -> list[str]: + """Get names of arguments for generate_video method""" + return [field.name for field in dataclasses.fields(SamplingParams)] + + def cmd(self, args: argparse.Namespace) -> None: + generate_cmd(args) + + def validate(self, args: argparse.Namespace) -> None: + """Validate the arguments for this command""" + if args.num_gpus is not None and args.num_gpus <= 0: + raise ValueError("Number of gpus must be positive") + + if args.config and not os.path.exists(args.config): + raise ValueError(f"Config file not found: {args.config}") + + def subparser_init( + self, subparsers: argparse._SubParsersAction + ) -> FlexibleArgumentParser: + generate_parser = subparsers.add_parser( + "generate", + help="Run inference on a model", + usage="sgl_diffusion generate (--model-path MODEL_PATH_OR_ID --prompt PROMPT) | --config CONFIG_FILE [OPTIONS]", + ) + + generate_parser = add_multimodal_gen_generate_args(generate_parser) + + return cast(FlexibleArgumentParser, generate_parser) diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/cli/main.py b/python/sglang/multimodal_gen/runtime/entrypoints/cli/main.py new file mode 100644 index 000000000..5158aab01 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/entrypoints/cli/main.py @@ -0,0 +1,44 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/cli/main.py + +from sglang.multimodal_gen.runtime.entrypoints.cli.cli_types import CLISubcommand +from sglang.multimodal_gen.runtime.entrypoints.cli.generate import GenerateSubcommand +from sglang.multimodal_gen.runtime.entrypoints.cli.serve import ServeSubcommand +from sglang.multimodal_gen.utils import FlexibleArgumentParser + + 
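+# Illustrative invocations handled by this dispatcher (model path and prompt are
+# placeholders; see the `usage` strings of the individual subcommands):
+#   sgl_diffusion generate --model-path MODEL_PATH_OR_ID --prompt "a red panda drinking tea"
+#   sgl_diffusion serve --model-path MODEL_PATH_OR_ID
+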
+def generate_cmd_init() -> list[CLISubcommand]: + return [GenerateSubcommand(), ServeSubcommand()] + + +def cmd_init() -> list[CLISubcommand]: + """Initialize all commands from separate modules""" + commands = [] + commands.extend(generate_cmd_init()) + return commands + + +def main() -> None: + parser = FlexibleArgumentParser(description="sgl-diffusion CLI") + parser.add_argument("-v", "--version", action="version", version="0.1.0") + + subparsers = parser.add_subparsers(required=False, dest="subparser") + + cmds = {} + for cmd in cmd_init(): + cmd.subparser_init(subparsers).set_defaults(dispatch_function=cmd.cmd) + cmds[cmd.name] = cmd + args = parser.parse_args() + if args.subparser in cmds: + cmds[args.subparser].validate(args) + + if hasattr(args, "dispatch_function"): + args.dispatch_function(args) + else: + parser.print_help() + + +if __name__ == "__main__": + main() diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/cli/serve.py b/python/sglang/multimodal_gen/runtime/entrypoints/cli/serve.py new file mode 100644 index 000000000..3df5e3fd0 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/entrypoints/cli/serve.py @@ -0,0 +1,69 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os +from typing import cast + +from sglang.multimodal_gen.runtime.entrypoints.cli.cli_types import CLISubcommand +from sglang.multimodal_gen.runtime.launch_server import launch_server +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import FlexibleArgumentParser + +logger = init_logger(__name__) + + +def add_multimodal_gen_serve_args(parser: argparse.ArgumentParser): + """Add the arguments for the serve command.""" + parser.add_argument( + "--config", + type=str, + default="", + required=False, + help="Read CLI options from a config JSON or YAML file.", + ) + return ServerArgs.add_cli_args(parser) + + +def execute_serve_cmd(args: argparse.Namespace, unknown_args: list[str] | None = None): + """The entry point for the serve command.""" + server_args = ServerArgs.from_cli_args(args, unknown_args) + server_args.post_init_serve() + launch_server(server_args) + + +class ServeSubcommand(CLISubcommand): + """The `serve` subcommand for the sgl-diffusion CLI""" + + def __init__(self) -> None: + self.name = "serve" + super().__init__() + + def cmd( + self, args: argparse.Namespace, unknown_args: list[str] | None = None + ) -> None: + execute_serve_cmd(args, unknown_args) + + def validate(self, args: argparse.Namespace) -> None: + """Validate the arguments for this command""" + if args.config and not os.path.exists(args.config): + raise ValueError(f"Config file not found: {args.config}") + + def subparser_init( + self, subparsers: argparse._SubParsersAction + ) -> FlexibleArgumentParser: + serve_parser = subparsers.add_parser( + "serve", + help="Launch the server and start FastAPI listener.", + usage="sgl_diffusion serve --model-path MODEL_PATH_OR_ID [OPTIONS]", + ) + + serve_parser = add_multimodal_gen_serve_args(serve_parser) + + return cast(FlexibleArgumentParser, serve_parser) + + +def cmd_init() -> list[CLISubcommand]: + return [ServeSubcommand()] diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/cli/utils.py b/python/sglang/multimodal_gen/runtime/entrypoints/cli/utils.py new file mode 100644 index 000000000..a4fc75272 --- /dev/null +++ 
b/python/sglang/multimodal_gen/runtime/entrypoints/cli/utils.py @@ -0,0 +1,74 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import os +import subprocess +import sys + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class RaiseNotImplementedAction(argparse.Action): + + def __call__(self, parser, namespace, values, option_string=None): + raise NotImplementedError(f"The {option_string} option is not yet implemented") + + +def launch_distributed( + num_gpus: int, args: list[str], master_port: int | None = None +) -> int: + """ + Launch a distributed job with the given arguments + + Args: + num_gpus: Number of GPUs to use + args: Arguments to pass to v1_sgl_diffusion_inference.py (defaults to sys.argv[1:]) + master_port: Port for the master process (default: random) + """ + + current_env = os.environ.copy() + python_executable = sys.executable + project_root = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../../..") + ) + main_script = os.path.join( + project_root, "sgl_diffusion/sample/v1_sgl_diffusion_inference.py" + ) + + cmd = [ + python_executable, + "-m", + "torch.distributed.run", + f"--nproc_per_node={num_gpus}", + ] + + if master_port is not None: + cmd.append(f"--master_port={master_port}") + + cmd.append(main_script) + cmd.extend(args) + + logger.info("Running inference with %d GPU(s)", num_gpus) + logger.info("Launching command: %s", " ".join(cmd)) + + current_env["PYTHONIOENCODING"] = "utf-8" + process = subprocess.Popen( + cmd, + env=current_env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True, + bufsize=1, + encoding="utf-8", + errors="replace", + ) + + if process.stdout: + for line in iter(process.stdout.readline, ""): + print(line.strip()) + + return process.wait() diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py b/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py new file mode 100644 index 000000000..9f8d86f82 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py @@ -0,0 +1,429 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +DiffGenerator module for sgl-diffusion. + +This module provides a consolidated interface for generating videos using +diffusion models. +""" + +import logging +import multiprocessing as mp +import os +import time +from copy import deepcopy +from typing import Any + +import imageio +import numpy as np +import torch +import torchvision +from einops import rearrange + +# Suppress verbose logging from imageio, which is triggered when saving images. 
+logging.getLogger("imageio").setLevel(logging.WARNING) +logging.getLogger("imageio_ffmpeg").setLevel(logging.WARNING) +# Suppress Pillow plugin import logs when app log level is DEBUG +logging.getLogger("PIL").setLevel(logging.WARNING) +logging.getLogger("PIL.Image").setLevel(logging.WARNING) + +from sglang.multimodal_gen.configs.sample.base import DataType, SamplingParams +from sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request +from sglang.multimodal_gen.runtime.launch_server import launch_server +from sglang.multimodal_gen.runtime.managers.schedulerbase import SchedulerBase +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import OutputBatch, Req +from sglang.multimodal_gen.runtime.server_args import PortArgs, ServerArgs +from sglang.multimodal_gen.runtime.sync_scheduler_client import sync_scheduler_client +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + +# TODO: move to somewhere appropriate +try: + # Set the start method to 'spawn' to avoid CUDA errors in forked processes. + # This must be done at the top level of the module, before any CUDA context + # or other processes are initialized. + mp.set_start_method("spawn", force=True) +except RuntimeError: + # The start method can only be set once per program execution. + pass + + +# TODO: rename +class DiffGenerator: + """ + A unified class for generating images/videos using diffusion models. + + This class provides a simple interface for image/video generation with rich + customization options, similar to popular frameworks like HF Diffusers. + """ + + def __init__( + self, + server_args: ServerArgs, + ): + """ + Initialize the generator. + + Args: + server_args: The inference arguments + """ + self.server_args = server_args + self.port_args = PortArgs.from_server_args(server_args) + + # The executor is now a client to the Scheduler service + self.local_scheduler_process: list[mp.Process] | None = None + self.owns_scheduler_client: bool = False + + @classmethod + def from_pretrained( + cls, + **kwargs, + ) -> "DiffGenerator": + """ + Create a DiffGenerator from a pretrained model. + + Args: + **kwargs: Additional arguments to customize model loading, set any ServerArgs or PipelineConfig attributes here. + + Returns: + The created DiffGenerator + + Priority level: Default pipeline config < User's pipeline config < User's kwargs + """ + # If users also provide some kwargs, it will override the ServerArgs and PipelineConfig. + + if (server_args := kwargs.get("server_args", None)) is not None: + if isinstance(server_args, ServerArgs): + pass + elif isinstance(server_args, dict): + server_args = ServerArgs.from_kwargs(**server_args) + else: + server_args = ServerArgs.from_kwargs(**kwargs) + + return cls.from_server_args(server_args) + + @classmethod + def from_server_args(cls, server_args: ServerArgs) -> "DiffGenerator": + """ + Create a DiffGenerator with the specified arguments. + + Args: + server_args: The inference arguments + + Returns: + The created DiffGenerator + """ + executor_class = SchedulerBase.get_class(server_args) + instance = cls( + server_args=server_args, + ) + is_local_mode = server_args.is_local_mode + logger.info(f"Local mode: {is_local_mode}") + if is_local_mode: + instance.local_scheduler_process = instance._start_local_server_if_needed() + else: + # In remote mode, we just need to connect and check. 
+ sync_scheduler_client.initialize(server_args) + instance._check_remote_scheduler() + + # In both modes, this DiffGenerator instance is responsible for the client's lifecycle. + instance.owns_scheduler_client = True + return instance + + def _start_local_server_if_needed( + self, + ) -> list[mp.Process]: + """Check if a local server is running; if not, start it and return the process handles.""" + # First, we need a client to test the server. Initialize it temporarily. + sync_scheduler_client.initialize(self.server_args) + + processes = launch_server(self.server_args, launch_http_server=False) + + return processes + + def _check_remote_scheduler(self): + """Check if the remote scheduler is accessible.""" + if not sync_scheduler_client.ping(): + raise ConnectionError( + f"Could not connect to remote scheduler at " + f"{self.server_args.scheduler_endpoint()} with `local mode` as False. " + "Please ensure the server is running." + ) + logger.info( + f"Successfully connected to remote scheduler at " + f"{self.server_args.scheduler_endpoint()}." + ) + + def post_process_sample( + self, + sample: torch.Tensor, + data_type: DataType, + fps: int, + save_output: bool = True, + save_file_path: str = None, + ): + """ + Process a single sample output and save output if necessary + """ + # Process outputs + if sample.dim() == 3: + # for images, dim t is missing + sample = sample.unsqueeze(1) + sample = rearrange(sample, "c t h w -> t c h w") + frames = [] + # TODO: this can be batched + for x in sample: + x = torchvision.utils.make_grid(x, nrow=6) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + frames.append((x * 255).numpy().astype(np.uint8)) + + # Save outputs if requested + if save_output: + if save_file_path: + os.makedirs(os.path.dirname(save_file_path), exist_ok=True) + if data_type == DataType.VIDEO: + imageio.mimsave( + save_file_path, + frames, + fps=fps, + format=data_type.get_default_extension(), + ) + else: + imageio.imwrite(save_file_path, frames[0]) + logger.info("Saved output to %s", save_file_path) + else: + logger.warning("No output path provided, output not saved") + + return frames + + def generate( + self, + prompt: str | list[str] | None = None, + sampling_params: SamplingParams | None = None, + **kwargs, + ) -> dict[str, Any] | list[np.ndarray] | list[dict[str, Any]] | None: + """ + Generate a image/video based on the given prompt. + + Args: + prompt: The prompt to use for generation (optional if prompt_txt is provided) + output_file_name: Name of the file to save. Default is the first 100 characters of the prompt. + save_output: Whether to save the output to disk + return_frames: Whether to return the raw frames + num_inference_steps: Number of denoising steps (overrides server_args) + guidance_scale: Classifier-free guidance scale (overrides server_args) + num_frames: Number of frames to generate (overrides server_args) + height: Height of generated file (overrides server_args) + width: Width of generated file (overrides server_args) + fps: Frames per second for saved file (overrides server_args) + seed: Random seed for generation (overrides server_args) + callback: Callback function called after each step + callback_steps: Number of steps between each callback + + Returns: + Either the output dictionary, list of frames, or list of results for batch processing + """ + # 1. 
prepare requests + prompts: list[str] = [] + # Handle batch processing from text file + if self.server_args.prompt_file_path is not None: + prompt_txt_path = self.server_args.prompt_file_path + if not os.path.exists(prompt_txt_path): + raise FileNotFoundError( + f"Prompt text file not found: {prompt_txt_path}" + ) + # Read prompts from file + with open(prompt_txt_path, encoding="utf-8") as f: + prompts.extend(line.strip() for line in f if line.strip()) + + if not prompts: + raise ValueError(f"No prompts found in file: {prompt_txt_path}") + + logger.info("Found %d prompts in %s", len(prompts), prompt_txt_path) + elif prompt is not None: + if isinstance(prompt, str): + prompts.append(prompt) + elif isinstance(prompt, list): + prompts.extend(prompt) + else: + raise ValueError("Either prompt or prompt_txt must be provided") + + pretrained_sampling_params = SamplingParams.from_pretrained( + self.server_args.model_path, **kwargs + ) + pretrained_sampling_params._merge_with_user_params(sampling_params) + # TODO: simplify + data_type = ( + DataType.IMAGE + if self.server_args.pipeline_config.is_image_gen + or sampling_params.num_frames == 1 + else DataType.VIDEO + ) + sampling_params.data_type = data_type + pretrained_sampling_params.set_output_file_name() + + requests: list[Req] = [] + for output_idx, p in enumerate(prompts): + current_sampling_params = deepcopy(pretrained_sampling_params) + current_sampling_params.prompt = p + requests.append( + prepare_request( + p, + server_args=self.server_args, + sampling_params=current_sampling_params, + ) + ) + + results = [] + total_start_time = time.perf_counter() + # 2. send requests to scheduler, one at a time + # TODO: send batch when supported + for request_idx, req in enumerate(requests): + logger.info( + "Processing prompt %d/%d: %s...", + request_idx + 1, + len(requests), + req.prompt[:100], + ) + try: + start_time = time.perf_counter() + output_batch = self._send_to_scheduler_and_wait_for_response([req]) + gen_time = time.perf_counter() - start_time + if output_batch.error: + raise Exception(f"{output_batch.error}") + + # FIXME: in generate mode, an internal assertion error won't raise an error + logger.info( + "Pixel data generated successfully in %.2f seconds", + gen_time, + ) + + if output_batch.output is None: + logger.error( + "Received empty output from scheduler for prompt %d", + request_idx + 1, + ) + continue + for output_idx, sample in enumerate(output_batch.output): + num_outputs = len(output_batch.output) + output_file_name = req.output_file_name + if num_outputs > 1 and output_file_name: + base, ext = os.path.splitext(output_file_name) + output_file_name = f"{base}_{output_idx}{ext}" + + save_path = ( + os.path.join(req.output_path, output_file_name) + if output_file_name + else None + ) + frames = self.post_process_sample( + sample, + fps=req.fps, + save_output=req.save_output, + save_file_path=save_path, + data_type=req.data_type, + ) + + result_item: dict[str, Any] = { + "samples": sample, + "frames": frames, + "prompts": req.prompt, + "size": (req.height, req.width, req.num_frames), + "generation_time": gen_time, + "logging_info": output_batch.logging_info, + "trajectory": output_batch.trajectory_latents, + "trajectory_timesteps": output_batch.trajectory_timesteps, + "trajectory_decoded": output_batch.trajectory_decoded, + "prompt_index": output_idx, + } + results.append(result_item) + except Exception as e: + logger.error( + "Failed to generate output for prompt %d: %s", request_idx + 1, e + ) + continue + + total_gen_time = 
time.perf_counter() - total_start_time + logger.info( + "Completed batch processing. Generated %d outputs in %.2f seconds.", + len(results), + total_gen_time, + ) + + if len(results) == 0: + return None + else: + if requests[0].return_frames: + results = [r["frames"] for r in results] + if len(results) == 1: + return results[0] + return results + + def _send_to_scheduler_and_wait_for_response(self, batch: list[Req]) -> OutputBatch: + """ + Sends a request to the scheduler and waits for a response. + """ + return sync_scheduler_client.forward(batch) + + def set_lora_adapter( + self, lora_nickname: str, lora_path: str | None = None + ) -> None: + # self.scheduler.set_lora_adapter(lora_nickname, lora_path) + pass # Removed as per edit hint + + def unmerge_lora_weights(self) -> None: + """ + Use unmerged weights for inference to produce outputs that align with + validation outputs generated during training. + """ + # self.scheduler.unmerge_lora_weights() + pass # Removed as per edit hint + + def merge_lora_weights(self) -> None: + # self.scheduler.merge_lora_weights() + pass # Removed as per edit hint + + def shutdown(self): + """ + Shutdown the generator. + If in local mode, it also shuts down the scheduler server. + """ + # This sends the shutdown command to the server + # self.scheduler.shutdown() + + if self.local_scheduler_process: + logger.info("Waiting for local worker processes to terminate...") + for process in self.local_scheduler_process: + process.join(timeout=10) + if process.is_alive(): + logger.warning( + f"Local worker {process.name} did not terminate gracefully, forcing." + ) + process.terminate() + self.local_scheduler_process = None + + if self.owns_scheduler_client: + sync_scheduler_client.close() + self.owns_scheduler_client = False + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.shutdown() + + def __del__(self): + if self.owns_scheduler_client: + logger.warning( + "Generator was garbage collected without being shut down. " + "Attempting to shut down the local server and client." + ) + self.shutdown() + elif self.local_scheduler_process: + logger.warning( + "Generator was garbage collected without being shut down. " + "Attempting to shut down the local server." + ) + self.shutdown() diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py b/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py new file mode 100644 index 000000000..ac880aeca --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/entrypoints/http_server.py @@ -0,0 +1,58 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import asyncio +from contextlib import asynccontextmanager + +from fastapi import FastAPI + +from sglang.multimodal_gen.runtime.entrypoints.openai import image_api, video_api +from sglang.multimodal_gen.runtime.server_args import ServerArgs, prepare_server_args +from sglang.multimodal_gen.runtime.utils.logging_utils import configure_logger + + +@asynccontextmanager +async def lifespan(app: FastAPI): + from sglang.multimodal_gen.runtime.scheduler_client import ( + run_zeromq_broker, + scheduler_client, + ) + + # 1. Initialize the singleton client that connects to the backend Scheduler + server_args = app.state.server_args + scheduler_client.initialize(server_args) + + # 2. 
Start the ZMQ Broker in the background to handle offline requests + broker_task = asyncio.create_task(run_zeromq_broker(server_args)) + + yield + + # On shutdown + print("FastAPI app is shutting down...") + broker_task.cancel() + scheduler_client.close() + + +def create_app(server_args: ServerArgs): + """ + Create and configure the FastAPI application instance. + """ + app = FastAPI(lifespan=lifespan) + app.include_router(image_api.router) + app.include_router(video_api.router) + app.state.server_args = server_args + return app + + +if __name__ == "__main__": + import uvicorn + + server_args = prepare_server_args([]) + configure_logger(server_args) + app = create_app(server_args) + uvicorn.run( + app, + host=server_args.host, + port=server_args.port, + log_config=None, + reload=False, # Set to True during development for auto-reloading + ) diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py b/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py new file mode 100644 index 000000000..be77cd555 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py @@ -0,0 +1,255 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import base64 +import os +import time +from typing import List, Optional + +from fastapi import APIRouter, File, Form, HTTPException, Path, Query, UploadFile +from fastapi.responses import FileResponse + +from sglang.multimodal_gen.configs.sample.base import ( + SamplingParams, + generate_request_id, +) +from sglang.multimodal_gen.runtime.entrypoints.openai.protocol import ( + ImageGenerationsRequest, + ImageResponse, + ImageResponseData, +) +from sglang.multimodal_gen.runtime.entrypoints.openai.stores import IMAGE_STORE +from sglang.multimodal_gen.runtime.entrypoints.openai.utils import ( + _parse_size, + _save_upload_to_path, + post_process_sample, +) +from sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.scheduler_client import scheduler_client +from sglang.multimodal_gen.runtime.server_args import get_global_server_args +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +router = APIRouter(prefix="/v1/images", tags=["images"]) +logger = init_logger(__name__) + + +def _choose_ext(output_format: Optional[str], background: Optional[str]) -> str: + # Normalize and choose extension + fmt = (output_format or "").lower() + if fmt in {"png", "webp", "jpeg", "jpg"}: + return "jpg" if fmt == "jpeg" else fmt + # If transparency requested, prefer png + if (background or "auto").lower() == "transparent": + return "png" + # Default + return "jpg" + + +def _build_sampling_params_from_request( + request_id: str, + prompt: str, + n: int, + size: Optional[str], + output_format: Optional[str], + background: Optional[str], + image_path: Optional[str] = None, +) -> SamplingParams: + width, height = _parse_size(size) + ext = _choose_ext(output_format, background) + + server_args = get_global_server_args() + sampling_params = SamplingParams.from_pretrained(server_args.model_path) + + # Build user params + user_params = SamplingParams( + request_id=request_id, + prompt=prompt, + image_path=image_path, + num_frames=1, # image + width=width, + height=height, + num_outputs_per_prompt=max(1, min(int(n or 1), 10)), + save_output=True, + ) + + # Let SamplingParams auto-generate a file name, then force desired extension + sampling_params = 
sampling_params.from_user_sampling_params(user_params) + if not sampling_params.output_file_name: + sampling_params.output_file_name = request_id + if not sampling_params.output_file_name.endswith(f".{ext}"): + # strip any existing extension and apply desired one + base = sampling_params.output_file_name.rsplit(".", 1)[0] + sampling_params.output_file_name = f"{base}.{ext}" + + sampling_params.log(server_args) + return sampling_params + + +def _build_req_from_sampling(s: SamplingParams) -> Req: + return Req( + request_id=s.request_id, + data_type=s.data_type, + prompt=s.prompt, + image_path=s.image_path, + height=s.height, + width=s.width, + fps=1, + num_frames=s.num_frames, + seed=s.seed, + output_path=s.output_path, + output_file_name=s.output_file_name, + num_outputs_per_prompt=s.num_outputs_per_prompt, + save_output=s.save_output, + ) + + +@router.post("/generations", response_model=ImageResponse) +async def generations( + request: ImageGenerationsRequest, +): + + request_id = generate_request_id() + sampling = _build_sampling_params_from_request( + request_id=request_id, + prompt=request.prompt, + n=request.n or 1, + size=request.size, + output_format=request.output_format, + background=request.background, + ) + batch = prepare_request( + prompt=request.prompt, + server_args=get_global_server_args(), + sampling_params=sampling, + ) + # Run synchronously for images and save to disk + result = await scheduler_client.forward([batch]) + save_file_path = os.path.join(batch.output_path, batch.output_file_name) + post_process_sample( + result.output[0], + batch.data_type, + 1, + batch.save_output, + save_file_path, + ) + + await IMAGE_STORE.upsert( + request_id, + { + "id": request_id, + "created_at": int(time.time()), + "file_path": save_file_path, + }, + ) + + resp_format = (request.response_format or "b64_json").lower() + if resp_format == "b64_json": + with open(save_file_path, "rb") as f: + b64 = base64.b64encode(f.read()).decode("utf-8") + return ImageResponse( + data=[ + ImageResponseData( + b64_json=b64, + revised_prompt=request.prompt, + ) + ] + ) + else: + # Return error, not supported + raise HTTPException( + status_code=400, detail="response_format=url is not supported" + ) + + +@router.post("/edits", response_model=ImageResponse) +async def edits( + image: Optional[List[UploadFile]] = File(None), + image_array: Optional[List[UploadFile]] = File(None, alias="image[]"), + prompt: str = Form(...), + mask: Optional[UploadFile] = File(None), + model: Optional[str] = Form(None), + n: Optional[int] = Form(1), + response_format: Optional[str] = Form(None), + size: Optional[str] = Form("1024x1024"), + output_format: Optional[str] = Form(None), + background: Optional[str] = Form("auto"), + user: Optional[str] = Form(None), +): + + request_id = generate_request_id() + # Resolve images from either `image` or `image[]` (OpenAI SDK sends `image[]` when list is provided) + images = image or image_array + if not images or len(images) == 0: + raise HTTPException(status_code=422, detail="Field 'image' is required") + + # Save first input image; additional images or mask are not yet used by the pipeline + uploads_dir = os.path.join("outputs", "uploads") + os.makedirs(uploads_dir, exist_ok=True) + first_image = images[0] + input_path = os.path.join(uploads_dir, f"{request_id}_{first_image.filename}") + await _save_upload_to_path(first_image, input_path) + + sampling = _build_sampling_params_from_request( + request_id=request_id, + prompt=prompt, + n=n or 1, + size=size, + 
output_format=output_format, + background=background, + image_path=input_path, + ) + batch = _build_req_from_sampling(sampling) + + result = await scheduler_client.forward([batch]) + save_file_path = os.path.join(batch.output_path, batch.output_file_name) + post_process_sample( + result.output[0], + batch.data_type, + 1, + batch.save_output, + save_file_path, + ) + + await IMAGE_STORE.upsert( + request_id, + { + "id": request_id, + "created_at": int(time.time()), + "file_path": save_file_path, + }, + ) + + # Default to b64_json to align with gpt-image-1 behavior in OpenAI examples + if (response_format or "b64_json").lower() == "b64_json": + with open(save_file_path, "rb") as f: + b64 = base64.b64encode(f.read()).decode("utf-8") + return ImageResponse( + data=[ImageResponseData(b64_json=b64, revised_prompt=prompt)] + ) + else: + url = f"/v1/images/{request_id}/content" + return ImageResponse(data=[ImageResponseData(url=url, revised_prompt=prompt)]) + + +@router.get("/{image_id}/content") +async def download_image_content( + image_id: str = Path(...), variant: Optional[str] = Query(None) +): + item = await IMAGE_STORE.get(image_id) + if not item: + raise HTTPException(status_code=404, detail="Image not found") + + file_path = item.get("file_path") + if not file_path or not os.path.exists(file_path): + raise HTTPException(status_code=404, detail="Image is still being generated") + + ext = os.path.splitext(file_path)[1].lower() + media_type = "image/jpeg" + if ext == ".png": + media_type = "image/png" + elif ext == ".webp": + media_type = "image/webp" + + return FileResponse( + path=file_path, media_type=media_type, filename=os.path.basename(file_path) + ) diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py b/python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py new file mode 100644 index 000000000..00800ab15 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py @@ -0,0 +1,65 @@ +import time +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + + +# Image API protocol models +class ImageResponseData(BaseModel): + b64_json: Optional[str] = None + url: Optional[str] = None + revised_prompt: Optional[str] = None + + +class ImageResponse(BaseModel): + created: int = Field(default_factory=lambda: int(time.time())) + data: List[ImageResponseData] + + +class ImageGenerationsRequest(BaseModel): + prompt: str + model: Optional[str] = None + n: Optional[int] = 1 + quality: Optional[str] = "auto" + response_format: Optional[str] = "url" # url | b64_json + size: Optional[str] = "1024x1024" # e.g., 1024x1024 + style: Optional[str] = "vivid" + background: Optional[str] = "auto" # transparent | opaque | auto + output_format: Optional[str] = None # png | jpeg | webp + user: Optional[str] = None + + +# Video API protocol models +class VideoResponse(BaseModel): + id: str + object: str = "video" + model: str = "sora-2" + status: str = "queued" + progress: int = 0 + created_at: int = Field(default_factory=lambda: int(time.time())) + size: str = "720x1280" + seconds: str = "4" + quality: str = "standard" + remixed_from_video_id: Optional[str] = None + completed_at: Optional[int] = None + expires_at: Optional[int] = None + error: Optional[Dict[str, Any]] = None + + +class VideoGenerationsRequest(BaseModel): + prompt: str + input_reference: Optional[str] = None + model: Optional[str] = None + seconds: Optional[int] = 4 + size: Optional[str] = "720x1280" + fps: Optional[int] = None + num_frames: 
Optional[int] = None + + +class VideoListResponse(BaseModel): + data: List[VideoResponse] + object: str = "list" + + +class VideoRemixRequest(BaseModel): + prompt: str diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/openai/stores.py b/python/sglang/multimodal_gen/runtime/entrypoints/openai/stores.py new file mode 100644 index 000000000..f924de819 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/entrypoints/openai/stores.py @@ -0,0 +1,46 @@ +import asyncio +from typing import Any, Dict, List, Optional + + +class AsyncDictStore: + """A small async-safe in-memory key-value store for dict items. + + This encapsulates the usual pattern of a module-level dict guarded by + an asyncio.Lock and provides simple CRUD methods that are safe to call + concurrently from FastAPI request handlers and background tasks. + """ + + def __init__(self) -> None: + self._items: Dict[str, Dict[str, Any]] = {} + self._lock = asyncio.Lock() + + async def upsert(self, key: str, value: Dict[str, Any]) -> None: + async with self._lock: + self._items[key] = value + + async def update_fields( + self, key: str, updates: Dict[str, Any] + ) -> Optional[Dict[str, Any]]: + async with self._lock: + item = self._items.get(key) + if item is None: + return None + item.update(updates) + return item + + async def get(self, key: str) -> Optional[Dict[str, Any]]: + async with self._lock: + return self._items.get(key) + + async def pop(self, key: str) -> Optional[Dict[str, Any]]: + async with self._lock: + return self._items.pop(key, None) + + async def list_values(self) -> List[Dict[str, Any]]: + async with self._lock: + return list(self._items.values()) + + +# Global stores shared by OpenAI entrypoints +VIDEO_STORE = AsyncDictStore() +IMAGE_STORE = AsyncDictStore() diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/openai/utils.py b/python/sglang/multimodal_gen/runtime/entrypoints/openai/utils.py new file mode 100644 index 000000000..42bda15e0 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/entrypoints/openai/utils.py @@ -0,0 +1,77 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import os + +import imageio +import numpy as np +import torch +import torchvision +from einops import rearrange +from fastapi import UploadFile + +from sglang.multimodal_gen.configs.sample.base import DataType +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +def post_process_sample( + sample: torch.Tensor, + data_type: DataType, + fps: int, + save_output: bool = True, + save_file_path: str = None, +): + """ + Process sample output and save video if necessary + """ + # Process outputs + if sample.dim() == 3: + # for images, dim t is missing + sample = sample.unsqueeze(1) + videos = rearrange(sample, "c t h w -> t c h w") + frames = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=6) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + frames.append((x * 255).numpy().astype(np.uint8)) + + # Save outputs if requested + if save_output: + if save_file_path: + os.makedirs(os.path.dirname(save_file_path), exist_ok=True) + if data_type == DataType.VIDEO: + imageio.mimsave( + save_file_path, + frames, + fps=fps, + format=data_type.get_default_extension(), + ) + else: + imageio.imwrite(save_file_path, frames[0]) + logger.info(f"Saved output to {save_file_path}") + else: + logger.info(f"No output path provided, output not saved") + + return frames + + +def _parse_size(size: str) -> tuple[int, int]: + try: + parts = 
size.lower().replace(" ", "").split("x") + if len(parts) != 2: + raise ValueError + w, h = int(parts[0]), int(parts[1]) + return w, h + except Exception: + # Fallback to default portrait 720x1280 + return 720, 1280 + + +# Helpers +async def _save_upload_to_path(upload: UploadFile, target_path: str) -> str: + os.makedirs(os.path.dirname(target_path), exist_ok=True) + content = await upload.read() + with open(target_path, "wb") as f: + f.write(content) + return target_path diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py b/python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py new file mode 100644 index 000000000..c6bf59235 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py @@ -0,0 +1,269 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import asyncio +import json +import os +import time +from typing import Any, Dict, Optional + +from fastapi import ( + APIRouter, + File, + Form, + HTTPException, + Path, + Query, + Request, + UploadFile, +) +from fastapi.responses import FileResponse + +from sglang.multimodal_gen.configs.sample.base import ( + SamplingParams, + generate_request_id, +) +from sglang.multimodal_gen.runtime.entrypoints.openai.protocol import ( + VideoGenerationsRequest, + VideoListResponse, + VideoResponse, +) +from sglang.multimodal_gen.runtime.entrypoints.openai.stores import VIDEO_STORE +from sglang.multimodal_gen.runtime.entrypoints.openai.utils import ( + _parse_size, + _save_upload_to_path, + post_process_sample, +) +from sglang.multimodal_gen.runtime.entrypoints.utils import prepare_request +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.server_args import get_global_server_args +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) +router = APIRouter(prefix="/v1/videos", tags=["videos"]) + + +def _build_sampling_params_from_request( + request_id: str, request: VideoGenerationsRequest +) -> SamplingParams: + width, height = _parse_size(request.size or "720x1280") + seconds = request.seconds if request.seconds is not None else 4 + # Prefer user-provided fps/num_frames from request; fallback to defaults + fps_default = 24 + fps = request.fps if request.fps is not None else fps_default + # If user provides num_frames, use it directly; otherwise derive from seconds * fps + derived_num_frames = fps * seconds + num_frames = ( + request.num_frames if request.num_frames is not None else derived_num_frames + ) + server_args = get_global_server_args() + # TODO: should we cache this sampling_params? 
+ sampling_params = SamplingParams.from_pretrained(server_args.model_path) + user_params = SamplingParams( + request_id=request_id, + prompt=request.prompt, + num_frames=num_frames, + fps=fps, + width=width, + height=height, + image_path=request.input_reference, + save_output=True, + ) + sampling_params = sampling_params.from_user_sampling_params(user_params) + sampling_params.log(server_args) + sampling_params.set_output_file_ext() + return sampling_params + + +# extract metadata which http_server needs to know +def _video_job_from_sampling( + request_id: str, req: VideoGenerationsRequest, sampling: SamplingParams +) -> Dict[str, Any]: + size_str = f"{sampling.width}x{sampling.height}" + seconds = int(round((sampling.num_frames or 0) / float(sampling.fps or 24))) + return { + "id": request_id, + "object": "video", + "model": req.model or "sora-2", + "status": "queued", + "progress": 0, + "created_at": int(time.time()), + "size": size_str, + "seconds": str(seconds), + "quality": "standard", + "file_path": sampling.output_file_path(), + } + + +async def _dispatch_job_async(job_id: str, batch: Req) -> None: + from sglang.multimodal_gen.runtime.scheduler_client import scheduler_client + + try: + result = await scheduler_client.forward([batch]) + post_process_sample( + result.output[0], + batch.data_type, + batch.fps, + batch.save_output, + os.path.join(batch.output_path, batch.output_file_name), + ) + await VIDEO_STORE.update_fields( + job_id, + {"status": "completed", "progress": 100, "completed_at": int(time.time())}, + ) + except Exception as e: + logger.error(f"{e}") + await VIDEO_STORE.update_fields( + job_id, {"status": "failed", "error": {"message": str(e)}} + ) + + +# TODO: support image to video generation +@router.post("", response_model=VideoResponse) +async def create_video( + request: Request, + # multipart/form-data fields (optional; used only when content-type is multipart) + prompt: Optional[str] = Form(None), + input_reference: Optional[UploadFile] = File(None), + model: Optional[str] = Form(None), + seconds: Optional[int] = Form(None), + size: Optional[str] = Form(None), + fps: Optional[int] = Form(None), + num_frames: Optional[int] = Form(None), + extra_body: Optional[str] = Form(None), +): + content_type = request.headers.get("content-type", "").lower() + request_id = generate_request_id() + + if "multipart/form-data" in content_type: + if not prompt: + raise HTTPException(status_code=400, detail="prompt is required") + if input_reference is None: + raise HTTPException( + status_code=400, detail="input_reference file is required" + ) + + uploads_dir = os.path.join("outputs", "uploads") + os.makedirs(uploads_dir, exist_ok=True) + input_path = os.path.join( + uploads_dir, f"{request_id}_{input_reference.filename}" + ) + await _save_upload_to_path(input_reference, input_path) + + # Parse extra_body JSON (if provided in multipart form) to get fps/num_frames overrides + extra_from_form: Dict[str, Any] = {} + if extra_body: + try: + extra_from_form = json.loads(extra_body) + except Exception: + extra_from_form = {} + + fps_val = fps if fps is not None else extra_from_form.get("fps") + num_frames_val = ( + num_frames if num_frames is not None else extra_from_form.get("num_frames") + ) + + req = VideoGenerationsRequest( + prompt=prompt, + input_reference=input_path, + model=model, + seconds=seconds if seconds is not None else 4, + size=size or "720x1280", + fps=fps_val, + num_frames=num_frames_val, + ) + else: + try: + body = await request.json() + except Exception: + body = {} + 
try: + # If client uses extra_body, merge it into the top-level payload + payload: Dict[str, Any] = dict(body or {}) + extra = payload.pop("extra_body", None) + if isinstance(extra, dict): + # Shallow-merge: only keys like fps/num_frames are expected + payload.update(extra) + req = VideoGenerationsRequest(**payload) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Invalid request body: {e}") + + logger.debug(f"Server received from create_video endpoint: req={req}") + + sampling_params = _build_sampling_params_from_request(request_id, req) + job = _video_job_from_sampling(request_id, req, sampling_params) + await VIDEO_STORE.upsert(request_id, job) + + # Build Req for scheduler + batch = prepare_request( + prompt=req.prompt, + server_args=get_global_server_args(), + sampling_params=sampling_params, + ) + # Enqueue the job asynchronously and return immediately + asyncio.create_task(_dispatch_job_async(request_id, batch)) + return VideoResponse(**job) + + +@router.get("", response_model=VideoListResponse) +async def list_videos( + after: Optional[str] = Query(None), + limit: Optional[int] = Query(None, ge=1, le=100), + order: Optional[str] = Query("desc"), +): + # Normalize order + order = (order or "desc").lower() + if order not in ("asc", "desc"): + order = "desc" + jobs = await VIDEO_STORE.list_values() + + reverse = order != "asc" + jobs.sort(key=lambda j: j.get("created_at", 0), reverse=reverse) + + if after is not None: + try: + idx = next(i for i, j in enumerate(jobs) if j["id"] == after) + jobs = jobs[idx + 1 :] + except StopIteration: + jobs = [] + + if limit is not None: + jobs = jobs[:limit] + items = [VideoResponse(**j) for j in jobs] + return VideoListResponse(data=items) + + +@router.get("/{video_id}", response_model=VideoResponse) +async def retrieve_video(video_id: str = Path(...)): + job = await VIDEO_STORE.get(video_id) + if not job: + raise HTTPException(status_code=404, detail="Video not found") + return VideoResponse(**job) + + +# TODO: support aborting a job. +@router.delete("/{video_id}", response_model=VideoResponse) +async def delete_video(video_id: str = Path(...)): + job = await VIDEO_STORE.pop(video_id) + if not job: + raise HTTPException(status_code=404, detail="Video not found") + # Mark as deleted in response semantics + job["status"] = "deleted" + return VideoResponse(**job) + + +@router.get("/{video_id}/content") +async def download_video_content( + video_id: str = Path(...), variant: Optional[str] = Query(None) +): + job = await VIDEO_STORE.get(video_id) + if not job: + raise HTTPException(status_code=404, detail="Video not found") + + file_path = job.get("file_path") + if not file_path or not os.path.exists(file_path): + raise HTTPException(status_code=404, detail="Generation is still in-progress") + + media_type = "video/mp4" # default variant + return FileResponse( + path=file_path, media_type=media_type, filename=os.path.basename(file_path) + ) diff --git a/python/sglang/multimodal_gen/runtime/entrypoints/utils.py b/python/sglang/multimodal_gen/runtime/entrypoints/utils.py new file mode 100644 index 000000000..123e3efec --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/entrypoints/utils.py @@ -0,0 +1,139 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +DiffGenerator module for sgl-diffusion. + +This module provides a consolidated interface for generating videos using +diffusion models. 
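+
+It also provides the request-preparation helpers prepare_sampling_params and
+prepare_request, which settle a request's sampling parameters before the
+request is handed to the scheduler as a Req.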
+""" + +import logging +import math + +# Suppress verbose logging from imageio, which is triggered when saving images. +logging.getLogger("imageio").setLevel(logging.WARNING) +logging.getLogger("imageio_ffmpeg").setLevel(logging.WARNING) + +from sglang.multimodal_gen.configs.sample.base import DataType, SamplingParams +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import shallow_asdict + +logger = init_logger(__name__) + + +def prepare_sampling_params( + prompt: str, + server_args: ServerArgs, + sampling_params: SamplingParams, +): + pipeline_config = server_args.pipeline_config + # Validate inputs + if not isinstance(prompt, str): + raise TypeError(f"`prompt` must be a string, but got {type(prompt)}") + + # Process negative prompt + if ( + sampling_params.negative_prompt is not None + and not sampling_params.negative_prompt.isspace() + ): + # avoid stripping default negative prompt: ' ' for qwen-image + sampling_params.negative_prompt = sampling_params.negative_prompt.strip() + + # Validate dimensions + if sampling_params.num_frames <= 0: + raise ValueError( + f"Height, width, and num_frames must be positive integers, got " + f"height={sampling_params.height}, width={sampling_params.width}, " + f"num_frames={sampling_params.num_frames}" + ) + + temporal_scale_factor = ( + pipeline_config.vae_config.arch_config.temporal_compression_ratio + ) + + # settle num_frames + if server_args.pipeline_config.is_image_gen: + logger.debug(f"Setting num_frames to 1 because this is a image-gen model") + sampling_params.num_frames = 1 + + num_frames = sampling_params.num_frames + num_gpus = server_args.num_gpus + use_temporal_scaling_frames = pipeline_config.vae_config.use_temporal_scaling_frames + + # Adjust number of frames based on number of GPUs + if use_temporal_scaling_frames: + orig_latent_num_frames = (num_frames - 1) // temporal_scale_factor + 1 + else: # stepvideo only + orig_latent_num_frames = sampling_params.num_frames // 17 * 3 + + if orig_latent_num_frames % server_args.num_gpus != 0: + # Adjust latent frames to be divisible by number of GPUs + if sampling_params.num_frames_round_down: + # Ensure we have at least 1 batch per GPU + new_latent_num_frames = ( + max(1, (orig_latent_num_frames // num_gpus)) * num_gpus + ) + else: + new_latent_num_frames = ( + math.ceil(orig_latent_num_frames / num_gpus) * num_gpus + ) + + if use_temporal_scaling_frames: + # Convert back to number of frames, ensuring num_frames-1 is a multiple of temporal_scale_factor + new_num_frames = (new_latent_num_frames - 1) * temporal_scale_factor + 1 + else: # stepvideo only + # Find the least common multiple of 3 and num_gpus + divisor = math.lcm(3, num_gpus) + # Round up to the nearest multiple of this LCM + new_latent_num_frames = ( + (new_latent_num_frames + divisor - 1) // divisor + ) * divisor + # Convert back to actual frames using the StepVideo formula + new_num_frames = new_latent_num_frames // 3 * 17 + + logger.info( + "Adjusting number of frames from %s to %s based on number of GPUs (%s)", + sampling_params.num_frames, + new_num_frames, + server_args.num_gpus, + ) + sampling_params.num_frames = new_num_frames + + if pipeline_config.is_image_gen: + sampling_params.data_type = DataType.IMAGE + + sampling_params.set_output_file_ext() + sampling_params.log(server_args=server_args) + return sampling_params + + +def 
prepare_request( + prompt: str, + server_args: ServerArgs, + sampling_params: SamplingParams, +) -> Req: + """ + Settle SamplingParams according to ServerArgs + + """ + # Create a copy of inference args to avoid modifying the original + + sampling_params = prepare_sampling_params(prompt, server_args, sampling_params) + + req = Req( + **shallow_asdict(sampling_params), + VSA_sparsity=server_args.VSA_sparsity, + ) + # req.set_width_and_height(server_args) + + # if (req.width <= 0 + # or req.height <= 0): + # raise ValueError( + # f"Height, width must be positive integers, got " + # f"height={req.height}, width={req.width}" + # ) + + return req diff --git a/python/sglang/multimodal_gen/runtime/launch_server.py b/python/sglang/multimodal_gen/runtime/launch_server.py new file mode 100644 index 000000000..36bc44c6e --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/launch_server.py @@ -0,0 +1,142 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import multiprocessing as mp + +import uvicorn + +from sglang.multimodal_gen.runtime.entrypoints.http_server import create_app +from sglang.multimodal_gen.runtime.managers.gpu_worker import run_scheduler_process +from sglang.multimodal_gen.runtime.server_args import ServerArgs, set_global_server_args +from sglang.multimodal_gen.runtime.utils.logging_utils import ( + configure_logger, + logger, + suppress_other_loggers, +) + + +def launch_server(server_args: ServerArgs, launch_http_server: bool = True): + """ + Args: + launch_http_server: False for offline local mode + """ + configure_logger(server_args) + suppress_other_loggers() + + # Start a new server with multiple worker processes + logger.info("Starting server...") + + num_gpus = server_args.num_gpus + processes = [] + + # Pipes for master to talk to slaves + task_pipes_to_slaves_w = [] + task_pipes_to_slaves_r = [] + for _ in range(num_gpus - 1): + r, w = mp.Pipe(duplex=False) + task_pipes_to_slaves_r.append(r) + task_pipes_to_slaves_w.append(w) + + # Pipes for slaves to talk to master + result_pipes_from_slaves_w = [] + result_pipes_from_slaves_r = [] + for _ in range(num_gpus - 1): + r, w = mp.Pipe(duplex=False) + result_pipes_from_slaves_r.append(r) + result_pipes_from_slaves_w.append(w) + + # Launch all worker processes + master_port = server_args.master_port or (server_args.master_port + 100) + scheduler_pipe_readers = [] + scheduler_pipe_writers = [] + + for i in range(num_gpus): + reader, writer = mp.Pipe(duplex=False) + scheduler_pipe_writers.append(writer) + if i == 0: # Master worker + process = mp.Process( + target=run_scheduler_process, + args=( + i, # local_rank + i, # rank + master_port, + server_args, + writer, + None, # No task pipe to read from master + None, # No result pipe to write to master + task_pipes_to_slaves_w, + result_pipes_from_slaves_r, + ), + name=f"sgl-diffusionWorker-{i}", + daemon=True, + ) + else: # Slave workers + process = mp.Process( + target=run_scheduler_process, + args=( + i, # local_rank + i, # rank + master_port, + server_args, + writer, + None, # No task pipe to read from master + None, # No result pipe to write to master + task_pipes_to_slaves_r[i - 1], + result_pipes_from_slaves_w[i - 1], + ), + name=f"sgl-diffusionWorker-{i}", + daemon=True, + ) + scheduler_pipe_readers.append(reader) + process.start() + processes.append(process) + + # Wait for all workers to be ready + scheduler_infos = [] + for writer in scheduler_pipe_writers: + writer.close() + + # Close unused pipe ends in parent process + for p in 
task_pipes_to_slaves_w: + p.close() + for p in task_pipes_to_slaves_r: + p.close() + for p in result_pipes_from_slaves_w: + p.close() + for p in result_pipes_from_slaves_r: + p.close() + + for i, reader in enumerate(scheduler_pipe_readers): + try: + data = reader.recv() + except EOFError: + logger.error( + f"Rank {i} scheduler is dead. Please check if there are relevant logs." + ) + processes[i].join() + logger.error(f"Exit code: {processes[i].exitcode}") + raise + + if data["status"] != "ready": + raise RuntimeError( + "Initialization failed. Please see the error messages above." + ) + scheduler_infos.append(data) + reader.close() + + logger.debug("All workers are ready") + + if launch_http_server: + logger.info("Starting FastAPI server.") + + # set for endpoints to access global_server_args + set_global_server_args(server_args) + + app = create_app(server_args) + uvicorn.run( + app, + log_config=None, + log_level=server_args.log_level, + host=server_args.host, + port=server_args.port, + reload=False, + ) diff --git a/python/sglang/multimodal_gen/runtime/layers/__init__.py b/python/sglang/multimodal_gen/runtime/layers/__init__.py new file mode 100644 index 000000000..af2eb7d10 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/python/sglang/multimodal_gen/runtime/layers/activation.py b/python/sglang/multimodal_gen/runtime/layers/activation.py new file mode 100644 index 000000000..4eff9ba1c --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/activation.py @@ -0,0 +1,129 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/activation.py +"""Custom activation functions.""" +import math +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# TODO (will): remove this dependency +from sglang.multimodal_gen.runtime.layers.custom_op import CustomOp + + +@CustomOp.register("silu_and_mul") +class SiluAndMul(CustomOp): + """An activation function for SwiGLU. + + The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2. + + Shapes: + x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d) + return: (num_tokens, d) or (batch_size, seq_len, d) + """ + + def __init__(self) -> None: + super().__init__() + + def forward_cuda(self, *args, **kwargs) -> Any: + return self.forward_native(*args, **kwargs) + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + d = x.shape[-1] // 2 + return F.silu(x[..., :d]) * x[..., d:] + + +@CustomOp.register("gelu_and_mul") +class GeluAndMul(CustomOp): + """An activation function for GeGLU. + + The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2. 
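+    The "approximate" argument selects between exact GELU ("none") and its
+    tanh approximation ("tanh"), matching torch.nn.functional.gelu.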
+ + Shapes: + x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d) + return: (batch_size, seq_len, d) or (num_tokens, d) + """ + + def __init__(self, approximate: str = "none"): + super().__init__() + self.approximate = approximate + if approximate not in ("none", "tanh"): + raise ValueError(f"Unknown approximate mode: {approximate}") + + def forward_cuda(self, *args, **kwargs) -> Any: + return self.forward_native(*args, **kwargs) + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + d = x.shape[-1] // 2 + return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:] + + def extra_repr(self) -> str: + return f"approximate={repr(self.approximate)}" + + +@CustomOp.register("gelu_new") +class NewGELU(CustomOp): + + def __init__(self): + super().__init__() + + def forward_cuda(self, *args, **kwargs) -> Any: + return self.forward_native(*args, **kwargs) + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + c = math.sqrt(2.0 / math.pi) + return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0)))) + + +@CustomOp.register("quick_gelu") +class QuickGELU(CustomOp): + # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90 + def __init__(self): + super().__init__() + + def forward_cuda(self, *args, **kwargs) -> Any: + return self.forward_native(*args, **kwargs) + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + return x * torch.sigmoid(1.702 * x) + + +_ACTIVATION_REGISTRY = { + "gelu": nn.GELU, + "gelu_new": NewGELU, + "gelu_pytorch_tanh": lambda: nn.GELU(approximate="tanh"), + "relu": nn.ReLU, + "silu": nn.SiLU, + "quick_gelu": QuickGELU, +} + + +def get_act_fn(act_fn_name: str) -> nn.Module: + """Get an activation function by name.""" + act_fn_name = act_fn_name.lower() + if act_fn_name not in _ACTIVATION_REGISTRY: + raise ValueError(f"Activation function {act_fn_name!r} is not supported.") + + return _ACTIVATION_REGISTRY[act_fn_name]() + + +_ACTIVATION_AND_MUL_REGISTRY = { + "gelu": GeluAndMul, + "silu": SiluAndMul, +} + + +def get_act_and_mul_fn(act_fn_name: str) -> nn.Module: + """Get an activation-and-mul (i.e. SiluAndMul) function by name.""" + act_fn_name = act_fn_name.lower() + if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY: + raise ValueError(f"Activation function {act_fn_name!r} is not supported.") + + return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]() diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/STA_configuration.py b/python/sglang/multimodal_gen/runtime/layers/attention/STA_configuration.py new file mode 100644 index 000000000..9635a6740 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/attention/STA_configuration.py @@ -0,0 +1,414 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +import json +import os +from collections import defaultdict +from typing import Any + +import numpy as np + +from sglang.multimodal_gen.utils import dict_to_3d_list + + +def configure_sta( + mode: str = "STA_searching", + layer_num: int = 40, + time_step_num: int = 50, + head_num: int = 40, + **kwargs, +) -> list[list[list[Any]]]: + """ + Configure Sliding Tile Attention (STA) parameters based on the specified mode. + + Parameters: + ---------- + mode : str + The STA mode to use. 
Options are: + - 'STA_searching': Generate a set of mask candidates for initial search + - 'STA_tuning': Select best mask strategy based on previously saved results + - 'STA_inference': Load and use a previously tuned mask strategy + layer_num: int, number of layers + time_step_num: int, number of timesteps + head_num: int, number of heads + + **kwargs : dict + Mode-specific parameters: + + For 'STA_searching': + - mask_candidates: list of str, optional, mask candidates to use + - mask_selected: list of int, optional, indices of selected masks + + For 'STA_tuning': + - mask_search_files_path: str, required, path to mask search results + - mask_candidates: list of str, optional, mask candidates to use + - mask_selected: list of int, optional, indices of selected masks + - skip_time_steps: int, optional, number of time steps to use full attention (default 12) + - save_dir: str, optional, directory to save mask strategy (default "mask_candidates") + + For 'STA_inference': + - load_path: str, optional, path to load mask strategy (default "mask_candidates/mask_strategy.json") + """ + valid_modes = ["STA_searching", "STA_tuning", "STA_inference", "STA_tuning_cfg"] + if mode not in valid_modes: + raise ValueError(f"Mode must be one of {valid_modes}, got {mode}") + + if mode == "STA_searching": + # Get parameters with defaults + mask_candidates: list[str] | None = kwargs.get("mask_candidates") + if mask_candidates is None: + raise ValueError("mask_candidates is required for STA_searching mode") + mask_selected: list[int] = kwargs.get( + "mask_selected", list(range(len(mask_candidates))) + ) + + # Parse selected masks + selected_masks: list[list[int]] = [] + for index in mask_selected: + mask = mask_candidates[index] + masks_list = [int(x) for x in mask.split(",")] + selected_masks.append(masks_list) + + # Create 3D mask structure with fixed dimensions (t=50, l=60) + masks_3d: list[list[list[list[int]]]] = [] + for i in range(time_step_num): # Fixed t dimension = 50 + row = [] + for j in range(layer_num): # Fixed l dimension = 60 + row.append(selected_masks) # Add all masks at each position + masks_3d.append(row) + + return masks_3d + + elif mode == "STA_tuning": + # Get required parameters + mask_search_files_path: str | None = kwargs.get("mask_search_files_path") + if not mask_search_files_path: + raise ValueError("mask_search_files_path is required for STA_tuning mode") + + # Get optional parameters with defaults + mask_candidates_tuning: list[str] | None = kwargs.get("mask_candidates") + if mask_candidates_tuning is None: + raise ValueError("mask_candidates is required for STA_tuning mode") + mask_selected_tuning: list[int] = kwargs.get( + "mask_selected", list(range(len(mask_candidates_tuning))) + ) + skip_time_steps_tuning: int | None = kwargs.get("skip_time_steps") + save_dir_tuning: str | None = kwargs.get("save_dir", "mask_candidates") + + # Parse selected masks + selected_masks_tuning: list[list[int]] = [] + for index in mask_selected_tuning: + mask = mask_candidates_tuning[index] + masks_list = [int(x) for x in mask.split(",")] + selected_masks_tuning.append(masks_list) + + # Read JSON results + results = read_specific_json_files(mask_search_files_path) + averaged_results = average_head_losses(results, selected_masks_tuning) + + # Add full attention mask for specific cases + full_attention_mask_tuning: list[int] | None = kwargs.get("full_attention_mask") + if full_attention_mask_tuning is not None: + selected_masks_tuning.append(full_attention_mask_tuning) + + # Select best mask strategy 
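+        # Full attention is forced for the first `skip_time_steps` timesteps;
+        # afterwards each (timestep, layer, head) keeps the candidate mask
+        # with the lowest averaged L2 loss (see select_best_mask_strategy).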
+ timesteps_tuning: int = kwargs.get("timesteps", time_step_num) + if skip_time_steps_tuning is None: + skip_time_steps_tuning = 12 + mask_strategy, sparsity, strategy_counts = select_best_mask_strategy( + averaged_results, + selected_masks_tuning, + skip_time_steps_tuning, + timesteps_tuning, + head_num, + ) + + # Save mask strategy + if save_dir_tuning is not None: + os.makedirs(save_dir_tuning, exist_ok=True) + file_path = os.path.join( + save_dir_tuning, f"mask_strategy_s{skip_time_steps_tuning}.json" + ) + with open(file_path, "w") as f: + json.dump(mask_strategy, f, indent=4) + print(f"Successfully saved mask_strategy to {file_path}") + + # Print sparsity and strategy counts for information + print(f"Overall sparsity: {sparsity:.4f}") + print("\nStrategy usage counts:") + total_heads = time_step_num * layer_num * head_num # Fixed dimensions + for strategy, count in strategy_counts.items(): + print(f"Strategy {strategy}: {count} heads ({count/total_heads*100:.2f}%)") + + # Convert dictionary to 3D list with fixed dimensions + mask_strategy_3d = dict_to_3d_list( + mask_strategy, t_max=time_step_num, l_max=layer_num, h_max=head_num + ) + + return mask_strategy_3d + elif mode == "STA_tuning_cfg": + # Get required parameters for both positive and negative paths + mask_search_files_path_pos: str | None = kwargs.get( + "mask_search_files_path_pos" + ) + mask_search_files_path_neg: str | None = kwargs.get( + "mask_search_files_path_neg" + ) + save_dir_cfg: str | None = kwargs.get("save_dir") + + if ( + not mask_search_files_path_pos + or not mask_search_files_path_neg + or not save_dir_cfg + ): + raise ValueError( + "mask_search_files_path_pos, mask_search_files_path_neg, and save_dir are required for STA_tuning_cfg mode" + ) + + # Get optional parameters with defaults + mask_candidates_cfg: list[str] | None = kwargs.get("mask_candidates") + if mask_candidates_cfg is None: + raise ValueError("mask_candidates is required for STA_tuning_cfg mode") + mask_selected_cfg: list[int] = kwargs.get( + "mask_selected", list(range(len(mask_candidates_cfg))) + ) + skip_time_steps_cfg: int | None = kwargs.get("skip_time_steps") + + # Parse selected masks + selected_masks_cfg: list[list[int]] = [] + for index in mask_selected_cfg: + mask = mask_candidates_cfg[index] + masks_list = [int(x) for x in mask.split(",")] + selected_masks_cfg.append(masks_list) + + # Read JSON results for both positive and negative paths + pos_results = read_specific_json_files(mask_search_files_path_pos) + neg_results = read_specific_json_files(mask_search_files_path_neg) + # Combine positive and negative results into one list + combined_results = pos_results + neg_results + + # Average the combined results + averaged_results = average_head_losses(combined_results, selected_masks_cfg) + + # Add full attention mask for specific cases + full_attention_mask_cfg: list[int] | None = kwargs.get("full_attention_mask") + if full_attention_mask_cfg is not None: + selected_masks_cfg.append(full_attention_mask_cfg) + + timesteps_cfg: int = kwargs.get("timesteps", time_step_num) + if skip_time_steps_cfg is None: + skip_time_steps_cfg = 12 + # Select best mask strategy using combined results + mask_strategy, sparsity, strategy_counts = select_best_mask_strategy( + averaged_results, + selected_masks_cfg, + skip_time_steps_cfg, + timesteps_cfg, + head_num, + ) + + # Save mask strategy + os.makedirs(save_dir_cfg, exist_ok=True) + file_path = os.path.join( + save_dir_cfg, f"mask_strategy_s{skip_time_steps_cfg}.json" + ) + with open(file_path, 
"w") as f: + json.dump(mask_strategy, f, indent=4) + print(f"Successfully saved mask_strategy to {file_path}") + + # Print sparsity and strategy counts for information + print(f"Overall sparsity: {sparsity:.4f}") + print("\nStrategy usage counts:") + total_heads = time_step_num * layer_num * head_num # Fixed dimensions + for strategy, count in strategy_counts.items(): + print(f"Strategy {strategy}: {count} heads ({count/total_heads*100:.2f}%)") + + # Convert dictionary to 3D list with fixed dimensions + mask_strategy_3d = dict_to_3d_list( + mask_strategy, t_max=time_step_num, l_max=layer_num, h_max=head_num + ) + + return mask_strategy_3d + + else: # STA_inference + # Get parameters with defaults + load_path: str | None = kwargs.get( + "load_path", "mask_candidates/mask_strategy.json" + ) + if load_path is None: + raise ValueError("load_path is required for STA_inference mode") + + # Load previously saved mask strategy + with open(load_path) as f: + mask_strategy = json.load(f) + + # Convert dictionary to 3D list with fixed dimensions + mask_strategy_3d = dict_to_3d_list( + mask_strategy, t_max=time_step_num, l_max=layer_num, h_max=head_num + ) + + return mask_strategy_3d + + +# Helper functions + + +def read_specific_json_files(folder_path: str) -> list[dict[str, Any]]: + """Read and parse JSON files containing mask search results.""" + json_contents: list[dict[str, Any]] = [] + + # List files only in the current directory (no walk) + files = os.listdir(folder_path) + # Filter files + matching_files = [f for f in files if "mask" in f and f.endswith(".json")] + print(f"Found {len(matching_files)} matching files: {matching_files}") + + for file_name in matching_files: + file_path = os.path.join(folder_path, file_name) + with open(file_path) as file: + data = json.load(file) + json_contents.append(data) + + return json_contents + + +def average_head_losses( + results: list[dict[str, Any]], selected_masks: list[list[int]] +) -> dict[str, dict[str, np.ndarray]]: + """Average losses across all prompts for each mask strategy.""" + # Initialize a dictionary to store the averaged results + averaged_losses: dict[str, dict[str, np.ndarray]] = {} + loss_type = "L2_loss" + # Get all loss types (e.g., 'L2_loss') + averaged_losses[loss_type] = {} + + for mask in selected_masks: + mask_str = str(mask) + data_shape = np.array(results[0][loss_type][mask_str]).shape + accumulated_data = np.zeros(data_shape) + + # Sum across all prompts + for prompt_result in results: + accumulated_data += np.array(prompt_result[loss_type][mask_str]) + + # Average by dividing by number of prompts + averaged_data = accumulated_data / len(results) + averaged_losses[loss_type][mask_str] = averaged_data + + return averaged_losses + + +def select_best_mask_strategy( + averaged_results: dict[str, dict[str, np.ndarray]], + selected_masks: list[list[int]], + skip_time_steps: int = 12, + timesteps: int = 50, + head_num: int = 40, +) -> tuple[dict[str, list[int]], float, dict[str, int]]: + """Select the best mask strategy for each head based on loss minimization.""" + best_mask_strategy: dict[str, list[int]] = {} + loss_type = "L2_loss" + # Get the shape of time steps and layers + layers = len(averaged_results[loss_type][str(selected_masks[0])][0]) + + # Counter for sparsity calculation + total_tokens = 0 # total number of masked tokens + total_length = 0 # total sequence length + + strategy_counts: dict[str, int] = {str(strategy): 0 for strategy in selected_masks} + full_attn_strategy = selected_masks[-1] # Last strategy is full 
attention + print(f"Strategy {full_attn_strategy}, skip first {skip_time_steps} steps ") + + for t in range(timesteps): + for layer_idx in range(layers): + for h in range(head_num): + if t < skip_time_steps: # First steps use full attention + strategy = full_attn_strategy + else: + # Get losses for this head across all strategies + head_losses = [] + for strategy in selected_masks[:-1]: # Exclude full attention + head_losses.append( + averaged_results[loss_type][str(strategy)][t][layer_idx][h] + ) + + # Find which strategy gives minimum loss + best_strategy_idx = np.argmin(head_losses) + strategy = selected_masks[best_strategy_idx] + + best_mask_strategy[f"{t}_{layer_idx}_{h}"] = strategy + + # Calculate sparsity + nums = strategy # strategy is already a list of numbers + total_tokens += ( + nums[0] * nums[1] * nums[2] + ) # masked tokens for chosen strategy + total_length += ( + full_attn_strategy[0] + * full_attn_strategy[1] + * full_attn_strategy[2] + ) + + # Count strategy usage + strategy_counts[str(strategy)] += 1 + + overall_sparsity = 1 - total_tokens / total_length + + return best_mask_strategy, overall_sparsity, strategy_counts + + +def save_mask_search_results( + mask_search_final_result: list[dict[str, list[float]]], + prompt: str, + mask_strategies: list[str], + output_dir: str = "output/mask_search_result/", +) -> str | None: + if not mask_search_final_result: + print("No mask search results to save") + return None + + # Create result dictionary with defaultdict for nested lists + mask_search_dict: dict[str, dict[str, list[list[float]]]] = { + "L2_loss": defaultdict(list), + "L1_loss": defaultdict(list), + } + + mask_selected = list(range(len(mask_strategies))) + selected_masks: list[list[int]] = [] + for index in mask_selected: + mask = mask_strategies[index] + masks_list = [int(x) for x in mask.split(",")] + selected_masks.append(masks_list) + + # Process each mask strategy + for i, mask_strategy in enumerate(selected_masks): + mask_strategy_str = str(mask_strategy) + # Process L2 loss + step_results: list[list[float]] = [] + for step_data in mask_search_final_result: + if isinstance(step_data, dict) and "L2_loss" in step_data: + layer_losses = [float(loss) for loss in step_data["L2_loss"]] + step_results.append(layer_losses) + mask_search_dict["L2_loss"][mask_strategy_str] = step_results + + step_results = [] + for step_data in mask_search_final_result: + if isinstance(step_data, dict) and "L1_loss" in step_data: + layer_losses = [float(loss) for loss in step_data["L1_loss"]] + step_results.append(layer_losses) + mask_search_dict["L1_loss"][mask_strategy_str] = step_results + + # Create the output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + # Create a filename based on the first 20 characters of the prompt + filename = prompt[:50].replace(" ", "_") + filepath = os.path.join(output_dir, f"mask_search_{filename}.json") + + # Save the results to a JSON file + with open(filepath, "w") as f: + json.dump(mask_search_dict, f, indent=4) + + print(f"Successfully saved mask research results to {filepath}") + + return filepath diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/__init__.py b/python/sglang/multimodal_gen/runtime/layers/attention/__init__.py new file mode 100644 index 000000000..1b40782be --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/attention/__init__.py @@ -0,0 +1,28 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +from 
sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionBackend, + AttentionMetadata, + AttentionMetadataBuilder, +) +from sglang.multimodal_gen.runtime.layers.attention.layer import ( + LocalAttention, + UlyssesAttention, + UlyssesAttention_VSA, + USPAttention, +) +from sglang.multimodal_gen.runtime.layers.attention.selector import get_attn_backend + +__all__ = [ + "USPAttention", + "LocalAttention", + "UlyssesAttention", + "UlyssesAttention_VSA", + "AttentionBackend", + "AttentionMetadata", + "AttentionMetadataBuilder", + # "AttentionState", + "get_attn_backend", +] diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/backends/__init__.py b/python/sglang/multimodal_gen/runtime/layers/attention/backends/__init__.py new file mode 100644 index 000000000..af2eb7d10 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/attention/backends/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/backends/aiter.py b/python/sglang/multimodal_gen/runtime/layers/attention/backends/aiter.py new file mode 100644 index 000000000..b96aad6a4 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/attention/backends/aiter.py @@ -0,0 +1,101 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import aiter +import torch + +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, +) + + +class AITerBackend(AttentionBackend): + """ + Backend for AITemplate attention implementation. + """ + + @staticmethod + def get_name() -> str: + return "AITER" + + @staticmethod + def get_impl_cls() -> type["AITerImpl"]: + return AITerImpl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + # AITer backend does not require special metadata. + return AttentionMetadata + + @staticmethod + def get_builder_cls() -> type["AttentionMetadataBuilder"]: + raise NotImplementedError("AITer backend does not have a metadata builder.") + + +class AITerImpl(AttentionImpl): + """ + Implementation of attention using AITemplate. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + softmax_scale: float, + causal: bool = False, + num_kv_heads: int | None = None, + prefix: str = "", + dropout_p: float = 0.0, + **extra_impl_args, + ) -> None: + super().__init__( + num_heads=num_heads, + head_size=head_size, + softmax_scale=softmax_scale, + causal=causal, + num_kv_heads=num_kv_heads, + prefix=prefix, + **extra_impl_args, + ) + if num_kv_heads is not None and num_kv_heads != num_heads: + raise NotImplementedError( + "AITer backend does not support Grouped Query Attention yet." + ) + self.causal = causal + self.dropout_p = dropout_p + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata | None = None, + ) -> torch.Tensor: + """ + Performs attention using aiter.flash_attn_func. + + Args: + query: Query tensor of shape [batch_size, num_heads, seq_len, head_dim] + key: Key tensor of shape [batch_size, num_heads, seq_len, head_dim] + value: Value tensor of shape [batch_size, num_heads, seq_len, head_dim] + attn_metadata: Metadata for the attention operation (unused). 
+ + Returns: + Output tensor of shape [batch_size, num_heads, seq_len, head_dim] + """ + # aiter.flash_attn_func expects tensors in [B, H, S, D] layout, + # which is what ring_attn provides. + output, _ = aiter.flash_attn_func( + query, + key, + value, + dropout_p=self.dropout_p, + causal=self.causal, + return_attn_probs=False, + return_lse=True, + ) + return output diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/backends/attention_backend.py b/python/sglang/multimodal_gen/runtime/layers/attention/backends/attention_backend.py new file mode 100644 index 000000000..3463ef05c --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/attention/backends/attention_backend.py @@ -0,0 +1,180 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/attention/backends/abstract.py + +from abc import ABC, abstractmethod +from dataclasses import dataclass, fields +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar + +if TYPE_CHECKING: + pass + +import torch + + +class AttentionBackend(ABC): + """Abstract class for attention backends.""" + + # For some attention backends, we allocate an output tensor before + # calling the custom op. When piecewise cudagraph is enabled, this + # makes sure the output tensor is allocated inside the cudagraph. + accept_output_buffer: bool = False + + @staticmethod + @abstractmethod + def get_name() -> str: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_impl_cls() -> type["AttentionImpl"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + raise NotImplementedError + + # @staticmethod + # @abstractmethod + # def get_state_cls() -> Type["AttentionState"]: + # raise NotImplementedError + + # @classmethod + # def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": + # return cls.get_metadata_cls()(*args, **kwargs) + + @staticmethod + @abstractmethod + def get_builder_cls() -> type["AttentionMetadataBuilder"]: + return None + + +@dataclass +class AttentionMetadata: + """Attention metadata for prefill and decode batched together.""" + + # Current step of diffusion process + current_timestep: int + + def asdict_zerocopy(self, skip_fields: set[str] | None = None) -> dict[str, Any]: + """Similar to dataclasses.asdict, but avoids deepcopying.""" + if skip_fields is None: + skip_fields = set() + # Note that if we add dataclasses as fields, they will need + # similar handling. 
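+        # Values are returned by reference (no tensor copies), which keeps
+        # this cheap compared to dataclasses.asdict for on-device metadata.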
+ return { + field.name: getattr(self, field.name) + for field in fields(self) + if field.name not in skip_fields + } + + +T = TypeVar("T", bound=AttentionMetadata) + + +class AttentionMetadataBuilder(ABC, Generic[T]): + """Abstract class for attention metadata builders.""" + + @abstractmethod + def __init__(self) -> None: + """Create the builder, remember some configuration and parameters.""" + raise NotImplementedError + + @abstractmethod + def prepare(self) -> None: + """Prepare for one batch.""" + raise NotImplementedError + + @abstractmethod + def build( + self, + **kwargs: dict[str, Any], + ) -> AttentionMetadata: + """Build attention metadata with on-device tensors.""" + raise NotImplementedError + + +class AttentionLayer(Protocol): + + _k_scale: torch.Tensor + _v_scale: torch.Tensor + _k_scale_float: float + _v_scale_float: float + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: ... + + +class AttentionImpl(ABC, Generic[T]): + + @abstractmethod + def __init__( + self, + num_heads: int, + head_size: int, + softmax_scale: float, + causal: bool = False, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + raise NotImplementedError + + def preprocess_qkv(self, qkv: torch.Tensor, attn_metadata: T) -> torch.Tensor: + """Preprocess QKV tensor before performing attention operation. + + Default implementation returns the tensor unchanged. + Subclasses can override this to implement custom preprocessing + like reshaping, tiling, scaling, or other transformations. + + Called AFTER all_to_all for distributed attention + + Args: + qkv: The query-key-value tensor + attn_metadata: Metadata for the attention operation + + Returns: + Processed QKV tensor + """ + return qkv + + def postprocess_output( + self, + output: torch.Tensor, + attn_metadata: T, + ) -> torch.Tensor: + """Postprocess the output tensor after the attention operation. + + Default implementation returns the tensor unchanged. + Subclasses can override this to implement custom postprocessing + like untiling, scaling, or other transformations. 
+ + Called BEFORE all_to_all for distributed attention + + Args: + output: The output tensor from the attention operation + attn_metadata: Metadata for the attention operation + + Returns: + Postprocessed output tensor + """ + + return output + + @abstractmethod + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: T, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py new file mode 100644 index 000000000..ee6cd38b8 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn.py @@ -0,0 +1,132 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass +from typing import Any + +import torch + +from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context +from sglang.srt.layers.attention.flashattention_backend import FlashAttentionMetadata + +try: + from sgl_kernel.flash_attn import flash_attn_varlen_func + + # flash_attn 3 no longer have a different API, see following commit: + # https://github.com/Dao-AILab/flash-attention/commit/ed209409acedbb2379f870bbd03abce31a7a51b7 + flash_attn_func = flash_attn_varlen_func +except ImportError as e: + raise e + +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +@dataclass +class FlashAttentionMetadata: + # Sequence lengths for the forward batch + # Maximum sequence length for query + max_seqlen_q: int = 1 + # Maximum sequence length for key + max_seqlen_k: int = 0 + # Cumulative sequence lengths for query + cu_seqlens_q: torch.Tensor = None + # Cumulative sequence lengths for key + cu_seqlens_k: torch.Tensor = None + + +class FlashAttentionMetadataBuilder(AttentionMetadataBuilder): + + def __init__(self): + pass + + def prepare(self): + pass + + def build( # type: ignore + self, + raw_latent_shape=list, + **kwargs: dict[str, Any], + ) -> FlashAttentionMetadata: + # TODO: put empty values here to be set at first-run, since the q_len calculation can be complicated + return FlashAttentionMetadata(max_seqlen_q=None, max_seqlen_k=None) + + +class FlashAttentionBackend(AttentionBackend): + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "FLASH_ATTN" + + @staticmethod + def get_impl_cls() -> type["FlashAttentionImpl"]: + return FlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + raise NotImplementedError + + @staticmethod + def get_builder_cls() -> type["AttentionMetadataBuilder"]: + return FlashAttentionMetadataBuilder + + +class FlashAttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + causal: bool, + softmax_scale: float, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + self.causal = causal + self.softmax_scale = softmax_scale + self.attention_metadata = FlashAttentionMetadata() + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: 
AttentionMetadata = None, + *, + return_softmax_lse: bool = False, + ): + attn_metadata: FlashAttentionMetadata = get_forward_context().attn_metadata + if attn_metadata is not None and attn_metadata.max_seqlen_q is None: + attn_metadata.max_seqlen_q = query.shape[1] + attn_metadata.max_seqlen_k = key.shape[1] + max_seqlen_q = attn_metadata.max_seqlen_q + max_seqlen_k = attn_metadata.max_seqlen_k + else: + max_seqlen_q = query.shape[1] + max_seqlen_k = key.shape[1] + output = flash_attn_func( + q=query, # type: ignore[no-untyped-call] + k=key, + v=value, + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + softmax_scale=self.softmax_scale, + causal=self.causal, + return_softmax_lse=return_softmax_lse, + ) + return output diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn_2.py b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn_2.py new file mode 100644 index 000000000..05f30c085 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/attention/backends/flash_attn_2.py @@ -0,0 +1,78 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, +) +from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import ( + flash_attn_func, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class FlashAttention2Backend(AttentionBackend): + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "FA3" + + @staticmethod + def get_impl_cls() -> type["FlashAttention2Impl"]: + return FlashAttention2Impl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + raise NotImplementedError + + @staticmethod + def get_builder_cls() -> type["AttentionMetadataBuilder"]: + raise NotImplementedError + + +class FlashAttention2Impl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + causal: bool, + softmax_scale: float, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + self.causal = causal + self.softmax_scale = softmax_scale + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata, + ): + output = flash_attn_func( + q=query, # type: ignore[no-untyped-call] + k=key, + v=value, + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=None, + max_seqlen_k=None, + softmax_scale=self.softmax_scale, + causal=self.causal, + ) + return output diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/backends/sage_attn.py b/python/sglang/multimodal_gen/runtime/layers/attention/backends/sage_attn.py new file mode 100644 index 000000000..3563ddd18 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/attention/backends/sage_attn.py @@ -0,0 +1,70 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import torch +from sageattention import sageattn + +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( # FlashAttentionMetadata, + AttentionBackend, + AttentionImpl, + AttentionMetadata, +) 
+from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class SageAttentionBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "SAGE_ATTN" + + @staticmethod + def get_impl_cls() -> type["SageAttentionImpl"]: + return SageAttentionImpl + + # @staticmethod + # def get_metadata_cls() -> Type["AttentionMetadata"]: + # return FlashAttentionMetadata + + +class SageAttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + causal: bool, + softmax_scale: float, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout = extra_impl_args.get("dropout_p", 0.0) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + output = sageattn( + query, + key, + value, + # since input is (batch_size, seq_len, head_num, head_dim) + tensor_layout="NHD", + is_causal=self.causal, + ) + return output diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/backends/sage_attn3.py b/python/sglang/multimodal_gen/runtime/layers/attention/backends/sage_attn3.py new file mode 100644 index 000000000..fd5b6f2b6 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/attention/backends/sage_attn3.py @@ -0,0 +1,78 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, +) +from sglang.multimodal_gen.runtime.layers.attention.backends.sageattn.api import ( + sageattn_blackwell, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class SageAttention3Backend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [64, 128, 256] + + @staticmethod + def get_name() -> str: + return "SAGE_ATTN_THREE" + + @staticmethod + def get_impl_cls() -> type["SageAttention3Impl"]: + return SageAttention3Impl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + raise NotImplementedError + + @staticmethod + def get_builder_cls() -> type["AttentionMetadataBuilder"]: + raise NotImplementedError + + # @staticmethod + # def get_metadata_cls() -> Type["AttentionMetadata"]: + # return FlashAttentionMetadata + + +class SageAttention3Impl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + causal: bool, + softmax_scale: float, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout = extra_impl_args.get("dropout_p", 0.0) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + output = sageattn_blackwell(query, key, value, is_causal=self.causal) + output = output.transpose(1, 2) + return output diff --git 
a/python/sglang/multimodal_gen/runtime/layers/attention/backends/sdpa.py b/python/sglang/multimodal_gen/runtime/layers/attention/backends/sdpa.py new file mode 100644 index 000000000..bfa3b430d --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/attention/backends/sdpa.py @@ -0,0 +1,77 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( # FlashAttentionMetadata, + AttentionBackend, + AttentionImpl, + AttentionMetadata, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class SDPABackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "SDPA" + + @staticmethod + def get_impl_cls() -> type["SDPAImpl"]: + return SDPAImpl + + # @staticmethod + # def get_metadata_cls() -> Type["AttentionMetadata"]: + # return FlashAttentionMetadata + + +class SDPAImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + causal: bool, + softmax_scale: float, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout = extra_impl_args.get("dropout_p", 0.0) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + # transpose to bs, heads, seq_len, head_dim + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + attn_kwargs = { + "attn_mask": None, + "dropout_p": self.dropout, + "is_causal": self.causal, + "scale": self.softmax_scale, + } + if query.shape[1] != key.shape[1]: + attn_kwargs["enable_gqa"] = True + output = torch.nn.functional.scaled_dot_product_attention( + query, key, value, **attn_kwargs + ) + output = output.transpose(1, 2) + return output diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/backends/sliding_tile_attn.py b/python/sglang/multimodal_gen/runtime/layers/attention/backends/sliding_tile_attn.py new file mode 100644 index 000000000..f7917c520 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/attention/backends/sliding_tile_attn.py @@ -0,0 +1,313 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +import json +from dataclasses import dataclass +from typing import Any + +import torch +from einops import rearrange + +import sglang.multimodal_gen.envs as envs +from sglang.multimodal_gen.runtime.distributed import get_sp_group +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, +) +from sglang.multimodal_gen.runtime.managers.forward_context import ( + ForwardContext, + get_forward_context, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import dict_to_3d_list + +try: + from st_attn import sliding_tile_attention + + st_attn_backend_available = True +except Exception: + st_attn_backend_available = False + +logger = init_logger(__name__) + + +class RangeDict(dict): + + def __getitem__(self, item: int) -> str: + for key in self.keys(): + if isinstance(key, 
tuple): + low, high = key + if low <= item <= high: + return str(super().__getitem__(key)) + elif key == item: + return str(super().__getitem__(key)) + raise KeyError(f"seq_len {item} not supported for STA") + + +class SlidingTileAttentionBackend(AttentionBackend): + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + # TODO(will-refactor): check this + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "SLIDING_TILE_ATTN" + + @staticmethod + def get_impl_cls() -> type["SlidingTileAttentionImpl"]: + return SlidingTileAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["SlidingTileAttentionMetadata"]: + return SlidingTileAttentionMetadata + + @staticmethod + def get_builder_cls() -> type["SlidingTileAttentionMetadataBuilder"]: + return SlidingTileAttentionMetadataBuilder + + +@dataclass +class SlidingTileAttentionMetadata(AttentionMetadata): + current_timestep: int + STA_param: list[ + list[Any] + ] # each timestep with one metadata, shape [num_layers, num_heads] + + +class SlidingTileAttentionMetadataBuilder(AttentionMetadataBuilder): + + def __init__(self): + pass + + def prepare(self): + pass + + def build( # type: ignore + self, + STA_param: list[list[Any]], + current_timestep: int, + **kwargs: dict[str, Any], + ) -> SlidingTileAttentionMetadata: + param = STA_param + if param is None: + return SlidingTileAttentionMetadata( + current_timestep=current_timestep, STA_param=[] + ) + return SlidingTileAttentionMetadata( + current_timestep=current_timestep, STA_param=param[current_timestep] + ) + + +class SlidingTileAttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + causal: bool, + softmax_scale: float, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + if not st_attn_backend_available: + raise ValueError("st attn not supported") + # TODO(will-refactor): for now this is the mask strategy, but maybe we should + # have a more general config for STA? 
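+        # The file is expected to contain the JSON written by configure_sta /
+        # select_best_mask_strategy: keys of the form "timestep_layer_head"
+        # mapped to tile window sizes, expanded below by dict_to_3d_list into
+        # a [timestep][layer][head] lookup.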
+ config_file = envs.SGL_DIFFUSION_ATTENTION_CONFIG + if config_file is None: + raise ValueError("SGL_DIFFUSION_ATTENTION_CONFIG is not set") + + # TODO(kevin): get mask strategy for different STA modes + with open(config_file) as f: + mask_strategy = json.load(f) + self.mask_strategy = dict_to_3d_list(mask_strategy) + + self.prefix = prefix + sp_group = get_sp_group() + self.sp_size = sp_group.world_size + # STA config + self.STA_base_tile_size = [6, 8, 8] + self.dit_seq_shape_mapping = RangeDict( + { + (115200, 115456): "30x48x80", + 82944: "36x48x48", + 69120: "18x48x80", + } + ) + self.full_window_mapping = { + "30x48x80": [5, 6, 10], + "36x48x48": [6, 6, 6], + "18x48x80": [3, 6, 10], + } + + def tile(self, x: torch.Tensor) -> torch.Tensor: + return rearrange( + x, + "b (n_t ts_t n_h ts_h n_w ts_w) h d -> b (n_t n_h n_w ts_t ts_h ts_w) h d", + n_t=self.full_window_size[0], + n_h=self.full_window_size[1], + n_w=self.full_window_size[2], + ts_t=self.STA_base_tile_size[0], + ts_h=self.STA_base_tile_size[1], + ts_w=self.STA_base_tile_size[2], + ) + + def untile(self, x: torch.Tensor) -> torch.Tensor: + x = rearrange( + x, + "b (n_t n_h n_w ts_t ts_h ts_w) h d -> b (n_t ts_t n_h ts_h n_w ts_w) h d", + n_t=self.full_window_size[0], + n_h=self.full_window_size[1], + n_w=self.full_window_size[2], + ts_t=self.STA_base_tile_size[0], + ts_h=self.STA_base_tile_size[1], + ts_w=self.STA_base_tile_size[2], + ) + return x + + def preprocess_qkv( + self, + qkv: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + img_sequence_length = qkv.shape[1] + self.dit_seq_shape_str = self.dit_seq_shape_mapping[img_sequence_length] + self.full_window_size = self.full_window_mapping[self.dit_seq_shape_str] + self.dit_seq_shape_int = list(map(int, self.dit_seq_shape_str.split("x"))) + self.img_seq_length = ( + self.dit_seq_shape_int[0] + * self.dit_seq_shape_int[1] + * self.dit_seq_shape_int[2] + ) + return self.tile(qkv) + + def postprocess_output( + self, + output: torch.Tensor, + attn_metadata: SlidingTileAttentionMetadata, + ) -> torch.Tensor: + return self.untile(output) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attn_metadata: SlidingTileAttentionMetadata, + ) -> torch.Tensor: + if self.mask_strategy is None: + raise ValueError("mask_strategy cannot be None for SlidingTileAttention") + if self.mask_strategy[0] is None: + raise ValueError("mask_strategy[0] cannot be None for SlidingTileAttention") + + timestep = attn_metadata.current_timestep + forward_context: ForwardContext = get_forward_context() + forward_batch = forward_context.forward_batch + if forward_batch is None: + raise ValueError("forward_batch cannot be None") + # pattern:'.double_blocks.0.attn.impl' or '.single_blocks.0.attn.impl' + layer_idx = int(self.prefix.split(".")[-3]) + if attn_metadata.STA_param is None or len(attn_metadata.STA_param) <= layer_idx: + raise ValueError("Invalid STA_param") + STA_param = attn_metadata.STA_param[layer_idx] + + text_length = q.shape[1] - self.img_seq_length + has_text = text_length > 0 + + query = q.transpose(1, 2).contiguous() + key = k.transpose(1, 2).contiguous() + value = v.transpose(1, 2).contiguous() + + head_num = query.size(1) + sp_group = get_sp_group() + current_rank = sp_group.rank_in_group + start_head = current_rank * head_num + + # searching or tuning mode + if len(STA_param) < head_num * sp_group.world_size: + sparse_attn_hidden_states_all = [] + full_mask_window = STA_param[-1] + for window_size in STA_param[:-1]: + 
sparse_hidden_states = sliding_tile_attention( + query, + key, + value, + [window_size] * head_num, + text_length, + has_text, + self.dit_seq_shape_str, + ).transpose(1, 2) + sparse_attn_hidden_states_all.append(sparse_hidden_states) + + hidden_states = sliding_tile_attention( + query, + key, + value, + [full_mask_window] * head_num, + text_length, + has_text, + self.dit_seq_shape_str, + ).transpose(1, 2) + + attn_L2_loss = [] + attn_L1_loss = [] + # average loss across all heads + for sparse_attn_hidden_states in sparse_attn_hidden_states_all: + # L2 loss + attn_L2_loss_ = ( + torch.mean( + (sparse_attn_hidden_states.float() - hidden_states.float()) + ** 2, + dim=[0, 1, 3], + ) + .cpu() + .numpy() + ) + attn_L2_loss_ = [round(float(x), 6) for x in attn_L2_loss_] + attn_L2_loss.append(attn_L2_loss_) + # L1 loss + attn_L1_loss_ = ( + torch.mean( + torch.abs( + sparse_attn_hidden_states.float() - hidden_states.float() + ), + dim=[0, 1, 3], + ) + .cpu() + .numpy() + ) + attn_L1_loss_ = [round(float(x), 6) for x in attn_L1_loss_] + attn_L1_loss.append(attn_L1_loss_) + + layer_loss_save = {"L2_loss": attn_L2_loss, "L1_loss": attn_L1_loss} + + if forward_batch.is_cfg_negative: + if forward_batch.mask_search_final_result_neg is not None: + forward_batch.mask_search_final_result_neg[timestep].append( + layer_loss_save + ) + else: + if forward_batch.mask_search_final_result_pos is not None: + forward_batch.mask_search_final_result_pos[timestep].append( + layer_loss_save + ) + else: + windows = [STA_param[head_idx + start_head] for head_idx in range(head_num)] + + hidden_states = sliding_tile_attention( + query, + key, + value, + windows, + text_length, + has_text, + self.dit_seq_shape_str, + ).transpose(1, 2) + + return hidden_states diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/backends/video_sparse_attn.py b/python/sglang/multimodal_gen/runtime/layers/attention/backends/video_sparse_attn.py new file mode 100644 index 000000000..6fe342922 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/attention/backends/video_sparse_attn.py @@ -0,0 +1,331 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +import functools +import math +from dataclasses import dataclass + +import torch + +try: + from vsa import video_sparse_attn +except ImportError: + video_sparse_attn = None + +from typing import Any + +from sglang.multimodal_gen.runtime.distributed import get_sp_group +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) +VSA_TILE_SIZE = (4, 4, 4) + + +@functools.lru_cache(maxsize=10) +def get_tile_partition_indices( + dit_seq_shape: tuple[int, int, int], + tile_size: tuple[int, int, int], + device: torch.device, +) -> torch.LongTensor: + T, H, W = dit_seq_shape + ts, hs, ws = tile_size + indices = torch.arange(T * H * W, device=device, dtype=torch.long).reshape(T, H, W) + ls = [] + for t in range(math.ceil(T / ts)): + for h in range(math.ceil(H / hs)): + for w in range(math.ceil(W / ws)): + ls.append( + indices[ + t * ts : min(t * ts + ts, T), + h * hs : min(h * hs + hs, H), + w * ws : min(w * ws + ws, W), + ].flatten() + ) + index = torch.cat(ls, dim=0) + return index + + +@functools.lru_cache(maxsize=10) +def get_reverse_tile_partition_indices( + dit_seq_shape: tuple[int, int, int], 
+ tile_size: tuple[int, int, int], + device: torch.device, +) -> torch.LongTensor: + return torch.argsort(get_tile_partition_indices(dit_seq_shape, tile_size, device)) + + +@functools.lru_cache(maxsize=10) +def construct_variable_block_sizes( + dit_seq_shape: tuple[int, int, int], + num_tiles: tuple[int, int, int], + device: torch.device, +) -> torch.LongTensor: + """ + Compute the number of valid (non‑padded) tokens inside every + (ts_t × ts_h × ts_w) tile after padding ‑‑ flattened in the order + (t‑tile, h‑tile, w‑tile) that `rearrange` uses. + + Returns + ------- + torch.LongTensor # shape: [∏ full_window_size] + """ + # unpack + t, h, w = dit_seq_shape + ts_t, ts_h, ts_w = VSA_TILE_SIZE + n_t, n_h, n_w = num_tiles + + def _sizes(dim_len: int, tile: int, n_tiles: int) -> torch.LongTensor: + """Vector with the size of each tile along one dimension.""" + sizes = torch.full((n_tiles,), tile, dtype=torch.int, device=device) + # size of last (possibly partial) tile + remainder = dim_len - (n_tiles - 1) * tile + sizes[-1] = remainder if remainder > 0 else tile + return sizes + + t_sizes = _sizes(t, ts_t, n_t) # [n_t] + h_sizes = _sizes(h, ts_h, n_h) # [n_h] + w_sizes = _sizes(w, ts_w, n_w) # [n_w] + + # broadcast‑multiply to get voxels per tile, then flatten + block_sizes = ( + t_sizes[:, None, None] # [n_t, 1, 1] + * h_sizes[None, :, None] # [1, n_h, 1] + * w_sizes[None, None, :] # [1, 1, n_w] + ).reshape( + -1 + ) # [n_t * n_h * n_w] + + return block_sizes + + +@functools.lru_cache(maxsize=10) +def get_non_pad_index( + variable_block_sizes: torch.LongTensor, + max_block_size: int, +): + n_win = variable_block_sizes.shape[0] + device = variable_block_sizes.device + starts_pad = torch.arange(n_win, device=device) * max_block_size + index_pad = ( + starts_pad[:, None] + torch.arange(max_block_size, device=device)[None, :] + ) + index_mask = ( + torch.arange(max_block_size, device=device)[None, :] + < variable_block_sizes[:, None] + ) + return index_pad[index_mask] + + +class VideoSparseAttentionBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [64, 128] + + @staticmethod + def get_name() -> str: + return "VIDEO_SPARSE_ATTN" + + @staticmethod + def get_impl_cls() -> type["VideoSparseAttentionImpl"]: + return VideoSparseAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["VideoSparseAttentionMetadata"]: + return VideoSparseAttentionMetadata + + @staticmethod + def get_builder_cls() -> type["VideoSparseAttentionMetadataBuilder"]: + return VideoSparseAttentionMetadataBuilder + + +@dataclass +class VideoSparseAttentionMetadata(AttentionMetadata): + current_timestep: int + dit_seq_shape: list[int] + VSA_sparsity: float + num_tiles: list[int] + total_seq_length: int + tile_partition_indices: torch.LongTensor + reverse_tile_partition_indices: torch.LongTensor + variable_block_sizes: torch.LongTensor + non_pad_index: torch.LongTensor + + # adaption for FastWan2.1-T2V-1.3B-Diffusers + # Sequence lengths for the forward batch + # Maximum sequence length for query + max_seqlen_q: int = 1 + # Maximum sequence length for key + max_seqlen_k: int = 0 + + +class VideoSparseAttentionMetadataBuilder(AttentionMetadataBuilder): + + def __init__(self): + pass + + def prepare(self): + pass + + def build( # type: ignore + self, + current_timestep: int, + raw_latent_shape: tuple[int, int, int], + patch_size: tuple[int, int, int], + VSA_sparsity: float, + device: torch.device, + **kwargs: dict[str, Any], + ) -> 
VideoSparseAttentionMetadata: + patch_size = patch_size + dit_seq_shape = ( + raw_latent_shape[0] // patch_size[0], + raw_latent_shape[1] // patch_size[1], + raw_latent_shape[2] // patch_size[2], + ) + + num_tiles = ( + math.ceil(dit_seq_shape[0] / VSA_TILE_SIZE[0]), + math.ceil(dit_seq_shape[1] / VSA_TILE_SIZE[1]), + math.ceil(dit_seq_shape[2] / VSA_TILE_SIZE[2]), + ) + total_seq_length = math.prod(dit_seq_shape) + + tile_partition_indices = get_tile_partition_indices( + dit_seq_shape, VSA_TILE_SIZE, device + ) + reverse_tile_partition_indices = get_reverse_tile_partition_indices( + dit_seq_shape, VSA_TILE_SIZE, device + ) + variable_block_sizes = construct_variable_block_sizes( + dit_seq_shape, num_tiles, device + ) + non_pad_index = get_non_pad_index( + variable_block_sizes, math.prod(VSA_TILE_SIZE) + ) + + return VideoSparseAttentionMetadata( + current_timestep=current_timestep, + dit_seq_shape=dit_seq_shape, # type: ignore + VSA_sparsity=VSA_sparsity, # type: ignore + num_tiles=num_tiles, # type: ignore + total_seq_length=total_seq_length, # type: ignore + tile_partition_indices=tile_partition_indices, # type: ignore + reverse_tile_partition_indices=reverse_tile_partition_indices, + variable_block_sizes=variable_block_sizes, + non_pad_index=non_pad_index, + ) + + +class VideoSparseAttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + causal: bool, + softmax_scale: float, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + self.prefix = prefix + sp_group = get_sp_group() + self.sp_size = sp_group.world_size + + def tile( + self, + x: torch.Tensor, + num_tiles: list[int], + tile_partition_indices: torch.LongTensor, + non_pad_index: torch.LongTensor, + ) -> torch.Tensor: + t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0] + h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1] + w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2] + + x_padded = torch.zeros( + ( + x.shape[0], + t_padded_size * h_padded_size * w_padded_size, + x.shape[-2], + x.shape[-1], + ), + device=x.device, + dtype=x.dtype, + ) + x_padded[:, non_pad_index] = x[:, tile_partition_indices] + return x_padded + + def untile( + self, + x: torch.Tensor, + reverse_tile_partition_indices: torch.LongTensor, + non_pad_index: torch.LongTensor, + ) -> torch.Tensor: + x = x[:, non_pad_index][:, reverse_tile_partition_indices] + return x + + def preprocess_qkv( + self, + qkv: torch.Tensor, + attn_metadata: VideoSparseAttentionMetadata, + ) -> torch.Tensor: + return self.tile( + qkv, + attn_metadata.num_tiles, + attn_metadata.tile_partition_indices, + attn_metadata.non_pad_index, + ) + + def postprocess_output( + self, + output: torch.Tensor, + attn_metadata: VideoSparseAttentionMetadata, + ) -> torch.Tensor: + return self.untile( + output, + attn_metadata.reverse_tile_partition_indices, + attn_metadata.non_pad_index, + ) + + def forward( # type: ignore[override] + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + gate_compress: torch.Tensor, + attn_metadata: VideoSparseAttentionMetadata, + ) -> torch.Tensor: + query = query.transpose(1, 2).contiguous() + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + gate_compress = gate_compress.transpose(1, 2).contiguous() + + VSA_sparsity = attn_metadata.VSA_sparsity + + cur_topk = math.ceil( + (1 - VSA_sparsity) + * (attn_metadata.total_seq_length / math.prod(VSA_TILE_SIZE)) + ) + + if video_sparse_attn is None: + raise NotImplementedError("video_sparse_attn is not 
installed") + hidden_states = video_sparse_attn( + query, + key, + value, + variable_block_sizes=attn_metadata.variable_block_sizes, + topk=cur_topk, + block_size=VSA_TILE_SIZE, + compress_attn_weight=gate_compress, + ).transpose(1, 2) + + return hidden_states diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/backends/vmoba.py b/python/sglang/multimodal_gen/runtime/layers/attention/backends/vmoba.py new file mode 100644 index 000000000..5709601d2 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/attention/backends/vmoba.py @@ -0,0 +1,258 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import re +from dataclasses import dataclass + +import torch +from einops import rearrange +from kernel.attn.vmoba_attn.vmoba import ( + moba_attn_varlen, + process_moba_input, + process_moba_output, +) + +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class VMOBAAttentionBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "VMOBA_ATTN" + + @staticmethod + def get_impl_cls() -> type["VMOBAAttentionImpl"]: + return VMOBAAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["VideoMobaAttentionMetadata"]: + return VideoMobaAttentionMetadata + + @staticmethod + def get_builder_cls() -> type["VideoMobaAttentionMetadataBuilder"]: + return VideoMobaAttentionMetadataBuilder + + +@dataclass +class VideoMobaAttentionMetadata(AttentionMetadata): + current_timestep: int + + temporal_chunk_size: int + temporal_topk: int + spatial_chunk_size: tuple[int, int] + spatial_topk: int + st_chunk_size: tuple[int, int, int] + st_topk: int + + moba_select_mode: str + moba_threshold: float + moba_threshold_type: str + patch_resolution: list[int] + + first_full_step: int = 12 + first_full_layer: int = 0 + # temporal_layer -> spatial_layer -> st_layer + temporal_layer: int = 1 + spatial_layer: int = 1 + st_layer: int = 1 + + +def pad_input(hidden_states, indices, batch, seqlen): + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[1:] + output = torch.zeros( + (batch * seqlen), *dim, device=hidden_states.device, dtype=hidden_states.dtype + ) + output[indices] = hidden_states + return rearrange(output, "(b s) ... 
-> b s ...", b=batch) + + +class VideoMobaAttentionMetadataBuilder(AttentionMetadataBuilder): + + def __init__(self): + pass + + def prepare(self): + pass + + def build( # type: ignore + self, + current_timestep: int, + raw_latent_shape: tuple[int, int, int], + patch_size: tuple[int, int, int], + temporal_chunk_size: int, + temporal_topk: int, + spatial_chunk_size: tuple[int, int], + spatial_topk: int, + st_chunk_size: tuple[int, int, int], + st_topk: int, + moba_select_mode: str = "threshold", + moba_threshold: float = 0.25, + moba_threshold_type: str = "query_head", + device: torch.device = None, + first_full_layer: int = 0, + first_full_step: int = 12, + temporal_layer: int = 1, + spatial_layer: int = 1, + st_layer: int = 1, + **kwargs, + ) -> VideoMobaAttentionMetadata: + if device is None: + device = torch.device("cpu") + assert ( + raw_latent_shape[0] % patch_size[0] == 0 + and raw_latent_shape[1] % patch_size[1] == 0 + and raw_latent_shape[2] % patch_size[2] == 0 + ), f"spatial patch_resolution {raw_latent_shape} should be divisible by patch_size {patch_size}" + patch_resolution = [ + t // pt for t, pt in zip(raw_latent_shape, patch_size, strict=False) + ] + + return VideoMobaAttentionMetadata( + current_timestep=current_timestep, + temporal_chunk_size=temporal_chunk_size, + temporal_topk=temporal_topk, + spatial_chunk_size=spatial_chunk_size, + spatial_topk=spatial_topk, + st_chunk_size=st_chunk_size, + st_topk=st_topk, + moba_select_mode=moba_select_mode, + moba_threshold=moba_threshold, + moba_threshold_type=moba_threshold_type, + patch_resolution=patch_resolution, + first_full_layer=first_full_layer, + first_full_step=first_full_step, + temporal_layer=temporal_layer, + spatial_layer=spatial_layer, + st_layer=st_layer, + ) + + +class VMOBAAttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads, + head_size, + softmax_scale, + causal=False, + num_kv_heads=None, + prefix="", + **extra_impl_args, + ) -> None: + self.prefix = prefix + self.layer_idx = self._get_layer_idx(prefix) + + self.pad_input = pad_input + + def _get_layer_idx(self, prefix: str) -> int | None: + match = re.search(r"blocks\.(\d+)", prefix) + if not match: + raise ValueError(f"Invalid prefix: {prefix}") + return int(match.group(1)) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + """ + query: [B, L, H, D] + key: [B, L, H, D] + value: [B, L, H, D] + attn_metadata: AttentionMetadata + """ + batch_size, sequence_length, num_heads, head_dim = query.shape + + # select chunk type according to layer idx: + loop_layer_num = ( + attn_metadata.temporal_layer + + attn_metadata.spatial_layer + + attn_metadata.st_layer + ) + moba_layer = self.layer_idx - attn_metadata.first_full_layer + if moba_layer % loop_layer_num < attn_metadata.temporal_layer: + moba_chunk_size = attn_metadata.temporal_chunk_size + moba_topk = attn_metadata.temporal_topk + elif ( + moba_layer % loop_layer_num + < attn_metadata.temporal_layer + attn_metadata.spatial_layer + ): + moba_chunk_size = attn_metadata.spatial_chunk_size + moba_topk = attn_metadata.spatial_topk + elif ( + moba_layer % loop_layer_num + < attn_metadata.temporal_layer + + attn_metadata.spatial_layer + + attn_metadata.st_layer + ): + moba_chunk_size = attn_metadata.st_chunk_size + moba_topk = attn_metadata.st_topk + + query, chunk_size = process_moba_input( + query, attn_metadata.patch_resolution, moba_chunk_size + ) + key, chunk_size = process_moba_input( + key, 
attn_metadata.patch_resolution, moba_chunk_size + ) + value, chunk_size = process_moba_input( + value, attn_metadata.patch_resolution, moba_chunk_size + ) + max_seqlen = query.shape[1] + indices_q = torch.arange( + 0, query.shape[0] * query.shape[1], device=query.device + ) + cu_seqlens = torch.arange( + 0, + query.shape[0] * query.shape[1] + 1, + query.shape[1], + dtype=torch.int32, + device=query.device, + ) + query = rearrange(query, "b s ... -> (b s) ...") + key = rearrange(key, "b s ... -> (b s) ...") + value = rearrange(value, "b s ... -> (b s) ...") + + # current_timestep=attn_metadata.current_timestep + hidden_states = moba_attn_varlen( + query, + key, + value, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + moba_chunk_size=chunk_size, + moba_topk=moba_topk, + select_mode=attn_metadata.moba_select_mode, + simsum_threshold=attn_metadata.moba_threshold, + threshold_type=attn_metadata.moba_threshold_type, + ) + hidden_states = self.pad_input( + hidden_states, indices_q, batch_size, sequence_length + ) + hidden_states = process_moba_output( + hidden_states, attn_metadata.patch_resolution, moba_chunk_size + ) + + return hidden_states diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/layer.py b/python/sglang/multimodal_gen/runtime/layers/attention/layer.py new file mode 100644 index 000000000..482ea4efc --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/attention/layer.py @@ -0,0 +1,399 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from typing import Type + +import torch +import torch.nn as nn + +from sglang.multimodal_gen.runtime.distributed.communication_op import ( + sequence_model_parallel_all_gather, + sequence_model_parallel_all_to_all_4D, +) +from sglang.multimodal_gen.runtime.distributed.parallel_state import ( + get_ring_parallel_world_size, + get_sequence_parallel_world_size, + get_sp_parallel_rank, + get_sp_world_size, + get_ulysses_parallel_world_size, +) +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionImpl, +) +from sglang.multimodal_gen.runtime.layers.attention.selector import ( + backend_name_to_enum, + get_attn_backend, +) +from sglang.multimodal_gen.runtime.layers.usp import ( + _usp_input_all_to_all, + _usp_output_all_to_all, + ring_attn, +) +from sglang.multimodal_gen.runtime.managers.forward_context import ( + ForwardContext, + get_forward_context, +) +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.utils import get_compute_dtype + + +class UlyssesAttention(nn.Module): + """Ulysses-style SequenceParallelism attention layer.""" + + def __init__( + self, + num_heads: int, + head_size: int, + num_kv_heads: int | None = None, + softmax_scale: float | None = None, + causal: bool = False, + supported_attention_backends: set[AttentionBackendEnum] | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + super().__init__() + if softmax_scale is None: + self.softmax_scale = head_size**-0.5 + else: + self.softmax_scale = softmax_scale + + if num_kv_heads is None: + num_kv_heads = num_heads + + dtype = get_compute_dtype() + attn_backend = get_attn_backend( + head_size, dtype, supported_attention_backends=supported_attention_backends + ) + impl_cls = attn_backend.get_impl_cls() + + self.attn_impl = impl_cls( + num_heads=num_heads, + head_size=head_size, + causal=causal, + softmax_scale=self.softmax_scale, + num_kv_heads=num_kv_heads, + prefix=f"{prefix}.impl", + 
**extra_impl_args, + ) + self.num_heads = num_heads + self.head_size = head_size + self.num_kv_heads = num_kv_heads + self.backend = backend_name_to_enum(attn_backend.get_name()) + self.dtype = dtype + + @torch.compiler.disable + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + replicated_q: torch.Tensor | None = None, + replicated_k: torch.Tensor | None = None, + replicated_v: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Forward pass for distributed attention. + + Args: + q (torch.Tensor): Query tensor [batch_size, seq_len, num_heads, head_dim] + k (torch.Tensor): Key tensor [batch_size, seq_len, num_heads, head_dim] + v (torch.Tensor): Value tensor [batch_size, seq_len, num_heads, head_dim] + replicated_q (Optional[torch.Tensor]): Replicated query tensor, typically for text tokens + replicated_k (Optional[torch.Tensor]): Replicated key tensor + replicated_v (Optional[torch.Tensor]): Replicated value tensor + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: A tuple containing: + - o (torch.Tensor): Output tensor after attention for the main sequence + - replicated_o (Optional[torch.Tensor]): Output tensor for replicated tokens, if provided + """ + # Check input shapes + assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, "Expected 4D tensors" + batch_size, seq_len, num_heads, head_dim = q.shape + local_rank = get_sp_parallel_rank() + world_size = get_sp_world_size() + + forward_context: ForwardContext = get_forward_context() + ctx_attn_metadata = forward_context.attn_metadata + + # Stack QKV + qkv = torch.cat([q, k, v], dim=0) # [3, seq_len, num_heads, head_dim] + + # Redistribute heads across sequence dimension + qkv = sequence_model_parallel_all_to_all_4D(qkv, scatter_dim=2, gather_dim=1) + # Apply backend-specific preprocess_qkv + qkv = self.attn_impl.preprocess_qkv(qkv, ctx_attn_metadata) + + # Concatenate with replicated QKV if provided + if replicated_q is not None: + assert replicated_k is not None and replicated_v is not None + replicated_qkv = torch.cat( + [replicated_q, replicated_k, replicated_v], dim=0 + ) # [3, seq_len, num_heads, head_dim] + heads_per_rank = num_heads // world_size + replicated_qkv = replicated_qkv[ + :, :, local_rank * heads_per_rank : (local_rank + 1) * heads_per_rank + ] + qkv = torch.cat([qkv, replicated_qkv], dim=1) + + q, k, v = qkv.chunk(3, dim=0) + + output = self.attn_impl.forward(q, k, v, ctx_attn_metadata) + + # Redistribute back if using sequence parallelism + replicated_output = None + if replicated_q is not None: + replicated_output = output[:, seq_len * world_size :] + output = output[:, : seq_len * world_size] + # TODO: make this asynchronous + replicated_output = sequence_model_parallel_all_gather( + replicated_output.contiguous(), dim=2 + ) + # Apply backend-specific postprocess_output + output = self.attn_impl.postprocess_output(output, ctx_attn_metadata) + + output = sequence_model_parallel_all_to_all_4D( + output, scatter_dim=1, gather_dim=2 + ) + return output, replicated_output + + +class UlyssesAttention_VSA(UlyssesAttention): + """Distributed attention layer with VSA support.""" + + @torch.compiler.disable + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + replicated_q: torch.Tensor | None = None, + replicated_k: torch.Tensor | None = None, + replicated_v: torch.Tensor | None = None, + gate_compress: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Forward pass for distributed attention. 
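+
+        Unlike UlyssesAttention.forward, q, k, v and gate_compress are stacked
+        into one tensor so that a single all-to-all redistributes all four
+        before the VSA backend runs.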
+ + Args: + q (torch.Tensor): Query tensor [batch_size, seq_len, num_heads, head_dim] + k (torch.Tensor): Key tensor [batch_size, seq_len, num_heads, head_dim] + v (torch.Tensor): Value tensor [batch_size, seq_len, num_heads, head_dim] + gate_compress (torch.Tensor): Gate compress tensor [batch_size, seq_len, num_heads, head_dim] + replicated_q (Optional[torch.Tensor]): Replicated query tensor, typically for text tokens + replicated_k (Optional[torch.Tensor]): Replicated key tensor + replicated_v (Optional[torch.Tensor]): Replicated value tensor + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: A tuple containing: + - o (torch.Tensor): Output tensor after attention for the main sequence + - replicated_o (Optional[torch.Tensor]): Output tensor for replicated tokens, if provided + """ + # Check text tokens are not supported for VSA now + assert ( + replicated_q is None and replicated_k is None and replicated_v is None + ), "Replicated QKV is not supported for VSA now" + # Check input shapes + assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, "Expected 4D tensors" + + forward_context: ForwardContext = get_forward_context() + ctx_attn_metadata = forward_context.attn_metadata + + # Stack QKV + qkvg = torch.cat( + [q, k, v, gate_compress], dim=0 + ) # [3, seq_len, num_heads, head_dim] + + # Redistribute heads across sequence dimension + qkvg = sequence_model_parallel_all_to_all_4D(qkvg, scatter_dim=2, gather_dim=1) + + qkvg = self.attn_impl.preprocess_qkv(qkvg, ctx_attn_metadata) + + q, k, v, gate_compress = qkvg.chunk(4, dim=0) + output = self.attn_impl.forward( + q, k, v, gate_compress=gate_compress, attn_metadata=ctx_attn_metadata + ) # type: ignore[call-arg] + + # Redistribute back if using sequence parallelism + replicated_output = None + + # Apply backend-specific postprocess_output + output = self.attn_impl.postprocess_output(output, ctx_attn_metadata) + + output = sequence_model_parallel_all_to_all_4D( + output, scatter_dim=1, gather_dim=2 + ) + return output, replicated_output + + +class LocalAttention(nn.Module): + """Attention layer.""" + + def __init__( + self, + num_heads: int, + head_size: int, + num_kv_heads: int | None = None, + softmax_scale: float | None = None, + causal: bool = False, + supported_attention_backends: set[AttentionBackendEnum] | None = None, + **extra_impl_args, + ) -> None: + super().__init__() + if softmax_scale is None: + self.softmax_scale = head_size**-0.5 + else: + self.softmax_scale = softmax_scale + if num_kv_heads is None: + num_kv_heads = num_heads + + dtype = get_compute_dtype() + attn_backend = get_attn_backend( + head_size, dtype, supported_attention_backends=supported_attention_backends + ) + impl_cls = attn_backend.get_impl_cls() + self.attn_impl = impl_cls( + num_heads=num_heads, + head_size=head_size, + softmax_scale=self.softmax_scale, + num_kv_heads=num_kv_heads, + causal=causal, + **extra_impl_args, + ) + self.num_heads = num_heads + self.head_size = head_size + self.num_kv_heads = num_kv_heads + self.backend = backend_name_to_enum(attn_backend.get_name()) + self.dtype = dtype + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ) -> torch.Tensor: + """ + Apply local attention between query, key and value tensors. 
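+
+        The attention metadata is read from the current ForwardContext rather
+        than passed in by the caller.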
+ + Args: + q (torch.Tensor): Query tensor of shape [batch_size, seq_len, num_heads, head_dim] + k (torch.Tensor): Key tensor of shape [batch_size, seq_len, num_heads, head_dim] + v (torch.Tensor): Value tensor of shape [batch_size, seq_len, num_heads, head_dim] + + Returns: + torch.Tensor: Output tensor after local attention + """ + # Check input shapes + assert q.dim() == 4 and k.dim() == 4 and v.dim() == 4, "Expected 4D tensors" + + forward_context: ForwardContext = get_forward_context() + ctx_attn_metadata = forward_context.attn_metadata + + output = self.attn_impl.forward(q, k, v, attn_metadata=ctx_attn_metadata) + return output + + +class USPAttention(nn.Module): + """ + Ulysses Sequence Parallelism with Ring Attention. + + This class implements the USP algorithm, which is a combination of + Ulysses-style all-to-all communication for sequence-head dimension sharding + and Ring Attention for fine-grained sequence parallelism within subgroups. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + num_kv_heads: int | None = None, + softmax_scale: float | None = None, + causal: bool = False, + supported_attention_backends: set[AttentionBackendEnum] | None = None, + prefix: str = "", + dropout_p: float = 0.0, + **extra_impl_args, + ) -> None: + super().__init__() + if softmax_scale is None: + self.softmax_scale = head_size**-0.5 + else: + self.softmax_scale = softmax_scale + + if num_kv_heads is None: + num_kv_heads = num_heads + + dtype = get_compute_dtype() + attn_backend = get_attn_backend( + head_size, dtype, supported_attention_backends=supported_attention_backends + ) + impl_cls: Type["AttentionImpl"] = attn_backend.get_impl_cls() + self.attn_impl = impl_cls( + num_heads=num_heads, + head_size=head_size, + causal=causal, + softmax_scale=self.softmax_scale, + num_kv_heads=num_kv_heads, + prefix=f"{prefix}.impl", + **extra_impl_args, + ) + self.num_heads = num_heads + self.head_size = head_size + self.num_kv_heads = num_kv_heads + self.backend = backend_name_to_enum(attn_backend.get_name()) + self.dtype = dtype + self.causal = causal + self.dropout_p = dropout_p + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + replicated_q: torch.Tensor | None = None, + replicated_k: torch.Tensor | None = None, + replicated_v: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Forward pass for USPAttention. + + q, k, v: [B, S_local, H, D] + + Note: Replicated tensors are not supported in this implementation. + """ + assert ( + replicated_q is None and replicated_k is None and replicated_v is None + ), "USPAttention does not support replicated_qkv." + forward_context: ForwardContext = get_forward_context() + ctx_attn_metadata = forward_context.attn_metadata + + if get_sequence_parallel_world_size() == 1: + # No sequence parallelism, just run local attention. 
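+            # (attn_metadata still comes from the forward context; no Ulysses
+            # all-to-all or ring exchange is needed in this case.)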
+ out = self.attn_impl.forward(q, k, v, ctx_attn_metadata) + return out, None + + # Ulysses-style All-to-All for sequence/head sharding + if get_ulysses_parallel_world_size() > 1: + # -> [B, S, H_local, D] + q = _usp_input_all_to_all(q, head_dim=2) + k = _usp_input_all_to_all(k, head_dim=2) + v = _usp_input_all_to_all(v, head_dim=2) + + # Ring Attention within subgroups or local attention + if get_ring_parallel_world_size() > 1: + out = ring_attn( + q, + k, + v, + attn_impl=self.attn_impl, + is_causal=self.causal, + dropout_p=self.dropout_p, + ) + else: + # -> [B, S, H_local, D] + out = self.attn_impl.forward(q, k, v, ctx_attn_metadata) + + # Ulysses-style All-to-All to restore original sharding + if get_ulysses_parallel_world_size() > 1: + # -> [B, S_local, H, D] + out = _usp_output_all_to_all(out, head_dim=2) + + return out, None diff --git a/python/sglang/multimodal_gen/runtime/layers/attention/selector.py b/python/sglang/multimodal_gen/runtime/layers/attention/selector.py new file mode 100644 index 000000000..7d9d18463 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/attention/selector.py @@ -0,0 +1,197 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/attention/selector.py + +import os +from collections.abc import Generator +from contextlib import contextmanager +from functools import cache +from typing import cast + +import torch + +from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionBackend, +) +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.runtime.server_args import get_global_server_args +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname + +logger = init_logger(__name__) + + +def backend_name_to_enum(backend_name: str) -> AttentionBackendEnum | None: + """ + Convert a string backend name to a _Backend enum value. + + Returns: + * _Backend: enum value if backend_name is a valid in-tree type + * None: otherwise it's an invalid in-tree type or an out-of-tree platform is + loaded. + """ + assert backend_name is not None + return ( + AttentionBackendEnum[backend_name] + if backend_name in AttentionBackendEnum.__members__ + else None + ) + + +def get_env_variable_attn_backend() -> AttentionBackendEnum | None: + """ + Get the backend override specified by the sgl-diffusion attention + backend environment variable, if one is specified. + + Returns: + + * _Backend enum value if an override is specified + * None otherwise + """ + backend_name = os.environ.get(STR_BACKEND_ENV_VAR) + return None if backend_name is None else backend_name_to_enum(backend_name) + + +# Global state allows a particular choice of backend +# to be forced, overriding the logic which auto-selects +# a backend based on system & workload configuration +# (default behavior if this variable is None) +# +# THIS SELECTION TAKES PRECEDENCE OVER THE +# FASTVIDEO ATTENTION BACKEND ENVIRONMENT VARIABLE +forced_attn_backend: AttentionBackendEnum | None = None + + +def global_force_attn_backend(attn_backend: AttentionBackendEnum | None) -> None: + """ + Force all attention operations to use a specified backend. 
+ + Passing `None` for the argument re-enables automatic + backend selection., + + Arguments: + + * attn_backend: backend selection (None to revert to auto) + """ + global forced_attn_backend + forced_attn_backend = attn_backend + + +def get_global_forced_attn_backend() -> AttentionBackendEnum | None: + """ + Get the currently-forced choice of attention backend, + or None if auto-selection is currently enabled. + """ + return forced_attn_backend + + +def get_attn_backend( + head_size: int, + dtype: torch.dtype, + supported_attention_backends: set[AttentionBackendEnum] | None = None, +) -> type[AttentionBackend]: + if supported_attention_backends is not None: + # Sort the backend names to ensure consistent cache key + be_tuple = tuple( + sorted(list(supported_attention_backends), key=lambda b: b.name) + ) + else: + be_tuple = None + return _cached_get_attn_backend(head_size, dtype, be_tuple) + + +@cache +def _cached_get_attn_backend( + head_size: int, + dtype: torch.dtype, + supported_attention_backends: tuple[AttentionBackendEnum] | None = None, +) -> type[AttentionBackend]: + # Check whether a particular choice of backend was + # previously forced. + # + # THIS SELECTION OVERRIDES THE SGL_DIFFUSION_ATTENTION_BACKEND + # ENVIRONMENT VARIABLE. + from sglang.multimodal_gen.runtime.platforms import current_platform + + supported_attention_backends = set(supported_attention_backends) + if not supported_attention_backends: + raise ValueError("supported_attention_backends is empty") + selected_backend = None + backend_by_global_setting: AttentionBackendEnum | None = ( + get_global_forced_attn_backend() + ) + if backend_by_global_setting is not None: + selected_backend = backend_by_global_setting + else: + # Check the server arguments for a backend override + server_args = get_global_server_args() + if server_args.attention_backend is not None: + try: + selected_backend = AttentionBackendEnum[ + server_args.attention_backend.upper() + ] + + except KeyError: + raise ValueError( + f"Invalid attention backend '{server_args.attention_backend}' specified via command line. " + f"Available options are: {[e.name.lower() for e in AttentionBackendEnum]}" + ) + + # get device-specific attn_backend + if selected_backend is None: + logger.debug(f"Attention backend not specified") + elif ( + not supported_attention_backends + or selected_backend not in supported_attention_backends + ): + supported_attention_backends_str = [ + supported_attention_backend.__str__() + for supported_attention_backend in supported_attention_backends + ] + logger.debug( + f"Selected attention backend: '{selected_backend}' not in supported attention backends: {supported_attention_backends_str}" + ) + selected_backend = None + + attention_cls = current_platform.get_attn_backend_cls_str( + selected_backend, head_size, dtype + ) + if not attention_cls: + raise ValueError( + f"Invalid attention backend for {current_platform.device_name}" + ) + return cast(type[AttentionBackend], resolve_obj_by_qualname(attention_cls)) + + +@contextmanager +def global_force_attn_backend_context_manager( + attn_backend: AttentionBackendEnum, +) -> Generator[None, None, None]: + """ + Globally force a sgl-diffusion attention backend override within a + context manager, reverting the global attention backend + override to its prior state upon exiting the context + manager. 
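+
+    Illustrative usage (assumes FLASH_ATTN is a valid AttentionBackendEnum
+    member on this platform and `attn_layer` is a hypothetical attention
+    module):
+
+        with global_force_attn_backend_context_manager(AttentionBackendEnum.FLASH_ATTN):
+            out = attn_layer(q, k, v)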
+ + Arguments: + + * attn_backend: attention backend to force + + Returns: + + * Generator + """ + + # Save the current state of the global backend override (if any) + original_value = get_global_forced_attn_backend() + + # Globally force the new backend override + global_force_attn_backend(attn_backend) + + # Yield control back to the enclosed code block + try: + yield + finally: + # Revert the original global backend override, if any + global_force_attn_backend(original_value) diff --git a/python/sglang/multimodal_gen/runtime/layers/custom_op.py b/python/sglang/multimodal_gen/runtime/layers/custom_op.py new file mode 100644 index 000000000..abc2f1238 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/custom_op.py @@ -0,0 +1,110 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/custom_op.py + +from collections.abc import Callable +from typing import Any + +import torch.nn as nn + +from sglang.multimodal_gen.runtime.utils.common import ( + is_cpu, + is_cuda, + is_hip, + is_npu, + is_xpu, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + +_is_cuda = is_cuda() +_is_hip = is_hip() +_is_cpu = is_cpu() +_is_npu = is_npu() +_is_xpu = is_xpu() + + +class CustomOp(nn.Module): + """ + Base class for custom ops. + Dispatches the forward method to the appropriate backend. + """ + + def __init__(self) -> None: + super().__init__() + self._forward_method = self.dispatch_forward() + + def forward(self, *args, **kwargs) -> Any: + return self._forward_method(*args, **kwargs) + + def forward_native(self, *args, **kwargs) -> Any: + """PyTorch-native implementation of the forward method. + This method is optional. If implemented, it can be used with compilers + such as torch.compile or PyTorch XLA. Also, it can be used for testing + purposes. + """ + raise NotImplementedError + + def forward_cuda(self, *args, **kwargs) -> Any: + raise NotImplementedError + + def forward_cpu(self, *args, **kwargs) -> Any: + # By default, we assume that CPU ops are compatible with CUDA ops. + return self.forward_cuda(*args, **kwargs) + + def forward_tpu(self, *args, **kwargs) -> Any: + # By default, we assume that TPU ops are compatible with the + # PyTorch-native implementation. + # NOTE(woosuk): This is a placeholder for future extensions. + return self.forward_native(*args, **kwargs) + + def forward_oot(self, *args, **kwargs) -> Any: + # By default, we assume that OOT ops are compatible with the + # PyTorch-native implementation. + return self.forward_native(*args, **kwargs) + + def dispatch_forward(self) -> Callable: + if _is_cuda: + return self.forward_cuda + elif _is_hip: + return self.forward_hip + elif _is_npu: + return self.forward_npu + elif _is_xpu: + return self.forward_xpu + else: + return self.forward_native + + @classmethod + def enabled(cls) -> bool: + # since we are not using Inductor, we always return True + return True + + @staticmethod + def default_on() -> bool: + """ + On by default if level < CompilationLevel.PIECEWISE + Specifying 'all' or 'none' in custom_op takes precedence. + """ + raise NotImplementedError + + # Dictionary of all custom ops (classes, indexed by registered name). + # To check if an op with a name is enabled, call .enabled() on the class. 
+ # Examples: + # - MyOp.enabled() + # - op_registry["my_op"].enabled() + op_registry: dict[str, type["CustomOp"]] = {} + + # Decorator to register custom ops. + @classmethod + def register(cls, name: str) -> Callable: + + def decorator(op_cls): + assert name not in cls.op_registry, f"Duplicate op name: {name}" + op_cls.name = name + cls.op_registry[name] = op_cls + return op_cls + + return decorator diff --git a/python/sglang/multimodal_gen/runtime/layers/layernorm.py b/python/sglang/multimodal_gen/runtime/layers/layernorm.py new file mode 100644 index 000000000..166ab24d5 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/layernorm.py @@ -0,0 +1,429 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/layernorm.py +"""Custom normalization layers.""" +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sglang.multimodal_gen.runtime.layers.custom_op import CustomOp +from sglang.multimodal_gen.runtime.layers.triton_ops import ( + fuse_scale_shift_kernel, + norm_infer, + rms_norm_fn, +) +from sglang.multimodal_gen.runtime.utils.common import ( + get_bool_env_var, + is_cpu, + is_cuda, + is_hip, + is_npu, + is_xpu, +) + +_is_cuda = is_cuda() +_is_hip = is_hip() +_is_npu = is_npu() +_is_cpu = is_cpu() +_is_xpu = is_xpu() + +from sgl_kernel import fused_add_rmsnorm, rmsnorm + + +# Copied and adapted from sglang +@CustomOp.register("rms_norm") +class RMSNorm(CustomOp): + """Root mean square normalization. + + Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight. + Refer to https://arxiv.org/abs/1910.07467 + """ + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + dtype: torch.dtype = torch.float32, + var_hidden_size: Optional[int] = None, + ) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.hidden_size = hidden_size + self.variance_size_override = ( + None if var_hidden_size == hidden_size else var_hidden_size + ) + if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"): + self._forward_method = self.forward_native + + def forward_triton(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None): + return rms_norm_fn( + x, self.weight, bias=None, residual=residual, eps=self.variance_epsilon + ) + + def forward_cuda( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + shape = x.shape + x = x.view(-1, shape[-1]) + if residual is not None: + residual_shape = residual.shape + residual = residual.view(-1, shape[-1]) + + if x.dtype == torch.float: + # fp32 + out = self.forward_triton(x, residual) + elif self.variance_size_override is not None: + return self.forward_native(x, residual) + elif residual is not None: + fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) + return x.view(shape), residual.view(residual_shape) + else: + out = rmsnorm(x, self.weight.data, self.variance_epsilon) + out = out.view(shape) + return out + + def forward_native( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if not x.is_contiguous(): + x = x.contiguous() + orig_dtype = x.dtype + x = x.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = 
x.to(orig_dtype) + + hidden_size = x.shape[-1] + if hidden_size != self.hidden_size: + raise ValueError( + "Expected hidden_size to be " + f"{self.hidden_size}, but found: {hidden_size}" + ) + + if self.variance_size_override is None: + x_var = x + else: + if hidden_size < self.variance_size_override: + raise ValueError( + "Expected hidden_size to be at least " + f"{self.variance_size_override}, but found: {hidden_size}" + ) + + x_var = x[..., : self.variance_size_override] + + variance = x_var.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = (x * self.weight).to(orig_dtype) + if residual is None: + return x + else: + return x, residual + + def forward_cpu( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + return self.forward_native(x, residual) + + def extra_repr(self) -> str: + s = f"hidden_size={self.weight.data.size(0)}" + s += f", eps={self.variance_epsilon}" + return s + + +# Copied and adapted from sglang +@CustomOp.register("layer_norm") +class LayerNorm(CustomOp): + def __init__( + self, + hidden_size: int, + eps=1e-5, + bias: bool = True, + elementwise_affine=True, + device=None, + dtype=None, + ) -> None: + super().__init__() + self.eps = eps + factory_kwargs = {"device": device, "dtype": dtype} + self.hidden_size = hidden_size + if elementwise_affine: + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.bias = ( + torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + if bias + else None + ) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + # Lazy cache for ones vector (not a registered buffer to avoid FSDP/meta issues) + self._weight_fallback_cache = None + + def _get_weight_fallback(self, x: torch.Tensor) -> torch.Tensor: + wf = getattr(self, "_weight_fallback_cache", None) + if ( + wf is None + or wf.device != x.device + or wf.dtype != x.dtype + or wf.numel() != self.hidden_size + ): + wf = torch.ones(self.hidden_size, device=x.device, dtype=x.dtype) + self._weight_fallback_cache = wf + return wf + + def forward_triton(self, x: torch.Tensor): + # Fast inference kernel without residual/dropout branches + return norm_infer( + x.view(-1, self.hidden_size), + self.weight, + self.bias, + eps=self.eps, + is_rms_norm=False, + ).view(x.shape) + + def forward_cuda( + self, + x: torch.Tensor, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + shape = x.shape + x = x.view(-1, self.hidden_size) + return self.forward_triton(x).view(shape) + + @torch.compile(backend="inductor") + def forward_native( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + input_dtype = x.dtype + mean = x.mean(-1, keepdim=True) + variance = (x - mean).pow(2).mean(-1, keepdim=True) + x = (x - mean) * torch.rsqrt(variance + self.eps) + if self.weight is not None: + x = self.weight * x + # if no affine, this is a no-op + if self.bias is not None: + x = x + self.bias + return x.to(input_dtype) + + def forward_cpu( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + return self.forward_native(x, residual) + + def extra_repr(self) -> str: + s = f"hidden_size={self.weight.data.size(0)}" + s += f", eps={self.variance_epsilon}" + return s + + +class ScaleResidual(nn.Module): + """ + Applies gated residual connection. 
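+
+    Given a residual tensor and a gate, this computes residual + gate * x; a
+    4-D gate of shape [batch, num_frames, 1, dim] is broadcast over the tokens
+    of each frame before the addition.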
+ """ + + def __init__(self, prefix: str = ""): + super().__init__() + + def forward( + self, residual: torch.Tensor, x: torch.Tensor, gate: torch.Tensor + ) -> torch.Tensor: + """Apply gated residual connection.""" + # x.shape: [batch_size, seq_len, inner_dim] + if gate.dim() == 4: + # gate.shape: [batch_size, num_frames, 1, inner_dim] + num_frames = gate.shape[1] + frame_seqlen = x.shape[1] // num_frames + return residual + ( + x.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * gate + ).flatten(1, 2) + else: + # gate.shape: [batch_size, 1, inner_dim] + return residual + x * gate + + +# adapted from Diffusers: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/normalization.py +# NOTE(will): Needed to match behavior of diffusers and wan2.1 even while using +# FSDP's MixedPrecisionPolicy +class FP32LayerNorm(nn.LayerNorm): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + origin_dtype = inputs.dtype + return F.layer_norm( + inputs.float(), + self.normalized_shape, + self.weight.float() if self.weight is not None else None, + self.bias.float() if self.bias is not None else None, + self.eps, + ).to(origin_dtype) + + +class ScaleResidualLayerNormScaleShift(nn.Module): + """ + Fused operation that combines: + 1. Gated residual connection + 2. LayerNorm + 3. Scale and shift operations + + This reduces memory bandwidth by combining memory-bound operations. + """ + + def __init__( + self, + hidden_size: int, + norm_type: str = "rms", + eps: float = 1e-6, + elementwise_affine: bool = False, + dtype: torch.dtype = torch.float32, + compute_dtype: torch.dtype | None = None, + prefix: str = "", + ): + super().__init__() + if norm_type == "rms": + self.norm = RMSNorm( + hidden_size, has_weight=elementwise_affine, eps=eps, dtype=dtype + ) + elif norm_type == "layer": + if compute_dtype == torch.float32: + self.norm = FP32LayerNorm( + hidden_size, elementwise_affine=elementwise_affine, eps=eps + ) + else: + self.norm = LayerNorm( + hidden_size, + elementwise_affine=elementwise_affine, + eps=eps, + dtype=dtype, + ) + else: + raise NotImplementedError(f"Norm type {norm_type} not implemented") + + def forward( + self, + residual: torch.Tensor, + x: torch.Tensor, + gate: torch.Tensor | int, + shift: torch.Tensor, + scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply gated residual connection, followed by layernorm and + scale/shift in a single fused operation. 
+ + Returns: + Tuple containing: + - normalized and modulated output of shape: [batch_size, seq_len, inner_dim] + - residual value (value after residual connection + but before normalization) + """ + # x.shape: [batch_size, seq_len, inner_dim] + # Apply residual connection with gating + if isinstance(gate, int): + # used by cross-attention, should be 1 + assert gate == 1 + residual_output = residual + x + elif isinstance(gate, torch.Tensor): + if gate.dim() == 4: + # gate.shape: [batch_size, num_frames, 1, inner_dim] + num_frames = gate.shape[1] + frame_seqlen = x.shape[1] // num_frames + residual_output = residual + ( + x.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * gate + ).flatten(1, 2) + else: + # used by bidirectional self attention + # gate.shape: [batch_size, 1, inner_dim] + residual_output = residual + x * gate + else: + raise ValueError(f"Gate type {type(gate)} not supported") + # residual_output.shape: [batch_size, seq_len, inner_dim] + + # Apply normalization + normalized = self.norm(residual_output) + + # modulated = fused_scale_shift( + # normalized, + # scale, + # shift, + # ) + modulated = fuse_scale_shift_kernel( + normalized, + scale, + shift, + ) + return modulated, residual_output + + +class LayerNormScaleShift(nn.Module): + """ + Fused operation that combines LayerNorm with scale and shift operations. + This reduces memory bandwidth by combining memory-bound operations. + """ + + def __init__( + self, + hidden_size: int, + norm_type: str = "rms", + eps: float = 1e-6, + elementwise_affine: bool = False, + dtype: torch.dtype = torch.float32, + compute_dtype: torch.dtype | None = None, + prefix: str = "", + ): + super().__init__() + self.compute_dtype = compute_dtype + if norm_type == "rms": + self.norm = RMSNorm(hidden_size, has_weight=elementwise_affine, eps=eps) + elif norm_type == "layer": + if self.compute_dtype == torch.float32: + self.norm = FP32LayerNorm( + hidden_size, elementwise_affine=elementwise_affine, eps=eps + ) + else: + self.norm = nn.LayerNorm( + hidden_size, + elementwise_affine=elementwise_affine, + eps=eps, + dtype=dtype, + ) + else: + raise NotImplementedError(f"Norm type {norm_type} not implemented") + + def forward( + self, x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor + ) -> torch.Tensor: + """Apply ln followed by scale and shift in a single fused operation.""" + # x.shape: [batch_size, seq_len, inner_dim] + normalized = self.norm(x) + if self.compute_dtype == torch.float32: + normalized = normalized.float() + + if scale.dim() == 4: + # scale.shape: [batch_size, num_frames, 1, inner_dim] + num_frames = scale.shape[1] + frame_seqlen = normalized.shape[1] // num_frames + output = ( + normalized.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) + * (1.0 + scale) + + shift + ).flatten(1, 2) + else: + # scale.shape: [batch_size, 1, inner_dim] + # shift.shape: [batch_size, 1, inner_dim] + output = normalized * (1.0 + scale) + shift + + if self.compute_dtype == torch.float32: + output = output.to(x.dtype) + + return output diff --git a/python/sglang/multimodal_gen/runtime/layers/linear.py b/python/sglang/multimodal_gen/runtime/layers/linear.py new file mode 100644 index 000000000..65c71372a --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/linear.py @@ -0,0 +1,1057 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/linear.py + +from abc import abstractmethod + 
+import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from sglang.multimodal_gen.runtime.distributed import ( + divide, + get_tp_rank, + get_tp_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) +from sglang.multimodal_gen.runtime.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) + +# yapf: disable +from sglang.multimodal_gen.runtime.models.parameter import ( + BasevLLMParameter, + BlockQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + PerTensorScaleParameter, + RowvLLMParameter, +) + +# yapf: enable +from sglang.multimodal_gen.runtime.models.utils import set_weight_attrs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + +WEIGHT_LOADER_V2_SUPPORTED = [ + "CompressedTensorsLinearMethod", + "AWQMarlinLinearMethod", + "AWQLinearMethod", + "GPTQMarlinLinearMethod", + "Fp8LinearMethod", + "MarlinLinearMethod", + "QQQLinearMethod", + "GPTQMarlin24LinearMethod", + "TPUInt8LinearMethod", + "GPTQLinearMethod", + "FBGEMMFp8LinearMethod", + "ModelOptFp8LinearMethod", + "IPEXAWQLinearMethod", + "IPEXGPTQLinearMethod", + "HQQMarlinMethod", + "QuarkLinearMethod", +] + + +def adjust_scalar_to_fused_array( + param: torch.Tensor, loaded_weight: torch.Tensor, shard_id: str | int +) -> tuple[torch.Tensor, torch.Tensor]: + """For fused modules (QKV and MLP) we have an array of length + N that holds 1 scale for each "logical" matrix. So the param + is an array of length N. The loaded_weight corresponds to + one of the shards on disk. Here, we slice the param based on + the shard_id for loading. + """ + qkv_idxs = {"q": 0, "k": 1, "v": 2} + + if isinstance(shard_id, str): + shard_id = qkv_idxs[shard_id] + elif not isinstance(shard_id, int): + raise ValueError(f"Unknown Shard Id {shard_id}") + + # AutoFP8 scales do not have a shape + # compressed-tensors scales do have a shape + if len(loaded_weight.shape) != 0: + assert loaded_weight.shape[0] == 1 + loaded_weight = loaded_weight[0] + + return param[shard_id], loaded_weight + + +class LinearMethodBase(QuantizeMethodBase): + """Base class for different (maybe quantized) linear methods.""" + + @abstractmethod + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + """Create weights for a linear layer. + The weights will be set as attributes of the layer. + + Args: + layer: The layer that is using the LinearMethodBase factory. + input_size_per_partition: Size of the weight input dim on rank X. + output_partition_sizes: Sizes of the output dim of each logical + weight on rank X. E.g., output_partition_sizes for QKVLinear + is a list contains the width of Wq, Wk, Wv on rank X. + input_size: Size of the input dim of the weight across all ranks. + output_size: Size of the output dim of the weight across all ranks. + params_dtype: Datatype of the parameters. + """ + raise NotImplementedError + + @abstractmethod + def apply( + self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None + ) -> torch.Tensor: + """Apply the weights in layer to the input tensor. 
+ Expects create_weights to have been called before on the layer.""" + raise NotImplementedError + + +class UnquantizedLinearMethod(LinearMethodBase): + """Linear method without quantization.""" + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + weight = Parameter( + torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply( + self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None + ) -> torch.Tensor: + output = ( + F.linear(x, layer.weight, bias) + if torch.cuda.is_available() or bias is None + else F.linear(x, layer.weight, bias.to(x.dtype)) + ) # NOTE: this line assumes that we are using amp when using cuda and is needed to account for the fact that amp isn't supported in mps + return output + + +class LinearBase(torch.nn.Module): + """Base linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + """ + + def __init__( + self, + input_size: int, + output_size: int, + skip_bias_add: bool = False, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.skip_bias_add = skip_bias_add + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + self.quant_config = quant_config + self.prefix = prefix + if quant_config is None: + self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod() + else: + self.quant_method = quant_config.get_quant_method(self, prefix=prefix) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, Parameter | None]: + raise NotImplementedError + + +class ReplicatedLinear(LinearBase): + """Replicated linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__( + input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix=prefix, + ) + + # All the linear layer supports quant method. 
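+        # For the unquantized path, create_weights below registers a "weight"
+        # Parameter of shape [output_size, input_size] on this layer (see
+        # UnquantizedLinearMethod above) and attaches weight_loader so that
+        # checkpoint tensors can be copied in during loading.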
+ assert self.quant_method is not None + self.quant_method.create_weights( + self, + self.input_size, + [self.output_size], + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader, + ) + + if bias: + self.bias = Parameter( + torch.empty( + self.output_size, + dtype=self.params_dtype, + ) + ) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor) -> None: + # If the weight on disk does not have a shape, give it one + # (such scales for AutoFp8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param.size() == loaded_weight.size(), ( + f"Tried to load weights of size {loaded_weight.size()}" + f"to a parameter of size {param.size()}" + ) + param.data.copy_(loaded_weight) + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, Parameter | None]: + bias = self.bias if not self.skip_bias_add else None + assert self.quant_method is not None + output = self.quant_method.apply(self, x, bias) + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + return s + + +class ColumnParallelLinear(LinearBase): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Args: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. + gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + output_sizes: list of output sizes packed into one output, like for QKV + the list would be size 3. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + output_sizes: list[int] | None = None, + prefix: str = "", + ): + # Divide the weight matrix along the last dimension. + self.tp_size = get_tp_world_size() + self.input_size_per_partition = input_size + self.output_size_per_partition = divide(output_size, self.tp_size) + self.output_partition_sizes = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. 
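+        # Illustrative example (sizes are hypothetical): a QKV subclass with
+        # output_sizes=[4096, 1024, 1024] and tp_size=2 ends up with
+        # output_partition_sizes=[2048, 512, 512], i.e. every logical matrix
+        # is halved along its output dimension on each rank.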
+ if hasattr(self, "output_sizes"): + self.output_partition_sizes = [ + divide(output_size, self.tp_size) for output_size in self.output_sizes + ] + + super().__init__( + input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix + ) + + self.gather_output = gather_output + + if output_sizes is None: + output_sizes = [output_size] + + assert self.quant_method is not None + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=( + self.weight_loader_v2 + if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED + else self.weight_loader + ), + ) + if bias: + self.bias = Parameter( + torch.empty( + self.output_size_per_partition, + dtype=params_dtype, + ) + ) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor) -> None: + tp_rank = get_tp_rank() + output_dim = getattr(param, "output_dim", None) + + is_sharded_weight = getattr(param, "is_sharded_weight", False) + is_sharded_weight = is_sharded_weight + + param_data = param.data + if output_dim is not None and not is_sharded_weight: + shard_size = param_data.shape[output_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor) -> None: + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + assert loaded_weight.numel() == 1 + loaded_weight = loaded_weight.reshape(1) + param.load_column_parallel_weight(loaded_weight=loaded_weight) + + def forward(self, input_: torch.Tensor) -> tuple[torch.Tensor, Parameter | None]: + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + assert self.quant_method is not None + output_parallel = self.quant_method.apply(self, input_, bias) + if self.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size_per_partition}" + s += f", bias={self.bias is not None}" + s += f", tp_size={get_tp_world_size()}" + s += f", gather_output={self.gather_output}" + return s + + +class MergedColumnParallelLinear(ColumnParallelLinear): + """Packed linear layers with column parallelism. + + Similar to ColumnParallelLinear, but the weight matrix is concatenated + along the output dimension. When the weight matrix is loaded, the + different partitions are sharded separately. + + Args: + input_size: input dimension of the linear layer. + output_sizes: list of output dimensions of the linear layer. + bias: If true, add bias. 
+ gather_output: If true, call all-gather on output and make the output + available to all GPUs, otherwise, every GPU will have + its own output. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__( + self, + input_size: int, + output_sizes: list[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + self.output_sizes = output_sizes + tp_size = get_tp_world_size() + assert all(output_size % tp_size == 0 for output_size in output_sizes) + super().__init__( + input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + ) + + def weight_loader( + self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: int | None = None, + ) -> None: + + param_data = param.data + output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. + is_metadata = getattr(param, "is_metadata", False) + # Special case for per-tensor scale to load scalar into fused array. + needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) + + if loaded_shard_id is None: + # Loaded weight is already fused on disk (mlp). + # (e.g., Phi-3's gate_up_proj). + if output_dim is None: + if needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, 0 + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + current_shard_offset = 0 + shard_offsets: list[tuple[int, int, int]] = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + for shard_id, shard_offset, shard_size in shard_offsets: + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size + ) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + assert loaded_shard_id < len(self.output_sizes) + tp_rank = get_tp_rank() + tp_size = get_tp_world_size() + if output_dim is not None: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + + is_sharded_weight = getattr(param, "is_sharded_weight", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight + + param_data = param_data.narrow(output_dim, shard_offset, shard_size) + start_idx = tp_rank * shard_size + if not is_sharded_weight: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # Special case for AQLM codebooks. + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_offset = loaded_shard_id * shard_size + param_data = param_data.narrow(0, shard_offset, shard_size) + + # Special case for per-tensor scales in fused case. 
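+            # (Illustrative) An AutoFP8-style checkpoint stores one 0-dim scale
+            # per logical matrix; adjust_scalar_to_fused_array selects the slot
+            # of the length-N fused scale array that corresponds to
+            # loaded_shard_id and only that slot is overwritten.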
+ elif needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, loaded_shard_id + ) + + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "MergedColumnParallelLinear, assume the weight is " + "the same for all partitions." + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def _load_fused_module_from_checkpoint( + self, param: BasevLLMParameter, loaded_weight: torch.Tensor + ) -> None: + """ + Handle special case for models where MLP layers are already + fused on disk. In this case, we have no shard id. This function + determmines the shard id by splitting these layers and then calls + the weight loader using the shard id. + + An example of a model with these fused layers: + https://huggingface.co/microsoft/Phi-3-mini-4k-instruct + """ + + current_shard_offset = 0 + shard_offsets: list[tuple[int, int, int]] = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if ( + isinstance(param, PackedColumnParameter | PackedvLLMParameter) + and param.packed_dim == param.output_dim + ): + shard_size, shard_offset = param.adjust_shard_indexes_for_packing( + shard_size=shard_size, shard_offset=shard_offset + ) + + loaded_weight_shard = loaded_weight.narrow( + param.output_dim, shard_offset, shard_size + ) + self.weight_loader_v2(param, loaded_weight_shard, shard_id) + + def weight_loader_v2( + self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_shard_id: int | None = None, + ) -> None: + if loaded_shard_id is None: + if isinstance(param, PerTensorScaleParameter): + param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0) + return + elif type(param) in (RowvLLMParameter, BasevLLMParameter): + param.load_merged_column_weight(loaded_weight=loaded_weight) + return + # TODO: @dsikka - move to parameter.py + self._load_fused_module_from_checkpoint(param, loaded_weight) + return + + assert loaded_shard_id < len(self.output_sizes) + + tp_size = get_tp_world_size() + + if isinstance(param, BlockQuantScaleParameter): + raise NotImplementedError("FP8 is not implemented yet") + # FIXME(will): add fp8 support + # from vllm.model_executor.layers.quantization.fp8 import ( + # Fp8LinearMethod, Fp8MoEMethod) + # assert self.quant_method is not None + # assert isinstance(self.quant_method, + # (Fp8LinearMethod, Fp8MoEMethod)) + # weight_block_size = self.quant_method.quant_config.weight_block_size + # assert weight_block_size is not None + # block_n, _ = weight_block_size[0], weight_block_size[1] + # shard_offset = ( + # (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // + # block_n) // tp_size + # shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) // + # block_n // tp_size) + else: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + + param.load_merged_column_weight( + loaded_weight=loaded_weight, + shard_id=loaded_shard_id, + shard_offset=shard_offset, + shard_size=shard_size, + ) + + +class QKVParallelLinear(ColumnParallelLinear): + """Linear layers for the attention's QKV 
transformation. + + Linear layers for the linear transformation of the query, key, and value + vectors in the attention layer. The weight matrix is concatenated along + the output dimension. The layer is parallelized along the head dimension. + When the number of key/value heads is smaller than the number of query + heads (e.g., multi-query/grouped-query attention), the key/value head may + be replicated while the query heads are partitioned. + + Args: + hidden_size: input hidden state size of the transformer. + head_size: size of each attention head. + total_num_heads: total number of attention query heads. + total_num_kv_heads: total number of attention key/value heads. If + None, assume total_num_kv_heads = total_num_heads. + bias: If true, add bias. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: int | None = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + self.hidden_size = hidden_size + self.head_size = head_size + self.total_num_heads = total_num_heads + if total_num_kv_heads is None: + total_num_kv_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + # Divide the weight matrix along the last dimension. + tp_size = get_tp_world_size() + self.num_heads = divide(self.total_num_heads, tp_size) + if tp_size >= self.total_num_kv_heads: + self.num_kv_heads = 1 + self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads) + else: + self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) + self.num_kv_head_replicas = 1 + input_size = self.hidden_size + output_size = ( + (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size + ) + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj + ] + + super().__init__( + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + ) + + def _get_shard_offset_mapping(self, loaded_shard_id: str) -> int | None: + shard_offset_mapping = { + "q": 0, + "k": self.num_heads * self.head_size, + "v": (self.num_heads + self.num_kv_heads) * self.head_size, + "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size, + } + return shard_offset_mapping.get(loaded_shard_id) + + def _get_shard_size_mapping(self, loaded_shard_id: str) -> int | None: + shard_size_mapping = { + "q": self.num_heads * self.head_size, + "k": self.num_kv_heads * self.head_size, + "v": self.num_kv_heads * self.head_size, + } + return shard_size_mapping.get(loaded_shard_id) + + def _load_fused_module_from_checkpoint( + self, param: BasevLLMParameter, loaded_weight: torch.Tensor + ): + """ + Handle special case for models where QKV layers are already + fused on disk. In this case, we have no shard id. This function + determmines the shard id by splitting these layers and then calls + the weight loader using the shard id. 
+ + An example of a model with these fused layers: + https://huggingface.co/microsoft/Phi-3-mini-4k-instruct + """ + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, self.total_num_heads * self.head_size), + ( + "k", + self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + ( + "v", + (self.total_num_heads + self.total_num_kv_heads) * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + ] + + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if ( + isinstance(param, PackedColumnParameter | PackedvLLMParameter) + and param.packed_dim == param.output_dim + ): + shard_size, shard_offset = param.adjust_shard_indexes_for_packing( + shard_size=shard_size, shard_offset=shard_offset + ) + + loaded_weight_shard = loaded_weight.narrow( + param.output_dim, shard_offset, shard_size + ) + self.weight_loader_v2(param, loaded_weight_shard, shard_id) + + def weight_loader_v2( + self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_shard_id: str | None = None, + ): + if loaded_shard_id is None: # special case for certain models + if isinstance(param, PerTensorScaleParameter): + param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0) + return + elif type(param) in (RowvLLMParameter, BasevLLMParameter): + param.load_qkv_weight(loaded_weight=loaded_weight) + return + # TODO: @dsikka - move to parameter.py + self._load_fused_module_from_checkpoint(param, loaded_weight) + return + + assert loaded_shard_id in ["q", "k", "v"] + + shard_offset = self._get_shard_offset_mapping(loaded_shard_id) + shard_size = self._get_shard_size_mapping(loaded_shard_id) + + param.load_qkv_weight( + loaded_weight=loaded_weight, + num_heads=self.num_kv_head_replicas, + shard_id=loaded_shard_id, + shard_offset=shard_offset, + shard_size=shard_size, + ) + + def weight_loader( + self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: str | None = None, + ): + + param_data = param.data + output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. + is_metadata = getattr(param, "is_metadata", False) + + # Special case for per-tensor scales in fused case. + needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) + + if loaded_shard_id is None: + # Loaded weight is already fused on disk (qkv). + # (e.g., Phi-3's qkv_proj). + if output_dim is None: + if needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, 0 + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, self.total_num_heads * self.head_size), + ( + "k", + self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + ( + "v", + (self.total_num_heads + self.total_num_kv_heads) * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + ] + + for shard_id, shard_offset, shard_size in shard_offsets: + + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size + ) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + tp_rank = get_tp_rank() + assert loaded_shard_id in ["q", "k", "v"] + + # If output dim is defined, use the default loading process. 
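+        # Illustrative shard layout (hypothetical sizes): with per-rank
+        # num_heads=8, num_kv_heads=8 and head_size=64, q occupies
+        # [0, 512), k [512, 1024) and v [1024, 1536) of this rank's fused
+        # qkv weight along output_dim.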
+ if output_dim is not None: + if loaded_shard_id == "q": + shard_offset = 0 + shard_size = self.num_heads * self.head_size + elif loaded_shard_id == "k": + shard_offset = self.num_heads * self.head_size + shard_size = self.num_kv_heads * self.head_size + elif loaded_shard_id == "v": + shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size + shard_size = self.num_kv_heads * self.head_size + + is_sharded_weight = getattr(param, "is_sharded_weight", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight + + shard_idx = 0 + param_data = param_data.narrow(output_dim, shard_offset, shard_size) + if loaded_shard_id == "q": + shard_idx = tp_rank + else: + shard_idx = tp_rank // self.num_kv_head_replicas + start_idx = shard_idx * shard_size + + if not is_sharded_weight: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + + # Special case for for AQLM codebooks. + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_index = ["q", "k", "v"].index(loaded_shard_id) + param_data = param_data.narrow(0, shard_index * shard_size, shard_size) + # Special case for per-tensor scales in fused case. + elif needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, loaded_shard_id + ) + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "QKVParallelLinear, assume the weight is the same " + "for all partitions." + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class RowParallelLinear(LinearBase): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + skip_bias_add: This was added to enable performance optimization where + bias can be fused with other element-wise operations. + We skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: torch.dtype | None = None, + reduce_results: bool = True, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + # Divide the weight matrix along the first dimension. 
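+        # Sketch of the sharding (illustrative numbers): with input_size=4096,
+        # output_size=1024 and tp_size=4, each rank holds an A_i of shape
+        # [1024, 1024] and computes Y_i = X_i A_i; the partial results are
+        # summed by the all-reduce in forward() when reduce_results is set.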
+ self.tp_rank = get_tp_rank() + self.tp_size = get_tp_world_size() + self.input_size_per_partition = divide(input_size, self.tp_size) + self.output_size_per_partition = output_size + self.output_partition_sizes = [output_size] + + super().__init__( + input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix + ) + + self.input_is_parallel = input_is_parallel + self.reduce_results = reduce_results + + assert self.quant_method is not None + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=( + self.weight_loader_v2 + if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED + else self.weight_loader + ), + ) + if not reduce_results and (bias and not skip_bias_add): + raise ValueError( + "When not reduce the results, adding bias to the " + "results can lead to incorrect results" + ) + + if bias: + self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype)) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tp_rank() + input_dim = getattr(param, "input_dim", None) + is_sharded_weight = getattr(param, "is_sharded_weight", False) + # bitsandbytes loads the weights of the specific portion + # no need to narrow + is_sharded_weight = is_sharded_weight + + param_data = param.data + if input_dim is not None and not is_sharded_weight: + shard_size = param_data.shape[input_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor): + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + assert loaded_weight.numel() == 1 + loaded_weight = loaded_weight.reshape(1) + + param.load_row_parallel_weight(loaded_weight=loaded_weight) + + def forward(self, input_) -> tuple[torch.Tensor, Parameter | None]: + if self.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_tp_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size + ) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. 
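+        # Worked example (illustrative): with tp_size=2, rank 0 computes
+        # X_0 @ A_0 + b while rank 1 computes X_1 @ A_1; the subsequent
+        # all-reduce therefore yields X @ A + b with the bias added once.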
+ assert self.quant_method is not None + # Only fuse bias add into GEMM for rank 0 (this ensures that + # bias will not get added more than once in TP>1 case) + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_) + if self.reduce_results and self.tp_size > 1: + output = tensor_model_parallel_all_reduce(output_parallel) + else: + output = output_parallel + + output_bias = self.bias if self.skip_bias_add else None + + return output, output_bias + + def extra_repr(self) -> str: + s = f"input_features={self.input_size_per_partition}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + s += f", tp_size={self.tp_size}" + s += f", reduce_results={self.reduce_results}" + return s diff --git a/python/sglang/multimodal_gen/runtime/layers/lora/linear.py b/python/sglang/multimodal_gen/runtime/layers/lora/linear.py new file mode 100644 index 000000000..e21e4dd6e --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/lora/linear.py @@ -0,0 +1,426 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Code adapted from SGLang https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/lora/layers.py + +import math + +import torch +from torch import nn +from torch.distributed._composable.fsdp import ( + CPUOffloadPolicy, + OffloadPolicy, + fully_shard, +) +from torch.distributed.tensor import DTensor + +from sglang.multimodal_gen.runtime.distributed import ( + get_local_torch_device, + get_tp_rank, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) +from sglang.multimodal_gen.runtime.layers.linear import ( + ColumnParallelLinear, + LinearBase, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from sglang.multimodal_gen.runtime.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, +) +from sglang.multimodal_gen.utils import get_mixed_precision_state + +torch._dynamo.config.recompile_limit = 16 + + +class BaseLayerWithLoRA(nn.Module): + + def __init__( + self, + base_layer: nn.Module, + lora_rank: int | None = None, + lora_alpha: int | None = None, + training_mode: bool = False, + ): + super().__init__() + self.base_layer: nn.Module = base_layer + + self.merged: bool = False + self.cpu_weight = base_layer.weight.to("cpu") + # indicates adapter weights don't contain this layer + # (which shouldn't normally happen, but we want to separate it from the case of erroneous merging) + self.disable_lora: bool = False + self.lora_rank = lora_rank + self.lora_alpha = lora_alpha + self.training_mode = training_mode + self.lora_path: str | None = None + + if training_mode: + assert ( + self.lora_rank is not None + ), "LoRA rank must be set for training mode" + if self.lora_rank is None or self.lora_alpha is None: + self.lora_alpha = lora_rank + self.base_layer.requires_grad_(False) + in_dim = self.base_layer.weight.shape[1] + out_dim = self.base_layer.weight.shape[0] + self.lora_A = nn.Parameter( + torch.zeros( + self.lora_rank, + in_dim, + device=self.base_layer.weight.device, + dtype=self.base_layer.weight.dtype, + ) + ) + self.lora_B = nn.Parameter( + torch.zeros( + out_dim, + self.lora_rank, + device=self.base_layer.weight.device, + dtype=self.base_layer.weight.dtype, + ) + ) + torch.nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_B) + else: + self.lora_A = 
None + self.lora_B = None + + @torch.compile() + def forward(self, x: torch.Tensor) -> torch.Tensor: + lora_A = self.lora_A + lora_B = self.lora_B + if isinstance(self.lora_B, DTensor): + lora_B = self.lora_B.to_local() + lora_A = self.lora_A.to_local() + + if not self.merged and not self.disable_lora: + lora_A_sliced = self.slice_lora_a_weights(lora_A.to(x, non_blocking=True)) + lora_B_sliced = self.slice_lora_b_weights(lora_B.to(x, non_blocking=True)) + delta = x @ lora_A_sliced.T @ lora_B_sliced.T + if self.lora_alpha != self.lora_rank: + delta = delta * ( + self.lora_alpha / self.lora_rank # type: ignore + ) # type: ignore + out, output_bias = self.base_layer(x) + return out + delta, output_bias + else: + out, output_bias = self.base_layer(x) + return out.to(x), output_bias + + def slice_lora_a_weights(self, A: torch.Tensor) -> torch.Tensor: + return A + + def slice_lora_b_weights(self, B: torch.Tensor) -> torch.Tensor: + return B + + def set_lora_weights( + self, + A: torch.Tensor, + B: torch.Tensor, + training_mode: bool = False, + lora_path: str | None = None, + ) -> None: + self.lora_A = torch.nn.Parameter( + A + ) # share storage with weights in the pipeline + self.lora_B = torch.nn.Parameter(B) + self.disable_lora = False + if not training_mode: + self.merge_lora_weights() + self.lora_path = lora_path + + @torch.no_grad() + def merge_lora_weights(self) -> None: + if self.disable_lora: + return + + if self.merged: + self.unmerge_lora_weights() + assert ( + self.lora_A is not None and self.lora_B is not None + ), "LoRA weights not set. Please set them first." + if isinstance(self.base_layer.weight, DTensor): + mesh = self.base_layer.weight.data.device_mesh + unsharded_base_layer = ReplicatedLinear( + input_size=self.base_layer.input_size, + output_size=self.base_layer.output_size, + bias=getattr(self.base_layer, "bias", None) is not None, + skip_bias_add=self.base_layer.skip_bias_add, + params_dtype=self.base_layer.params_dtype, + quant_config=self.base_layer.quant_config, + prefix=self.base_layer.prefix, + ) + # Using offload param is on CPU, so current_device is for "CPU -> GPU -> merge -> CPU" + current_device = self.base_layer.weight.data.device + data = self.base_layer.weight.data.to( + get_local_torch_device() + ).full_tensor() + data += self.slice_lora_b_weights(self.lora_B).to( + data + ) @ self.slice_lora_a_weights(self.lora_A).to(data) + unsharded_base_layer.weight = nn.Parameter(data.to(current_device)) + if isinstance(getattr(self.base_layer, "bias", None), DTensor): + unsharded_base_layer.bias = nn.Parameter( + self.base_layer.bias.to(get_local_torch_device(), non_blocking=True) + .full_tensor() + .to(current_device) + ) + + offload_policy = ( + CPUOffloadPolicy() if "cpu" in str(current_device) else OffloadPolicy() + ) + mp_policy = get_mixed_precision_state().mp_policy + + self.base_layer = fully_shard( + unsharded_base_layer, + mesh=mesh, + mp_policy=mp_policy, + offload_policy=offload_policy, + ) + else: + current_device = self.base_layer.weight.data.device + data = self.base_layer.weight.data.to(get_local_torch_device()) + data += self.slice_lora_b_weights( + self.lora_B.to(data) + ) @ self.slice_lora_a_weights(self.lora_A.to(data)) + self.base_layer.weight.data = data.to(current_device, non_blocking=True) + + self.merged = True + + @torch.no_grad() + # @torch.compile(dynamic=True) + def unmerge_lora_weights(self) -> None: + if self.disable_lora: + return + + if not self.merged: + raise ValueError( + "LoRA weights not merged. 
Please merge them first before unmerging." + ) + + # avoid precision loss + if isinstance(self.base_layer.weight, DTensor): + device = self.base_layer.weight.data.device + self.base_layer.weight = nn.Parameter( + self.cpu_weight.to(device, non_blocking=True) + ) + else: + self.base_layer.weight.data = self.cpu_weight.data.to( + self.base_layer.weight, non_blocking=True + ) + + self.merged = False + + +class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): + """ + Vocab parallel embedding layer with support for LoRA (Low-Rank Adaptation). + + Note: The current version does not yet implement the LoRA functionality. + This class behaves exactly the same as the base VocabParallelEmbedding. + Future versions will integrate LoRA functionality to support efficient parameter fine-tuning. + """ + + def __init__( + self, + base_layer: VocabParallelEmbedding, + ) -> None: + super().__init__(base_layer) + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + raise NotImplementedError( + "We don't support VocabParallelEmbeddingWithLoRA yet." + ) + + +class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): + + def __init__( + self, + base_layer: ColumnParallelLinear, + lora_rank: int | None = None, + lora_alpha: int | None = None, + training_mode: bool = False, + ) -> None: + super().__init__(base_layer, lora_rank, lora_alpha, training_mode) + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + # duplicate the logic in ColumnParallelLinear + bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None + output_parallel = self.base_layer.quant_method.apply( + self.base_layer, input_, bias + ) + if self.base_layer.gather_output: + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None + return output, output_bias + + def slice_lora_a_weights(self, A: torch.Tensor) -> torch.Tensor: + return A + + def slice_lora_b_weights(self, B: torch.Tensor) -> torch.Tensor: + tp_rank = get_tp_rank() + shard_size = self.base_layer.output_partition_sizes[0] + start_idx = tp_rank * shard_size + end_idx = (tp_rank + 1) * shard_size + B = B[start_idx:end_idx, :] + return B + + +class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + + def __init__( + self, + base_layer: MergedColumnParallelLinear, + lora_rank: int | None = None, + lora_alpha: int | None = None, + training_mode: bool = False, + ) -> None: + super().__init__(base_layer, lora_rank, lora_alpha, training_mode) + + def slice_lora_a_weights(self, A: torch.Tensor) -> torch.Tensor: + return A.to(self.base_layer.weight) + + def slice_lora_b_weights(self, B: torch.Tensor) -> torch.Tensor: + tp_rank = get_tp_rank() + # Since the outputs for both gate and up are identical, we use a random one. 
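+        # Sketch (assuming B stacks the merged projections along dim 0): each
+        # rank keeps the slice [tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        # of the per-projection output dimension so the LoRA delta lines up
+        # with this rank's shard of the fused base weight.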
+ shard_size = self.base_layer.output_partition_sizes[0] + start_idx = tp_rank * shard_size + end_idx = (tp_rank + 1) * shard_size + return B[:, start_idx:end_idx, :] + + +class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + + def __init__( + self, + base_layer: QKVParallelLinear, + lora_rank: int | None = None, + lora_alpha: int | None = None, + training_mode: bool = False, + ) -> None: + super().__init__(base_layer, lora_rank, lora_alpha, training_mode) + + def slice_lora_a_weights(self, A: torch.Tensor) -> torch.Tensor: + return A + + def slice_lora_b_weights( + self, B: list[torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor]: + tp_rank = get_tp_rank() + B_q, B_kv = B + base_layer = self.base_layer + q_proj_shard_size = base_layer.q_proj_shard_size + kv_proj_shard_size = base_layer.kv_proj_shard_size + num_kv_head_replicas = base_layer.num_kv_head_replicas + + q_start_idx = q_proj_shard_size * tp_rank + q_end_idx = q_start_idx + q_proj_shard_size + + kv_shard_id = tp_rank // num_kv_head_replicas + kv_start_idx = kv_proj_shard_size * kv_shard_id + kv_end_idx = kv_start_idx + kv_proj_shard_size + + return B_q[q_start_idx:q_end_idx, :], B_kv[:, kv_start_idx:kv_end_idx, :] + + +class RowParallelLinearWithLoRA(BaseLayerWithLoRA): + + def __init__( + self, + base_layer: RowParallelLinear, + lora_rank: int | None = None, + lora_alpha: int | None = None, + training_mode: bool = False, + ) -> None: + super().__init__(base_layer, lora_rank, lora_alpha, training_mode) + + def forward(self, input_: torch.Tensor): + # duplicate the logic in RowParallelLinear + if self.base_layer.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_tp_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.base_layer.tp_size + ) + input_parallel = splitted_input[tp_rank].contiguous() + output_parallel = self.base_layer.quant_method.apply( + self.base_layer, input_parallel + ) + + if self.set_lora: + output_parallel = self.apply_lora(output_parallel, input_parallel) + + if self.base_layer.reduce_results and self.base_layer.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.base_layer.skip_bias_add: + output = ( + output_ + self.base_layer.bias + if self.base_layer.bias is not None + else output_ + ) + output_bias = None + else: + output = output_ + output_bias = self.base_layer.bias + return output, output_bias + + def slice_lora_a_weights(self, A: torch.Tensor) -> torch.Tensor: + tp_rank = get_tp_rank() + shard_size = self.base_layer.input_size_per_partition + start_idx = tp_rank * shard_size + end_idx = (tp_rank + 1) * shard_size + A = A[:, start_idx:end_idx].contiguous() + return A + + def slice_lora_b_weights(self, B: torch.Tensor) -> torch.Tensor: + return B + + +def get_lora_layer( + layer: nn.Module, + lora_rank: int | None = None, + lora_alpha: int | None = None, + training_mode: bool = False, +) -> BaseLayerWithLoRA | None: + supported_layer_types: dict[type[LinearBase], type[BaseLayerWithLoRA]] = { + # the order matters + # VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA, + QKVParallelLinear: QKVParallelLinearWithLoRA, + MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA, + ColumnParallelLinear: ColumnParallelLinearWithLoRA, + RowParallelLinear: RowParallelLinearWithLoRA, + ReplicatedLinear: BaseLayerWithLoRA, + } + for src_layer_type, lora_layer_type in supported_layer_types.items(): + if isinstance(layer, src_layer_type): # pylint: 
disable=unidiomatic-typecheck + ret = lora_layer_type( + layer, + lora_rank=lora_rank, + lora_alpha=lora_alpha, + training_mode=training_mode, + ) + return ret + return None + + +# source: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/vllm/lora/utils.py#L9 +def replace_submodule( + model: nn.Module, module_name: str, new_module: nn.Module +) -> nn.Module: + """Replace a submodule in a model with a new module.""" + parent = model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + setattr(parent, target_name, new_module) + return new_module diff --git a/python/sglang/multimodal_gen/runtime/layers/mlp.py b/python/sglang/multimodal_gen/runtime/layers/mlp.py new file mode 100644 index 000000000..17918e2aa --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/mlp.py @@ -0,0 +1,46 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.nn as nn + +from sglang.multimodal_gen.runtime.layers.activation import get_act_fn +from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear + + +class MLP(nn.Module): + """ + MLP for DiT blocks, NO gated linear units + """ + + def __init__( + self, + input_dim: int, + mlp_hidden_dim: int, + output_dim: int | None = None, + bias: bool = True, + act_type: str = "gelu_pytorch_tanh", + dtype: torch.dtype | None = None, + prefix: str = "", + ): + super().__init__() + self.fc_in = ReplicatedLinear( + input_dim, + mlp_hidden_dim, # For activation func like SiLU that need 2x width + bias=bias, + params_dtype=dtype, + ) + + self.act = get_act_fn(act_type) + if output_dim is None: + output_dim = input_dim + self.fc_out = ReplicatedLinear( + mlp_hidden_dim, output_dim, bias=bias, params_dtype=dtype + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.fc_in(x) + x = self.act(x) + x, _ = self.fc_out(x) + return x diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py b/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py new file mode 100644 index 000000000..0d6c79797 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/quantization/__init__.py @@ -0,0 +1,71 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from typing import Literal, get_args + +from sglang.multimodal_gen.runtime.layers.quantization.base_config import ( + QuantizationConfig, +) + +QuantizationMethods = Literal[None] + +QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) + +# The customized quantization methods which will be added to this dict. +_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {} + + +def register_quantization_config(quantization: str): + """Register a customized vllm quantization config. + + When a quantization method is not supported by vllm, you can register a customized + quantization config to support it. + + Args: + quantization (str): The quantization method name. + + Examples: + >>> from sglang.multimodal_gen.runtime.layers.quantization import register_quantization_config + >>> from sglang.multimodal_gen.runtime.layers.quantization import get_quantization_config + >>> from sglang.multimodal_gen.runtime.layers.quantization.base_config import QuantizationConfig + >>> + >>> @register_quantization_config("my_quant") + ... class MyQuantConfig(QuantizationConfig): + ... 
pass
+        >>>
+        >>> get_quantization_config("my_quant")
+
+    """  # noqa: E501
+
+    def _wrapper(quant_config_cls):
+        if quantization in QUANTIZATION_METHODS:
+            raise ValueError(
+                f"The quantization method `{quantization}` already exists."
+            )
+        if not issubclass(quant_config_cls, QuantizationConfig):
+            raise ValueError(
+                "The quantization config must be a subclass of " "`QuantizationConfig`."
+            )
+        _CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = quant_config_cls
+        QUANTIZATION_METHODS.append(quantization)
+        return quant_config_cls
+
+    return _wrapper
+
+
+def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
+    if quantization not in QUANTIZATION_METHODS:
+        raise ValueError(f"Invalid quantization method: {quantization}")
+
+    method_to_config: dict[str, type[QuantizationConfig]] = {}
+    # Update the `method_to_config` with customized quantization methods.
+    method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
+
+    return method_to_config[quantization]
+
+
+__all__ = [
+    "QuantizationMethods",
+    "QuantizationConfig",
+    "get_quantization_config",
+    "QUANTIZATION_METHODS",
+]
diff --git a/python/sglang/multimodal_gen/runtime/layers/quantization/base_config.py b/python/sglang/multimodal_gen/runtime/layers/quantization/base_config.py
new file mode 100644
index 000000000..ffb275a8b
--- /dev/null
+++ b/python/sglang/multimodal_gen/runtime/layers/quantization/base_config.py
@@ -0,0 +1,152 @@
+# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
+
+# SPDX-License-Identifier: Apache-2.0
+# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/quantization/base_config.py
+
+import inspect
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any
+
+import torch
+from torch import nn
+
+if TYPE_CHECKING:
+    from sglang.multimodal_gen.runtime.layers.quantization import QuantizationMethods
+else:
+    QuantizationMethods = str
+
+
+class QuantizeMethodBase(ABC):
+    """Base class for different quantized methods."""
+
+    @abstractmethod
+    def create_weights(
+        self, layer: torch.nn.Module, *weight_args, **extra_weight_attrs
+    ):
+        """Create weights for a layer.
+
+        The weights will be set as attributes of the layer."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
+        """Apply the weights in layer to the input tensor.
+
+        Expects create_weights to have been called before on the layer."""
+        raise NotImplementedError
+
+    # Not required functions
+    def embedding(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
+        """Gather embeddings in the layer based on indices in the input tensor.
+
+        Expects create_weights to have been called before on the layer."""
+        raise NotImplementedError
+
+    def process_weights_after_loading(self, layer: nn.Module) -> None:
+        """Process the weight after loading.
+
+        This can be used, for example, to transpose weights for computation.
+        """
+        return
+
+
+def method_has_implemented_embedding(method_class: type[QuantizeMethodBase]) -> bool:
+    """
+    Not all quant methods have embedding implemented, so we need to check that
+    it exists for our given method. We check this by making sure the function
+    has been changed from the base implementation.
+ """ + base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", None) + class_embedding = inspect.getattr_static(method_class, "embedding", None) + + return class_embedding is not None and class_embedding is not base_embedding + + +class QuantizationConfig(ABC): + """Base class for quantization configs.""" + + def __init__(self): + super().__init__() + # mapping is updated by models as they initialize + self.packed_modules_mapping: dict[str, list[str]] = dict() + + @abstractmethod + def get_name(self) -> QuantizationMethods: + """Name of the quantization method.""" + raise NotImplementedError + + @abstractmethod + def get_supported_act_dtypes(self) -> list[torch.dtype]: + """List of supported activation dtypes.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + """Minimum GPU capability to support the quantization method. + + E.g., 70 for Volta, 75 for Turing, 80 for Ampere. + This requirement is due to the custom CUDA kernels used by the + quantization method. + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_config_filenames() -> list[str]: + """List of filenames to search for in the model directory.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def from_config(cls, config: dict[str, Any]) -> "QuantizationConfig": + """Create a config class from the model's quantization config.""" + raise NotImplementedError + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant + ) -> QuantizationMethods | None: + """ + Detects if this quantization method can support a given checkpoint + format by overriding the user specified quantization method -- + this method should only be overwritten by subclasses in exceptional + circumstances + """ + return None + + @staticmethod + def get_from_keys(config: dict[str, Any], keys: list[str]) -> Any: + """Get a value from the model's quantization config.""" + for key in keys: + if key in config: + return config[key] + raise ValueError( + f"Cannot find any of {keys} in the model's " "quantization config." + ) + + @staticmethod + def get_from_keys_or(config: dict[str, Any], keys: list[str], default: Any) -> Any: + """Get a optional value from the model's quantization config.""" + try: + return QuantizationConfig.get_from_keys(config, keys) + except ValueError: + return default + + @abstractmethod + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> QuantizeMethodBase | None: + """Get the quantize method to use for the quantized layer. + + Args: + layer: The layer for the quant method. + prefix: The full name of the layer in the state dict + Returns: + The quantize method. None if the given layer doesn't support quant + method. + """ + raise NotImplementedError + + def get_cache_scale(self, name: str) -> str | None: + return None diff --git a/python/sglang/multimodal_gen/runtime/layers/rotary_embedding.py b/python/sglang/multimodal_gen/runtime/layers/rotary_embedding.py new file mode 100644 index 000000000..698e3cd9a --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/rotary_embedding.py @@ -0,0 +1,886 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/rotary_embedding.py + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. 
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Rotary Positional Embeddings.""" +import functools +from collections import OrderedDict +from typing import Any + +import torch + +from sglang.multimodal_gen.runtime.distributed.parallel_state import get_sp_group +from sglang.multimodal_gen.runtime.layers.custom_op import CustomOp +from sglang.multimodal_gen.runtime.layers.triton_ops import apply_rotary_embedding +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +def _rotate_neox(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, + interleaved: bool = False, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] or [num_tokens, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. + """ + # cos = cos.unsqueeze(-2).to(x.dtype) + # sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + cos = cos.unsqueeze(-2) + sin = sin.unsqueeze(-2) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = (x1.float() * cos - x2.float() * sin).type_as(x) + o2 = (x2.float() * cos + x1.float() * sin).type_as(x) + return torch.cat((o1, o2), dim=-1) + else: + return apply_rotary_embedding(x, cos, sin, interleaved) + + +@CustomOp.register("rotary_embedding") +class RotaryEmbedding(CustomOp): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int | float, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: int | float) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. 
However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_cuda(self, *args, **kwargs) -> Any: + return self.forward_native(*args, **kwargs) + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s + + +class OneDRotaryEmbedding(torch.nn.Module): + """1D rotary positional embedding with caching.""" + + def __init__( + self, + dim: int, + theta: float = 10000.0, + theta_rescale_factor: float = 1.0, + interpolation_factor: float = 1.0, + dtype: torch.dtype = torch.float32, + use_real: bool = False, + repeat_interleave_real: bool = False, + ): + super().__init__() + assert dim % 2 == 0 + self.dim = dim + self.theta = theta + self.theta_rescale_factor = theta_rescale_factor + self.interpolation_factor = interpolation_factor + # dtype of freqs + self.dtype = dtype + self.use_real = use_real + self.repeat_interleave_real = repeat_interleave_real + + def build_freqs(self, device): + freqs = 1.0 / ( + self.theta + ** ( + torch.arange(0, self.dim, 2, dtype=self.dtype, device=device)[ + : (self.dim // 2) + ] + / self.dim + ).to(device=device) + ) + return freqs + + def build_freqs_outer(self, pos: torch.Tensor, device): + theta = self.theta + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + if self.theta_rescale_factor != 1.0: + theta *= self.theta_rescale_factor ** (self.dim / (self.dim - 2)) + + freqs = self.build_freqs(device) + + freqs = torch.outer(pos * self.interpolation_factor, freqs) + freqs_cos = freqs.cos() + freqs_sin = freqs.sin() + + if self.use_real and self.repeat_interleave_real: + freqs_cos = freqs_cos.repeat_interleave(2, dim=1) + 
freqs_sin = freqs_sin.repeat_interleave(2, dim=1) + + return freqs_cos.float(), freqs_sin.float() + + @functools.lru_cache(maxsize=16) + def forward_from_grid( + self, seq_len: int, start_pos: int, device_str: str + ) -> tuple[torch.Tensor, torch.Tensor]: + device = torch.device(device_str) + pos = torch.arange( + start_pos, start_pos + seq_len, dtype=self.dtype, device=device + ) + + freqs_cos, freqs_sin = self.build_freqs_outer(pos, device) + return freqs_cos, freqs_sin + + def forward(self, pos: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Calculates 1D rotary embeddings for the given positions. + + This method converts the input tensor to a hashable representation + and calls a cached helper method to perform the computation. + """ + pos_tuple = tuple(pos.tolist()) + device_str = str(pos.device) + return self._forward_cached(pos_tuple, device_str) + + @functools.lru_cache(maxsize=16) + def _forward_cached( + self, pos_tuple: tuple, device_str: str + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + The core implementation that computes 1D rotary embeddings. + This method is wrapped by an LRU cache. + """ + device = torch.device(device_str) + pos = torch.as_tensor(pos_tuple, dtype=self.dtype, device=device) + freqs_cos, freqs_sin = self.build_freqs_outer(pos, device) + return freqs_cos, freqs_sin + + +class NDRotaryEmbedding(torch.nn.Module): + """N-dimensional rotary positional embedding.""" + + def __init__( + self, + rope_dim_list: list[int], + rope_theta: float, + theta_rescale_factor: float | list[float] = 1.0, + interpolation_factor: float | list[float] = 1.0, + use_real: bool = False, + repeat_interleave_real: bool = False, + dtype: torch.dtype = torch.float32, + ): + super().__init__() + self.rope_dim_list = rope_dim_list + self.ndim = len(rope_dim_list) + self.rope_theta = rope_theta + # dtype of freqs + # does not control the output dtype + self.dtype = dtype + + if isinstance(theta_rescale_factor, (int, float)): + self.theta_rescale_factor = [theta_rescale_factor] * self.ndim + elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: + self.theta_rescale_factor = [theta_rescale_factor[0]] * self.ndim + else: + self.theta_rescale_factor = theta_rescale_factor + assert ( + len(self.theta_rescale_factor) == self.ndim + ), "len(theta_rescale_factor) should equal to len(rope_dim_list)" + + if isinstance(interpolation_factor, (int, float)): + self.interpolation_factor = [interpolation_factor] * self.ndim + elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: + self.interpolation_factor = [interpolation_factor[0]] * self.ndim + else: + self.interpolation_factor = interpolation_factor + assert ( + len(self.interpolation_factor) == self.ndim + ), "len(interpolation_factor) should equal to len(rope_dim_list)" + + self.rope_generators: list[OneDRotaryEmbedding] = torch.nn.ModuleList() + _config_to_gen_idx: dict[tuple, int] = {} + self.dim_idx_to_gen_idx: list[int] = [] + + for i in range(self.ndim): + dim = self.rope_dim_list[i] + rescale = self.theta_rescale_factor[i] + interp = self.interpolation_factor[i] + + config_key = (dim, rescale, interp, use_real, repeat_interleave_real) + if config_key not in _config_to_gen_idx: + generator = OneDRotaryEmbedding( + dim=dim, + theta=self.rope_theta, + theta_rescale_factor=rescale, + interpolation_factor=interp, + dtype=self.dtype, + use_real=use_real, + repeat_interleave_real=repeat_interleave_real, + ) + _config_to_gen_idx[config_key] = len(self.rope_generators) + 
self.rope_generators.append(generator) + + gen_idx = _config_to_gen_idx[config_key] + self.dim_idx_to_gen_idx.append(gen_idx) + + def forward(self, positions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Calculates n-d rotary embeddings for given absolute positions. + + Args: + positions (torch.Tensor): A tensor of shape `[num_tokens, ndim]` + containing the integer coordinates for each token. + + Returns: + A tuple of (cos, sin) tensors. + """ + # Caching wrapper: convert tensor to a hashable tuple of tuples. + pos_tuple = tuple(map(tuple, positions.tolist())) + device_str = str(positions.device) + return self._forward_cached(pos_tuple, device_str) + + @functools.lru_cache(maxsize=16) + def _forward_cached( + self, pos_tuple: tuple[tuple[int, ...], ...], device_str: str + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + The core implementation that computes embeddings from a position tensor. + This method is wrapped by an LRU cache. + """ + device = torch.device(device_str) + positions = torch.tensor(pos_tuple, dtype=torch.long, device=device) + return self.forward_uncached(pos=positions) + + def forward_uncached(self, pos: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + The core implementation that computes embeddings from a position tensor. + This method is wrapped by an LRU cache. + """ + device = pos.device + + # Pre-allocate the final tensors for efficiency. + num_tokens = pos.shape[0] + first_generator = self.rope_generators[0] + if first_generator.use_real and first_generator.repeat_interleave_real: + head_dim = sum(self.rope_dim_list) + else: + head_dim = sum(self.rope_dim_list) // 2 + + cos = torch.empty((num_tokens, head_dim), device=device, dtype=self.dtype) + sin = torch.empty((num_tokens, head_dim), device=device, dtype=self.dtype) + + col_offset = 0 + for i in range(self.ndim): + # Extract position coordinates for the current dimension for all tokens. + pos_i = pos[:, i].to(self.dtype) + + # Get the appropriate 1D generator. + gen_idx = self.dim_idx_to_gen_idx[i] + generator = self.rope_generators[gen_idx] + + # Calculate 1D embeddings. + cos_1d, sin_1d = generator(pos_i) + + slice_width = cos_1d.shape[1] + cos[:, col_offset : col_offset + slice_width] = cos_1d + sin[:, col_offset : col_offset + slice_width] = sin_1d + col_offset += slice_width + + return cos.float(), sin.float() + + def forward_from_grid( + self, + grid_size: tuple[int, ...], + shard_dim: int = 0, + start_frame: int = 0, + device: torch.device | str | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Caching wrapper: use grid parameters directly as the key. + # grid_tuple = _to_tuple(grid_size, dim=self.ndim) + device_str = str(device) if device is not None else "cpu" + return self._forward_cached_from_grid( + grid_size, shard_dim, start_frame, device_str + ) + + @functools.lru_cache(maxsize=16) + def _forward_cached_from_grid( + self, + grid_size: tuple[int, ...], + shard_dim: int, + start_frame: int, + device_str: str, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Computes embeddings for a structured grid, using a highly efficient + implementation that avoids materializing the full position tensor. + This method is wrapped by an LRU cache. 
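+
+        Note: ``functools.lru_cache`` keys on the full call arguments
+        (self, grid_size, shard_dim, start_frame, device_str), so repeated calls
+        for the same grid on the same device reuse the cached (cos, sin) pair
+        instead of recomputing it.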
+ """ + device = torch.device(device_str) + sp_group = get_sp_group() + sp_rank = sp_group.rank_in_group + sp_world_size = sp_group.world_size + + sizes = _to_tuple(grid_size, dim=self.ndim) + starts = (0,) * self.ndim + + # Apply sequence parallel sharding to the sizes and compute shard offset + shard_sizes = list(sizes) + shard_offsets = [0] * self.ndim + if sp_world_size > 1: + assert sizes[shard_dim] % sp_world_size == 0, ( + f"Dimension {shard_dim} with size {sizes[shard_dim]} is not divisible " + f"by sequence parallel world size {sp_world_size}" + ) + shard_size = sizes[shard_dim] // sp_world_size + shard_offsets[shard_dim] = sp_rank * shard_size + shard_sizes[shard_dim] = shard_size + + # Pre-allocate outputs on the requested device to avoid CPU ops and extra cats + num_tokens = 1 + for s in shard_sizes: + num_tokens *= int(s) + head_dim_half = sum(self.rope_dim_list) // 2 + cos = torch.empty((num_tokens, head_dim_half), device=device, dtype=self.dtype) + sin = torch.empty((num_tokens, head_dim_half), device=device, dtype=self.dtype) + + # Compute per-axis 1D embeddings once and expand via repeats to [N, d_i/2] + col_offset = 0 + for i in range(self.ndim): + dim_i = self.rope_dim_list[i] + dim_i_half = dim_i // 2 + size_i = int(shard_sizes[i]) + + # Starting position for this axis, with optional frame offset for time axis (i==0) + base_offset = starts[i] + if i == 0 and start_frame > 0: + base_offset += start_frame + if sp_world_size > 1 and i == shard_dim: + base_offset += shard_offsets[i] + + gen_idx = self.dim_idx_to_gen_idx[i] + generator = self.rope_generators[gen_idx] + cos_1d, sin_1d = generator.forward_from_grid( + size_i, base_offset, device_str + ) + + # Expand to [num_tokens, dim_i/2] matching flatten order (last dims vary fastest) + repeats_per_entry = 1 + for j in range(i + 1, self.ndim): + repeats_per_entry *= int(shard_sizes[j]) + tile_count = 1 + for j in range(0, i): + tile_count *= int(shard_sizes[j]) + + cos_expanded = cos_1d.repeat_interleave(repeats_per_entry, dim=0) + sin_expanded = sin_1d.repeat_interleave(repeats_per_entry, dim=0) + if tile_count > 1: + cos_expanded = cos_expanded.repeat(tile_count, 1) + sin_expanded = sin_expanded.repeat(tile_count, 1) + + cos[:, col_offset : col_offset + dim_i_half] = cos_expanded + sin[:, col_offset : col_offset + dim_i_half] = sin_expanded + col_offset += dim_i_half + + return cos.float(), sin.float() + + +def _to_tuple(x: int | tuple[int, ...], dim: int = 2) -> tuple[int, ...]: + if isinstance(x, int): + return (x,) * dim + elif len(x) == dim: + return x + else: + raise ValueError(f"Expected length {dim} or int, but got {x}") + + +def get_meshgrid_nd( + start: int | tuple[int, ...], + *args: int | tuple[int, ...], + dim: int = 2, + device: torch.device | str | None = None, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Get n-D meshgrid with start, stop and num. + + Args: + start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, + step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num + should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in + n-tuples. + *args: See above. + dim (int): Dimension of the meshgrid. Defaults to 2. + + Returns: + grid (np.ndarray): [dim, ...] 
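+
+    Example (illustrative): ``get_meshgrid_nd((2, 3), dim=2)`` treats (2, 3) as the
+    grid size, builds integer coordinates 0..1 and 0..2 along the two axes, and
+    returns a stacked grid of shape [2, 2, 3] (one coordinate plane per dimension).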
+ """ + if len(args) == 0: + # start is grid_size + num = _to_tuple(start, dim=dim) + start = (0,) * dim + stop = num + elif len(args) == 1: + # start is start, args[0] is stop, step is 1 + start = _to_tuple(start, dim=dim) + stop = _to_tuple(args[0], dim=dim) + num = tuple(stop[i] - start[i] for i in range(dim)) + elif len(args) == 2: + # start is start, args[0] is stop, args[1] is num + start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0 + stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32 + num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124 + else: + raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") + + # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False) + axis_grid = [] + for i in range(dim): + a, b, n = start[i], stop[i], num[i] + g = torch.linspace(a, b, n + 1, dtype=dtype, device=device)[:n] + axis_grid.append(g) + grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D] + grid = torch.stack(grid, dim=0) # [dim, W, H, D] + + return grid + + +def get_1d_rotary_pos_embed( + dim: int, + pos: torch.FloatTensor | int, + theta: float = 10000.0, + theta_rescale_factor: float = 1.0, + interpolation_factor: float = 1.0, + dtype: torch.dtype = torch.float32, + device: torch.device | str | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Precompute the frequency tensor for complex exponential (cis) with given dimensions. + (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.) + + This function calculates a frequency tensor with complex exponential using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + + Args: + dim (int): Dimension of the frequency tensor. + pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar + theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. + theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0. + interpolation_factor (float, optional): Factor to scale positions. Defaults to 1.0. + + Returns: + freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D] + """ + if isinstance(pos, int): + pos = torch.arange(pos, dtype=dtype, device=device) + elif ( + isinstance(pos, torch.Tensor) + and device is not None + and pos.device != torch.device(device) + ): + # Ensure positions are on the requested device to avoid implicit CPU ops. 
+ pos = pos.to(device) + + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + if theta_rescale_factor != 1.0: + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + freqs = 1.0 / ( + theta + ** (torch.arange(0, dim, 2, device=device)[: (dim // 2)].to(dtype) / dim).to( + device=device + ) + ) # [D/2] + freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2] + freqs_cos = freqs.cos() # [S, D/2] + freqs_sin = freqs.sin() # [S, D/2] + return freqs_cos, freqs_sin + + +def get_nd_rotary_pos_embed( + rope_dim_list, + start, + *args, + theta=10000.0, + theta_rescale_factor: float | list[float] = 1.0, + interpolation_factor: float | list[float] = 1.0, + shard_dim: int = 0, + sp_rank: int = 0, + sp_world_size: int = 1, + dtype: torch.dtype = torch.float32, + start_frame: int = 0, + device: torch.device | str | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure. + Supports sequence parallelism by allowing sharding of a specific dimension. + + Args: + rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n. + sum(rope_dim_list) should equal to head_dim of attention layer. + start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start, + args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. + *args: See above. + theta (float): Scaling factor for frequency computation. Defaults to 10000.0. + theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0. + interpolation_factor (float): Factor to scale positions. Defaults to 1.0. + shard_dim (int): Which dimension to shard for sequence parallelism. Defaults to 0. + sp_rank (int): Rank in the sequence parallel group. Defaults to 0. + sp_world_size (int): World size of the sequence parallel group. Defaults to 1. 
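+        dtype (torch.dtype): dtype used for the position/frequency math. Defaults to torch.float32.
+        start_frame (int): Extra offset added to the positions of the first (temporal) axis. Defaults to 0.
+        device (torch.device | str | None): Device on which the returned (cos, sin) tensors are allocated.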
+ + Returns: + Tuple[torch.Tensor, torch.Tensor]: (cos, sin) tensors of shape [HW, D/2] + """ + # Determine per-axis sizes for the (possibly sharded) grid without materializing it + ndim = len(rope_dim_list) + if len(args) == 0: + # start is grid_size + sizes = _to_tuple(start, dim=ndim) + starts = (0,) * ndim + elif len(args) == 1: + # start is start, args[0] is stop, step is 1 + starts = _to_tuple(start, dim=ndim) + stops = _to_tuple(args[0], dim=ndim) + sizes = tuple(stops[i] - starts[i] for i in range(ndim)) + elif len(args) == 2: + # start is start, args[0] is stop, args[1] is num + starts = _to_tuple(start, dim=ndim) + _ = _to_tuple(args[0], dim=ndim) # stop, unused here + sizes = _to_tuple(args[1], dim=ndim) + else: + raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") + + assert ( + shard_dim < ndim + ), f"shard_dim {shard_dim} must be less than number of dimensions {ndim}" + + # Apply sequence parallel sharding to the sizes and compute shard offset + shard_sizes = list(sizes) + shard_offsets = [0] * ndim + if sp_world_size > 1: + assert sizes[shard_dim] % sp_world_size == 0, ( + f"Dimension {shard_dim} with size {sizes[shard_dim]} is not divisible " + f"by sequence parallel world size {sp_world_size}" + ) + shard_size = sizes[shard_dim] // sp_world_size + shard_offsets[shard_dim] = sp_rank * shard_size + shard_sizes[shard_dim] = shard_size + + # Handle theta scaling/interpolation factor per-axis + if isinstance(theta_rescale_factor, int | float): + theta_rescale_factor = [theta_rescale_factor] * ndim + elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1: + theta_rescale_factor = [theta_rescale_factor[0]] * ndim + assert ( + len(theta_rescale_factor) == ndim + ), "len(theta_rescale_factor) should equal to len(rope_dim_list)" + + if isinstance(interpolation_factor, int | float): + interpolation_factor = [interpolation_factor] * ndim + elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1: + interpolation_factor = [interpolation_factor[0]] * ndim + assert ( + len(interpolation_factor) == ndim + ), "len(interpolation_factor) should equal to len(rope_dim_list)" + + # Pre-allocate outputs on the requested device to avoid CPU ops and extra cats + num_tokens = 1 + for s in shard_sizes: + num_tokens *= int(s) + head_dim_half = sum(rope_dim_list) // 2 + cos = torch.empty((num_tokens, head_dim_half), device=device, dtype=dtype) + sin = torch.empty((num_tokens, head_dim_half), device=device, dtype=dtype) + # Compute per-axis 1D embeddings once and expand via repeats to [N, d_i/2] + col_offset = 0 + for i in range(ndim): + dim_i = int(rope_dim_list[i]) + dim_i_half = dim_i // 2 + size_i = int(shard_sizes[i]) + + # Starting position for this axis, with optional frame offset for time axis (i==0) + base_offset = starts[i] + if i == 0 and start_frame > 0: + base_offset += start_frame + if sp_world_size > 1 and i == shard_dim: + base_offset += shard_offsets[i] + + pos_i = torch.arange(size_i, device=device, dtype=dtype) + base_offset + + cos_1d, sin_1d = get_1d_rotary_pos_embed( + dim_i, + pos_i, + theta=theta, + theta_rescale_factor=theta_rescale_factor[i], + interpolation_factor=interpolation_factor[i], + dtype=dtype, + device=device, + ) # [size_i, dim_i/2] + + # Expand to [num_tokens, dim_i/2] matching flatten order (last dims vary fastest) + repeats_per_entry = 1 + for j in range(i + 1, ndim): + repeats_per_entry *= int(shard_sizes[j]) + tile_count = 1 + for j in range(0, i): + tile_count *= int(shard_sizes[j]) + + 
cos_expanded = cos_1d.repeat_interleave(repeats_per_entry, dim=0) + sin_expanded = sin_1d.repeat_interleave(repeats_per_entry, dim=0) + if tile_count > 1: + cos_expanded = cos_expanded.repeat(tile_count, 1) + sin_expanded = sin_expanded.repeat(tile_count, 1) + + cos[:, col_offset : col_offset + dim_i_half] = cos_expanded + sin[:, col_offset : col_offset + dim_i_half] = sin_expanded + col_offset += dim_i_half + + return cos, sin + + +def get_rotary_pos_embed( + rope_sizes, + hidden_size, + heads_num, + rope_dim_list, + rope_theta, + theta_rescale_factor=1.0, + interpolation_factor=1.0, + shard_dim: int = 0, + dtype: torch.dtype = torch.float32, + start_frame: int = 0, + device: torch.device | str | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Generate rotary positional embeddings for the given sizes. + + Args: + rope_sizes: Tuple of dimensions (t, h, w) + hidden_size: Hidden dimension size + heads_num: Number of attention heads + rope_dim_list: List of dimensions for each axis, or None + rope_theta: Base for frequency calculations + theta_rescale_factor: Rescale factor for theta. Defaults to 1.0 + interpolation_factor: Factor to scale positions. Defaults to 1.0 + shard_dim: Which dimension to shard for sequence parallelism. Defaults to 0. + + Returns: + Tuple of (cos, sin) tensors for rotary embeddings + """ + + target_ndim = 3 + head_dim = hidden_size // heads_num + + if rope_dim_list is None: + rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + + assert ( + sum(rope_dim_list) == head_dim + ), "sum(rope_dim_list) should equal to head_dim of attention layer" + + # Get SP info - now handled within NDRotaryEmbedding + # sp_group = get_sp_group() + # sp_rank = sp_group.rank_in_group + # sp_world_size = sp_group.world_size + + # Simple LRU cache keyed by parameters + global _ND_ROPE_CACHE + key = ( + tuple(rope_dim_list), + float(rope_theta), + ( + tuple(theta_rescale_factor) + if isinstance(theta_rescale_factor, list) + else float(theta_rescale_factor) + ), + ( + tuple(interpolation_factor) + if isinstance(interpolation_factor, list) + else float(interpolation_factor) + ), + dtype, + ) + + cache_hit = key in _ND_ROPE_CACHE + if cache_hit: + rope_emb = _ND_ROPE_CACHE.pop(key) + _ND_ROPE_CACHE[key] = rope_emb # move to end (most-recent) + else: + rope_emb = NDRotaryEmbedding( + rope_dim_list=rope_dim_list, + rope_theta=rope_theta, + theta_rescale_factor=theta_rescale_factor, + interpolation_factor=interpolation_factor, + dtype=dtype, + ) + _ND_ROPE_CACHE[key] = rope_emb + if len(_ND_ROPE_CACHE) > 16: + # pop least-recently-used + _ND_ROPE_CACHE.pop(next(iter(_ND_ROPE_CACHE))) + + freqs_cos, freqs_sin = rope_emb.forward_from_grid( + grid_size=_to_tuple(rope_sizes, dim=3), + shard_dim=shard_dim, + start_frame=start_frame, + device=device, + ) + return freqs_cos, freqs_sin + + +_ROPE_DICT: dict[tuple, RotaryEmbedding] = {} +_ND_ROPE_CACHE: "OrderedDict[tuple, NDRotaryEmbedding]" = OrderedDict() +_ROPE_3D_CACHE: "OrderedDict[tuple, tuple[torch.Tensor, torch.Tensor]]" = OrderedDict() + + +def get_rope( + head_size: int, + rotary_dim: int, + max_position: int, + base: int | float, + is_neox_style: bool = True, + rope_scaling: dict[str, Any] | None = None, + dtype: torch.dtype | None = None, + partial_rotary_factor: float = 1.0, +) -> RotaryEmbedding: + if dtype is None: + dtype = torch.get_default_dtype() + if rope_scaling is not None: + # Transforms every value that is a list into a tuple for caching calls + rope_scaling_tuple = { + k: tuple(v) if isinstance(v, 
list) else v for k, v in rope_scaling.items() + } + rope_scaling_args = tuple(rope_scaling_tuple.items()) + else: + rope_scaling_args = None + if partial_rotary_factor < 1.0: + rotary_dim = int(rotary_dim * partial_rotary_factor) + key = ( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + rope_scaling_args, + dtype, + ) + if key in _ROPE_DICT: + return _ROPE_DICT[key] + + if rope_scaling is None: + rotary_emb = RotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, dtype + ) + else: + raise ValueError(f"Unknown RoPE scaling {rope_scaling}") + _ROPE_DICT[key] = rotary_emb + return rotary_emb diff --git a/python/sglang/multimodal_gen/runtime/layers/triton_ops.py b/python/sglang/multimodal_gen/runtime/layers/triton_ops.py new file mode 100644 index 000000000..8fa74994a --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/triton_ops.py @@ -0,0 +1,948 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# TODO: for temporary usage, expecting a refactor +from typing import Optional + +import torch +import triton # type: ignore +import triton.language as tl # type: ignore +from torch import Tensor + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_N": 64}, num_warps=2), + triton.Config({"BLOCK_N": 128}, num_warps=4), + triton.Config({"BLOCK_N": 256}, num_warps=4), + triton.Config({"BLOCK_N": 512}, num_warps=4), + triton.Config({"BLOCK_N": 1024}, num_warps=8), + ], + key=["inner_dim"], +) +@triton.jit +def _fused_scale_shift_4d_kernel( + output_ptr, + normalized_ptr, + scale_ptr, + shift_ptr, + rows, + inner_dim, + seq_len, + num_frames, + frame_seqlen, + BLOCK_N: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_col = tl.program_id(1) + + col_offsets = pid_col * BLOCK_N + tl.arange(0, BLOCK_N) + mask = col_offsets < inner_dim + + # Pointers for normalized and output + row_base = pid_row * inner_dim + norm_ptrs = normalized_ptr + row_base + col_offsets + out_ptrs = output_ptr + row_base + col_offsets + + # Pointers for scale and shift for 4D + b_idx = pid_row // seq_len + t_idx = pid_row % seq_len + frame_idx_in_batch = t_idx // frame_seqlen + + scale_row_idx = b_idx * num_frames + frame_idx_in_batch + scale_ptrs = scale_ptr + scale_row_idx * inner_dim + col_offsets + shift_ptrs = shift_ptr + scale_row_idx * inner_dim + col_offsets + + normalized = tl.load(norm_ptrs, mask=mask, other=0.0) + scale = tl.load(scale_ptrs, mask=mask, other=0.0) + shift = tl.load(shift_ptrs, mask=mask, other=0.0) + + one = tl.full([BLOCK_N], 1.0, dtype=scale.dtype) + output = normalized * (one + scale) + shift + + tl.store(out_ptrs, output, mask=mask) + + +@triton.jit +def fuse_scale_shift_kernel_blc_opt( + x_ptr, + shift_ptr, + scale_ptr, + y_ptr, + B, + L, + C, + stride_x_b, + stride_x_l, + stride_x_c, + stride_s_b, + stride_s_l, + stride_s_c, + stride_sc_b, + stride_sc_l, + stride_sc_c, + SCALE_IS_SCALAR: tl.constexpr, + SHIFT_IS_SCALAR: tl.constexpr, + BLOCK_L: tl.constexpr, + BLOCK_C: tl.constexpr, +): + pid_l = tl.program_id(0) + pid_c = tl.program_id(1) + pid_b = tl.program_id(2) + + l_offsets = pid_l * BLOCK_L + tl.arange(0, BLOCK_L) + c_offsets = pid_c * BLOCK_C + tl.arange(0, BLOCK_C) + + mask_l = l_offsets < L + mask_c = c_offsets < C + mask = mask_l[:, None] & mask_c[None, :] + + x_off = ( + pid_b * stride_x_b + + l_offsets[:, None] * stride_x_l + + c_offsets[None, :] * stride_x_c + ) + x = tl.load(x_ptr + x_off, mask=mask, other=0) + + if SHIFT_IS_SCALAR: + shift_val = tl.load(shift_ptr) + shift = tl.full((BLOCK_L, 
BLOCK_C), shift_val, dtype=shift_val.dtype) + else: + s_off = ( + pid_b * stride_s_b + + l_offsets[:, None] * stride_s_l + + c_offsets[None, :] * stride_s_c + ) + shift = tl.load(shift_ptr + s_off, mask=mask, other=0) + + if SCALE_IS_SCALAR: + scale_val = tl.load(scale_ptr) + scale = tl.full((BLOCK_L, BLOCK_C), scale_val, dtype=scale_val.dtype) + else: + sc_off = ( + pid_b * stride_sc_b + + l_offsets[:, None] * stride_sc_l + + c_offsets[None, :] * stride_sc_c + ) + scale = tl.load(scale_ptr + sc_off, mask=mask, other=0) + + y = x * (1 + scale) + shift + tl.store(y_ptr + x_off, y, mask=mask) + + +def fuse_scale_shift_kernel( + x: torch.Tensor, + scale: torch.Tensor, + shift: torch.Tensor, + block_l: int = 128, + block_c: int = 128, +): + assert x.is_cuda and scale.is_cuda + assert x.is_contiguous() + + B, L, C = x.shape + output = torch.empty_like(x) + + if scale.dim() == 4: + # scale/shift: [B, F, 1, C] + rows = B * L + x_2d = x.view(rows, C) + output_2d = output.view(rows, C) + grid = lambda META: (rows, triton.cdiv(C, META["BLOCK_N"])) + num_frames = scale.shape[1] + assert ( + L % num_frames == 0 + ), "seq_len must be divisible by num_frames for 4D scale/shift" + frame_seqlen = L // num_frames + + # Compact [B, F, C] without the singleton dim into [B*F, C] + scale_reshaped = scale.squeeze(2).reshape(-1, C).contiguous() + shift_reshaped = shift.squeeze(2).reshape(-1, C).contiguous() + + _fused_scale_shift_4d_kernel[grid]( + output_2d, + x_2d, + scale_reshaped, + shift_reshaped, + rows, + C, + L, + num_frames, + frame_seqlen, + ) + else: + # 2D: [B, C] or [1, C] -> treat as [B, 1, C] and broadcast over L + # 3D: [B, L, C] (or broadcastable variants like [B, 1, C], [1, L, C], [1, 1, C]) + # Also support scalar (0D or 1-element) + if scale.dim() == 0 or (scale.dim() == 1 and scale.numel() == 1): + scale_blc = scale.reshape(1) + elif scale.dim() == 2: + scale_blc = scale[:, None, :] + elif scale.dim() == 3: + scale_blc = scale + else: + raise ValueError("scale must be 0D/1D(1)/2D/3D or 4D") + + if shift.dim() == 0 or (shift.dim() == 1 and shift.numel() == 1): + shift_blc = shift.reshape(1) + elif shift.dim() == 2: + shift_blc = shift[:, None, :] + elif shift.dim() == 3: + shift_blc = shift + else: + # broadcast later via expand if possible + shift_blc = shift + + need_scale_scalar = scale_blc.dim() == 1 and scale_blc.numel() == 1 + need_shift_scalar = shift_blc.dim() == 1 and shift_blc.numel() == 1 + + if not need_scale_scalar: + scale_exp = scale_blc.expand(B, L, C) + s_sb, s_sl, s_sc = scale_exp.stride() + else: + s_sb = s_sl = s_sc = 0 + + if not need_shift_scalar: + shift_exp = shift_blc.expand(B, L, C) + sh_sb, sh_sl, sh_sc = shift_exp.stride() + else: + sh_sb = sh_sl = sh_sc = 0 + + # If both scalars and both zero, copy fast-path + if need_scale_scalar and need_shift_scalar: + if (scale_blc.abs().max() == 0) and (shift_blc.abs().max() == 0): + output.copy_(x) + return output + + grid = (triton.cdiv(L, block_l), triton.cdiv(C, block_c), B) + fuse_scale_shift_kernel_blc_opt[grid]( + x, + shift_blc if need_shift_scalar else shift_exp, + scale_blc if need_scale_scalar else scale_exp, + output, + B, + L, + C, + x.stride(0), + x.stride(1), + x.stride(2), + sh_sb, + sh_sl, + sh_sc, + s_sb, + s_sl, + s_sc, + SCALE_IS_SCALAR=need_scale_scalar, + SHIFT_IS_SCALAR=need_shift_scalar, + BLOCK_L=block_l, + BLOCK_C=block_c, + num_warps=4, + num_stages=2, + ) + return output + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_HS_HALF": 32}, num_warps=2), + triton.Config({"BLOCK_HS_HALF": 64}, 
num_warps=4), + triton.Config({"BLOCK_HS_HALF": 128}, num_warps=4), + triton.Config({"BLOCK_HS_HALF": 256}, num_warps=8), + ], + key=["head_size", "interleaved"], +) +@triton.jit +def _rotary_embedding_kernel( + output_ptr, + x_ptr, + cos_ptr, + sin_ptr, + num_heads, + head_size, + num_tokens, + stride_x_row, + stride_cos_row, + stride_sin_row, + interleaved: tl.constexpr, + BLOCK_HS_HALF: tl.constexpr, +): + row_idx = tl.program_id(0) + token_idx = (row_idx // num_heads) % num_tokens + + x_row_ptr = x_ptr + row_idx * stride_x_row + cos_row_ptr = cos_ptr + token_idx * stride_cos_row + sin_row_ptr = sin_ptr + token_idx * stride_sin_row + output_row_ptr = output_ptr + row_idx * stride_x_row + + # half size for x1 and x2 + head_size_half = head_size // 2 + + for block_start in range(0, head_size_half, BLOCK_HS_HALF): + offsets_half = block_start + tl.arange(0, BLOCK_HS_HALF) + mask = offsets_half < head_size_half + + cos_vals = tl.load(cos_row_ptr + offsets_half, mask=mask, other=0.0) + sin_vals = tl.load(sin_row_ptr + offsets_half, mask=mask, other=0.0) + + offsets_x1 = 2 * offsets_half + offsets_x2 = 2 * offsets_half + 1 + + x1_vals = tl.load(x_row_ptr + offsets_x1, mask=mask, other=0.0) + x2_vals = tl.load(x_row_ptr + offsets_x2, mask=mask, other=0.0) + + x1_fp32 = x1_vals.to(tl.float32) + x2_fp32 = x2_vals.to(tl.float32) + cos_fp32 = cos_vals.to(tl.float32) + sin_fp32 = sin_vals.to(tl.float32) + o1_vals = tl.fma(-x2_fp32, sin_fp32, x1_fp32 * cos_fp32) + o2_vals = tl.fma(x1_fp32, sin_fp32, x2_fp32 * cos_fp32) + + tl.store(output_row_ptr + offsets_x1, o1_vals.to(x1_vals.dtype), mask=mask) + tl.store(output_row_ptr + offsets_x2, o2_vals.to(x2_vals.dtype), mask=mask) + + +def apply_rotary_embedding( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False +) -> torch.Tensor: + output = torch.empty_like(x) + + if x.dim() > 3: + bsz, num_tokens, num_heads, head_size = x.shape + else: + num_tokens, num_heads, head_size = x.shape + bsz = 1 + + assert head_size % 2 == 0, "head_size must be divisible by 2" + + x_reshaped = x.view(-1, head_size) + output_reshaped = output.view(-1, head_size) + + # num_tokens per head, 1 token per block + grid = (bsz * num_tokens * num_heads,) + + if interleaved and cos.shape[-1] == head_size: + cos = cos[..., ::2].contiguous() + sin = sin[..., ::2].contiguous() + else: + cos = cos.contiguous() + sin = sin.contiguous() + + _rotary_embedding_kernel[grid]( + output_reshaped, + x_reshaped, + cos, + sin, + num_heads, + head_size, + num_tokens, + x_reshaped.stride(0), + cos.stride(0), + sin.stride(0), + interleaved, + ) + + return output + + +# RMSNorm-fp32 +def maybe_contiguous_lastdim(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +def maybe_contiguous(x): + return x.contiguous() if x is not None else None + + +def triton_autotune_configs(): + # Return configs with a valid warp count for the current device + configs = [] + # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024 + max_threads_per_block = 1024 + # Default to warp size 32 if not defined by device + warp_size = getattr( + torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32 + ) + # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit + return [ + triton.Config({}, num_warps=warp_count) + for warp_count in [1, 2, 4, 8, 16, 32] + if warp_count * warp_size <= max_threads_per_block + ] + # return [triton.Config({}, num_warps=8)] + + +# Copied from flash-attn 
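+# The fused forward kernel below normalizes one row per program: it optionally applies
+# rowscale and dropout, adds a second input x1 and/or a residual, computes LayerNorm or
+# RMSNorm statistics, and can emit a second output Y1 from an extra weight/bias pair
+# (W1/B1). Every HAS_* switch is a tl.constexpr, so unused paths are compiled away.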
+@triton.autotune( + configs=triton_autotune_configs(), + key=[ + "N", + "HAS_RESIDUAL", + "STORE_RESIDUAL_OUT", + "IS_RMS_NORM", + "HAS_BIAS", + "HAS_WEIGHT", + "HAS_X1", + "HAS_W1", + "HAS_B1", + ], +) +# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel +# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) +# @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) +# @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) +# @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + RESIDUAL, # pointer to the residual + X1, + W1, + B1, + Y1, + RESIDUAL_OUT, # pointer to the residual + ROWSCALE, + SEEDS, # Dropout seeds for each row + DROPOUT_MASK, + DROPOUT_MASK1, + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_res_row, + stride_res_out_row, + stride_x1_row, + stride_y1_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + dropout_p, # Dropout probability + zero_centered_weight, # If true, add 1.0 to the weight + IS_RMS_NORM: tl.constexpr, + BLOCK_N: tl.constexpr, + HAS_RESIDUAL: tl.constexpr, + STORE_RESIDUAL_OUT: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_DROPOUT: tl.constexpr, + STORE_DROPOUT_MASK: tl.constexpr, + HAS_ROWSCALE: tl.constexpr, + HAS_X1: tl.constexpr, + HAS_W1: tl.constexpr, + HAS_B1: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
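+    # Each program instance owns one row of length N and covers it with a single
+    # BLOCK_N-wide block of columns (the launcher guarantees BLOCK_N >= N).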
+ row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_RESIDUAL: + RESIDUAL += row * stride_res_row + if STORE_RESIDUAL_OUT: + RESIDUAL_OUT += row * stride_res_out_row + if HAS_X1: + X1 += row * stride_x1_row + if HAS_W1: + Y1 += row * stride_y1_row + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + row).to(tl.float32) + x *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = ( + tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p + ) + x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) + if HAS_X1: + x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_ROWSCALE: + rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) + x1 *= rowscale + if HAS_DROPOUT: + # Compute dropout mask + # 7 rounds is good enough, and reduces register pressure + keep_mask = ( + tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) + > dropout_p + ) + x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) + if STORE_DROPOUT_MASK: + tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N) + x += x1 + if HAS_RESIDUAL: + residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) + x += residual + if STORE_RESIDUAL_OUT: + tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w += 1.0 + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + if HAS_WEIGHT: + y = x_hat * w + b if HAS_BIAS else x_hat * w + else: + y = x_hat + b if HAS_BIAS else x_hat + # Write output + tl.store(Y + cols, y, mask=mask) + if HAS_W1: + w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) + if zero_centered_weight: + w1 += 1.0 + if HAS_B1: + b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) + y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 + tl.store(Y1 + cols, y1, mask=mask) + + +def _layer_norm_fwd( + x: Tensor, + weight: Tensor, + bias: Tensor, + eps: float, + residual: Optional[Tensor] = None, + x1: Optional[Tensor] = None, + weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + residual_dtype: Optional[torch.dtype] = None, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + return_dropout_mask: bool = False, + out: Optional[Tensor] = None, + residual_out: Optional[Tensor] = None, +) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): + # Need to wrap to handle the case where residual_out is a alias of x, which makes torch.library + # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None + # so that _layer_norm_fwd_impl doesn't have to return them. 
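+    # residual_out is only materialized when something actually changes the residual
+    # stream (an explicit residual, dropout, rowscale, a second input x1, or a residual
+    # dtype that differs from x); otherwise the caller simply gets x back as the residual.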
+ if out is None: + out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) + if residual is not None: + residual_dtype = residual.dtype + if residual_out is None and ( + residual is not None + or (residual_dtype is not None and residual_dtype != x.dtype) + or dropout_p > 0.0 + or rowscale is not None + or x1 is not None + ): + residual_out = torch.empty_like( + x, dtype=residual_dtype if residual_dtype is not None else x.dtype + ) + else: + residual_out = None + y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl( + x, + weight, + bias, + eps, + out, + residual=residual, + x1=x1, + weight1=weight1, + bias1=bias1, + dropout_p=dropout_p, + rowscale=rowscale, + zero_centered_weight=zero_centered_weight, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + residual_out=residual_out, + ) + # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 + if residual_out is None: + residual_out = x + return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 + + +# [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema +# since we're returning a tuple of tensors +def _layer_norm_fwd_impl( + x: Tensor, + weight: Optional[Tensor], + bias: Tensor, + eps: float, + out: Tensor, + residual: Optional[Tensor] = None, + x1: Optional[Tensor] = None, + weight1: Optional[Tensor] = None, + bias1: Optional[Tensor] = None, + dropout_p: float = 0.0, + rowscale: Optional[Tensor] = None, + zero_centered_weight: bool = False, + is_rms_norm: bool = False, + return_dropout_mask: bool = False, + residual_out: Optional[Tensor] = None, +) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor): + M, N = x.shape + assert x.stride(-1) == 1 + if residual is not None: + assert residual.stride(-1) == 1 + assert residual.shape == (M, N) + if weight is not None: + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + if x1 is not None: + assert x1.shape == x.shape + assert rowscale is None + assert x1.stride(-1) == 1 + if weight1 is not None: + assert weight1.shape == (N,) + assert weight1.stride(-1) == 1 + if bias1 is not None: + assert bias1.shape == (N,) + assert bias1.stride(-1) == 1 + if rowscale is not None: + assert rowscale.is_contiguous() + assert rowscale.shape == (M,) + assert out.shape == x.shape + assert out.stride(-1) == 1 + if residual_out is not None: + assert residual_out.shape == x.shape + assert residual_out.stride(-1) == 1 + if weight1 is not None: + y1 = torch.empty_like(out) + assert y1.stride(-1) == 1 + else: + y1 = None + mean = ( + torch.empty((M,), dtype=torch.float32, device=x.device) + if not is_rms_norm + else None + ) + rstd = torch.empty((M,), dtype=torch.float32, device=x.device) + if dropout_p > 0.0: + seeds = torch.randint( + 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64 + ) + else: + seeds = None + if return_dropout_mask and dropout_p > 0.0: + dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool) + if x1 is not None: + dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool) + else: + dropout_mask1 = None + else: + dropout_mask, dropout_mask1 = None, None + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + 
with torch.cuda.device(x.device.index): + torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)]( + x, + out, + weight if weight is not None else x, # unused when HAS_WEIGHT == False + bias, + residual, + x1, + weight1, + bias1, + y1, + residual_out, + rowscale, + seeds, + dropout_mask, + dropout_mask1, + mean, + rstd, + x.stride(0), + out.stride(0), + residual.stride(0) if residual is not None else 0, + residual_out.stride(0) if residual_out is not None else 0, + x1.stride(0) if x1 is not None else 0, + y1.stride(0) if y1 is not None else 0, + M, + N, + eps, + dropout_p, + # Passing bool make torch inductor very unhappy since it then tries to compare to int_max + int(zero_centered_weight), + is_rms_norm, + BLOCK_N, + residual is not None, + residual_out is not None, + weight is not None, + bias is not None, + dropout_p > 0.0, + dropout_mask is not None, + rowscale is not None, + HAS_X1=x1 is not None, + HAS_W1=weight1 is not None, + HAS_B1=bias1 is not None, + ) + return y1, mean, rstd, seeds, dropout_mask, dropout_mask1 + + +class LayerNormFn: + + @staticmethod + def forward( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out_dtype=None, + out=None, + residual_out=None, + ): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1])) + if residual is not None: + assert residual.shape == x_shape_og + residual = maybe_contiguous_lastdim( + residual.reshape(-1, residual.shape[-1]) + ) + if x1 is not None: + assert x1.shape == x_shape_og + assert rowscale is None, "rowscale is not supported with parallel LayerNorm" + x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1])) + # weight can be None when elementwise_affine=False for LayerNorm + if weight is not None: + weight = weight.contiguous() + bias = maybe_contiguous(bias) + weight1 = maybe_contiguous(weight1) + bias1 = maybe_contiguous(bias1) + if rowscale is not None: + rowscale = rowscale.reshape(-1).contiguous() + residual_dtype = ( + residual.dtype + if residual is not None + else (torch.float32 if residual_in_fp32 else None) + ) + if out is not None: + out = out.reshape(-1, out.shape[-1]) + if residual_out is not None: + residual_out = residual_out.reshape(-1, residual_out.shape[-1]) + y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = ( + _layer_norm_fwd( + x, + weight, + bias, + eps, + residual, + x1, + weight1, + bias1, + dropout_p=dropout_p, + rowscale=rowscale, + out_dtype=out_dtype, + residual_dtype=residual_dtype, + zero_centered_weight=zero_centered_weight, + is_rms_norm=is_rms_norm, + return_dropout_mask=return_dropout_mask, + out=out, + residual_out=residual_out, + ) + ) + y = y.reshape(x_shape_og) + return y + + +def layer_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + is_rms_norm=False, + return_dropout_mask=False, + out_dtype=None, + out=None, + residual_out=None, +): + return LayerNormFn.forward( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + zero_centered_weight, + is_rms_norm, + return_dropout_mask, + out_dtype, + out, + residual_out, + ) + + +@triton.jit +def _norm_infer_kernel( + X, + Y, + 
W, + B, + stride_x_row, + stride_y_row, + M, + N, + eps, + IS_RMS_NORM: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_BIAS: tl.constexpr, + BLOCK_N: tl.constexpr, +): + row = tl.program_id(0) + X += row * stride_x_row + Y += row * stride_y_row + if HAS_WEIGHT: + W += 0 + if HAS_BIAS: + B += 0 + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + if HAS_WEIGHT: + w = tl.load(W + cols, mask=cols < N, other=1.0).to(tl.float32) + y = x_hat * w + else: + y = x_hat + if HAS_BIAS: + b = tl.load(B + cols, mask=cols < N, other=0.0).to(tl.float32) + y += b + tl.store(Y + cols, y, mask=cols < N) + + +def norm_infer( + x: Tensor, + weight: Optional[Tensor], + bias: Optional[Tensor], + eps: float, + is_rms_norm: bool = False, + out: Optional[Tensor] = None, +): + M, N = x.shape + assert x.stride(-1) == 1 + if weight is not None: + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.shape == (N,) + assert bias.stride(-1) == 1 + if out is None: + out = torch.empty_like(x) + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + num_warps = min(max(BLOCK_N // 256, 1), 8) + _norm_infer_kernel[(M,)]( + x, + out, + weight if weight is not None else x, # dummy when HAS_WEIGHT=False + bias if bias is not None else x, # dummy when HAS_BIAS=False + x.stride(0), + out.stride(0), + M, + N, + eps, + IS_RMS_NORM=is_rms_norm, + HAS_WEIGHT=weight is not None, + HAS_BIAS=bias is not None, + BLOCK_N=BLOCK_N, + num_warps=num_warps, + ) + return out + + +def rms_norm_fn( + x, + weight, + bias, + residual=None, + x1=None, + weight1=None, + bias1=None, + eps=1e-6, + dropout_p=0.0, + rowscale=None, + prenorm=False, + residual_in_fp32=False, + zero_centered_weight=False, + return_dropout_mask=False, + out_dtype=None, + out=None, + residual_out=None, +): + return LayerNormFn.forward( + x, + weight, + bias, + residual, + x1, + weight1, + bias1, + eps, + dropout_p, + rowscale, + prenorm, + residual_in_fp32, + zero_centered_weight, + True, + return_dropout_mask, + out_dtype, + out, + residual_out, + ) diff --git a/python/sglang/multimodal_gen/runtime/layers/usp.py b/python/sglang/multimodal_gen/runtime/layers/usp.py new file mode 100644 index 000000000..4f3804c91 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/usp.py @@ -0,0 +1,255 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import logging +from typing import TYPE_CHECKING + +import torch +import torch.distributed._functional_collectives as ft_c +from packaging.version import parse +from torch.distributed.tensor.experimental._attention import _cp_options + +from sglang.multimodal_gen.runtime.distributed.parallel_state import ( + get_sp_group, + get_ulysses_parallel_world_size, +) + +_cp_options.enable_load_balance = False + +if TYPE_CHECKING: + from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionImpl, + ) + +logger = logging.getLogger(__name__) + + +def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor: + """ + When tracing the code, the result tensor 
is not an AsyncCollectiveTensor, + so we cannot call ``wait()``. + """ + if isinstance(tensor, ft_c.AsyncCollectiveTensor): + return tensor.wait() + return tensor + + +def _usp_all_to_all_single(x: torch.Tensor) -> torch.Tensor: + ulysses_pg = get_sp_group().ulysses_group + assert ulysses_pg is not None, "Ulysses process group is not initialized." + x_shape = x.shape + x = x.flatten() + x = ft_c.all_to_all_single( + x, output_split_sizes=None, input_split_sizes=None, group=ulysses_pg + ) + x = _maybe_wait(x) + x = x.reshape(x_shape) + return x + + +def _usp_input_all_to_all(x: torch.Tensor, head_dim: int = 1) -> torch.Tensor: + """ + Perform Ulysses-style input all-to-all over the head dimension. + + Default layout expects heads at dim=1 and sequence at dim=2: + [b, h, s_local, d] -> [b, h // world_size, s_global, d] + + If heads are at dim=2 (input is [b, s_local, h, d]), set head_dim=2, and the + function returns [b, s_global, h // world_size, d], preserving the original + head/sequence dim ordering. + + Args: + x: A 4D tensor with layout [b, *, *, d] where '*' are sequence and heads + head_dim: Which dimension index corresponds to heads (1 or 2) + + Returns: + Tensor with the same dim order as input, with heads sharded and sequence gathered. + """ + world_size = get_ulysses_parallel_world_size() + if world_size <= 1: + return x + + assert x.ndim == 4, f"x must have 4 dimensions, got {x.ndim}" + assert head_dim in (1, 2), f"head_dim must be 1 or 2, got {head_dim}" + seq_dim = 1 if head_dim == 2 else 2 + + # Bring to canonical [b, h, s, d] + if head_dim == 1 and seq_dim == 2: + x_c = x + else: + x_c = x.permute(0, head_dim, seq_dim, 3).contiguous() + + b, h, s, d = x_c.shape + assert ( + h % world_size == 0 + ), f"h ({h}) must be divisible by world_size ({world_size})" + + # [b, h, s, d] -> [h, b, s, d] + x_c = x_c.permute(1, 0, 2, 3).contiguous() + # all-to-all along h + x_c = _usp_all_to_all_single(x_c) + # -> [b, h // world, s * world, d] + x_c = ( + x_c.reshape(world_size, h // world_size, b, -1, d) + .permute(2, 1, 0, 3, 4) + .reshape(b, h // world_size, -1, d) + ) + + if head_dim == 1 and seq_dim == 2: + return x_c + + # Map back to original ordering, preserving head/seq positions + new_order = [0, None, None, 3] + new_order[head_dim] = 1 + new_order[seq_dim] = 2 + return x_c.permute(tuple(new_order)).contiguous() + + +def _usp_output_all_to_all(x: torch.Tensor, head_dim: int = 1) -> torch.Tensor: + """ + Perform Ulysses-style output all-to-all over the head dimension (inverse of input). + + Default layout expects heads at dim=1 and sequence at dim=2: + [b, h // world_size, s_global, d] -> [b, h, s_local, d] + + If heads are at dim=2 (input is [b, s_global, h // world_size, d]), set head_dim=2, + and the function returns [b, s_local, h, d], preserving the original head/sequence + dim ordering. + + Args: + x: A 4D tensor with layout [b, *, *, d] where '*' are sequence and heads + head_dim: Which dimension index corresponds to heads (1 or 2) + + Returns: + Tensor with the same dim order as input, with heads gathered and sequence sharded. 
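+
+    Example (shapes only, assuming a Ulysses world size of 4 and the default head_dim=1):
+        input:  [b, h // 4, s_global, d]
+        output: [b, h, s_global // 4, d]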
+ """ + world_size = get_ulysses_parallel_world_size() + if world_size <= 1: + return x + + assert x.ndim == 4, f"x must have 4 dimensions, got {x.ndim}" + assert head_dim in (1, 2), f"head_dim must be 1 or 2, got {head_dim}" + seq_dim = 1 if head_dim == 2 else 2 + + # Bring to canonical [b, h, s, d] + if head_dim == 1 and seq_dim == 2: + x_c = x + else: + x_c = x.permute(0, head_dim, seq_dim, 3).contiguous() + + b, h, s, d = x_c.shape + assert ( + s % world_size == 0 + ), f"s ({s}) must be divisible by world_size ({world_size})" + + # [b, h, s, d] -> [s, b, h, d] + x_c = x_c.permute(2, 0, 1, 3).contiguous() + x_c = _usp_all_to_all_single(x_c) + # -> [b, h * world, s // world, d] + x_c = ( + x_c.reshape(world_size, s // world_size, b, -1, d) + .permute(2, 0, 3, 1, 4) + .reshape(b, -1, s // world_size, d) + ) + + if head_dim == 1 and seq_dim == 2: + return x_c + + # Map back to original ordering, preserving head/seq positions + new_order = [0, None, None, 3] + new_order[head_dim] = 1 + new_order[seq_dim] = 2 + return x_c.permute(tuple(new_order)).contiguous() + + +def ring_attn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_impl: "AttentionImpl", + is_causal: bool = False, + dropout_p: float = 0.0, +): + """ + Ring Attention implementation. + + This function implements Ring Attention, a strategy for distributed attention + computation that reduces peak memory usage. It accepts a generic attention + implementation (`attn_impl`) which is called by the underlying PyTorch + distributed attention primitive. + + Args: + query, key, value: The input tensors for attention. + attn_impl: An instance of an attention implementation backend + (e.g., FlashAttentionImpl) whose `forward` method will be + used as the computational kernel. + is_causal: Whether to apply causal masking. + dropout_p: Dropout probability. + """ + # torch.distributed.tensor.experimental._attention is not a public API, + from torch.distributed.tensor.experimental._attention import ( + _templated_ring_attention, + ) + + ring_pg = get_sp_group().ring_group + assert ring_pg is not None, "Ring process group is not initialized." + + # Ring attention primitives expect tensors in [B, H, S, D] layout. + # We permute the inputs here. + query = torch.permute(query, [0, 2, 1, 3]).contiguous() + key = torch.permute(key, [0, 2, 1, 3]).contiguous() + value = torch.permute(value, [0, 2, 1, 3]).contiguous() + + # Create an adapter function that matches the signature expected by + # _templated_ring_attention. The `attn_impl` already has dropout and + # causal settings configured during its initialization. + + # Note: Please be aware that Attention Backend and Ring Attention may require different QKV tensor shapes. + # For example, FlashAttention expects the format to be BSHD. + def attn_callable_adapter(q, k, v, *args, **kwargs): + # We ignore the dropout_p and is_causal passed by _templated_ring_attention + # and rely on the pre-configured attn_impl. + # The `attn_metadata` is not available here, so we pass None. + # This is a limitation we must accept when using this experimental API. 
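+        # The ring primitive hands q/k/v to this adapter in the same [B, H, S, D]
+        # layout we passed in above, while backends such as FlashAttention expect
+        # [B, S, H, D]; hence the permutes on the way in and out of attn_impl.forward.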
+        q = torch.permute(q, [0, 2, 1, 3])
+        k = torch.permute(k, [0, 2, 1, 3])
+        v = torch.permute(v, [0, 2, 1, 3])
+        # logger.warning(f"Warning: return_softmax_lse is only supported for FlashAttentionImpl")
+        output, softmax_lse, *rest = attn_impl.forward(
+            q,
+            k,
+            v,
+            attn_metadata=None,
+            return_softmax_lse=True,
+        )
+        output = torch.permute(output, [0, 2, 1, 3])
+        return output, softmax_lse, *rest
+
+    # Starting from torch 2.6.0, _templated_ring_attention expects an integer
+    # segment_id for the attention function.
+    use_segment_id = parse(torch.__version__).release >= parse("2.6.0").release
+
+    attn_kwargs = dict(
+        mesh=ring_pg,
+        op=attn_callable_adapter,
+        dropout_p=dropout_p,
+        is_causal=is_causal,
+        query=query,
+        key=key,
+        value=value,
+    )
+
+    if use_segment_id:
+        # For torch >= 2.6, segment_id is required. The value '1' is a placeholder
+        # as we are not using complex segmentation features.
+        out, *_ = _templated_ring_attention(
+            seq_dim=1,  # segment_id
+            **attn_kwargs,
+        )
+    else:
+        out, *_ = _templated_ring_attention(
+            **attn_kwargs,
+        )
+
+    # Permute the output back to [B, S, H, D] layout.
+    output = torch.permute(out, [0, 2, 1, 3])
+    return output
diff --git a/python/sglang/multimodal_gen/runtime/layers/utils.py b/python/sglang/multimodal_gen/runtime/layers/utils.py
new file mode 100644
index 000000000..615ebc385
--- /dev/null
+++ b/python/sglang/multimodal_gen/runtime/layers/utils.py
@@ -0,0 +1,24 @@
+# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
+
+# SPDX-License-Identifier: Apache-2.0
+# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/layers/utils.py
+"""Utility methods for model layers."""
+
+import torch
+
+
+def get_token_bin_counts_and_mask(
+    tokens: torch.Tensor,
+    vocab_size: int,
+    num_seqs: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    # Compute the bin counts for the tokens.
+    # vocab_size + 1 for padding.
+    bin_counts = torch.zeros(
+        (num_seqs, vocab_size + 1), dtype=torch.long, device=tokens.device
+    )
+    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
+    bin_counts = bin_counts[:, :vocab_size]
+    mask = bin_counts > 0
+
+    return bin_counts, mask
diff --git a/python/sglang/multimodal_gen/runtime/layers/visual_embedding.py b/python/sglang/multimodal_gen/runtime/layers/visual_embedding.py
new file mode 100644
index 000000000..715a1b874
--- /dev/null
+++ b/python/sglang/multimodal_gen/runtime/layers/visual_embedding.py
@@ -0,0 +1,186 @@
+# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
+
+# SPDX-License-Identifier: Apache-2.0
+
+import math
+
+import torch
+import torch.nn as nn
+
+from sglang.multimodal_gen.runtime.layers.activation import get_act_fn
+from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear
+from sglang.multimodal_gen.runtime.layers.mlp import MLP
+
+
+class PatchEmbed(nn.Module):
+    """Image/video to patch embedding.
+
+    A convolution based approach to patchifying the input with an embedding projection:
+    the Conv3d projection uses the patch size as both kernel size and stride.
+
+    Based on the 2D impl in https://github.com/google-research/vision_transformer
+
+    Hacked together by / Copyright 2020 Ross Wightman
+
+    The _assert in forward was removed to stay compatible with multi-resolution inputs. 
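+
+    Example (illustrative): with flatten=True, a video latent of shape [B, C, T, H, W]
+    is projected patch-wise by the Conv3d and reshaped to [B, num_patches, embed_dim],
+    where num_patches is T * H * W divided by the product of the patch dimensions.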
+ """ + + def __init__( + self, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + bias=True, + dtype=None, + prefix: str = "", + ): + super().__init__() + # Convert patch_size to 2-tuple + if isinstance(patch_size, list | tuple): + if len(patch_size) == 1: + patch_size = (patch_size[0], patch_size[0]) + else: + patch_size = (patch_size, patch_size) + + self.patch_size = patch_size + self.flatten = flatten + + self.proj = nn.Conv3d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_size, + bias=bias, + dtype=dtype, + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__( + self, + hidden_size, + act_layer="silu", + frequency_embedding_size=256, + max_period=10000, + dtype=None, + freq_dtype=torch.float32, + prefix: str = "", + ): + super().__init__() + self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period + + self.mlp = MLP( + frequency_embedding_size, + hidden_size, + hidden_size, + act_type=act_layer, + dtype=dtype, + ) + self.freq_dtype = freq_dtype + + def forward( + self, t: torch.Tensor, timestep_seq_len: int | None = None + ) -> torch.Tensor: + t_freq = timestep_embedding( + t, self.frequency_embedding_size, self.max_period, dtype=self.freq_dtype + ).to(self.mlp.fc_in.weight.dtype) + if timestep_seq_len is not None: + t_freq = t_freq.unflatten(0, (1, timestep_seq_len)) + # t_freq = t_freq.to(self.mlp.fc_in.weight.dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +def timestep_embedding( + t: torch.Tensor, + dim: int, + max_period: int = 10000, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Create sinusoidal timestep embeddings. + + Args: + t: Tensor of shape [B] with timesteps + dim: Embedding dimension + max_period: Controls the minimum frequency of the embeddings + + Returns: + Tensor of shape [B, dim] with embeddings + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=dtype, device=t.device) + / half + ) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +class ModulateProjection(nn.Module): + """Modulation layer for DiT blocks.""" + + def __init__( + self, + hidden_size: int, + factor: int = 2, + act_layer: str = "silu", + dtype: torch.dtype | None = None, + prefix: str = "", + ): + super().__init__() + self.factor = factor + self.hidden_size = hidden_size + self.linear = ReplicatedLinear( + hidden_size, hidden_size * factor, bias=True, params_dtype=dtype + ) + self.act = get_act_fn(act_layer) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.act(x) + x, _ = self.linear(x) + return x + + +def unpatchify(x, t, h, w, patch_size, channels) -> torch.Tensor: + """ + Convert patched representation back to image space. 
+ + Args: + x: Tensor of shape [B, T*H*W, C*P_t*P_h*P_w] + t, h, w: Temporal and spatial dimensions + + Returns: + Unpatchified tensor of shape [B, C, T*P_t, H*P_h, W*P_w] + """ + assert x.ndim == 3, f"x.ndim: {x.ndim}" + assert len(patch_size) == 3, f"patch_size: {patch_size}" + assert t * h * w == x.shape[1], f"t * h * w: {t * h * w}, x.shape[1]: {x.shape[1]}" + c = channels + pt, ph, pw = patch_size + + x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw)) + x = torch.einsum("nthwcopq->nctohpwq", x) + imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) + + return imgs diff --git a/python/sglang/multimodal_gen/runtime/layers/vocab_parallel_embedding.py b/python/sglang/multimodal_gen/runtime/layers/vocab_parallel_embedding.py new file mode 100644 index 000000000..fbddaab40 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/layers/vocab_parallel_embedding.py @@ -0,0 +1,480 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Sequence +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter, UninitializedParameter + +from sglang.multimodal_gen.runtime.distributed import ( + divide, + get_tp_rank, + get_tp_world_size, + tensor_model_parallel_all_reduce, +) +from sglang.multimodal_gen.runtime.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, + method_has_implemented_embedding, +) +from sglang.multimodal_gen.runtime.models.parameter import BasevLLMParameter +from sglang.multimodal_gen.runtime.models.utils import set_weight_attrs +from sglang.multimodal_gen.runtime.platforms import current_platform + +DEFAULT_VOCAB_PADDING_SIZE = 64 + + +class UnquantizedEmbeddingMethod(QuantizeMethodBase): + """Unquantized method for embeddings.""" + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + """Create weights for embedding layer.""" + + weight = Parameter( + torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply( + self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None + ) -> torch.Tensor: + return F.linear(x, layer.weight, bias) + + def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor: + return F.embedding(input_, layer.weight) + + +def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: + """Pad the vocab size to the given value.""" + return ((vocab_size + pad_to - 1) // pad_to) * pad_to + + +def vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size: int, rank: int, offset: int = 0 +) -> Sequence[int]: + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f + offset, index_l + offset + + +def vocab_range_from_global_vocab_size( + global_vocab_size: int, rank: int, world_size: int, offset: int = 0 +) -> Sequence[int]: + per_partition_vocab_size = divide(global_vocab_size, world_size) + return vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size, rank, offset=offset + ) + + +@dataclass +class 
VocabParallelEmbeddingShardIndices: + """Indices for a shard of a vocab parallel embedding.""" + + padded_org_vocab_start_index: int + padded_org_vocab_end_index: int + padded_added_vocab_start_index: int + padded_added_vocab_end_index: int + + org_vocab_start_index: int + org_vocab_end_index: int + added_vocab_start_index: int + added_vocab_end_index: int + + @property + def num_org_elements(self) -> int: + return self.org_vocab_end_index - self.org_vocab_start_index + + @property + def num_added_elements(self) -> int: + return self.added_vocab_end_index - self.added_vocab_start_index + + @property + def num_org_elements_padded(self) -> int: + return self.padded_org_vocab_end_index - self.padded_org_vocab_start_index + + @property + def num_added_elements_padded(self) -> int: + return self.padded_added_vocab_end_index - self.padded_added_vocab_start_index + + @property + def num_org_vocab_padding(self) -> int: + return self.num_org_elements_padded - self.num_org_elements + + @property + def num_added_vocab_padding(self) -> int: + return self.num_added_elements_padded - self.num_added_elements + + @property + def num_elements_padded(self) -> int: + return self.num_org_elements_padded + self.num_added_elements_padded + + def __post_init__(self): + # sanity checks + assert self.padded_org_vocab_start_index <= self.padded_org_vocab_end_index + assert self.padded_added_vocab_start_index <= self.padded_added_vocab_end_index + + assert self.org_vocab_start_index <= self.org_vocab_end_index + assert self.added_vocab_start_index <= self.added_vocab_end_index + + assert self.org_vocab_start_index <= self.padded_org_vocab_start_index + assert self.added_vocab_start_index <= self.padded_added_vocab_start_index + assert self.org_vocab_end_index <= self.padded_org_vocab_end_index + assert self.added_vocab_end_index <= self.padded_added_vocab_end_index + + assert self.num_org_elements <= self.num_org_elements_padded + assert self.num_added_elements <= self.num_added_elements_padded + + +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +def get_masked_input_and_mask( + input_: torch.Tensor, + org_vocab_start_index: int, + org_vocab_end_index: int, + num_org_vocab_padding: int, + added_vocab_start_index: int, + added_vocab_end_index: int, +) -> tuple[torch.Tensor, torch.Tensor]: + # torch.compile will fuse all of the pointwise ops below + # into a single kernel, making it very fast + org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index) + added_vocab_mask = (input_ >= added_vocab_start_index) & ( + input_ < added_vocab_end_index + ) + added_offset = ( + added_vocab_start_index + - (org_vocab_end_index - org_vocab_start_index) + - num_org_vocab_padding + ) + valid_offset = (org_vocab_start_index * org_vocab_mask) + ( + added_offset * added_vocab_mask + ) + vocab_mask = org_vocab_mask | added_vocab_mask + input_ = vocab_mask * (input_ - valid_offset) + return input_, ~vocab_mask + + +class VocabParallelEmbedding(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. + + Adapted from torch.nn.Embedding, note that we pad the vocabulary size to + make sure it is divisible by the number of model parallel GPUs. + + In order to support various loading methods, we ensure that LoRA-added + embeddings are always at the end of TP-sharded tensors. In other words, + we shard base embeddings and LoRA embeddings separately (both padded), + and place them in the same tensor. 
+ In this example, we will have the original vocab size = 1010, + added vocab size = 16 and padding to 64. Therefore, the total + vocab size with padding will be 1088 (because we first pad 1010 to + 1024, add 16, and then pad to 1088). + Therefore, the tensor format looks like the following: + TP1, rank 0 (no sharding): + |< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >| + corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1015 | -1 | ... | -1 | + index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 | + + TP2, rank 0: + |< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >| + corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1000 | ... | 1015 | -1 | ... | -1 | + index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 520 | ... | 543 | + TP2, rank 1: + |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >| + corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 | + index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 | + + Args: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + params_dtype: type of the parameters. + org_num_embeddings: original vocabulary size (without LoRA). + padding_size: padding size for the vocabulary. + quant_config: quant config for the layer + prefix: full name of the layer in the state dict + """ # noqa: E501 + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + params_dtype: torch.dtype | None = None, + org_num_embeddings: int | None = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + + # Keep the input dimensions. + tp_rank = get_tp_rank() + self.tp_size = get_tp_world_size() + self.num_embeddings = num_embeddings + self.padding_size = padding_size + self.org_vocab_size = org_num_embeddings or num_embeddings + num_added_embeddings = num_embeddings - self.org_vocab_size + self.org_vocab_size_padded = pad_vocab_size( + self.org_vocab_size, self.padding_size + ) + self.num_embeddings_padded = pad_vocab_size( + self.org_vocab_size_padded + num_added_embeddings, self.padding_size + ) + assert self.org_vocab_size_padded <= self.num_embeddings_padded + + self.shard_indices = self._get_indices( + self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, + tp_rank, + self.tp_size, + ) + self.embedding_dim = embedding_dim + + quant_method = None + if quant_config is not None: + quant_method = quant_config.get_quant_method(self, prefix=prefix) + if quant_method is None: + quant_method = UnquantizedEmbeddingMethod() + + # If we are making an embedding layer, then our quantization linear + # method must implement the embedding operation. If we are another + # layer type like ParallelLMHead, this is not important. + is_embedding_layer = type(self.__class__) is VocabParallelEmbedding + quant_method_implements_embedding = method_has_implemented_embedding( + type(quant_method) + ) + if is_embedding_layer and not quant_method_implements_embedding: + raise NotImplementedError( + f"The class {type(quant_method).__name__} must implement " + "the 'embedding' method, see UnquantizedEmbeddingMethod." 
+ ) + + self.quant_method: QuantizeMethodBase = quant_method + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + # Divide the weight matrix along the vocaburaly dimension. + self.num_added_embeddings = self.num_embeddings - self.org_vocab_size + self.num_embeddings_per_partition = divide( + self.num_embeddings_padded, self.tp_size + ) + assert ( + self.shard_indices.num_elements_padded == self.num_embeddings_per_partition + ) + self.num_org_embeddings_per_partition = ( + self.shard_indices.org_vocab_end_index + - self.shard_indices.org_vocab_start_index + ) + self.num_added_embeddings_per_partition = ( + self.shard_indices.added_vocab_end_index + - self.shard_indices.added_vocab_start_index + ) + + self.quant_method.create_weights( + self, + self.embedding_dim, + [self.num_embeddings_per_partition], + self.embedding_dim, + self.num_embeddings_padded, + params_dtype=params_dtype, + weight_loader=self.weight_loader, + ) + + @classmethod + def _get_indices( + cls, + vocab_size_padded: int, + org_vocab_size_padded: int, + vocab_size: int, + org_vocab_size: int, + tp_rank: int, + tp_size: int, + ) -> VocabParallelEmbeddingShardIndices: + """Get start and end indices for vocab parallel embedding, following the + layout outlined in the class docstring, based on the given tp_rank and + tp_size.""" + num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded + padded_org_vocab_start_index, padded_org_vocab_end_index = ( + vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, tp_size) + ) + padded_added_vocab_start_index, padded_added_vocab_end_index = ( + vocab_range_from_global_vocab_size( + num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size + ) + ) + # remove padding + org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size) + org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size) + added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size) + added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size) + return VocabParallelEmbeddingShardIndices( + padded_org_vocab_start_index, + padded_org_vocab_end_index, + padded_added_vocab_start_index, + padded_added_vocab_end_index, + org_vocab_start_index, + org_vocab_end_index, + added_vocab_start_index, + added_vocab_end_index, + ) + + def get_sharded_to_full_mapping(self) -> list[int] | None: + """Get a mapping that can be used to reindex the gathered + logits for sampling. + + During sampling, we gather logits from all ranks. The relationship + of index->token_id will follow the same format as outlined in the class + docstring. However, after the gather, we want to reindex the final + logits tensor to map index->token_id one-to-one (the index is always + equal the token_id it corresponds to). The indices returned by this + method allow us to do that. 
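+
+        A minimal usage sketch (assuming `gathered_logits` holds the TP-gathered
+        logits and `lm_head` is this layer):
+
+            >>> mapping = lm_head.get_sharded_to_full_mapping()
+            >>> if mapping is not None:
+            ...     gathered_logits = gathered_logits[:, mapping]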
+ """ + if self.tp_size < 2: + return None + + base_embeddings: list[int] = [] + added_embeddings: list[int] = [] + padding: list[int] = [] + for tp_rank in range(self.tp_size): + shard_indices = self._get_indices( + self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, + tp_rank, + self.tp_size, + ) + range_start = self.num_embeddings_per_partition * tp_rank + range_end = self.num_embeddings_per_partition * (tp_rank + 1) + base_embeddings.extend( + range(range_start, range_start + shard_indices.num_org_elements) + ) + padding.extend( + range( + range_start + shard_indices.num_org_elements, + range_start + shard_indices.num_org_elements_padded, + ) + ) + added_embeddings.extend( + range( + range_start + shard_indices.num_org_elements_padded, + range_start + + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements, + ) + ) + padding.extend( + range( + range_start + + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements, + range_start + + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements_padded, + ) + ) + assert ( + range_start + + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements_padded + == range_end + ) + ret = base_embeddings + added_embeddings + padding + assert len(ret) == self.num_embeddings_padded + return ret + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + output_dim = getattr(param, "output_dim", None) + packed_dim = getattr(param, "packed_dim", None) + + # If the parameter is a gguf weight, then load it directly. + if getattr(param, "is_gguf_weight_type", None): + param.data.copy_(loaded_weight) + param.weight_type = loaded_weight.item() + return + elif isinstance(param, UninitializedParameter): + shape = list(loaded_weight.shape) + if output_dim is not None: + shape[output_dim] = self.num_embeddings_per_partition + param.materialize(tuple(shape), dtype=loaded_weight.dtype) + + # If parameter does not have output dim, then it should + # be copied onto all gpus (e.g. g_idx for act_order gptq). + if output_dim is None: + assert param.data.shape == loaded_weight.shape + param.data.copy_(loaded_weight) + return + + # Shard indexes for loading the weight + start_idx = self.shard_indices.org_vocab_start_index + shard_size = self.shard_indices.org_vocab_end_index - start_idx + + # If param packed on the same dim we are sharding on, then + # need to adjust offsets of loaded weight by pack_factor. + if packed_dim is not None and packed_dim == output_dim: + packed_factor = ( + param.packed_factor + if isinstance(param, BasevLLMParameter) + else param.pack_factor + ) + assert loaded_weight.shape[output_dim] == ( + self.org_vocab_size // param.packed_factor + ) + start_idx = start_idx // packed_factor + shard_size = shard_size // packed_factor + else: + assert loaded_weight.shape[output_dim] == self.org_vocab_size + + # Copy the data. Select chunk corresponding to current shard. + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + + param[: loaded_weight.shape[0]].data.copy_(loaded_weight) + param[loaded_weight.shape[0] :].data.fill_(0) + + def forward(self, input_): + if self.tp_size > 1: + # Build the mask. 
+ masked_input, input_mask = get_masked_input_and_mask( + input_, + self.shard_indices.org_vocab_start_index, + self.shard_indices.org_vocab_end_index, + self.shard_indices.num_org_vocab_padding, + self.shard_indices.added_vocab_start_index, + self.shard_indices.added_vocab_end_index, + ) + else: + masked_input = input_ + # Get the embeddings. + output_parallel = self.quant_method.embedding(self, masked_input.long()) + # Mask the output embedding. + if self.tp_size > 1: + output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) + # Reduce across all the model parallel GPUs. + output = tensor_model_parallel_all_reduce(output_parallel) + return output + + def extra_repr(self) -> str: + s = f"num_embeddings={self.num_embeddings_per_partition}" + s += f", embedding_dim={self.embedding_dim}" + s += f", org_vocab_size={self.org_vocab_size}" + s += f", num_embeddings_padded={self.num_embeddings_padded}" + s += f", tp_size={self.tp_size}" + return s diff --git a/python/sglang/multimodal_gen/runtime/loader/__init__.py b/python/sglang/multimodal_gen/runtime/loader/__init__.py new file mode 100644 index 000000000..af2eb7d10 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/loader/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/python/sglang/multimodal_gen/runtime/loader/component_loader.py b/python/sglang/multimodal_gen/runtime/loader/component_loader.py new file mode 100644 index 000000000..bdd3c4822 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/loader/component_loader.py @@ -0,0 +1,670 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +import glob +import json +import os +import time +from abc import ABC, abstractmethod +from collections.abc import Generator, Iterable +from copy import deepcopy +from typing import cast + +import torch +import torch.distributed as dist +import torch.nn as nn +from safetensors.torch import load_file as safetensors_load_file +from torch.distributed import init_device_mesh +from transformers import AutoImageProcessor, AutoProcessor, AutoTokenizer +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME + +from sglang.multimodal_gen.configs.models import EncoderConfig +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.loader.fsdp_load import ( + maybe_load_fsdp_model, + shard_model, +) +from sglang.multimodal_gen.runtime.loader.utils import set_default_torch_dtype +from sglang.multimodal_gen.runtime.loader.weight_utils import ( + filter_duplicate_safetensors_files, + filter_files_not_needed_for_inference, + pt_weights_iterator, + safetensors_weights_iterator, +) +from sglang.multimodal_gen.runtime.models.registry import ModelRegistry +from sglang.multimodal_gen.runtime.platforms import current_platform +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import ( + get_config, + get_diffusers_config, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import PRECISION_TO_TYPE + +logger = init_logger(__name__) + + +class ComponentLoader(ABC): + """Base class for loading a specific type of model component.""" + + def __init__(self, device=None) -> None: + self.device = device + + @abstractmethod + def load(self, model_path: str, server_args: ServerArgs, module_name: str): + """ + Load the component based on the model 
path, architecture, and inference args. + + Args: + model_path: Path to the component model + server_args: ServerArgs + + Returns: + The loaded component + """ + raise NotImplementedError + + @classmethod + def for_module_type( + cls, module_type: str, transformers_or_diffusers: str + ) -> "ComponentLoader": + """ + Factory method to create a component loader for a specific module type. + + Args: + module_type: Type of module (e.g., "vae", "text_encoder", "transformer", "scheduler") + transformers_or_diffusers: Whether the module is from transformers or diffusers + + Returns: + A component loader for the specified module type + """ + # Map of module types to their loader classes and expected library + module_loaders = { + "scheduler": (SchedulerLoader, "diffusers"), + "transformer": (TransformerLoader, "diffusers"), + "transformer_2": (TransformerLoader, "diffusers"), + "vae": (VAELoader, "diffusers"), + "text_encoder": (TextEncoderLoader, "transformers"), + "text_encoder_2": (TextEncoderLoader, "transformers"), + "tokenizer": (TokenizerLoader, "transformers"), + "tokenizer_2": (TokenizerLoader, "transformers"), + "image_processor": (ImageProcessorLoader, "transformers"), + "image_encoder": (ImageEncoderLoader, "transformers"), + "processor": (AutoProcessorLoader, "transformers"), + } + + if module_type in module_loaders: + loader_cls, expected_library = module_loaders[module_type] + # Assert that the library matches what's expected for this module type + assert ( + transformers_or_diffusers == expected_library + ), f"{module_type} must be loaded from {expected_library}, got {transformers_or_diffusers}" + return loader_cls() + + # For unknown module types, use a generic loader + logger.warning( + "No specific loader found for module type: %s. Using generic loader.", + module_type, + ) + return GenericComponentLoader(transformers_or_diffusers) + + +class TextEncoderLoader(ComponentLoader): + """Loader for text encoders.""" + + @dataclasses.dataclass + class Source: + """A source for weights.""" + + model_or_path: str + """The model ID or path.""" + + prefix: str = "" + """A prefix to prepend to all weights.""" + + fall_back_to_pt: bool = True + """Whether .pt weights can be used.""" + + allow_patterns_overrides: list[str] | None = None + """If defined, weights will load exclusively using these patterns.""" + + counter_before_loading_weights: float = 0.0 + counter_after_loading_weights: float = 0.0 + + def _prepare_weights( + self, + model_name_or_path: str, + fall_back_to_pt: bool, + allow_patterns_overrides: list[str] | None, + ) -> tuple[str, list[str], bool]: + """Prepare weights for the model. 
+ + If the model is not local, it will be downloaded.""" + # model_name_or_path = (self._maybe_download_from_modelscope( + # model_name_or_path, revision) or model_name_or_path) + + is_local = os.path.isdir(model_name_or_path) + assert is_local, "Model path must be a local directory" + + use_safetensors = False + index_file = SAFE_WEIGHTS_INDEX_NAME + allow_patterns = ["*.safetensors", "*.bin"] + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if allow_patterns_overrides is not None: + allow_patterns = allow_patterns_overrides + + hf_folder = model_name_or_path + + hf_weights_files: list[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + + if use_safetensors: + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder, index_file + ) + else: + hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`" + ) + + return hf_folder, hf_weights_files, use_safetensors + + def _get_weights_iterator( + self, source: "Source", to_cpu: bool + ) -> Generator[tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( + source.model_or_path, + source.fall_back_to_pt, + source.allow_patterns_overrides, + ) + if use_safetensors: + weights_iterator = safetensors_weights_iterator( + hf_weights_files, to_cpu=to_cpu + ) + else: + weights_iterator = pt_weights_iterator(hf_weights_files, to_cpu=to_cpu) + + if self.counter_before_loading_weights == 0.0: + self.counter_before_loading_weights = time.perf_counter() + # Apply the prefix. 
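+        # e.g. a hypothetical Source(prefix="text_model.") would rename
+        # "encoder.layers.0.mlp.fc1.weight" to
+        # "text_model.encoder.layers.0.mlp.fc1.weight" before the names reach
+        # model.load_weights().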
+ return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) + + def _get_all_weights( + self, + model: nn.Module, + model_path: str, + to_cpu: bool, + ) -> Generator[tuple[str, torch.Tensor], None, None]: + primary_weights = TextEncoderLoader.Source( + model_path, + prefix="", + fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), + allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None), + ) + yield from self._get_weights_iterator(primary_weights, to_cpu) + + secondary_weights = cast( + Iterable[TextEncoderLoader.Source], + getattr(model, "secondary_weights", ()), + ) + for source in secondary_weights: + yield from self._get_weights_iterator(source, to_cpu) + + def load(self, model_path: str, server_args: ServerArgs, module_name: str): + """Load the text encoders based on the model path, and inference args.""" + # model_config: PretrainedConfig = get_hf_config( + # model=model_path, + # trust_remote_code=server_args.trust_remote_code, + # revision=server_args.revision, + # model_override_args=None, + # ) + diffusers_pretrained_config = get_config(model_path, trust_remote_code=True) + model_config = get_diffusers_config(model=model_path) + model_config.pop("_name_or_path", None) + model_config.pop("transformers_version", None) + model_config.pop("model_type", None) + model_config.pop("tokenizer_class", None) + model_config.pop("torch_dtype", None) + logger.info("HF model config: %s", model_config) + + def is_not_first_encoder(module_name): + return "2" in module_name + + # TODO(mick): had to throw an exception for different text-encoder arch + if not is_not_first_encoder(module_name): + encoder_config = server_args.pipeline_config.text_encoder_configs[0] + encoder_config.update_model_arch(model_config) + for key, value in diffusers_pretrained_config.__dict__.items(): + setattr(encoder_config.arch_config, key, value) + encoder_dtype = server_args.pipeline_config.text_encoder_precisions[0] + else: + assert len(server_args.pipeline_config.text_encoder_configs) == 2 + encoder_config = server_args.pipeline_config.text_encoder_configs[1] + encoder_config.update_model_arch(model_config) + encoder_dtype = server_args.pipeline_config.text_encoder_precisions[1] + target_device = get_local_torch_device() + # TODO(will): add support for other dtypes + return self.load_model( + model_path, + encoder_config, + target_device, + server_args, + encoder_dtype, + ) + + def load_model( + self, + model_path: str, + model_config: EncoderConfig, + target_device: torch.device, + server_args: ServerArgs, + dtype: str = "fp16", + ): + use_cpu_offload = ( + server_args.text_encoder_cpu_offload + and len(getattr(model_config, "_fsdp_shard_conditions", [])) > 0 + ) + + if server_args.text_encoder_cpu_offload: + target_device = ( + torch.device("mps") + if current_platform.is_mps() + else torch.device("cpu") + ) + + with set_default_torch_dtype(PRECISION_TO_TYPE[dtype]): + with target_device: + architectures = getattr(model_config, "architectures", []) + model_cls, _ = ModelRegistry.resolve_model_cls(architectures) + model = model_cls(model_config) + + weights_to_load = {name for name, _ in model.named_parameters()} + loaded_weights = model.load_weights( + self._get_all_weights(model, model_path, to_cpu=use_cpu_offload) + ) + self.counter_after_loading_weights = time.perf_counter() + logger.info( + "Loading weights took %.2f seconds", + self.counter_after_loading_weights + - self.counter_before_loading_weights, + ) + + # Explicitly move model to target device after 
loading weights + model = model.to(target_device) + + if use_cpu_offload: + # Disable FSDP for MPS as it's not compatible + if current_platform.is_mps(): + logger.info( + "Disabling FSDP sharding for MPS platform as it's not compatible" + ) + else: + mesh = init_device_mesh( + "cuda", + mesh_shape=(1, dist.get_world_size()), + mesh_dim_names=("offload", "replicate"), + ) + shard_model( + model, + cpu_offload=True, + reshard_after_forward=True, + mesh=mesh["offload"], + fsdp_shard_conditions=model._fsdp_shard_conditions, + pin_cpu_memory=server_args.pin_cpu_memory, + ) + # We only enable strict check for non-quantized models + # that have loaded weights tracking currently. + # if loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + raise ValueError( + "Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}" + ) + + return model.eval() + + +class ImageEncoderLoader(TextEncoderLoader): + + def load(self, model_path: str, server_args: ServerArgs, *args): + """Load the text encoders based on the model path, and inference args.""" + # model_config: PretrainedConfig = get_hf_config( + # model=model_path, + # trust_remote_code=server_args.trust_remote_code, + # revision=server_args.revision, + # model_override_args=None, + # ) + with open(os.path.join(model_path, "config.json")) as f: + model_config = json.load(f) + model_config.pop("_name_or_path", None) + model_config.pop("transformers_version", None) + model_config.pop("torch_dtype", None) + model_config.pop("model_type", None) + logger.info("HF model config: %s", model_config) + + encoder_config = server_args.pipeline_config.image_encoder_config + encoder_config.update_model_arch(model_config) + + if server_args.image_encoder_cpu_offload: + target_device = ( + torch.device("mps") + if current_platform.is_mps() + else torch.device("cpu") + ) + else: + target_device = get_local_torch_device() + # TODO(will): add support for other dtypes + return self.load_model( + model_path, + encoder_config, + target_device, + server_args, + server_args.pipeline_config.image_encoder_precision, + ) + + +class ImageProcessorLoader(ComponentLoader): + """Loader for image processor.""" + + def load(self, model_path: str, server_args: ServerArgs, *args): + """Load the image processor based on the model path, and inference args.""" + logger.info("Loading image processor from %s", model_path) + + image_processor = AutoImageProcessor.from_pretrained(model_path, use_fast=True) + logger.info("Loaded image processor: %s", image_processor.__class__.__name__) + return image_processor + + +class AutoProcessorLoader(ComponentLoader): + """Loader for auto processor.""" + + def load(self, model_path: str, server_args: ServerArgs, *args): + """Load the image processor based on the model path, and inference args.""" + logger.info("Loading auto processor from %s", model_path) + + processor = AutoProcessor.from_pretrained( + model_path, + ) + logger.info("Loaded auto processor: %s", processor.__class__.__name__) + return processor + + +class TokenizerLoader(ComponentLoader): + """Loader for tokenizers.""" + + def load(self, model_path: str, server_args: ServerArgs, *args): + """Load the tokenizer based on the model path, and inference args.""" + logger.info("Loading tokenizer from %s", model_path) + + tokenizer = AutoTokenizer.from_pretrained( + model_path, # "/tokenizer" + # in v0, this was same string as encoder_name "ClipTextModel" + # TODO(will): pass these tokenizer kwargs from inference 
args? Maybe + # other method of config? + padding_size="right", + ) + logger.info("Loaded tokenizer: %s", tokenizer.__class__.__name__) + return tokenizer + + +class VAELoader(ComponentLoader): + """Loader for VAE.""" + + def load(self, model_path: str, server_args: ServerArgs, *args): + """Load the VAE based on the model path, and inference args.""" + config = get_diffusers_config(model=model_path) + class_name = config.pop("_class_name") + assert ( + class_name is not None + ), "Model config does not contain a _class_name attribute. Only diffusers format is supported." + + server_args.model_paths["vae"] = model_path + + # TODO: abstract these logics + logger.info("HF model config: %s", config) + vae_config = server_args.pipeline_config.vae_config + vae_config.update_model_arch(config) + + # NOTE: some post init logics are only available after updated with config + vae_config.post_init() + + if server_args.vae_cpu_offload: + target_device = ( + torch.device("mps") + if current_platform.is_mps() + else torch.device("cpu") + ) + else: + target_device = get_local_torch_device() + + with set_default_torch_dtype( + PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision] + ): + vae_cls, _ = ModelRegistry.resolve_model_cls(class_name) + vae = vae_cls(vae_config).to(target_device) + + # Find all safetensors files + safetensors_list = glob.glob(os.path.join(str(model_path), "*.safetensors")) + # TODO(PY) + assert ( + len(safetensors_list) == 1 + ), f"Found {len(safetensors_list)} safetensors files in {model_path}" + loaded = safetensors_load_file(safetensors_list[0]) + vae.load_state_dict( + loaded, strict=False + ) # We might only load encoder or decoder + + return vae.eval() + + +class TransformerLoader(ComponentLoader): + """Loader for transformer.""" + + def load(self, model_path: str, server_args: ServerArgs, *args): + """Load the transformer based on the model path, and inference args.""" + config = get_diffusers_config(model=model_path) + hf_config = deepcopy(config) + cls_name = config.pop("_class_name") + if cls_name is None: + raise ValueError( + "Model config does not contain a _class_name attribute. " + "Only diffusers format is supported." 
+ ) + + logger.info("transformer cls_name: %s", cls_name) + if server_args.override_transformer_cls_name is not None: + cls_name = server_args.override_transformer_cls_name + logger.info("Overriding transformer cls_name to %s", cls_name) + + server_args.model_paths["transformer"] = model_path + + # Config from Diffusers supersedes sgl_diffusion's model config + dit_config = server_args.pipeline_config.dit_config + dit_config.update_model_arch(config) + + model_cls, _ = ModelRegistry.resolve_model_cls(cls_name) + + # Find all safetensors files + safetensors_list = glob.glob(os.path.join(str(model_path), "*.safetensors")) + if not safetensors_list: + raise ValueError(f"No safetensors files found in {model_path}") + + # Check if we should use custom initialization weights + custom_weights_path = getattr( + server_args, "init_weights_from_safetensors", None + ) + use_custom_weights = False + + if use_custom_weights: + logger.info( + "Using custom initialization weights from: %s", custom_weights_path + ) + assert ( + custom_weights_path is not None + ), "Custom initialization weights must be provided" + if os.path.isdir(custom_weights_path): + safetensors_list = glob.glob( + os.path.join(str(custom_weights_path), "*.safetensors") + ) + else: + assert custom_weights_path.endswith( + ".safetensors" + ), "Custom initialization weights must be a safetensors file" + safetensors_list = [custom_weights_path] + + logger.info( + "Loading model from %s safetensors files: %s", + len(safetensors_list), + safetensors_list, + ) + + default_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.dit_precision] + + # Load the model using FSDP loader + logger.info("Loading %s, default_dtype: %s", cls_name, default_dtype) + assert server_args.hsdp_shard_dim is not None + model = maybe_load_fsdp_model( + model_cls=model_cls, + init_params={"config": dit_config, "hf_config": hf_config}, + weight_dir_list=safetensors_list, + device=get_local_torch_device(), + hsdp_replicate_dim=server_args.hsdp_replicate_dim, + hsdp_shard_dim=server_args.hsdp_shard_dim, + cpu_offload=server_args.dit_cpu_offload, + pin_cpu_memory=server_args.pin_cpu_memory, + fsdp_inference=server_args.use_fsdp_inference, + # TODO(will): make these configurable + default_dtype=default_dtype, + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + output_dtype=None, + ) + + total_params = sum(p.numel() for p in model.parameters()) + logger.info("Loaded model with %.2fB parameters", total_params / 1e9) + + assert ( + next(model.parameters()).dtype == default_dtype + ), "Model dtype does not match default dtype" + + model = model.eval() + return model + + +class SchedulerLoader(ComponentLoader): + """Loader for scheduler.""" + + def load(self, model_path: str, server_args: ServerArgs, *args): + """Load the scheduler based on the model path, and inference args.""" + config = get_diffusers_config(model=model_path) + + class_name = config.pop("_class_name") + assert ( + class_name is not None + ), "Model config does not contain a _class_name attribute. Only diffusers format is supported." 
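+        # The remaining config entries are passed straight to the scheduler's
+        # __init__ as keyword arguments (they come from the component's diffusers
+        # config); flow_shift and timesteps_scale from the pipeline config, when
+        # set, are applied on top via set_shift / set_timesteps_scale below.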
+ + scheduler_cls, _ = ModelRegistry.resolve_model_cls(class_name) + + scheduler = scheduler_cls(**config) + if server_args.pipeline_config.flow_shift is not None: + scheduler.set_shift(server_args.pipeline_config.flow_shift) + if server_args.pipeline_config.timesteps_scale is not None: + scheduler.set_timesteps_scale(server_args.pipeline_config.timesteps_scale) + return scheduler + + +class GenericComponentLoader(ComponentLoader): + """Generic loader for components that don't have a specific loader.""" + + def __init__(self, library="transformers") -> None: + super().__init__() + self.library = library + + def load(self, model_path: str, server_args: ServerArgs, *args): + """Load a generic component based on the model path, and inference args.""" + logger.warning( + "Using generic loader for %s with library %s", model_path, self.library + ) + + if self.library == "transformers": + from transformers import AutoModel + + model = AutoModel.from_pretrained( + model_path, + trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, + ) + logger.info( + "Loaded generic transformers model: %s", model.__class__.__name__ + ) + return model + elif self.library == "diffusers": + logger.warning( + "Generic loading for diffusers components is not fully implemented" + ) + + model_config = get_diffusers_config(model=model_path) + logger.info("Diffusers Model config: %s", model_config) + # This is a placeholder - in a real implementation, you'd need to handle this properly + return None + else: + raise ValueError(f"Unsupported library: {self.library}") + + +class PipelineComponentLoader: + """ + Utility class for loading pipeline components. + This replaces the chain of if-else statements in load_pipeline_module. + """ + + @staticmethod + def load_module( + module_name: str, + component_model_path: str, + transformers_or_diffusers: str, + server_args: ServerArgs, + ): + """ + Load a pipeline module. + + Args: + module_name: Name of the module (e.g., "vae", "text_encoder", "transformer", "scheduler") + component_model_path: Path to the component model + transformers_or_diffusers: Whether the module is from transformers or diffusers + + Returns: + The loaded module + """ + logger.info( + "Loading %s using %s from %s", + module_name, + transformers_or_diffusers, + component_model_path, + ) + + # Get the appropriate loader for this module type + loader = ComponentLoader.for_module_type(module_name, transformers_or_diffusers) + + try: + # Load the module + return loader.load(component_model_path, server_args, module_name) + except Exception as e: + logger.error( + f"Error while loading component: {module_name}, {component_model_path=}" + ) + raise e diff --git a/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py b/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py new file mode 100644 index 000000000..d11da7dc6 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/loader/fsdp_load.py @@ -0,0 +1,314 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from torchtune +# Copyright 2024 The TorchTune Authors. +# Copyright 2025 The sgl-diffusion Authors. 
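+#
+# Typical flow (a sketch of how the helpers below fit together): build the model
+# on the meta device under set_default_dtype, optionally wrap submodules with
+# fully_shard via shard_model on a (replicate, shard) device mesh, then stream
+# safetensors weights through load_model_from_full_model_state_dict, which
+# distributes each full tensor onto the matching DTensor placement.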
+ +import contextlib +from collections.abc import Callable, Generator +from itertools import chain +from typing import Any + +import torch +from torch import nn +from torch.distributed import DeviceMesh, init_device_mesh +from torch.distributed._tensor import distribute_tensor +from torch.distributed.fsdp import ( + CPUOffloadPolicy, + FSDPModule, + MixedPrecisionPolicy, + fully_shard, +) +from torch.nn.modules.module import _IncompatibleKeys + +from sglang.multimodal_gen.runtime.loader.utils import ( + get_param_names_mapping, + hf_to_custom_state_dict, +) +from sglang.multimodal_gen.runtime.loader.weight_utils import ( + safetensors_weights_iterator, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import set_mixed_precision_policy + +logger = init_logger(__name__) + + +# TODO(PY): move this to utils elsewhere +@contextlib.contextmanager +def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]: + """ + Context manager to set torch's default dtype. + + Args: + dtype (torch.dtype): The desired default dtype inside the context manager. + + Returns: + ContextManager: context manager for setting default dtype. + + Example: + >>> with set_default_dtype(torch.bfloat16): + >>> x = torch.tensor([1, 2, 3]) + >>> x.dtype + torch.bfloat16 + + + """ + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + try: + yield + finally: + torch.set_default_dtype(old_dtype) + + +# TODO(PY): add compile option +def maybe_load_fsdp_model( + model_cls: type[nn.Module], + init_params: dict[str, Any], + weight_dir_list: list[str], + device: torch.device, + hsdp_replicate_dim: int, + hsdp_shard_dim: int, + default_dtype: torch.dtype, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + cpu_offload: bool = False, + fsdp_inference: bool = False, + output_dtype: torch.dtype | None = None, + pin_cpu_memory: bool = True, +) -> torch.nn.Module: + """ + Load the model with FSDP if is training, else load the model without FSDP. 
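+
+    A rough call sketch (argument values are illustrative, not defaults):
+
+        model = maybe_load_fsdp_model(
+            model_cls=MyTransformer,
+            init_params={"config": dit_config, "hf_config": hf_config},
+            weight_dir_list=["/path/to/model.safetensors"],
+            device=torch.device("cuda"),
+            hsdp_replicate_dim=1,
+            hsdp_shard_dim=world_size,
+            default_dtype=torch.bfloat16,
+            param_dtype=torch.bfloat16,
+            reduce_dtype=torch.float32,
+            fsdp_inference=True,
+        )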
+ """ + # NOTE(will): cast_forward_inputs=True shouldn't be needed as we are + # manually casting the inputs to the model + mp_policy = MixedPrecisionPolicy( + param_dtype, reduce_dtype, output_dtype, cast_forward_inputs=False + ) + + set_mixed_precision_policy( + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + output_dtype=output_dtype, + mp_policy=mp_policy, + ) + + with set_default_dtype(default_dtype), torch.device("meta"): + model = model_cls(**init_params) + + # Check if we should use FSDP + use_fsdp = fsdp_inference + + # Disable FSDP for MPS as it's not compatible + from sglang.multimodal_gen.runtime.platforms import current_platform + + if current_platform.is_mps(): + use_fsdp = False + logger.info("Disabling FSDP for MPS platform as it's not compatible") + + if use_fsdp: + world_size = hsdp_replicate_dim * hsdp_shard_dim + if not fsdp_inference: + hsdp_replicate_dim = world_size + hsdp_shard_dim = 1 + + device_mesh = init_device_mesh( + "cuda", + # (Replicate(), Shard(dim=0)) + mesh_shape=(hsdp_replicate_dim, hsdp_shard_dim), + mesh_dim_names=("replicate", "shard"), + ) + shard_model( + model, + cpu_offload=cpu_offload, + reshard_after_forward=True, + mp_policy=mp_policy, + mesh=device_mesh, + fsdp_shard_conditions=model._fsdp_shard_conditions, + pin_cpu_memory=pin_cpu_memory, + ) + + weight_iterator = safetensors_weights_iterator(weight_dir_list) + param_names_mapping_fn = get_param_names_mapping(model.param_names_mapping) + load_model_from_full_model_state_dict( + model, + weight_iterator, + device, + default_dtype, + strict=True, + cpu_offload=cpu_offload, + param_names_mapping=param_names_mapping_fn, + ) + for n, p in chain(model.named_parameters(), model.named_buffers()): + if p.is_meta: + raise RuntimeError(f"Unexpected param or buffer {n} on meta device.") + # Avoid unintended computation graph accumulation during inference + if isinstance(p, torch.nn.Parameter): + p.requires_grad = False + return model + + +def shard_model( + model, + *, + cpu_offload: bool, + reshard_after_forward: bool = True, + mp_policy: MixedPrecisionPolicy | None = MixedPrecisionPolicy(), # noqa + mesh: DeviceMesh | None = None, + fsdp_shard_conditions: list[Callable[[str, nn.Module], bool]] = [], # noqa + pin_cpu_memory: bool = True, +) -> None: + """ + Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API. + + This method will over the model's named modules from the bottom-up and apply shard modules + based on whether they meet any of the criteria from shard_conditions. + + Args: + model (TransformerDecoder): Model to shard with FSDP. + cpu_offload (bool): If set to True, FSDP will offload parameters, gradients, and optimizer + states to CPU. + reshard_after_forward (bool): Whether to reshard parameters and buffers after + the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy + from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy. + mesh (Optional[DeviceMesh]): Device mesh to use for FSDP sharding under multiple parallelism. + Default to None. + fsdp_shard_conditions (List[Callable[[str, nn.Module], bool]]): A list of functions to determine + which modules to shard with FSDP. + pin_cpu_memory (bool): If set to True, FSDP will pin the CPU memory of the offloaded parameters. + + Raises: + ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered. 
+ """ + if fsdp_shard_conditions is None or len(fsdp_shard_conditions) == 0: + logger.warning( + "The FSDP shard condition list is empty or None. No modules will be sharded in %s", + type(model).__name__, + ) + return + + fsdp_kwargs = { + "reshard_after_forward": reshard_after_forward, + "mesh": mesh, + "mp_policy": mp_policy, + } + if cpu_offload: + fsdp_kwargs["offload_policy"] = CPUOffloadPolicy(pin_memory=pin_cpu_memory) + + # iterating in reverse to start with + # lowest-level modules first + num_layers_sharded = 0 + # TODO(will): don't reshard after forward for the last layer to save on the + # all-gather that will immediately happen Shard the model with FSDP, + for n, m in reversed(list(model.named_modules())): + if any([shard_condition(n, m) for shard_condition in fsdp_shard_conditions]): + fully_shard(m, **fsdp_kwargs) + num_layers_sharded += 1 + + if num_layers_sharded == 0: + raise ValueError( + "No layer modules were sharded. Please check if shard conditions are working as expected." + ) + + # Finally shard the entire model to account for any stragglers + fully_shard(model, **fsdp_kwargs) + + +# TODO(PY): device mesh for cfg parallel +def load_model_from_full_model_state_dict( + model: FSDPModule | torch.nn.Module, + full_sd_iterator: Generator[tuple[str, torch.Tensor], None, None], + device: torch.device, + param_dtype: torch.dtype, + strict: bool = False, + cpu_offload: bool = False, + param_names_mapping: Callable[[str], tuple[str, Any, Any]] | None = None, +) -> _IncompatibleKeys: + """ + Converting full state dict into a sharded state dict + and loading it into FSDP model (if training) or normal huggingface model + Args: + model (Union[FSDPModule, torch.nn.Module]): Model to generate fully qualified names for cpu_state_dict + full_sd_iterator (Generator): an iterator yielding (param_name, tensor) pairs + device (torch.device): device used to move full state dict tensors + param_dtype (torch.dtype): dtype used to move full state dict tensors + strict (bool): flag to check if to load the model in strict mode + cpu_offload (bool): flag to check if FSDP offload is enabled + param_names_mapping (Optional[Callable[[str], str]]): a function that maps full param name to sharded param name + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys + * **unexpected_keys** is a list of str containing the unexpected keys + + Raises: + NotImplementedError: If got FSDP with more than 1D. + """ + meta_sd = model.state_dict() + sharded_sd = {} + custom_param_sd, reverse_param_names_mapping = hf_to_custom_state_dict( + full_sd_iterator, param_names_mapping + ) # type: ignore + for target_param_name, full_tensor in custom_param_sd.items(): + meta_sharded_param = meta_sd.get(target_param_name) + if meta_sharded_param is None: + raise ValueError( + f"Parameter {target_param_name} not found in custom model state dict. The hf to custom mapping may be incorrect." 
+ ) + if not hasattr(meta_sharded_param, "device_mesh"): + full_tensor = full_tensor.to(device=device, dtype=param_dtype) + # In cases where parts of the model aren't sharded, some parameters will be plain tensors + sharded_tensor = full_tensor + else: + full_tensor = full_tensor.to(device=device, dtype=param_dtype) + sharded_tensor = distribute_tensor( + full_tensor, + meta_sharded_param.device_mesh, + meta_sharded_param.placements, + ) + if cpu_offload: + sharded_tensor = sharded_tensor.cpu() + sharded_sd[target_param_name] = nn.Parameter(sharded_tensor) + + model.reverse_param_names_mapping = reverse_param_names_mapping + unused_keys = set(meta_sd.keys()) - set(sharded_sd.keys()) + if unused_keys: + logger.warning("Found unloaded parameters in meta state dict: %s", unused_keys) + + # List of allowed parameter name patterns + ALLOWED_NEW_PARAM_PATTERNS = ["gate_compress"] # Can be extended as needed + for new_param_name in unused_keys: + if not any(pattern in new_param_name for pattern in ALLOWED_NEW_PARAM_PATTERNS): + logger.error( + "Unsupported new parameter: %s. Allowed patterns: %s", + new_param_name, + ALLOWED_NEW_PARAM_PATTERNS, + ) + raise ValueError( + f"New parameter '{new_param_name}' is not supported. " + f"Currently only parameters containing {ALLOWED_NEW_PARAM_PATTERNS} are allowed." + ) + meta_sharded_param = meta_sd.get(new_param_name) + if not hasattr(meta_sharded_param, "device_mesh"): + # Initialize with zeros + sharded_tensor = torch.zeros_like( + meta_sharded_param, device=device, dtype=param_dtype + ) + else: + # Initialize with zeros and distribute + full_tensor = torch.zeros_like( + meta_sharded_param, device=device, dtype=param_dtype + ) + sharded_tensor = distribute_tensor( + full_tensor, + meta_sharded_param.device_mesh, + meta_sharded_param.placements, + ) + if cpu_offload: + sharded_tensor = sharded_tensor.cpu() + sharded_sd[new_param_name] = nn.Parameter(sharded_tensor) + + # choose `assign=True` since we cannot call `copy_` on meta tensor + return model.load_state_dict(sharded_sd, strict=strict, assign=True) diff --git a/python/sglang/multimodal_gen/runtime/loader/utils.py b/python/sglang/multimodal_gen/runtime/loader/utils.py new file mode 100644 index 000000000..fe3c2de69 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/loader/utils.py @@ -0,0 +1,103 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +"""Utilities for selecting and loading models.""" +import contextlib +import re +from collections import defaultdict +from collections.abc import Callable, Iterator +from typing import Any + +import torch + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +@contextlib.contextmanager +def set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(old_dtype) + + +def get_param_names_mapping( + mapping_dict: dict[str, str] +) -> Callable[[str], tuple[str, Any, Any]]: + """ + Creates a mapping function that transforms parameter names using regex patterns. 
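+
+    For example (hypothetical mapping), {r"^model\.(.*)$": r"\1"} simply strips a
+    leading "model." prefix, while a tuple replacement such as
+    (r"blocks.\1.qkv.weight", 0, 3) additionally marks the weight as piece 0 of 3
+    that hf_to_custom_state_dict will concatenate along dim 0 (e.g. fusing
+    separate q/k/v projections into one parameter).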
+ + Args: + mapping_dict (Dict[str, str]): Dictionary mapping regex patterns to replacement patterns + param_name (str): The parameter name to be transformed + + Returns: + Callable[[str], str]: A function that maps parameter names from source to target format + """ + + def mapping_fn(name: str) -> tuple[str, Any, Any]: + # Try to match and transform the name using the regex patterns in mapping_dict + for pattern, replacement in mapping_dict.items(): + match = re.match(pattern, name) + if match: + merge_index = None + total_splitted_params = None + if isinstance(replacement, tuple): + merge_index = replacement[1] + total_splitted_params = replacement[2] + replacement = replacement[0] + name = re.sub(pattern, replacement, name) + return name, merge_index, total_splitted_params + + # If no pattern matches, return the original name + return name, None, None + + return mapping_fn + + +def hf_to_custom_state_dict( + hf_param_sd: dict[str, torch.Tensor] | Iterator[tuple[str, torch.Tensor]], + param_names_mapping: Callable[[str], tuple[str, Any, Any]], +) -> tuple[dict[str, torch.Tensor], dict[str, tuple[str, Any, Any]]]: + """ + Converts a Hugging Face parameter state dictionary to a custom parameter state dictionary. + + Args: + hf_param_sd (Dict[str, torch.Tensor]): The Hugging Face parameter state dictionary + param_names_mapping (Callable[[str], tuple[str, Any, Any]]): A function that maps parameter names from source to target format + + Returns: + custom_param_sd (Dict[str, torch.Tensor]): The custom formatted parameter state dict + reverse_param_names_mapping (Dict[str, Tuple[str, Any, Any]]): Maps back from custom to hf + """ + custom_param_sd = {} + to_merge_params = defaultdict(dict) # type: ignore + reverse_param_names_mapping = {} + if isinstance(hf_param_sd, dict): + hf_param_sd = hf_param_sd.items() # type: ignore + for source_param_name, full_tensor in hf_param_sd: # type: ignore + target_param_name, merge_index, num_params_to_merge = param_names_mapping( + source_param_name + ) + reverse_param_names_mapping[target_param_name] = ( + source_param_name, + merge_index, + num_params_to_merge, + ) + if merge_index is not None: + to_merge_params[target_param_name][merge_index] = full_tensor + if len(to_merge_params[target_param_name]) == num_params_to_merge: + # cat at output dim according to the merge_index order + sorted_tensors = [ + to_merge_params[target_param_name][i] + for i in range(num_params_to_merge) + ] + full_tensor = torch.cat(sorted_tensors, dim=0) + del to_merge_params[target_param_name] + else: + continue + custom_param_sd[target_param_name] = full_tensor + return custom_param_sd, reverse_param_names_mapping diff --git a/python/sglang/multimodal_gen/runtime/loader/weight_utils.py b/python/sglang/multimodal_gen/runtime/loader/weight_utils.py new file mode 100644 index 000000000..7796defd8 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/loader/weight_utils.py @@ -0,0 +1,238 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/model_loader/weight_utils.py +"""Utilities for downloading and initializing model weights.""" +import hashlib +import json +import os +import tempfile +from collections.abc import Generator +from pathlib import Path + +import filelock +import huggingface_hub.constants +import torch +from safetensors.torch import safe_open +from tqdm.auto import tqdm + +from 
sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + +# use system-level temp directory for file locks, so that multiple users +# can share the same lock without error. +# lock files in the temp directory will be automatically deleted when the +# system reboots, so users will not complain about annoying lock files +temp_dir = tempfile.gettempdir() + + +def enable_hf_transfer() -> None: + """automatically activates hf_transfer""" + if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: + try: + # enable hf hub transfer if available + import hf_transfer # type: ignore # noqa + + huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True + except ImportError: + pass + + +enable_hf_transfer() + + +class DisabledTqdm(tqdm): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, disable=True) + + +def get_lock(model_name_or_path: str | Path, cache_dir: str | None = None): + lock_dir = cache_dir or temp_dir + model_name_or_path = str(model_name_or_path) + os.makedirs(os.path.dirname(lock_dir), exist_ok=True) + model_name = model_name_or_path.replace("/", "-") + hash_name = hashlib.sha256(model_name.encode()).hexdigest() + # add hash to avoid conflict with old users' lock files + lock_file_name = hash_name + model_name + ".lock" + # mode 0o666 is required for the filelock to be shared across users + lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666) + return lock + + +# For models like Mistral-7B-v0.3, there are both sharded +# safetensors files and a consolidated safetensors file. +# Passing both of these to the weight loader functionality breaks. +# So, we use the index_file to +# look up which safetensors files should be used. +def filter_duplicate_safetensors_files( + hf_weights_files: list[str], hf_folder: str, index_file: str +) -> list[str]: + # model.safetensors.index.json is a mapping from keys in the + # torch state_dict to safetensors file holding that weight. + index_file_name = os.path.join(hf_folder, index_file) + if not os.path.isfile(index_file_name): + return hf_weights_files + + # Iterate through the weight_map (weight_name: safetensors files) + # to identify weights that we should use. + with open(index_file_name) as f: + weight_map = json.load(f)["weight_map"] + weight_files_in_index = set() + for weight_name in weight_map: + weight_files_in_index.add(os.path.join(hf_folder, weight_map[weight_name])) + # Filter out any fields that are not found in the index file. + hf_weights_files = [f for f in hf_weights_files if f in weight_files_in_index] + return hf_weights_files + + +def filter_files_not_needed_for_inference(hf_weights_files: list[str]) -> list[str]: + """ + Exclude files that are not needed for inference. + + See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 + """ + blacklist = [ + "training_args.bin", + "optimizer.bin", + "optimizer.pt", + "scheduler.pt", + "scaler.pt", + ] + hf_weights_files = [ + f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist) + ] + return hf_weights_files + + +# explicitly use pure text format, with a newline at the end +# this makes it impossible to see the animation in the progress bar +# but will avoid messing up with ray or multiprocessing, which wraps +# each line of output with some prefix. 
+_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501 + + +def safetensors_weights_iterator( + hf_weights_files: list[str], + to_cpu: bool = True, +) -> Generator[tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model safetensor files.""" + enable_tqdm = ( + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + ) + device = "cpu" if to_cpu else str(get_local_torch_device()) + for st_file in tqdm( + hf_weights_files, + desc="Loading safetensors checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, + ): + with safe_open(st_file, framework="pt", device=device) as f: + for name in f.keys(): # noqa: SIM118 + param = f.get_tensor(name) + yield name, param + + +def pt_weights_iterator( + hf_weights_files: list[str], + to_cpu: bool = True, +) -> Generator[tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model bin/pt files.""" + device = "cpu" if to_cpu else str(get_local_torch_device()) + enable_tqdm = ( + not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 + ) + for bin_file in tqdm( + hf_weights_files, + desc="Loading pt checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, + ): + state = torch.load(bin_file, map_location=device, weights_only=True) + yield from state.items() + del state + + +def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + """Default weight loader.""" + try: + if param.numel() == 1 and loaded_weight.numel() == 1: + # Sometimes scalar values aren't considered tensors with shapes + # so if both param and loaded_weight are a scalar, + # "broadcast" instead of copy + param.data.fill_(loaded_weight.item()) + else: + assert param.size() == loaded_weight.size(), ( + f"Attempted to load weight ({loaded_weight.size()}) " + f"into parameter ({param.size()})" + ) + + param.data.copy_(loaded_weight) + except Exception: + # NOTE: This exception is added for the purpose of setting breakpoint to + # debug weight loading issues. + raise + + +def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None: + """Remap the name of FP8 k/v_scale parameters. + + This function handles the remapping of FP8 k/v_scale parameter names. + It detects if the given name ends with a suffix and attempts to remap + it to the expected name format in the model. If the remapped name is not + found in the params_dict, a warning is printed and None is returned. + + Args: + name (str): The original loaded checkpoint parameter name. + params_dict (dict): Dictionary containing the model's named parameters. + + Returns: + str: The remapped parameter name if successful, or the original name + if no remapping is needed. + None: If the remapped name is not found in params_dict. + """ + if name.endswith(".kv_scale"): + logger.warning_once( + "DEPRECATED. Found kv_scale in the checkpoint. " + "This format is deprecated in favor of separate k_scale and " + "v_scale tensors and will be removed in a future release. " + "Functionally, we will remap kv_scale to k_scale and duplicate " + "k_scale to v_scale" + ) + # NOTE: we remap the deprecated kv_scale to k_scale + remapped_name = name.replace(".kv_scale", ".attn.k_scale") + if remapped_name not in params_dict: + logger.warning_once( + f"Found kv_scale in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_name}). kv_scale is " + "not loaded." 
+ ) + return None + return remapped_name + + possible_scale_names = [".k_scale", ".v_scale"] + modelopt_scale_names = [".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"] + for scale_name in possible_scale_names: + if name.endswith(scale_name): + if any(mo_scale_name in name for mo_scale_name in modelopt_scale_names): + remapped_name = name.replace( + f".self_attn.{scale_name[1]}_proj{scale_name}", + f".self_attn.attn{scale_name}", + ) + else: + remapped_name = name.replace(scale_name, f".attn{scale_name}") + if remapped_name not in params_dict: + logger.warning_once( + f"Found {scale_name} in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_name}). {scale_name} is " + "not loaded." + ) + return None + return remapped_name + + # If there were no matches, return the untouched param name + return name diff --git a/python/sglang/multimodal_gen/runtime/managers/forward_context.py b/python/sglang/multimodal_gen/runtime/managers/forward_context.py new file mode 100644 index 000000000..d9d107e69 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/managers/forward_context.py @@ -0,0 +1,120 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/forward_context.py +import time +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional, Type + +import torch + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +if TYPE_CHECKING: + from sglang.multimodal_gen.runtime.layers.attention import AttentionMetadata + from sglang.multimodal_gen.runtime.pipelines import Req + +logger = init_logger(__name__) + +# TODO(will): check if this is needed +# track_batchsize: bool = envs.SGL_DIFFUSION_LOG_BATCHSIZE_INTERVAL >= 0 +track_batchsize: bool = False +last_logging_time: float = 0 +forward_start_time: float = 0 +# batchsize_logging_interval: float = envs.SGL_DIFFUSION_LOG_BATCHSIZE_INTERVAL +batchsize_logging_interval: float = 1000 +batchsize_forward_time: defaultdict = defaultdict(list) + + +@dataclass +class ForwardContext: + current_timestep: int + # TODO(will): check this arg + # copy from vllm_config.compilation_config.static_forward_context + # attn_layers: Dict[str, Any] + # TODO: extend to support per-layer dynamic forward context + attn_metadata: "AttentionMetadata" # set dynamically for each forward pass + forward_batch: Optional["Req"] = None + attention_backend_cls: Optional[Type] = None + + def set_attn_backend_cls(self, attention_backend_cls: Type): + if self.attention_backend_cls: + if self.attention_backend_cls != attention_backend_cls: + raise RuntimeError( + f"Different types of attention backend in a same context detected, previous: {self.attention_backend_cls}, new: {attention_backend_cls}" + ) + else: + self.attention_backend_cls = attention_backend_cls + + +_forward_context: Optional["ForwardContext"] = None + + +def get_forward_context() -> "ForwardContext": + """Get the current forward context.""" + assert _forward_context is not None, ( + "Forward context is not set. " + "Please use `set_forward_context` to set the forward context." 
+ ) + return _forward_context + + +# TODO(will): finalize the interface +@contextmanager +def set_forward_context( + current_timestep, attn_metadata, forward_batch: Optional["Req"] = None +): + """A context manager that stores the current forward context, + can be attention metadata, etc. + Here we can inject common logic for every model forward pass. + """ + global forward_start_time + need_to_track_batchsize = track_batchsize and attn_metadata is not None + if need_to_track_batchsize: + forward_start_time = time.perf_counter() + global _forward_context + prev_context = _forward_context + _forward_context = ForwardContext( + current_timestep=current_timestep, + attn_metadata=attn_metadata, + forward_batch=forward_batch, + ) + + try: + yield + finally: + global last_logging_time, batchsize_logging_interval + if need_to_track_batchsize: + if hasattr(attn_metadata, "num_prefill_tokens"): + # for v0 attention backends + batchsize = ( + attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens + ) + else: + # for v1 attention backends + batchsize = attn_metadata.num_input_tokens + now = time.perf_counter() + # time measurement is in milliseconds + batchsize_forward_time[batchsize].append((now - forward_start_time) * 1000) + if now - last_logging_time > batchsize_logging_interval: + last_logging_time = now + forward_stats = [] + for bs, times in batchsize_forward_time.items(): + if len(times) <= 1: + # can be cudagraph / profiling run + continue + medium = torch.quantile(torch.tensor(times), q=0.5).item() + medium = round(medium, 2) + forward_stats.append((bs, len(times), medium)) + forward_stats.sort(key=lambda x: x[1], reverse=True) + if forward_stats: + logger.info( + ( + "Batchsize forward time stats " + "(batchsize, count, median_time(ms)): %s" + ), + forward_stats, + ) + _forward_context = prev_context diff --git a/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py new file mode 100644 index 000000000..c6606fa8a --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/managers/gpu_worker.py @@ -0,0 +1,171 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +import multiprocessing as mp +import os +from typing import List + +import torch +from setproctitle import setproctitle + +from sglang.multimodal_gen.runtime.distributed import ( + get_sp_group, + maybe_init_distributed_environment_and_model_parallel, +) +from sglang.multimodal_gen.runtime.distributed.parallel_state import ( + get_cfg_group, + get_tp_group, +) +from sglang.multimodal_gen.runtime.pipelines import build_pipeline +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import OutputBatch, Req +from sglang.multimodal_gen.runtime.server_args import PortArgs, ServerArgs +from sglang.multimodal_gen.runtime.utils.common import set_cuda_arch +from sglang.multimodal_gen.runtime.utils.logging_utils import ( + configure_logger, + init_logger, + suppress_other_loggers, +) + +logger = init_logger(__name__) + +# ANSI color codes +CYAN = "\033[1;36m" +RESET = "\033[0;0m" + + +class GPUWorker: + """ + A worker that executes the model on a single GPU. + """ + + def __init__( + self, + local_rank: int, + rank: int, + master_port: int, + server_args: ServerArgs, + ): + self.local_rank = local_rank + self.rank = rank + self.master_port = master_port + # FIXME: should we use tcp as distribute init method? 
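+        # Bring-up order: store the args, initialize the CUDA device and the
+        # distributed process groups, build the pipeline, then cache the
+        # sp/tp/cfg group handles (and their CPU groups) that the scheduler
+        # later uses to broadcast requests to non-zero ranks.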
+ self.server_args = server_args + self.pipeline = None + + self.init_device_and_model() + self.sp_group = get_sp_group() + self.sp_cpu_group = self.sp_group.cpu_group + self.tp_group = get_tp_group() + self.tp_cpu_group = self.tp_group.cpu_group + + self.cfg_group = get_cfg_group() + self.cfg_cpu_group = self.cfg_group.cpu_group + + def init_device_and_model(self) -> None: + """Initialize the device and load the model.""" + setproctitle(f"sgl_diffusion::scheduler:{self.local_rank}") + torch.cuda.set_device(self.local_rank) + # Set environment variables for distributed initialization + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(self.master_port) + os.environ["LOCAL_RANK"] = str(self.local_rank) + os.environ["RANK"] = str(self.rank) + os.environ["WORLD_SIZE"] = str(self.server_args.num_gpus) + # Initialize the distributed environment + maybe_init_distributed_environment_and_model_parallel( + tp_size=self.server_args.tp_size, + enable_cfg_parallel=self.server_args.enable_cfg_parallel, + ulysses_degree=self.server_args.ulysses_degree, + ring_degree=self.server_args.ring_degree, + sp_size=self.server_args.sp_degree, + dp_size=self.server_args.dp_size, + ) + + self.pipeline = build_pipeline(self.server_args) + + logger.info( + f"Worker {self.rank}: Initialized device, model, and distributed environment." + ) + + def execute_forward(self, batch: List[Req], server_args: ServerArgs) -> OutputBatch: + """ + Execute a forward pass. + """ + assert self.pipeline is not None + # TODO: dealing with first req for now + req = batch[0] + output_batch = self.pipeline.forward(req, server_args) + if req.perf_logger: + req.perf_logger.log_total_duration("total_inference_time") + return output_batch + + def set_lora_adapter( + self, lora_nickname: str, lora_path: str | None = None + ) -> None: + """ + Set the LoRA adapter for the pipeline. + """ + assert self.pipeline is not None + self.pipeline.set_lora_adapter(lora_nickname, lora_path) + + def merge_lora_weights(self) -> None: + """ + Merge LoRA weights. + """ + assert self.pipeline is not None + self.pipeline.merge_lora_weights() + + def unmerge_lora_weights(self) -> None: + """ + Unmerge LoRA weights. + """ + assert self.pipeline is not None + self.pipeline.unmerge_lora_weights() + + +def run_scheduler_process( + local_rank: int, + rank: int, + master_port: int, + server_args: ServerArgs, + pipe_writer: mp.connection.Connection, + # For all workers: pipe to receive tasks from rank 0 + task_pipe_r: mp.connection.Connection, + # For slave workers: pipe to send results back to rank 0 + result_pipe_w: mp.connection.Connection | None, + # For rank 0 worker only: pipes to send tasks to slaves + task_pipes_to_slaves: list[mp.connection.Connection] | None = None, + # For rank 0 worker only: pipes to receive results from slaves + result_pipes_from_slaves: list[mp.connection.Connection] | None = None, +) -> None: + """ + The entry point for the worker process. + Rank 0 acts as the master, handling ZMQ requests and coordinating slaves. + Ranks > 0 act as slaves, waiting for tasks from the master. 
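+    Once the Scheduler is constructed, a {"status": "ready"} message is sent on
+    `pipe_writer` before entering the blocking event loop.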
+ """ + configure_logger(server_args) + suppress_other_loggers() + set_cuda_arch() + + port_args = PortArgs.from_server_args(server_args) + + # start the scheduler event loop + assert task_pipes_to_slaves is not None + assert result_pipes_from_slaves is not None + from sglang.multimodal_gen.runtime.managers.scheduler import Scheduler + + scheduler = Scheduler( + server_args, + gpu_id=rank, + port_args=port_args, + task_pipes_to_slaves=task_pipes_to_slaves, + result_pipes_from_slaves=result_pipes_from_slaves, + ) + logger.info(f"Worker {rank}: Scheduler loop started.") + pipe_writer.send( + { + "status": "ready", + } + ) + scheduler.event_loop() + logger.info(f"Worker {rank}: Shutdown complete.") diff --git a/python/sglang/multimodal_gen/runtime/managers/scheduler.py b/python/sglang/multimodal_gen/runtime/managers/scheduler.py new file mode 100644 index 000000000..d2e07e9b1 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/managers/scheduler.py @@ -0,0 +1,177 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import zmq + +from sglang.multimodal_gen.runtime.managers.gpu_worker import GPUWorker +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import OutputBatch +from sglang.multimodal_gen.runtime.server_args import ( + PortArgs, + ServerArgs, + set_global_server_args, +) +from sglang.multimodal_gen.runtime.utils.common import get_zmq_socket +from sglang.multimodal_gen.runtime.utils.distributed import broadcast_pyobj +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class Scheduler: + """ + Runs the main event loop for the rank 0 worker. + It listens for external requests via ZMQ and coordinates with other workers. + This class does NOT manage worker processes. 
+ """ + + def __init__( + self, + server_args: ServerArgs, + gpu_id: int, + port_args: PortArgs, + task_pipes_to_slaves: list = None, + result_pipes_from_slaves: list = None, + ): + self.server_args = server_args + self.port_args = port_args + + set_global_server_args(server_args=server_args) + + # Inter-process Communication + self.context = zmq.Context(io_threads=2) + endpoint = server_args.scheduler_endpoint() + logger.info(f"Scheduler listening at endpoint: {endpoint}") + if gpu_id == 0: + self.receiver = get_zmq_socket(self.context, zmq.REP, endpoint, True) + else: + self.receiver = None + + worker = GPUWorker( + local_rank=gpu_id, + master_port=port_args.master_port, + rank=gpu_id, + server_args=server_args, + ) + self.worker = worker + self.task_pipes_to_slaves = task_pipes_to_slaves + self.result_pipes_from_slaves = result_pipes_from_slaves + self.gpu_id = gpu_id + self._running = True + + def return_result(self, output_batch: OutputBatch): + """ + replies to client, only on rank 0 + """ + if self.receiver is not None: + self.receiver.send_pyobj(output_batch) + + def recv_reqs(self): + """ + For non-main schedulers, reqs are broadcasted from main using broadcast_pyobj + """ + if self.receiver is not None: + recv_reqs = self.receiver.recv_pyobj() + assert isinstance(recv_reqs, list) + else: + recv_reqs = None + + # TODO: fix this condition + if self.server_args.sp_degree != 1: + recv_reqs = broadcast_pyobj( + recv_reqs, + self.worker.sp_group.rank, + self.worker.sp_cpu_group, + src=self.worker.sp_group.ranks[0], + ) + + if self.server_args.enable_cfg_parallel: + recv_reqs = broadcast_pyobj( + recv_reqs, + self.worker.cfg_group.rank, + self.worker.cfg_cpu_group, + src=self.worker.cfg_group.ranks[0], + ) + + if self.server_args.tp_size > 1: + recv_reqs = broadcast_pyobj( + recv_reqs, + self.worker.tp_group.rank, + self.worker.tp_cpu_group, + src=self.worker.tp_group.ranks[0], + ) + + assert recv_reqs is not None + + return recv_reqs + + # TODO: queueing, cancellation + def event_loop(self) -> None: + """ + The main event loop that listens for ZMQ requests. 
+ Handles abortion + """ + + logger.info( + f"Rank 0 scheduler listening on tcp://*:{self.server_args.scheduler_port}" + ) + + while self._running: + reqs = None + # 1: receive requests + try: + reqs = self.recv_reqs() + except Exception as e: + logger.error( + f"Error receiving requests in scheduler event loop: {e}", + exc_info=True, + ) + continue + + # 2: execute, make sure a reply is always sent + try: + output_batch = self.worker.execute_forward(reqs, self.server_args) + except Exception as e: + logger.error( + f"Error executing forward in scheduler event loop: {e}", + exc_info=True, + ) + output_batch = OutputBatch(error=str(e)) + + try: + self.return_result(output_batch) + except zmq.ZMQError as e: + # Reply failed; log and keep loop alive to accept future requests + logger.error(f"ZMQ error sending reply: {e}") + continue + + logger.info("Scheduler event loop terminated.") + if self.receiver is not None: + self.receiver.close() + self.context.term() + + def _broadcast_task(self, payload: dict[str, Any]) -> None: + """Broadcast a task to all slave worker processes.""" + method = payload["method"] + kwargs = {k: v for k, v in payload.items() if k != "method"} + task = {"method": method, "kwargs": kwargs} + for pipe in self.task_pipes_to_slaves: + pipe.send(task) + + def _execute_on_rank0(self, payload: dict[str, Any]) -> dict[str, Any]: + """Execute task locally on the rank 0 worker.""" + method = payload["method"] + kwargs = {k: v for k, v in payload.items() if k != "method"} + handler = getattr(self.worker, method, None) + if handler: + result = handler(**kwargs) + return {"status": "ok", "result": result} + return {"status": "error", "error": f"Unknown method: {method}"} + + def _collect_slave_results(self) -> list[dict[str, Any]]: + """Collect results from all slave worker processes.""" + results = [] + for pipe in self.result_pipes_from_slaves: + results.append(pipe.recv()) + return results diff --git a/python/sglang/multimodal_gen/runtime/managers/schedulerbase.py b/python/sglang/multimodal_gen/runtime/managers/schedulerbase.py new file mode 100644 index 000000000..4bf392250 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/managers/schedulerbase.py @@ -0,0 +1,103 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from abc import ABC +from typing import TypeVar + +import zmq + +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import OutputBatch, Req +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.utils import init_logger + +logger = init_logger(__name__) + +_R = TypeVar("_R") + + +class SchedulerBase(ABC): + """ + Abstract base class for all schedulers. + """ + + def __init__(self, server_args: "ServerArgs"): + """ + Initialize the scheduler. + + Args: + server_args: The inference arguments + """ + self.server_args = server_args + self.context = zmq.Context() + self.socket = self.context.socket(zmq.REQ) + self.socket.connect(self.server_args.scheduler_endpoint()) + + @classmethod + def get_class(cls, server_args: "ServerArgs") -> type["SchedulerBase"]: + """ + Get the scheduler class based on the server arguments. 
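+
+        Only the multiprocessing backend ("mp") is wired up at the moment; it
+        resolves to the ZMQ-based `Scheduler`.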
+ """ + if server_args.distributed_executor_backend == "mp": + from sglang.multimodal_gen.runtime.managers.scheduler import Scheduler + + # For now, always return the new Scheduler + return Scheduler + else: + raise ValueError( + f"Unsupported distributed executor backend: {server_args.distributed_executor_backend}" + ) + + # @abstractmethod + def start(self) -> None: + """ + Start the scheduler service. + """ + raise NotImplementedError + + def execute_forward(self, batch: Req, server_args: "ServerArgs") -> OutputBatch: + """ + Execute a forward pass. This method now sends a request over ZMQ. + """ + payload = {"method": "execute_forward", "batch": batch} + self.socket.send_pyobj(payload) + output_batch = self.socket.recv_pyobj() + return output_batch + + def set_lora_adapter( + self, lora_nickname: str, lora_path: str | None = None + ) -> None: + """ + Set the LoRA adapter. + """ + payload = { + "method": "set_lora_adapter", + "lora_nickname": lora_nickname, + "lora_path": lora_path, + } + self.socket.send_pyobj(payload) + self.socket.recv_pyobj() # Wait for confirmation + + # @abstractmethod + def unmerge_lora_weights(self) -> None: + """ + Unmerge the LoRA weights for the workers. + """ + raise NotImplementedError + + # @abstractmethod + def merge_lora_weights(self) -> None: + """ + Merge the LoRA weights for the workers. + """ + raise NotImplementedError + + def shutdown(self) -> None: + """ + Shutdown the scheduler. + """ + logger.info("Shutting down scheduler client.") + payload = {"method": "shutdown"} + self.socket.send_pyobj(payload) + self.socket.recv_pyobj() # Wait for shutdown confirmation + self.socket.close() + self.context.term() diff --git a/python/sglang/multimodal_gen/runtime/models/__init__.py b/python/sglang/multimodal_gen/runtime/models/__init__.py new file mode 100644 index 000000000..af2eb7d10 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/python/sglang/multimodal_gen/runtime/models/dits/base.py b/python/sglang/multimodal_gen/runtime/models/dits/base.py new file mode 100644 index 000000000..886a6a331 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/dits/base.py @@ -0,0 +1,134 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from abc import ABC, abstractmethod +from typing import Any + +import torch +from torch import nn + +from sglang.multimodal_gen.configs.models import DiTConfig +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum + + +# TODO +class BaseDiT(nn.Module, ABC): + _fsdp_shard_conditions: list = [] + _compile_conditions: list = [] + param_names_mapping: dict + reverse_param_names_mapping: dict + hidden_size: int + num_attention_heads: int + num_channels_latents: int + # always supports torch_sdpa + _supported_attention_backends: set[AttentionBackendEnum] = ( + DiTConfig()._supported_attention_backends + ) + + def __init_subclass__(cls) -> None: + required_class_attrs = [ + "_fsdp_shard_conditions", + "param_names_mapping", + "_compile_conditions", + ] + super().__init_subclass__() + for attr in required_class_attrs: + if not hasattr(cls, attr): + raise AttributeError( + f"Subclasses of BaseDiT must define '{attr}' class variable" + ) + + def __init__(self, config: DiTConfig, hf_config: dict[str, Any], **kwargs) -> None: + super().__init__() + self.config = config + self.hf_config = hf_config + if not 
self.supported_attention_backends: + raise ValueError( + f"Subclass {self.__class__.__name__} must define _supported_attention_backends" + ) + + @abstractmethod + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | list[torch.Tensor], + timestep: torch.LongTensor, + encoder_hidden_states_image: torch.Tensor | list[torch.Tensor] | None = None, + guidance=None, + **kwargs, + ) -> torch.Tensor: + pass + + def __post_init__(self) -> None: + required_attrs = ["hidden_size", "num_attention_heads", "num_channels_latents"] + for attr in required_attrs: + if not hasattr(self, attr): + raise AttributeError( + f"Subclasses of BaseDiT must define '{attr}' instance variable" + ) + + @property + def supported_attention_backends(self) -> set[AttentionBackendEnum]: + return self._supported_attention_backends + + @property + def device(self) -> torch.device: + """Get the device of the model.""" + return next(self.parameters()).device + + +class CachableDiT(BaseDiT): + """ + An intermediate base class that adds TeaCache optimization functionality to DiT models. + TeaCache accelerates inference by selectively skipping redundant computation when consecutive + diffusion steps are similar enough. + """ + + # These are required class attributes that should be overridden by concrete implementations + _fsdp_shard_conditions = [] + param_names_mapping = {} + reverse_param_names_mapping = {} + lora_param_names_mapping: dict = {} + # Ensure these instance attributes are properly defined in subclasses + hidden_size: int + num_attention_heads: int + num_channels_latents: int + # always supports torch_sdpa + _supported_attention_backends: set[AttentionBackendEnum] = ( + DiTConfig()._supported_attention_backends + ) + + def __init__(self, config: DiTConfig, **kwargs) -> None: + super().__init__(config, **kwargs) + + self.cnt = 0 + self.teacache_thresh = 0 + self.coefficients: list[float] = [] + + # NOTE(will): Only wan2.1 needs these, so we are hardcoding it here + if self.config.prefix == "wan": + self.use_ret_steps = self.config.cache_config.use_ret_steps + self.is_even = False + self.previous_residual_even: torch.Tensor | None = None + self.previous_residual_odd: torch.Tensor | None = None + self.accumulated_rel_l1_distance_even = 0 + self.accumulated_rel_l1_distance_odd = 0 + self.should_calc_even = True + self.should_calc_odd = True + else: + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.previous_resiual = None + self.previous_e0_even: torch.Tensor | None = None + self.previous_e0_odd: torch.Tensor | None = None + + def maybe_cache_states( + self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor + ) -> None: + pass + + def should_skip_forward_for_cached_states(self, **kwargs: dict[str, Any]) -> bool: + return False + + def retrieve_cached_states(self, hidden_states: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("maybe_retrieve_cached_states is not implemented") diff --git a/python/sglang/multimodal_gen/runtime/models/dits/causal_wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/causal_wanvideo.py new file mode 100644 index 000000000..9cce5ac1f --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/dits/causal_wanvideo.py @@ -0,0 +1,851 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import math +from typing import Any + +import torch +import torch.nn as nn +from torch.nn.attention.flex_attention import ( + BlockMask, + 
create_block_mask, + flex_attention, +) + +# wan 1.3B model has a weird channel / head configurations and require max-autotune to work with flexattention +# see https://github.com/pytorch/pytorch/issues/133254 +# change to default for other models +flex_attention = torch.compile( + flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs" +) +import torch.distributed as dist + +from sglang.multimodal_gen.configs.models.dits import WanVideoConfig +from sglang.multimodal_gen.runtime.distributed.parallel_state import get_sp_world_size +from sglang.multimodal_gen.runtime.layers.attention import LocalAttention +from sglang.multimodal_gen.runtime.layers.layernorm import ( + FP32LayerNorm, + LayerNormScaleShift, + RMSNorm, + ScaleResidual, + ScaleResidualLayerNormScaleShift, +) +from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear +from sglang.multimodal_gen.runtime.layers.mlp import MLP +from sglang.multimodal_gen.runtime.layers.rotary_embedding import ( + _apply_rotary_emb, + get_rotary_pos_embed, +) +from sglang.multimodal_gen.runtime.layers.visual_embedding import PatchEmbed +from sglang.multimodal_gen.runtime.models.dits.base import BaseDiT +from sglang.multimodal_gen.runtime.models.dits.wanvideo import ( + WanT2VCrossAttention, + WanTimeTextImageEmbedding, +) +from sglang.multimodal_gen.runtime.platforms import ( + AttentionBackendEnum, + current_platform, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class CausalWanSelfAttention(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + local_attn_size: int = -1, + sink_size: int = 0, + qk_norm=True, + eps=1e-6, + parallel_attention=False, + ) -> None: + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.local_attn_size = local_attn_size + self.sink_size = sink_size + self.qk_norm = qk_norm + self.eps = eps + self.parallel_attention = parallel_attention + self.max_attention_size = ( + 32760 if local_attn_size == -1 else local_attn_size * 1560 + ) + + # Scaled dot product attention + self.attn = LocalAttention( + num_heads=num_heads, + head_size=self.head_dim, + dropout_rate=0, + softmax_scale=None, + causal=False, + supported_attention_backends=( + AttentionBackendEnum.FA3, + AttentionBackendEnum.TORCH_SDPA, + ), + ) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + freqs_cis: tuple[torch.Tensor, torch.Tensor], + block_mask: BlockMask, + kv_cache: dict | None = None, + current_start: int = 0, + cache_start: int | None = None, + ): + r""" + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + seq_lens(Tensor): Shape [B] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + if cache_start is None: + cache_start = current_start + + cos, sin = freqs_cis + roped_query = _apply_rotary_emb(q, cos, sin, is_neox_style=False).type_as(v) + roped_key = _apply_rotary_emb(k, cos, sin, is_neox_style=False).type_as(v) + + if kv_cache is None: + # Padding for flex attention + padded_length = math.ceil(q.shape[1] / 128) * 128 - q.shape[1] + padded_roped_query = torch.cat( + [ + roped_query, + torch.zeros( + [q.shape[0], padded_length, q.shape[2], q.shape[3]], + device=q.device, + dtype=v.dtype, + ), + ], + dim=1, + ) + + padded_roped_key = torch.cat( + [ + roped_key, + torch.zeros( + [k.shape[0], padded_length, k.shape[2], k.shape[3]], 
+ device=k.device, + dtype=v.dtype, + ), + ], + dim=1, + ) + + padded_v = torch.cat( + [ + v, + torch.zeros( + [v.shape[0], padded_length, v.shape[2], v.shape[3]], + device=v.device, + dtype=v.dtype, + ), + ], + dim=1, + ) + + x = flex_attention( + query=padded_roped_query.transpose(2, 1), + key=padded_roped_key.transpose(2, 1), + value=padded_v.transpose(2, 1), + block_mask=block_mask, + )[:, :, :-padded_length].transpose(2, 1) + else: + frame_seqlen = q.shape[1] + current_end = current_start + roped_query.shape[1] + sink_tokens = self.sink_size * frame_seqlen + # If we are using local attention and the current KV cache size is larger than the local attention size, we need to truncate the KV cache + kv_cache_size = kv_cache["k"].shape[1] + num_new_tokens = roped_query.shape[1] + if ( + self.local_attn_size != -1 + and (current_end > kv_cache["global_end_index"].item()) + and ( + num_new_tokens + kv_cache["local_end_index"].item() > kv_cache_size + ) + ): + # Calculate the number of new tokens added in this step + # Shift existing cache content left to discard oldest tokens + # Clone the source slice to avoid overlapping memory error + num_evicted_tokens = ( + num_new_tokens + kv_cache["local_end_index"].item() - kv_cache_size + ) + num_rolled_tokens = ( + kv_cache["local_end_index"].item() + - num_evicted_tokens + - sink_tokens + ) + kv_cache["k"][ + :, sink_tokens : sink_tokens + num_rolled_tokens + ] = kv_cache["k"][ + :, + sink_tokens + + num_evicted_tokens : sink_tokens + + num_evicted_tokens + + num_rolled_tokens, + ].clone() + kv_cache["v"][ + :, sink_tokens : sink_tokens + num_rolled_tokens + ] = kv_cache["v"][ + :, + sink_tokens + + num_evicted_tokens : sink_tokens + + num_evicted_tokens + + num_rolled_tokens, + ].clone() + # Insert the new keys/values at the end + local_end_index = ( + kv_cache["local_end_index"].item() + + current_end + - kv_cache["global_end_index"].item() + - num_evicted_tokens + ) + local_start_index = local_end_index - num_new_tokens + kv_cache["k"][:, local_start_index:local_end_index] = roped_key + kv_cache["v"][:, local_start_index:local_end_index] = v + else: + # Assign new keys/values directly up to current_end + local_end_index = ( + kv_cache["local_end_index"].item() + + current_end + - kv_cache["global_end_index"].item() + ) + local_start_index = local_end_index - num_new_tokens + kv_cache["k"] = kv_cache["k"].detach() + kv_cache["v"] = kv_cache["v"].detach() + # logger.info("kv_cache['k'] is in comp graph: %s", kv_cache["k"].requires_grad or kv_cache["k"].grad_fn is not None) + kv_cache["k"][:, local_start_index:local_end_index] = roped_key + kv_cache["v"][:, local_start_index:local_end_index] = v + x = self.attn( + roped_query, + kv_cache["k"][ + :, + max(0, local_end_index - self.max_attention_size) : local_end_index, + ], + kv_cache["v"][ + :, + max(0, local_end_index - self.max_attention_size) : local_end_index, + ], + ) + kv_cache["global_end_index"].fill_(current_end) + kv_cache["local_end_index"].fill_(local_end_index) + + return x + + +class CausalWanTransformerBlock(nn.Module): + + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + local_attn_size: int = -1, + sink_size: int = 0, + qk_norm: str = "rms_norm_across_heads", + cross_attn_norm: bool = False, + eps: float = 1e-6, + added_kv_proj_dim: int | None = None, + supported_attention_backends: set[AttentionBackendEnum] | None = None, + prefix: str = "", + ): + super().__init__() + + # 1. 
Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.to_q = ReplicatedLinear(dim, dim, bias=True) + self.to_k = ReplicatedLinear(dim, dim, bias=True) + self.to_v = ReplicatedLinear(dim, dim, bias=True) + + self.to_out = ReplicatedLinear(dim, dim, bias=True) + self.attn1 = CausalWanSelfAttention( + dim, + num_heads, + local_attn_size=local_attn_size, + sink_size=sink_size, + qk_norm=qk_norm, + eps=eps, + ) + self.hidden_dim = dim + self.num_attention_heads = num_heads + self.local_attn_size = local_attn_size + dim_head = dim // num_heads + if qk_norm == "rms_norm": + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + elif qk_norm == "rms_norm_across_heads": + # LTX applies qk norm across all heads + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + else: + print("QK Norm type not supported") + raise Exception + assert cross_attn_norm is True + self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift( + dim, + norm_type="layer", + eps=eps, + elementwise_affine=True, + dtype=torch.float32, + compute_dtype=torch.float32, + ) + + # 2. Cross-attention + # Only T2V for now + self.attn2 = WanT2VCrossAttention(dim, num_heads, qk_norm=qk_norm, eps=eps) + self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift( + dim, + norm_type="layer", + eps=eps, + elementwise_affine=False, + dtype=torch.float32, + compute_dtype=torch.float32, + ) + + # 3. Feed-forward + self.ffn = MLP(dim, ffn_dim, act_type="gelu_pytorch_tanh") + self.mlp_residual = ScaleResidual() + + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + freqs_cis: tuple[torch.Tensor, torch.Tensor], + block_mask: BlockMask, + kv_cache: dict | None = None, + crossattn_cache: dict | None = None, + current_start: int = 0, + cache_start: int | None = None, + ) -> torch.Tensor: + # hidden_states.shape: [batch_size, seq_length, inner_dim] + # temb.shape: [batch_size, num_frames, 6, inner_dim] + if hidden_states.dim() == 4: + hidden_states = hidden_states.squeeze(1) + num_frames = temb.shape[1] + frame_seqlen = hidden_states.shape[1] // num_frames + bs, seq_length, _ = hidden_states.shape + orig_dtype = hidden_states.dtype + # assert orig_dtype != torch.float32 + e = self.scale_shift_table + temb.float() + # e.shape: [batch_size, num_frames, 6, inner_dim] + assert e.shape == (bs, num_frames, 6, self.hidden_dim) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = e.chunk( + 6, dim=2 + ) + # *_msa.shape: [batch_size, num_frames, 1, inner_dim] + assert shift_msa.dtype == torch.float32 + + # 1. 
Self-attention
+        norm_hidden_states = (
+            (
+                self.norm1(hidden_states.float()).unflatten(
+                    dim=1, sizes=(num_frames, frame_seqlen)
+                )
+                * (1 + scale_msa)
+                + shift_msa
+            )
+            .flatten(1, 2)
+            .to(orig_dtype)
+        )
+        query, _ = self.to_q(norm_hidden_states)
+        key, _ = self.to_k(norm_hidden_states)
+        value, _ = self.to_v(norm_hidden_states)
+
+        if self.norm_q is not None:
+            query = self.norm_q(query)
+        if self.norm_k is not None:
+            key = self.norm_k(key)
+
+        query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
+        key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
+        value = value.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
+
+        attn_output = self.attn1(
+            query,
+            key,
+            value,
+            freqs_cis,
+            block_mask,
+            kv_cache,
+            current_start,
+            cache_start,
+        )
+        attn_output = attn_output.flatten(2)
+        attn_output, _ = self.to_out(attn_output)
+        attn_output = attn_output.squeeze(1)
+
+        null_shift = null_scale = torch.zeros(
+            (1,), device=hidden_states.device, dtype=hidden_states.dtype
+        )
+        norm_hidden_states, hidden_states = self.self_attn_residual_norm(
+            hidden_states, attn_output, gate_msa, null_shift, null_scale
+        )
+        norm_hidden_states, hidden_states = norm_hidden_states.to(
+            orig_dtype
+        ), hidden_states.to(orig_dtype)
+
+        # 2. Cross-attention
+        attn_output = self.attn2(
+            norm_hidden_states,
+            context=encoder_hidden_states,
+            context_lens=None,
+            crossattn_cache=crossattn_cache,
+        )
+        norm_hidden_states, hidden_states = self.cross_attn_residual_norm(
+            hidden_states, attn_output, 1, c_shift_msa, c_scale_msa
+        )
+        norm_hidden_states, hidden_states = norm_hidden_states.to(
+            orig_dtype
+        ), hidden_states.to(orig_dtype)
+
+        # 3. Feed-forward
+        ff_output = self.ffn(norm_hidden_states)
+        hidden_states = self.mlp_residual(hidden_states, ff_output, c_gate_msa)
+        hidden_states = hidden_states.to(orig_dtype)
+
+        return hidden_states
+
+
+class CausalWanTransformer3DModel(BaseDiT):
+    _fsdp_shard_conditions = WanVideoConfig()._fsdp_shard_conditions
+    _compile_conditions = WanVideoConfig()._compile_conditions
+    _supported_attention_backends = WanVideoConfig()._supported_attention_backends
+    param_names_mapping = WanVideoConfig().param_names_mapping
+    reverse_param_names_mapping = WanVideoConfig().reverse_param_names_mapping
+    lora_param_names_mapping = WanVideoConfig().lora_param_names_mapping
+
+    def __init__(self, config: WanVideoConfig, hf_config: dict[str, Any]) -> None:
+        super().__init__(config=config, hf_config=hf_config)
+
+        inner_dim = config.num_attention_heads * config.attention_head_dim
+        self.hidden_size = config.hidden_size
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_dim = config.attention_head_dim
+        self.in_channels = config.in_channels
+        self.out_channels = config.out_channels
+        self.num_channels_latents = config.num_channels_latents
+        self.patch_size = config.patch_size
+        self.text_len = config.text_len
+        self.local_attn_size = config.local_attn_size
+
+        # 1. Patch & position embedding
+        self.patch_embedding = PatchEmbed(
+            in_chans=config.in_channels,
+            embed_dim=inner_dim,
+            patch_size=config.patch_size,
+            flatten=False,
+        )
+
+        # 2. Condition embeddings
+        self.condition_embedder = WanTimeTextImageEmbedding(
+            dim=inner_dim,
+            time_freq_dim=config.freq_dim,
+            text_embed_dim=config.text_dim,
+            image_embed_dim=config.image_dim,
+        )
+
+        # 3. 
Transformer blocks + self.blocks = nn.ModuleList( + [ + CausalWanTransformerBlock( + inner_dim, + config.ffn_dim, + config.num_attention_heads, + config.local_attn_size, + config.sink_size, + config.qk_norm, + config.cross_attn_norm, + config.eps, + config.added_kv_proj_dim, + self._supported_attention_backends, + prefix=f"{config.prefix}.blocks.{i}", + ) + for i in range(config.num_layers) + ] + ) + + # 4. Output norm & projection + self.norm_out = LayerNormScaleShift( + inner_dim, + norm_type="layer", + eps=config.eps, + elementwise_affine=False, + dtype=torch.float32, + compute_dtype=torch.float32, + ) + self.proj_out = nn.Linear( + inner_dim, config.out_channels * math.prod(config.patch_size) + ) + self.scale_shift_table = nn.Parameter( + torch.randn(1, 2, inner_dim) / inner_dim**0.5 + ) + + self.gradient_checkpointing = False + + # Causal-specific + self.block_mask = None + self.num_frame_per_block = config.arch_config.num_frames_per_block + assert self.num_frame_per_block <= 3 + self.independent_first_frame = False + + self.__post_init__() + + @staticmethod + def _prepare_blockwise_causal_attn_mask( + device: torch.device | str, + num_frames: int = 21, + frame_seqlen: int = 1560, + num_frame_per_block=1, + local_attn_size=-1, + ) -> BlockMask: + """ + we will divide the token sequence into the following format + [1 latent frame] [1 latent frame] ... [1 latent frame] + We use flexattention to construct the attention mask + """ + total_length = num_frames * frame_seqlen + + # we do right padding to get to a multiple of 128 + padded_length = math.ceil(total_length / 128) * 128 - total_length + + ends = torch.zeros( + total_length + padded_length, device=device, dtype=torch.long + ) + + # Block-wise causal mask will attend to all elements that are before the end of the current chunk + frame_indices = torch.arange( + start=0, + end=total_length, + step=frame_seqlen * num_frame_per_block, + device=device, + ) + + for tmp in frame_indices: + ends[tmp : tmp + frame_seqlen * num_frame_per_block] = ( + tmp + frame_seqlen * num_frame_per_block + ) + + def attention_mask(b, h, q_idx, kv_idx): + if local_attn_size == -1: + return (kv_idx < ends[q_idx]) | (q_idx == kv_idx) + else: + return ( + (kv_idx < ends[q_idx]) + & (kv_idx >= (ends[q_idx] - local_attn_size * frame_seqlen)) + ) | (q_idx == kv_idx) + # return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask + + block_mask = create_block_mask( + attention_mask, + B=None, + H=None, + Q_LEN=total_length + padded_length, + KV_LEN=total_length + padded_length, + _compile=False, + device=device, + ) + + if not dist.is_initialized() or dist.get_rank() == 0: + print( + f" cache a block wise causal mask with block size of {num_frame_per_block} frames" + ) + print(block_mask) + + # import imageio + # import numpy as np + # from torch.nn.attention.flex_attention import create_mask + + # mask = create_mask(attention_mask, B=None, H=None, Q_LEN=total_length + + # padded_length, KV_LEN=total_length + padded_length, device=device) + # import cv2 + # mask = cv2.resize(mask[0, 0].cpu().float().numpy(), (1024, 1024)) + # imageio.imwrite("mask_%d.jpg" % (0), np.uint8(255. 
* mask)) + + return block_mask + + def _forward_inference( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | list[torch.Tensor], + timestep: torch.LongTensor, + encoder_hidden_states_image: torch.Tensor | list[torch.Tensor] | None = None, + kv_cache: dict = None, + crossattn_cache: dict = None, + current_start: int = 0, + cache_start: int = 0, + start_frame: int = 0, + **kwargs, + ) -> torch.Tensor: + r""" + Run the diffusion model with kv caching. + See Algorithm 2 of CausVid paper https://arxiv.org/abs/2412.07772 for details. + This function will be run for num_frame times. + Process the latent frames one by one (1560 tokens each) + """ + + orig_dtype = hidden_states.dtype + if not isinstance(encoder_hidden_states, torch.Tensor): + encoder_hidden_states = encoder_hidden_states[0] + if ( + isinstance(encoder_hidden_states_image, list) + and len(encoder_hidden_states_image) > 0 + ): + encoder_hidden_states_image = encoder_hidden_states_image[0] + else: + encoder_hidden_states_image = None + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # Get rotary embeddings + d = self.hidden_size // self.num_attention_heads + rope_dim_list = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)] + freqs_cos, freqs_sin = get_rotary_pos_embed( + ( + post_patch_num_frames * get_sp_world_size(), + post_patch_height, + post_patch_width, + ), + self.hidden_size, + self.num_attention_heads, + rope_dim_list, + dtype=torch.float32 if current_platform.is_mps() else torch.float64, + rope_theta=10000, + start_frame=start_frame, # Assume that start_frame is 0 when kv_cache is None + ) + freqs_cos = freqs_cos.to(hidden_states.device) + freqs_sin = freqs_sin.to(hidden_states.device) + freqs_cis = ( + (freqs_cos.float(), freqs_sin.float()) if freqs_cos is not None else None + ) + + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = ( + self.condition_embedder( + timestep.flatten(), encoder_hidden_states, encoder_hidden_states_image + ) + ) + timestep_proj = timestep_proj.unflatten(1, (6, self.hidden_size)).unflatten( + dim=0, sizes=timestep.shape + ) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat( + [encoder_hidden_states_image, encoder_hidden_states], dim=1 + ) + + encoder_hidden_states = ( + encoder_hidden_states.to(orig_dtype) + if current_platform.is_mps() + else encoder_hidden_states + ) # cast to orig_dtype for MPS + + assert encoder_hidden_states.dtype == orig_dtype + + # 4. 
Transformer blocks + for block_index, block in enumerate(self.blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + causal_kwargs = { + "kv_cache": kv_cache[block_index], + "current_start": current_start, + "cache_start": cache_start, + "block_mask": self.block_mask, + } + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + timestep_proj, + freqs_cis, + **causal_kwargs, + ) + else: + causal_kwargs = { + "kv_cache": kv_cache[block_index], + "crossattn_cache": crossattn_cache[block_index], + "current_start": current_start, + "cache_start": cache_start, + "block_mask": self.block_mask, + } + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep_proj, + freqs_cis, + **causal_kwargs, + ) + + # 5. Output norm, projection & unpatchify + temb = temb.unflatten(dim=0, sizes=timestep.shape).unsqueeze(2) + shift, scale = (self.scale_shift_table.unsqueeze(1) + temb).chunk(2, dim=2) + hidden_states = self.norm_out(hidden_states, shift, scale) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, + post_patch_num_frames, + post_patch_height, + post_patch_width, + p_t, + p_h, + p_w, + -1, + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + return output + + def _forward_train( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | list[torch.Tensor], + timestep: torch.LongTensor, + encoder_hidden_states_image: torch.Tensor | list[torch.Tensor] | None = None, + start_frame: int = 0, + **kwargs, + ) -> torch.Tensor: + + orig_dtype = hidden_states.dtype + if not isinstance(encoder_hidden_states, torch.Tensor): + encoder_hidden_states = encoder_hidden_states[0] + if ( + isinstance(encoder_hidden_states_image, list) + and len(encoder_hidden_states_image) > 0 + ): + encoder_hidden_states_image = encoder_hidden_states_image[0] + else: + encoder_hidden_states_image = None + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + # Get rotary embeddings + d = self.hidden_size // self.num_attention_heads + rope_dim_list = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)] + freqs_cos, freqs_sin = get_rotary_pos_embed( + ( + post_patch_num_frames * get_sp_world_size(), + post_patch_height, + post_patch_width, + ), + self.hidden_size, + self.num_attention_heads, + rope_dim_list, + dtype=torch.float32 if current_platform.is_mps() else torch.float64, + rope_theta=10000, + start_frame=start_frame, + ) + freqs_cos = freqs_cos.to(hidden_states.device) + freqs_sin = freqs_sin.to(hidden_states.device) + freqs_cis = ( + (freqs_cos.float(), freqs_sin.float()) if freqs_cos is not None else None + ) + + # Construct blockwise causal attn mask + if self.block_mask is None: + self.block_mask = self._prepare_blockwise_causal_attn_mask( + device=hidden_states.device, + num_frames=num_frames, + frame_seqlen=post_patch_height * post_patch_width, + num_frame_per_block=self.num_frame_per_block, + local_attn_size=self.local_attn_size, + ) + + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = ( + self.condition_embedder( + timestep.flatten(), encoder_hidden_states, 
encoder_hidden_states_image + ) + ) + timestep_proj = timestep_proj.unflatten(1, (6, self.hidden_size)).unflatten( + dim=0, sizes=timestep.shape + ) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat( + [encoder_hidden_states_image, encoder_hidden_states], dim=1 + ) + + encoder_hidden_states = ( + encoder_hidden_states.to(orig_dtype) + if current_platform.is_mps() + else encoder_hidden_states + ) # cast to orig_dtype for MPS + + assert encoder_hidden_states.dtype == orig_dtype + + # 4. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.blocks: + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + timestep_proj, + freqs_cis, + block_mask=self.block_mask, + ) + else: + for block in self.blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep_proj, + freqs_cis, + block_mask=self.block_mask, + ) + + # 5. Output norm, projection & unpatchify + temb = temb.unflatten(dim=0, sizes=timestep.shape).unsqueeze(2) + shift, scale = (self.scale_shift_table.unsqueeze(1) + temb).chunk(2, dim=2) + hidden_states = self.norm_out(hidden_states, shift, scale) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, + post_patch_num_frames, + post_patch_height, + post_patch_width, + p_t, + p_h, + p_w, + -1, + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + return output + + def forward(self, *args, **kwargs): + if kwargs.get("kv_cache") is not None: + return self._forward_inference(*args, **kwargs) + else: + return self._forward_train(*args, **kwargs) + + +EntryClass = CausalWanTransformer3DModel diff --git a/python/sglang/multimodal_gen/runtime/models/dits/flux.py b/python/sglang/multimodal_gen/runtime/models/dits/flux.py new file mode 100644 index 000000000..5bc7dad76 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/dits/flux.py @@ -0,0 +1,559 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
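+#
+# This module provides the Flux DiT building blocks: a joint `FluxAttention`
+# layer plus the double-stream `FluxTransformerBlock` and the single-stream
+# `FluxSingleTransformerBlock`, using rotary embeddings and AdaLayerNorm-style
+# modulation on top of the multimodal_gen runtime layer primitives.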
+ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from diffusers.models.attention import AttentionModuleMixin, FeedForward +from diffusers.models.embeddings import ( + CombinedTimestepGuidanceTextProjEmbeddings, + CombinedTimestepTextProjEmbeddings, +) +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import ( + AdaLayerNormContinuous, + AdaLayerNormZero, + AdaLayerNormZeroSingle, +) +from torch.nn import LayerNorm as LayerNorm + +from sglang.multimodal_gen.configs.models.dits.flux import FluxConfig +from sglang.multimodal_gen.runtime.layers.attention import LocalAttention + +# from sglang.multimodal_gen.runtime.layers.layernorm import LayerNorm as LayerNorm +from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm +from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear +from sglang.multimodal_gen.runtime.layers.mlp import MLP +from sglang.multimodal_gen.runtime.layers.rotary_embedding import ( + NDRotaryEmbedding, + _apply_rotary_emb, +) +from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT +from sglang.multimodal_gen.runtime.platforms import ( + AttentionBackendEnum, + current_platform, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) # pylint: disable=invalid-name + + +def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): + query, _ = attn.to_q(hidden_states) + key, _ = attn.to_k(hidden_states) + value, _ = attn.to_v(hidden_states) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: + encoder_query, _ = attn.add_q_proj(encoder_hidden_states) + encoder_key, _ = attn.add_k_proj(encoder_hidden_states) + encoder_value, _ = attn.add_v_proj(encoder_hidden_states) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_fused_projections( + attn: "FluxAttention", hidden_states, encoder_hidden_states=None +): + query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) + + encoder_query = encoder_key = encoder_value = None + if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): + encoder_query, encoder_key, encoder_value = attn.to_added_qkv( + encoder_hidden_states + ).chunk(3, dim=-1) + + return query, key, value, encoder_query, encoder_key, encoder_value + + +def _get_qkv_projections( + attn: "FluxAttention", hidden_states, encoder_hidden_states=None +): + if attn.fused_projections: + return _get_fused_projections(attn, hidden_states, encoder_hidden_states) + return _get_projections(attn, hidden_states, encoder_hidden_states) + + +class FluxAttention(torch.nn.Module, AttentionModuleMixin): + + def __init__( + self, + query_dim: int, + num_heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + out_bias: bool = True, + eps: float = 1e-5, + out_dim: int = None, + context_pre_only: Optional[bool] = None, + pre_only: bool = False, + ): + super().__init__() + + self.head_dim = dim_head + self.inner_dim = out_dim if out_dim is not None else dim_head * num_heads + self.query_dim = query_dim + self.use_bias = bias + self.dropout = dropout + self.out_dim = out_dim if out_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.heads = out_dim // dim_head if 
out_dim is not None else num_heads + self.added_kv_proj_dim = added_kv_proj_dim + self.added_proj_bias = added_proj_bias + + self.norm_q = RMSNorm(dim_head, eps=eps) + + self.norm_k = RMSNorm(dim_head, eps=eps) + self.to_q = ReplicatedLinear(query_dim, self.inner_dim, bias=bias) + self.to_k = ReplicatedLinear(query_dim, self.inner_dim, bias=bias) + self.to_v = ReplicatedLinear(query_dim, self.inner_dim, bias=bias) + + if not self.pre_only: + self.to_out = torch.nn.ModuleList([]) + self.to_out.append( + ReplicatedLinear(self.inner_dim, self.out_dim, bias=out_bias) + ) + if dropout != 0.0: + self.to_out.append(torch.nn.Dropout(dropout)) + + if added_kv_proj_dim is not None: + self.norm_added_q = RMSNorm(dim_head, eps=eps) + self.norm_added_k = RMSNorm(dim_head, eps=eps) + self.add_q_proj = ReplicatedLinear( + added_kv_proj_dim, self.inner_dim, bias=added_proj_bias + ) + self.add_k_proj = ReplicatedLinear( + added_kv_proj_dim, self.inner_dim, bias=added_proj_bias + ) + self.add_v_proj = ReplicatedLinear( + added_kv_proj_dim, self.inner_dim, bias=added_proj_bias + ) + self.to_add_out = ReplicatedLinear(self.inner_dim, query_dim, bias=out_bias) + + # Scaled dot product attention + self.attn = LocalAttention( + num_heads=num_heads, + head_size=self.head_dim, + dropout_rate=0, + softmax_scale=None, + causal=False, + supported_attention_backends=( + AttentionBackendEnum.FA3, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.SAGE_ATTN, + ), + ) + + def forward( + self, + x: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + freqs_cis=None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + query, key, value, encoder_query, encoder_key, encoder_value = ( + _get_qkv_projections(self, x, encoder_hidden_states) + ) + + query = query.unflatten(-1, (self.heads, -1)) + key = key.unflatten(-1, (self.heads, -1)) + value = value.unflatten(-1, (self.heads, -1)) + query = self.norm_q(query) + key = self.norm_k(key) + + if self.added_kv_proj_dim is not None: + encoder_query = encoder_query.unflatten(-1, (self.heads, -1)) + encoder_key = encoder_key.unflatten(-1, (self.heads, -1)) + encoder_value = encoder_value.unflatten(-1, (self.heads, -1)) + + encoder_query = self.norm_added_q(encoder_query) + encoder_key = self.norm_added_k(encoder_key) + + bsz, seq_len, _, _ = query.shape + query = torch.cat([encoder_query, query], dim=1) + key = torch.cat([encoder_key, key], dim=1) + value = torch.cat([encoder_value, value], dim=1) + + if freqs_cis is not None: + cos, sin = freqs_cis + query = _apply_rotary_emb( + query, cos, sin, is_neox_style=False, interleaved=False + ) + key = _apply_rotary_emb( + key, cos, sin, is_neox_style=False, interleaved=False + ) + + x = self.attn(query, key, value) + x = x.flatten(2, 3) + x = x.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, x = x.split_with_sizes( + [ + encoder_hidden_states.shape[1], + x.shape[1] - encoder_hidden_states.shape[1], + ], + dim=1, + ) + x, _ = self.to_out[0](x) + if len(self.to_out) == 2: + x = self.to_out[1](x) + encoder_hidden_states, _ = self.to_add_out(encoder_hidden_states) + + return x, encoder_hidden_states + else: + return x + + +class FluxSingleTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 4.0, + ): + super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim) + self.proj_mlp = ReplicatedLinear(dim, self.mlp_hidden_dim) + self.act_mlp = 
nn.GELU(approximate="tanh") + self.proj_out = ReplicatedLinear(dim + self.mlp_hidden_dim, dim) + + self.attn = FluxAttention( + query_dim=dim, + dim_head=attention_head_dim, + num_heads=num_attention_heads, + out_dim=dim, + bias=True, + eps=1e-6, + pre_only=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + proj_hidden_states, _ = self.proj_mlp(norm_hidden_states) + mlp_hidden_states = self.act_mlp(proj_hidden_states) + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + x=norm_hidden_states, + freqs_cis=freqs_cis, + **joint_attention_kwargs, + ) + + hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2) + gate = gate.unsqueeze(1) + proj_out, _ = self.proj_out(hidden_states) + hidden_states = gate * proj_out + hidden_states = residual + hidden_states + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + encoder_hidden_states, hidden_states = ( + hidden_states[:, :text_seq_len], + hidden_states[:, text_seq_len:], + ) + return encoder_hidden_states, hidden_states + + +class FluxTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + qk_norm: str = "rms_norm", + eps: float = 1e-6, + ): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + self.norm1_context = AdaLayerNormZero(dim) + + self.attn = FluxAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + num_heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + eps=eps, + ) + + self.norm2 = LayerNorm(dim, eps=1e-6, elementwise_affine=False) + self.ff = MLP( + input_dim=dim, mlp_hidden_dim=dim * 4, output_dim=dim, act_type="gelu" + ) + self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + + self.norm2_context = LayerNorm(dim, eps=1e-6, elementwise_affine=False) + self.ff_context = MLP( + input_dim=dim, mlp_hidden_dim=dim * 4, output_dim=dim, act_type="gelu" + ) + + self.ff_context = FeedForward( + dim=dim, dim_out=dim, activation_fn="gelu-approximate" + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, emb=temb + ) + + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = ( + self.norm1_context(encoder_hidden_states, emb=temb) + ) + + joint_attention_kwargs = joint_attention_kwargs or {} + # Attention. 
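+        # FluxAttention returns (sample_out, context_out) when
+        # `encoder_hidden_states` is passed; variants that add an image-prompt
+        # branch (IP-Adapter style) return a third tensor, handled below.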
+ attention_outputs = self.attn( + x=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + freqs_cis=freqs_cis, + **joint_attention_kwargs, + ) + + if len(attention_outputs) == 2: + attn_output, context_attn_output = attention_outputs + elif len(attention_outputs) == 3: + attn_output, context_attn_output, ip_attn_output = attention_outputs + + # Process attention outputs for the `hidden_states`. + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = hidden_states + attn_output + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = hidden_states + ff_output + + if len(attention_outputs) == 3: + hidden_states = hidden_states + ip_attn_output + # Process attention outputs for the `encoder_hidden_states`. + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) + norm_encoder_hidden_states = ( + norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + + c_shift_mlp[:, None] + ) + + context_ff_output = self.ff_context(norm_encoder_hidden_states) + encoder_hidden_states = ( + encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + ) + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class FluxPosEmbed(nn.Module): + # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 + def __init__(self, theta: int, axes_dim: List[int]): + super().__init__() + self.rope = NDRotaryEmbedding( + rope_dim_list=axes_dim, + rope_theta=theta, + use_real=False, + repeat_interleave_real=False, + dtype=torch.float32 if current_platform.is_mps() else torch.float64, + ) + + def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + pos = ids.float() + # freqs_cos, freqs_sin = self.rope.forward(positions=pos) + freqs_cos, freqs_sin = self.rope.forward_uncached(pos=pos) + return freqs_cos.contiguous().float(), freqs_sin.contiguous().float() + + +class FluxTransformer2DModel(CachableDiT): + """ + The Transformer model introduced in Flux. 
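+    It stacks ``num_layers`` dual-stream blocks (separate image/text projections
+    with joint attention) followed by ``num_single_layers`` single-stream blocks.
+
+    Example call (placeholder inputs; shapes follow the argument docs in
+    ``forward`` and are not tied to a specific checkpoint):
+
+    ```python
+    model = FluxTransformer2DModel(config, hf_config)
+    sample = model(
+        hidden_states=latents,                # [B, img_seq_len, in_channels]
+        encoder_hidden_states=prompt_embeds,  # [B, txt_seq_len, joint_attention_dim]
+        pooled_projections=pooled_embeds,     # [B, pooled_projection_dim]
+        timestep=timesteps,
+        freqs_cis=(cos, sin),                 # from FluxPosEmbed / self.rotary_emb
+    )
+    ```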
+ + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + """ + + def __init__(self, config: FluxConfig, hf_config: dict[str, Any]) -> None: + super().__init__(config=config, hf_config=hf_config) + self.config = config.arch_config + + self.out_channels = ( + getattr(self.config, "out_channels", None) or self.config.in_channels + ) + self.inner_dim = ( + self.config.num_attention_heads * self.config.attention_head_dim + ) + + self.rotary_emb = FluxPosEmbed(theta=10000, axes_dim=self.config.axes_dims_rope) + + text_time_guidance_cls = ( + CombinedTimestepGuidanceTextProjEmbeddings + if self.config.guidance_embeds + else CombinedTimestepTextProjEmbeddings + ) + self.time_text_embed = text_time_guidance_cls( + embedding_dim=self.inner_dim, + pooled_projection_dim=self.config.pooled_projection_dim, + ) + + self.context_embedder = ReplicatedLinear( + self.config.joint_attention_dim, self.inner_dim + ) + self.x_embedder = ReplicatedLinear(self.config.in_channels, self.inner_dim) + self.transformer_blocks = nn.ModuleList( + [ + FluxTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for _ in range(self.config.num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + FluxSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=self.config.num_attention_heads, + attention_head_dim=self.config.attention_head_dim, + ) + for _ in range(self.config.num_single_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6 + ) + self.proj_out = ReplicatedLinear( + self.inner_dim, + self.config.patch_size * self.config.patch_size * self.out_channels, + bias=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + guidance: torch.Tensor = None, + freqs_cis: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + guidance (`torch.Tensor`): + Guidance embeddings. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + + """ + if ( + joint_attention_kwargs is not None + and joint_attention_kwargs.get("scale", None) is not None + ): + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
+ ) + hidden_states, _ = self.x_embedder(hidden_states) + + temb = ( + self.time_text_embed(timestep, pooled_projections) + if guidance is None + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + + encoder_hidden_states, _ = self.context_embedder(encoder_hidden_states) + + if ( + joint_attention_kwargs is not None + and "ip_adapter_image_embeds" in joint_attention_kwargs + ): + ip_adapter_image_embeds = joint_attention_kwargs.pop( + "ip_adapter_image_embeds" + ) + ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) + joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) + + for index_block, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + freqs_cis=freqs_cis, + joint_attention_kwargs=joint_attention_kwargs, + ) + + for index_block, block in enumerate(self.single_transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + freqs_cis=freqs_cis, + joint_attention_kwargs=joint_attention_kwargs, + ) + + hidden_states = self.norm_out(hidden_states, temb) + + output, _ = self.proj_out(hidden_states) + + return output + + +EntryClass = FluxTransformer2DModel diff --git a/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py new file mode 100644 index 000000000..f6394e942 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/dits/hunyuanvideo.py @@ -0,0 +1,961 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +import numpy as np +import torch +import torch.nn as nn + +from sglang.multimodal_gen.configs.models.dits import HunyuanVideoConfig +from sglang.multimodal_gen.configs.sample.teacache import TeaCacheParams +from sglang.multimodal_gen.runtime.distributed.parallel_state import get_sp_world_size +from sglang.multimodal_gen.runtime.layers.attention import ( + LocalAttention, + UlyssesAttention, +) +from sglang.multimodal_gen.runtime.layers.layernorm import ( + LayerNormScaleShift, + RMSNorm, + ScaleResidual, + ScaleResidualLayerNormScaleShift, +) +from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear +from sglang.multimodal_gen.runtime.layers.mlp import MLP +from sglang.multimodal_gen.runtime.layers.rotary_embedding import ( + _apply_rotary_emb, + get_rotary_pos_embed, +) +from sglang.multimodal_gen.runtime.layers.visual_embedding import ( + ModulateProjection, + PatchEmbed, + TimestepEmbedder, + unpatchify, +) +from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context +from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT +from sglang.multimodal_gen.runtime.models.utils import modulate +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum + + +class MMDoubleStreamBlock(nn.Module): + """ + A multimodal DiT block with separate modulation for text and image/video, + using distributed attention and linear layers. 
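+
+    Each stream keeps its own modulation, QKV projection, QK-norm and MLP;
+    the two streams attend jointly only inside the Ulysses
+    (sequence-parallel) attention call.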
+ """ + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + mlp_ratio: float, + dtype: torch.dtype | None = None, + supported_attention_backends: set[AttentionBackendEnum] | None = None, + prefix: str = "", + ): + super().__init__() + + self.deterministic = False + self.num_attention_heads = num_attention_heads + head_dim = hidden_size // num_attention_heads + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + # Image modulation components + self.img_mod = ModulateProjection( + hidden_size, + factor=6, + act_layer="silu", + dtype=dtype, + prefix=f"{prefix}.img_mod", + ) + + # Fused operations for image stream + self.img_attn_norm = LayerNormScaleShift( + hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype + ) + self.img_attn_residual_mlp_norm = ScaleResidualLayerNormScaleShift( + hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype + ) + self.img_mlp_residual = ScaleResidual() + + # Image attention components + self.img_attn_qkv = ReplicatedLinear( + hidden_size, + hidden_size * 3, + bias=True, + params_dtype=dtype, + prefix=f"{prefix}.img_attn_qkv", + ) + + self.img_attn_q_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype) + self.img_attn_k_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype) + + self.img_attn_proj = ReplicatedLinear( + hidden_size, + hidden_size, + bias=True, + params_dtype=dtype, + prefix=f"{prefix}.img_attn_proj", + ) + + self.img_mlp = MLP( + hidden_size, + mlp_hidden_dim, + bias=True, + dtype=dtype, + prefix=f"{prefix}.img_mlp", + ) + + # Text modulation components + self.txt_mod = ModulateProjection( + hidden_size, + factor=6, + act_layer="silu", + dtype=dtype, + prefix=f"{prefix}.txt_mod", + ) + + # Fused operations for text stream + self.txt_attn_norm = LayerNormScaleShift( + hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype + ) + self.txt_attn_residual_mlp_norm = ScaleResidualLayerNormScaleShift( + hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype + ) + self.txt_mlp_residual = ScaleResidual() + + # Text attention components + self.txt_attn_qkv = ReplicatedLinear( + hidden_size, hidden_size * 3, bias=True, params_dtype=dtype + ) + + # QK norm layers for text + self.txt_attn_q_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype) + self.txt_attn_k_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype) + + self.txt_attn_proj = ReplicatedLinear( + hidden_size, hidden_size, bias=True, params_dtype=dtype + ) + + self.txt_mlp = MLP(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype) + + # Use UlyssesAttention to replace Distributed attention + self.attn = UlyssesAttention( + num_heads=num_attention_heads, + head_size=head_dim, + causal=False, + supported_attention_backends=supported_attention_backends, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + img: torch.Tensor, + txt: torch.Tensor, + vec: torch.Tensor, + freqs_cis: tuple, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Process modulation vectors + img_mod_outputs = self.img_mod(vec) + ( + img_attn_shift, + img_attn_scale, + img_attn_gate, + img_mlp_shift, + img_mlp_scale, + img_mlp_gate, + ) = torch.chunk(img_mod_outputs, 6, dim=-1) + + txt_mod_outputs = self.txt_mod(vec) + ( + txt_attn_shift, + txt_attn_scale, + txt_attn_gate, + txt_mlp_shift, + txt_mlp_scale, + txt_mlp_gate, + ) = torch.chunk(txt_mod_outputs, 6, dim=-1) + + # Prepare image for attention using fused operation + img_attn_input = self.img_attn_norm(img, img_attn_shift, img_attn_scale) + # Get QKV for image + img_qkv, _ = self.img_attn_qkv(img_attn_input) + 
batch_size, image_seq_len = img_qkv.shape[0], img_qkv.shape[1] + + # Split QKV + img_qkv = img_qkv.view( + batch_size, image_seq_len, 3, self.num_attention_heads, -1 + ) + img_q, img_k, img_v = img_qkv[:, :, 0], img_qkv[:, :, 1], img_qkv[:, :, 2] + + # Apply QK-Norm if needed + + img_q = self.img_attn_q_norm(img_q.contiguous()).to(img_v) + img_k = self.img_attn_k_norm(img_k.contiguous()).to(img_v) + # Apply rotary embeddings + cos, sin = freqs_cis + img_q, img_k = _apply_rotary_emb( + img_q, cos, sin, is_neox_style=False + ), _apply_rotary_emb(img_k, cos, sin, is_neox_style=False) + # Prepare text for attention using fused operation + txt_attn_input = self.txt_attn_norm(txt, txt_attn_shift, txt_attn_scale) + + # Get QKV for text + txt_qkv, _ = self.txt_attn_qkv(txt_attn_input) + batch_size, text_seq_len = txt_qkv.shape[0], txt_qkv.shape[1] + + # Split QKV + txt_qkv = txt_qkv.view( + batch_size, text_seq_len, 3, self.num_attention_heads, -1 + ) + txt_q, txt_k, txt_v = txt_qkv[:, :, 0], txt_qkv[:, :, 1], txt_qkv[:, :, 2] + + # Apply QK-Norm if needed + txt_q = self.txt_attn_q_norm(txt_q.contiguous()).to(txt_q.dtype) + txt_k = self.txt_attn_k_norm(txt_k.contiguous()).to(txt_k.dtype) + + # Run distributed attention + img_attn, txt_attn = self.attn(img_q, img_k, img_v, txt_q, txt_k, txt_v) + img_attn_out, _ = self.img_attn_proj( + img_attn.view(batch_size, image_seq_len, -1) + ) + # Use fused operation for residual connection, normalization, and modulation + img_mlp_input, img_residual = self.img_attn_residual_mlp_norm( + img, img_attn_out, img_attn_gate, img_mlp_shift, img_mlp_scale + ) + + # Process image MLP + img_mlp_out = self.img_mlp(img_mlp_input) + img = self.img_mlp_residual(img_residual, img_mlp_out, img_mlp_gate) + + # Process text attention output + txt_attn_out, _ = self.txt_attn_proj( + txt_attn.reshape(batch_size, text_seq_len, -1) + ) + + # Use fused operation for residual connection, normalization, and modulation + txt_mlp_input, txt_residual = self.txt_attn_residual_mlp_norm( + txt, txt_attn_out, txt_attn_gate, txt_mlp_shift, txt_mlp_scale + ) + + # Process text MLP + txt_mlp_out = self.txt_mlp(txt_mlp_input) + txt = self.txt_mlp_residual(txt_residual, txt_mlp_out, txt_mlp_gate) + + return img, txt + + +class MMSingleStreamBlock(nn.Module): + """ + A DiT block with parallel linear layers using distributed attention + and tensor parallelism. 
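+
+    A single fused projection (``linear1``) produces QKV and the MLP input
+    together, and ``linear2`` projects the concatenated attention and MLP
+    outputs back to the hidden size.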
+ """ + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + mlp_ratio: float = 4.0, + dtype: torch.dtype | None = None, + supported_attention_backends: set[AttentionBackendEnum] | None = None, + prefix: str = "", + ): + super().__init__() + + self.deterministic = False + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + head_dim = hidden_size // num_attention_heads + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp_hidden_dim = mlp_hidden_dim + + # Combined QKV and MLP input projection + self.linear1 = ReplicatedLinear( + hidden_size, + hidden_size * 3 + mlp_hidden_dim, + bias=True, + params_dtype=dtype, + prefix=f"{prefix}.linear1", + ) + + # Combined projection and MLP output + self.linear2 = ReplicatedLinear( + hidden_size + mlp_hidden_dim, + hidden_size, + bias=True, + params_dtype=dtype, + prefix=f"{prefix}.linear2", + ) + + # QK norm layers + self.q_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype) + self.k_norm = RMSNorm(head_dim, eps=1e-6, dtype=dtype) + + # Fused operations with better naming + self.input_norm_scale_shift = LayerNormScaleShift( + hidden_size, + norm_type="layer", + eps=1e-6, + elementwise_affine=False, + dtype=dtype, + ) + self.output_residual = ScaleResidual() + + # Activation function + self.mlp_act = nn.GELU(approximate="tanh") + + # Modulation + self.modulation = ModulateProjection( + hidden_size, + factor=3, + act_layer="silu", + dtype=dtype, + prefix=f"{prefix}.modulation", + ) + + # Use UlyssesAttention to replace Distributed attention + self.attn = UlyssesAttention( + num_heads=num_attention_heads, + head_size=head_dim, + causal=False, + supported_attention_backends=supported_attention_backends, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + x: torch.Tensor, + vec: torch.Tensor, + txt_len: int, + freqs_cis: tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + # Process modulation + mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1) + + # Apply pre-norm and modulation using fused operation + x_mod = self.input_norm_scale_shift(x, mod_shift, mod_scale) + + # Get combined projections + linear1_out, _ = self.linear1(x_mod) + + # Split into QKV and MLP parts + qkv, mlp = torch.split( + linear1_out, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1 + ) + + # Process QKV + batch_size, seq_len = qkv.shape[0], qkv.shape[1] + qkv = qkv.view(batch_size, seq_len, 3, self.num_attention_heads, -1) + q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] + + # Apply QK-Norm + q = self.q_norm(q.contiguous()).to(v.dtype) + k = self.k_norm(k.contiguous()).to(v.dtype) + + # Split into image and text parts + img_q, txt_q = q[:, :-txt_len], q[:, -txt_len:] + img_k, txt_k = k[:, :-txt_len], k[:, -txt_len:] + img_v, txt_v = v[:, :-txt_len], v[:, -txt_len:] + # Apply rotary embeddings to image parts + cos, sin = freqs_cis + img_q, img_k = _apply_rotary_emb( + img_q, cos, sin, is_neox_style=False + ), _apply_rotary_emb(img_k, cos, sin, is_neox_style=False) + + # Run distributed attention + img_attn_output, txt_attn_output = self.attn( + img_q, img_k, img_v, txt_q, txt_k, txt_v + ) + attn_output = torch.cat((img_attn_output, txt_attn_output), dim=1).view( + batch_size, seq_len, -1 + ) + # Process MLP activation + mlp_output = self.mlp_act(mlp) + + # Combine attention and MLP outputs + combined = torch.cat((attn_output, mlp_output), dim=-1) + + # Final projection + output, _ = self.linear2(combined) + + # Apply residual connection with gating using fused operation + return 
self.output_residual(x, output, mod_gate) + + +class HunyuanVideoTransformer3DModel(CachableDiT): + """ + HunyuanVideo Transformer backbone adapted for distributed training. + + This implementation uses distributed attention and linear layers for efficient + parallel processing across multiple GPUs. + + Based on the architecture from: + - Flux.1: https://github.com/black-forest-labs/flux + - MMDiT: http://arxiv.org/abs/2403.03206 + """ + + # PY: we make the input args the same as HF config + + # shard single stream, double stream blocks, and refiner_blocks + _fsdp_shard_conditions = HunyuanVideoConfig()._fsdp_shard_conditions + _compile_conditions = HunyuanVideoConfig()._compile_conditions + _supported_attention_backends = HunyuanVideoConfig()._supported_attention_backends + param_names_mapping = HunyuanVideoConfig().param_names_mapping + reverse_param_names_mapping = HunyuanVideoConfig().reverse_param_names_mapping + lora_param_names_mapping = HunyuanVideoConfig().lora_param_names_mapping + + def __init__(self, config: HunyuanVideoConfig, hf_config: dict[str, Any]): + super().__init__(config=config, hf_config=hf_config) + + self.patch_size = [config.patch_size_t, config.patch_size, config.patch_size] + self.in_channels = config.in_channels + self.num_channels_latents = config.num_channels_latents + self.out_channels = ( + config.in_channels if config.out_channels is None else config.out_channels + ) + self.unpatchify_channels = self.out_channels + self.guidance_embeds = config.guidance_embeds + self.rope_dim_list = list(config.rope_axes_dim) + self.rope_theta = config.rope_theta + self.text_states_dim = config.text_embed_dim + self.text_states_dim_2 = config.pooled_projection_dim + # TODO(will): hack? + self.dtype = config.dtype + + pe_dim = config.hidden_size // config.num_attention_heads + if sum(config.rope_axes_dim) != pe_dim: + raise ValueError( + f"Got {config.rope_axes_dim} but expected positional dim {pe_dim}" + ) + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_channels_latents = config.num_channels_latents + + # Image projection + self.img_in = PatchEmbed( + self.patch_size, + self.in_channels, + self.hidden_size, + dtype=config.dtype, + prefix=f"{config.prefix}.img_in", + ) + + self.txt_in = SingleTokenRefiner( + self.text_states_dim, + config.hidden_size, + config.num_attention_heads, + depth=config.num_refiner_layers, + dtype=config.dtype, + prefix=f"{config.prefix}.txt_in", + ) + + # Time modulation + self.time_in = TimestepEmbedder( + self.hidden_size, + act_layer="silu", + dtype=config.dtype, + prefix=f"{config.prefix}.time_in", + ) + + # Text modulation + self.vector_in = MLP( + self.text_states_dim_2, + self.hidden_size, + self.hidden_size, + act_type="silu", + dtype=config.dtype, + prefix=f"{config.prefix}.vector_in", + ) + + # Guidance modulation + self.guidance_in = ( + TimestepEmbedder( + self.hidden_size, + act_layer="silu", + dtype=config.dtype, + prefix=f"{config.prefix}.guidance_in", + ) + if self.guidance_embeds + else None + ) + + # Double blocks + self.double_blocks = nn.ModuleList( + [ + MMDoubleStreamBlock( + config.hidden_size, + config.num_attention_heads, + mlp_ratio=config.mlp_ratio, + dtype=config.dtype, + supported_attention_backends=self._supported_attention_backends, + prefix=f"{config.prefix}.double_blocks.{i}", + ) + for i in range(config.num_layers) + ] + ) + + # Single blocks + self.single_blocks = nn.ModuleList( + [ + MMSingleStreamBlock( + config.hidden_size, + 
config.num_attention_heads, + mlp_ratio=config.mlp_ratio, + dtype=config.dtype, + supported_attention_backends=self._supported_attention_backends, + prefix=f"{config.prefix}.single_blocks.{i+config.num_layers}", + ) + for i in range(config.num_single_layers) + ] + ) + + self.final_layer = FinalLayer( + config.hidden_size, + self.patch_size, + self.out_channels, + dtype=config.dtype, + prefix=f"{config.prefix}.final_layer", + ) + + self.__post_init__() + + # TODO: change the input the FORWARD_BATCH Dict + # TODO: change output to a dict + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | list[torch.Tensor], + timestep: torch.LongTensor, + encoder_hidden_states_image: torch.Tensor | list[torch.Tensor] | None = None, + guidance=None, + **kwargs, + ): + """ + Forward pass of the HunyuanDiT model. + + Args: + hidden_states: Input image/video latents [B, C, T, H, W] + encoder_hidden_states: Text embeddings [B, L, D] + timestep: Diffusion timestep + guidance: Guidance scale for CFG + + Returns: + Tuple of (output) + """ + forward_context = get_forward_context() + forward_batch = forward_context.forward_batch + enable_teacache = forward_batch is not None and forward_batch.enable_teacache + + if guidance is None: + guidance = torch.tensor( + [6016.0], device=hidden_states.device, dtype=hidden_states.dtype + ) + + img = x = hidden_states + t = timestep + + # Split text embeddings - first token is global, rest are per-token + if isinstance(encoder_hidden_states, torch.Tensor): + txt = encoder_hidden_states[:, 1:] + text_states_2 = encoder_hidden_states[:, 0, : self.text_states_dim_2] + else: + txt = encoder_hidden_states[0] + text_states_2 = encoder_hidden_states[1] + + # Get spatial dimensions + _, _, ot, oh, ow = x.shape # codespell:ignore + tt, th, tw = ( + ot // self.patch_size[0], # codespell:ignore + oh // self.patch_size[1], + ow // self.patch_size[2], + ) + + # Get rotary embeddings + freqs_cos, freqs_sin = get_rotary_pos_embed( + (tt * get_sp_world_size(), th, tw), + self.hidden_size, + self.num_attention_heads, + self.rope_dim_list, + self.rope_theta, + ) + freqs_cos = freqs_cos.to(x.device) + freqs_sin = freqs_sin.to(x.device) + # Prepare modulation vectors + vec = self.time_in(t) + + # Add text modulation + vec = vec + self.vector_in(text_states_2) + + # Add guidance modulation if needed + if self.guidance_in and guidance is not None: + vec = vec + self.guidance_in(guidance) + # Embed image and text + img = self.img_in(img) + txt = self.txt_in(txt, t) + txt_seq_len = txt.shape[1] + img_seq_len = img.shape[1] + + freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None + + should_skip_forward = self.should_skip_forward_for_cached_states( + img=img, vec=vec + ) + + if should_skip_forward: + img = self.retrieve_cached_states(img) + else: + if enable_teacache: + original_img = img.clone() + + # Process through double stream blocks + for index, block in enumerate(self.double_blocks): + double_block_args = [img, txt, vec, freqs_cis] + img, txt = block(*double_block_args) + # Merge txt and img to pass through single stream blocks + x = torch.cat((img, txt), 1) + + # Process through single stream blocks + if len(self.single_blocks) > 0: + for index, block in enumerate(self.single_blocks): + single_block_args = [ + x, + vec, + txt_seq_len, + freqs_cis, + ] + x = block(*single_block_args) + + # Extract image features + img = x[:, :img_seq_len, ...] 
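+            # TeaCache: cache the residual contributed by the transformer blocks
+            # (img - original_img) so later "skipped" steps can re-apply it via
+            # retrieve_cached_states() instead of re-running every block.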
+ + if enable_teacache: + self.maybe_cache_states(img, original_img) + + # Final layer processing + img = self.final_layer(img, vec) + # Unpatchify to get original shape + img = unpatchify(img, tt, th, tw, self.patch_size, self.out_channels) + + return img + + def maybe_cache_states( + self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor + ) -> None: + self.previous_residual = hidden_states - original_hidden_states + + def should_skip_forward_for_cached_states(self, **kwargs) -> bool: + + forward_context = get_forward_context() + forward_batch = forward_context.forward_batch + if forward_batch is None: + return False + current_timestep = forward_context.current_timestep + enable_teacache = forward_batch.enable_teacache + + if not enable_teacache: + return False + raise NotImplementedError("teacache is not supported yet for HunyuanVideo") + + teacache_params = forward_batch.teacache_params + assert teacache_params is not None, "teacache_params is not initialized" + assert isinstance( + teacache_params, TeaCacheParams + ), "teacache_params is not a TeaCacheParams" + num_inference_steps = forward_batch.num_inference_steps + teache_thresh = teacache_params.teacache_thresh + + coefficients = teacache_params.coefficients + + if current_timestep == 0: + self.cnt = 0 + + inp = kwargs["img"].clone() + vec_ = kwargs["vec"].clone() + # convert to DTensor + vec_ = torch.distributed.tensor.DTensor.from_local( + vec_, + torch.distributed.DeviceMesh( + "cuda", list(range(get_sp_world_size())), mesh_dim_names=("dp",) + ), + [torch.distributed.tensor.Replicate()], + ) + + inp = torch.distributed.tensor.DTensor.from_local( + inp, + torch.distributed.DeviceMesh( + "cuda", list(range(get_sp_world_size())), mesh_dim_names=("dp",) + ), + [torch.distributed.tensor.Replicate()], + ) + + # txt_ = kwargs["txt"].clone() + + # inp = img.clone() + # vec_ = vec.clone() + # txt_ = txt.clone() + ( + img_mod1_shift, + img_mod1_scale, + img_mod1_gate, + img_mod2_shift, + img_mod2_scale, + img_mod2_gate, + ) = ( + self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1) + ) + normed_inp = self.double_blocks[0].img_attn_norm.norm(inp) + modulated_inp = modulate(normed_inp, shift=img_mod1_shift, scale=img_mod1_scale) + if self.cnt == 0 or self.cnt == num_inference_steps - 1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = [ + 7.33226126e02, + -4.01131952e02, + 6.75869174e01, + -3.14987800e00, + 9.61237896e-02, + ] + rescale_func = np.poly1d(coefficients) + assert ( + self.previous_modulated_input is not None + ), "previous_modulated_input is not initialized" + self.accumulated_rel_l1_distance += rescale_func( + ( + (modulated_inp - self.previous_modulated_input).abs().mean() + / self.previous_modulated_input.abs().mean() + ) + .cpu() + .item() + ) + if self.accumulated_rel_l1_distance < teache_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.cnt += 1 + + return not should_calc + + def retrieve_cached_states(self, hidden_states: torch.Tensor) -> torch.Tensor: + return hidden_states + self.previous_residual + + +class SingleTokenRefiner(nn.Module): + """ + A token refiner that processes text embeddings with attention to improve + their representation for cross-attention with image features. 
+ """ + + def __init__( + self, + in_channels, + hidden_size, + num_attention_heads, + depth=2, + qkv_bias=True, + dtype=None, + prefix: str = "", + ) -> None: + super().__init__() + + # Input projection + self.input_embedder = ReplicatedLinear( + in_channels, + hidden_size, + bias=True, + params_dtype=dtype, + prefix=f"{prefix}.input_embedder", + ) + + # Timestep embedding + self.t_embedder = TimestepEmbedder( + hidden_size, act_layer="silu", dtype=dtype, prefix=f"{prefix}.t_embedder" + ) + + # Context embedding + self.c_embedder = MLP( + in_channels, + hidden_size, + hidden_size, + act_type="silu", + dtype=dtype, + prefix=f"{prefix}.c_embedder", + ) + + # Refiner blocks + self.refiner_blocks = nn.ModuleList( + [ + IndividualTokenRefinerBlock( + hidden_size, + num_attention_heads, + qkv_bias=qkv_bias, + dtype=dtype, + prefix=f"{prefix}.refiner_blocks.{i}", + ) + for i in range(depth) + ] + ) + + def forward(self, x, t): + # Get timestep embeddings + timestep_aware_representations = self.t_embedder(t) + + # Get context-aware representations + + context_aware_representations = torch.mean(x, dim=1) + + context_aware_representations = self.c_embedder(context_aware_representations) + c = timestep_aware_representations + context_aware_representations + # Project input + x, _ = self.input_embedder(x) + # Process through refiner blocks + for block in self.refiner_blocks: + x = block(x, c) + return x + + +class IndividualTokenRefinerBlock(nn.Module): + """ + A transformer block for refining individual tokens with self-attention. + """ + + def __init__( + self, + hidden_size, + num_attention_heads, + mlp_ratio=4.0, + qkv_bias=True, + dtype=None, + prefix: str = "", + ) -> None: + super().__init__() + self.num_attention_heads = num_attention_heads + mlp_hidden_dim = int(hidden_size * mlp_ratio) + + # Normalization and attention + self.norm1 = nn.LayerNorm( + hidden_size, eps=1e-6, elementwise_affine=True, dtype=dtype + ) + + self.self_attn_qkv = ReplicatedLinear( + hidden_size, + hidden_size * 3, + bias=qkv_bias, + params_dtype=dtype, + prefix=f"{prefix}.self_attn_qkv", + ) + + self.self_attn_proj = ReplicatedLinear( + hidden_size, + hidden_size, + bias=qkv_bias, + params_dtype=dtype, + prefix=f"{prefix}.self_attn_proj", + ) + + # MLP + self.norm2 = nn.LayerNorm( + hidden_size, eps=1e-6, elementwise_affine=True, dtype=dtype + ) + self.mlp = MLP( + hidden_size, + mlp_hidden_dim, + bias=True, + act_type="silu", + dtype=dtype, + prefix=f"{prefix}.mlp", + ) + + # Modulation + self.adaLN_modulation = ModulateProjection( + hidden_size, + factor=2, + act_layer="silu", + dtype=dtype, + prefix=f"{prefix}.adaLN_modulation", + ) + + # Scaled dot product attention + self.attn = LocalAttention( + num_heads=num_attention_heads, + head_size=hidden_size // num_attention_heads, + # TODO: remove hardcode; remove STA + supported_attention_backends=( + AttentionBackendEnum.FA3, + AttentionBackendEnum.TORCH_SDPA, + ), + ) + + def forward(self, x, c): + # Get modulation parameters + gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=-1) + # Self-attention + norm_x = self.norm1(x) + qkv, _ = self.self_attn_qkv(norm_x) + + batch_size, seq_len = qkv.shape[0], qkv.shape[1] + qkv = qkv.view(batch_size, seq_len, 3, self.num_attention_heads, -1) + q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] + + # Run scaled dot product attention + attn_output = self.attn(q, k, v) # [B, L, H, D] + attn_output = attn_output.reshape(batch_size, seq_len, -1) # [B, L, H*D] + + # Project and apply residual connection with gating + 
attn_out, _ = self.self_attn_proj(attn_output) + x = x + attn_out * gate_msa.unsqueeze(1) + + # MLP + mlp_out = self.mlp(self.norm2(x)) + x = x + mlp_out * gate_mlp.unsqueeze(1) + + return x + + +class FinalLayer(nn.Module): + """ + The final layer of DiT that projects features to pixel space. + """ + + def __init__( + self, hidden_size, patch_size, out_channels, dtype=None, prefix: str = "" + ) -> None: + super().__init__() + + # Normalization + self.norm_final = nn.LayerNorm( + hidden_size, eps=1e-6, elementwise_affine=False, dtype=dtype + ) + + output_dim = patch_size[0] * patch_size[1] * patch_size[2] * out_channels + + self.linear = ReplicatedLinear( + hidden_size, + output_dim, + bias=True, + params_dtype=dtype, + prefix=f"{prefix}.linear", + ) + + # Modulation + self.adaLN_modulation = ModulateProjection( + hidden_size, + factor=2, + act_layer="silu", + dtype=dtype, + prefix=f"{prefix}.adaLN_modulation", + ) + + def forward(self, x, c): + # What the heck HF? Why you change the scale and shift order here??? + scale, shift = self.adaLN_modulation(c).chunk(2, dim=-1) + x = self.norm_final(x) * (1.0 + scale.unsqueeze(1)) + shift.unsqueeze(1) + x, _ = self.linear(x) + return x + + +EntryClass = HunyuanVideoTransformer3DModel diff --git a/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py new file mode 100644 index 000000000..97a676891 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/dits/qwen_image.py @@ -0,0 +1,651 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import functools +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from diffusers.models.attention import FeedForward +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import AdaLayerNormContinuous + +from sglang.multimodal_gen.configs.models.dits.qwenimage import QwenImageDitConfig +from sglang.multimodal_gen.runtime.layers.attention import LocalAttention +from sglang.multimodal_gen.runtime.layers.layernorm import LayerNorm, RMSNorm +from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear +from sglang.multimodal_gen.runtime.layers.triton_ops import ( + apply_rotary_embedding, + fuse_scale_shift_kernel, +) +from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) # pylint: disable=invalid-name + + +class QwenTimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim): + super().__init__() + + self.time_proj = Timesteps( + num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000 + ) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, time_embed_dim=embedding_dim + ) + + def forward(self, timestep, hidden_states): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder( + timesteps_proj.to(dtype=hidden_states.dtype) + ) # (N, D) + + conditioning = timesteps_emb + + return conditioning + + +class QwenEmbedRope(nn.Module): + def __init__(self, theta: int, axes_dim: List[int], scale_rope=False): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + pos_index = torch.arange(4096) + 
neg_index = torch.arange(4096).flip(0) * -1 - 1 + self.pos_freqs = torch.cat( + [ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + self.neg_freqs = torch.cat( + [ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], + dim=1, + ) + + # self.rope = NDRotaryEmbedding( + # rope_dim_list=axes_dim, + # rope_theta=theta, + # use_real=False, + # repeat_interleave_real=False, + # dtype=torch.float32 if current_platform.is_mps() else torch.float64, + # ) + + # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART + self.scale_rope = scale_rope + + def rope_params(self, index, dim, theta=10000): + """ + Args: + index: [0, 1, 2, 3] 1D Tensor representing the position index of the token + """ + device = index.device + assert dim % 2 == 0 + freqs = torch.outer( + index, + ( + 1.0 + / torch.pow( + theta, + torch.arange(0, dim, 2, device=device).to(torch.float32).div(dim), + ) + ).to(device=device), + ) + freqs = torch.polar(torch.ones_like(freqs), freqs) + return freqs + + def forward( + self, + video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]], + txt_seq_lens: List[int], + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`): + A list of 3 integers [frame, height, width] representing the shape of the video. + txt_seq_lens (`List[int]`): + A list of integers of length batch_size representing the length of each text prompt. + device: (`torch.device`): + The device on which to perform the RoPE computation. + """ + # When models are initialized under a "meta" device context (e.g. init_empty_weights), + # tensors created during __init__ become meta tensors. Calling .to(...) on a meta tensor + # raises "Cannot copy out of meta tensor". Rebuild the frequencies on the target device + # in that case; otherwise move them if just on a different device. 
+ if getattr(self.pos_freqs, "device", torch.device("meta")).type == "meta": + pos_index = torch.arange(4096, device=device) + neg_index = torch.arange(4096, device=device).flip(0) * -1 - 1 + self.pos_freqs = torch.cat( + [ + self.rope_params(pos_index, self.axes_dim[0], self.theta), + self.rope_params(pos_index, self.axes_dim[1], self.theta), + self.rope_params(pos_index, self.axes_dim[2], self.theta), + ], + dim=1, + ).to(device=device) + self.neg_freqs = torch.cat( + [ + self.rope_params(neg_index, self.axes_dim[0], self.theta), + self.rope_params(neg_index, self.axes_dim[1], self.theta), + self.rope_params(neg_index, self.axes_dim[2], self.theta), + ], + dim=1, + ).to(device=device) + elif self.pos_freqs.device != device: + self.pos_freqs = self.pos_freqs.to(device) + self.neg_freqs = self.neg_freqs.to(device) + + if isinstance(video_fhw, list): + video_fhw = video_fhw[0] + if not isinstance(video_fhw, list): + video_fhw = [video_fhw] + + vid_freqs = [] + max_vid_index = 0 + for idx, fhw in enumerate(video_fhw): + frame, height, width = fhw + # RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs + video_freq = self._compute_video_freqs(frame, height, width, idx) + video_freq = video_freq.to(device) + vid_freqs.append(video_freq) + + if self.scale_rope: + max_vid_index = max(height // 2, width // 2, max_vid_index) + else: + max_vid_index = max(height, width, max_vid_index) + + max_len = max(txt_seq_lens) + txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] + vid_freqs = torch.cat(vid_freqs, dim=0).to(device=device) + return vid_freqs, txt_freqs + + @functools.lru_cache(maxsize=128) + def _compute_video_freqs( + self, frame: int, height: int, width: int, idx: int = 0 + ) -> torch.Tensor: + seq_lens = frame * height * width + freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) + freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) + + freqs_frame = ( + freqs_pos[0][idx : idx + frame] + .view(frame, 1, 1, -1) + .expand(frame, height, width, -1) + ) + if self.scale_rope: + freqs_height = torch.cat( + [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], + dim=0, + ) + freqs_height = freqs_height.view(1, height, 1, -1).expand( + frame, height, width, -1 + ) + freqs_width = torch.cat( + [freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], + dim=0, + ) + freqs_width = freqs_width.view(1, 1, width, -1).expand( + frame, height, width, -1 + ) + else: + freqs_height = ( + freqs_pos[1][:height] + .view(1, height, 1, -1) + .expand(frame, height, width, -1) + ) + freqs_width = ( + freqs_pos[2][:width] + .view(1, 1, width, -1) + .expand(frame, height, width, -1) + ) + + freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape( + seq_lens, -1 + ) + return freqs.clone().contiguous() + + +class QwenImageCrossAttention(nn.Module): + + def __init__( + self, + dim: int, # query_dim + num_heads: int, + head_dim: int, + window_size=(-1, -1), + added_kv_proj_dim: int = None, + out_bias: bool = True, + qk_norm=True, # rmsnorm + eps=1e-6, + pre_only=False, + context_pre_only: bool = False, + parallel_attention=False, + out_dim: int = None, + ) -> None: + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + self.parallel_attention = parallel_attention + + # layers + self.to_q = ReplicatedLinear(dim, dim) + self.to_k = 
ReplicatedLinear(dim, dim) + self.to_v = ReplicatedLinear(dim, dim) + if self.qk_norm: + self.norm_q = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = RMSNorm(head_dim, eps=eps) if qk_norm else nn.Identity() + self.inner_dim = out_dim if out_dim is not None else head_dim * num_heads + self.inner_kv_dim = self.inner_dim + if added_kv_proj_dim is not None: + self.add_k_proj = ReplicatedLinear( + added_kv_proj_dim, self.inner_kv_dim, bias=True + ) + self.add_v_proj = ReplicatedLinear( + added_kv_proj_dim, self.inner_kv_dim, bias=True + ) + if context_pre_only is not None: + self.add_q_proj = ReplicatedLinear( + added_kv_proj_dim, self.inner_dim, bias=True + ) + + if context_pre_only is not None and not context_pre_only: + self.to_add_out = ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias) + else: + self.to_add_out = None + + if not pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append( + ReplicatedLinear(self.inner_dim, self.dim, bias=out_bias) + ) + else: + self.to_out = None + + self.norm_added_q = RMSNorm(head_dim, eps=eps) + self.norm_added_k = RMSNorm(head_dim, eps=eps) + + # Scaled dot product attention + self.attn = LocalAttention( + num_heads=num_heads, + head_size=self.head_dim, + dropout_rate=0, + softmax_scale=None, + causal=False, + supported_attention_backends={ + AttentionBackendEnum.FA3, + AttentionBackendEnum.TORCH_SDPA, + }, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + image_rotary_emb: tuple[torch.Tensor, torch.Tensor], + **cross_attention_kwargs, + ): + seq_txt = encoder_hidden_states.shape[1] + + # Compute QKV for image stream (sample projections) + img_query, _ = self.to_q(hidden_states) + img_key, _ = self.to_k(hidden_states) + img_value, _ = self.to_v(hidden_states) + + # Compute QKV for text stream (context projections) + txt_query, _ = self.add_q_proj(encoder_hidden_states) + txt_key, _ = self.add_k_proj(encoder_hidden_states) + txt_value, _ = self.add_v_proj(encoder_hidden_states) + + # Reshape for multi-head attention + img_query = img_query.unflatten(-1, (self.num_heads, -1)) + img_key = img_key.unflatten(-1, (self.num_heads, -1)) + img_value = img_value.unflatten(-1, (self.num_heads, -1)) + + txt_query = txt_query.unflatten(-1, (self.num_heads, -1)) + txt_key = txt_key.unflatten(-1, (self.num_heads, -1)) + txt_value = txt_value.unflatten(-1, (self.num_heads, -1)) + + # Apply QK normalization + if self.norm_q is not None: + img_query = self.norm_q(img_query) + if self.norm_k is not None: + img_key = self.norm_k(img_key) + if self.norm_added_q is not None: + txt_query = self.norm_added_q(txt_query) + if self.norm_added_k is not None: + txt_key = self.norm_added_k(txt_key) + + # Apply RoPE + if image_rotary_emb is not None: + (img_cos, img_sin), (txt_cos, txt_sin) = image_rotary_emb + img_query = apply_rotary_embedding( + img_query, img_cos, img_sin, interleaved=True + ) + img_key = apply_rotary_embedding( + img_key, img_cos, img_sin, interleaved=True + ) + txt_query = apply_rotary_embedding( + txt_query, txt_cos, txt_sin, interleaved=True + ) + txt_key = apply_rotary_embedding( + txt_key, txt_cos, txt_sin, interleaved=True + ) + + # Concatenate for joint attention + # Order: [text, image] + joint_query = torch.cat([txt_query, img_query], dim=1) + joint_key = torch.cat([txt_key, img_key], dim=1) + joint_value = torch.cat([txt_value, img_value], dim=1) + + # Compute joint attention + joint_hidden_states = self.attn( + joint_query, + joint_key, + joint_value, + ) + + # 
Reshape back + joint_hidden_states = joint_hidden_states.flatten(2, 3) + joint_hidden_states = joint_hidden_states.to(joint_query.dtype) + + # Split attention outputs back + txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part + img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part + + # Apply output projections + img_attn_output, _ = self.to_out[0](img_attn_output) + if len(self.to_out) > 1: + (img_attn_output,) = self.to_out[1](img_attn_output) # dropout + + txt_attn_output, _ = self.to_add_out(txt_attn_output) + + return img_attn_output, txt_attn_output + + +class QwenImageTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + qk_norm: str = "rms_norm", + eps: float = 1e-6, + ): + super().__init__() + + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + # Image processing modules + self.img_mod = nn.Sequential( + nn.SiLU(), + nn.Linear( + dim, 6 * dim, bias=True + ), # For scale, shift, gate for norm1 and norm2 + ) + self.img_norm1 = LayerNorm(dim, elementwise_affine=False, eps=eps) + + self.attn = QwenImageCrossAttention( + dim=dim, + num_heads=num_attention_heads, + added_kv_proj_dim=dim, + context_pre_only=False, + head_dim=attention_head_dim, + ) + self.img_norm2 = LayerNorm(dim, eps=eps, elementwise_affine=False) + self.img_mlp = FeedForward( + dim=dim, dim_out=dim, activation_fn="gelu-approximate" + ) + + # Text processing modules + self.txt_mod = nn.Sequential( + nn.SiLU(), + nn.Linear( + dim, 6 * dim, bias=True + ), # For scale, shift, gate for norm1 and norm2 + ) + self.txt_norm1 = LayerNorm(dim, elementwise_affine=False, eps=eps) + # Text doesn't need separate attention - it's handled by img_attn joint computation + self.txt_norm2 = LayerNorm(dim, elementwise_affine=False, eps=eps) + self.txt_mlp = FeedForward( + dim=dim, dim_out=dim, activation_fn="gelu-approximate" + ) + + def _modulate(self, x, mod_params): + """Apply modulation to input tensor""" + shift, scale, gate = mod_params.chunk(3, dim=-1) + return fuse_scale_shift_kernel(x, scale, shift), gate.unsqueeze(1) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_mask: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Get modulation parameters for both streams + img_mod_params = self.img_mod(temb) # [B, 6*dim] + txt_mod_params = self.txt_mod(temb) # [B, 6*dim] + + # Split modulation parameters for norm1 and norm2 + img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim] + + # Process image stream - norm1 + modulation + + img_normed = self.img_norm1(hidden_states) + + img_modulated, img_gate1 = self._modulate(img_normed, img_mod1) + + # Process text stream - norm1 + modulation + txt_normed = self.txt_norm1(encoder_hidden_states) + txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1) + + # Use QwenAttnProcessor2_0 for joint attention computation + # This directly implements the DoubleStreamLayerMegatron logic: + # 1. Computes QKV for both streams + # 2. Applies QK normalization and RoPE + # 3. Concatenates and runs joint attention + # 4. 
Splits results back to separate streams + joint_attention_kwargs = joint_attention_kwargs or {} + attn_output = self.attn( + hidden_states=img_modulated, # Image stream (will be processed as "sample") + encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context") + encoder_hidden_states_mask=encoder_hidden_states_mask, + image_rotary_emb=image_rotary_emb, + **joint_attention_kwargs, + ) + + # QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided + img_attn_output, txt_attn_output = attn_output + + # Apply attention gates and add residual (like in Megatron) + hidden_states = hidden_states + img_gate1 * img_attn_output + + encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output + + # Process image stream - norm2 + MLP + img_normed2 = self.img_norm2(hidden_states) + img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2) + img_mlp_output = self.img_mlp(img_modulated2) + hidden_states = hidden_states + img_gate2 * img_mlp_output + + # Process text stream - norm2 + MLP + txt_normed2 = self.txt_norm2(encoder_hidden_states) + txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2) + txt_mlp_output = self.txt_mlp(txt_modulated2) + encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output + + # Clip to prevent overflow for fp16 + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +class QwenImageTransformer2DModel(CachableDiT): + """ + The Transformer model introduced in Qwen. + + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["QwenImageTransformerBlock"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + _repeated_blocks = ["QwenImageTransformerBlock"] + + def __init__( + self, + config: QwenImageDitConfig, + hf_config: dict[str, Any], + ): + super().__init__(config=config, hf_config=hf_config) + patch_size = config.arch_config.patch_size + in_channels = config.arch_config.in_channels + out_channels = config.arch_config.out_channels + num_layers = config.arch_config.num_layers + attention_head_dim = config.arch_config.attention_head_dim + num_attention_heads = config.arch_config.num_attention_heads + joint_attention_dim = config.arch_config.joint_attention_dim + axes_dims_rope = config.arch_config.axes_dims_rope + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.rotary_emb = QwenEmbedRope( + theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True + ) + + self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim) + + self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6) + + self.img_in = nn.Linear(in_channels, self.inner_dim) + self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + QwenImageTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6 + ) + self.proj_out = nn.Linear( + self.inner_dim, patch_size * patch_size * self.out_channels, bias=True + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + 
encoder_hidden_states_mask: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_shapes: Optional[List[Tuple[int, int, int]]] = None, + txt_seq_lens: Optional[List[int]] = None, + freqs_cis: tuple[torch.Tensor, torch.Tensor] = None, + guidance: torch.Tensor = None, # TODO: this should probably be removed + attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`QwenTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`): + Mask of the input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if ( + attention_kwargs is not None + and attention_kwargs.get("scale", None) is not None + ): + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
+ ) + + if isinstance(encoder_hidden_states, list): + encoder_hidden_states = encoder_hidden_states[0] + + hidden_states = self.img_in(hidden_states) + + timestep = (timestep / 1000).to(hidden_states.dtype) + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + temb = self.time_text_embed(timestep, hidden_states) + + image_rotary_emb = freqs_cis + for index_block, block in enumerate(self.transformer_blocks): + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=attention_kwargs, + ) + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len( + controlnet_block_samples + ) + interval_control = int(np.ceil(interval_control)) + hidden_states = ( + hidden_states + + controlnet_block_samples[index_block // interval_control] + ) + + # Use only the image part (hidden_states) from the dual-stream blocks + hidden_states = self.norm_out(hidden_states, temb) + + output = self.proj_out(hidden_states) + return output + + +EntryClass = QwenImageTransformer2DModel diff --git a/python/sglang/multimodal_gen/runtime/models/dits/stepvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/stepvideo.py new file mode 100644 index 000000000..50b15b61b --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/dits/stepvideo.py @@ -0,0 +1,729 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# Copyright 2025 StepFun Inc. All Rights Reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# ============================================================================== +from typing import Any + +import torch +from einops import rearrange, repeat +from torch import nn + +from sglang.multimodal_gen.configs.models.dits import StepVideoConfig +from sglang.multimodal_gen.runtime.distributed.parallel_state import get_sp_world_size +from sglang.multimodal_gen.runtime.layers.attention import LocalAttention, USPAttention +from sglang.multimodal_gen.runtime.layers.layernorm import LayerNormScaleShift +from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear +from sglang.multimodal_gen.runtime.layers.mlp import MLP +from sglang.multimodal_gen.runtime.layers.rotary_embedding import ( + _apply_rotary_emb, + get_rotary_pos_embed, +) +from sglang.multimodal_gen.runtime.layers.visual_embedding import TimestepEmbedder +from sglang.multimodal_gen.runtime.models.dits.base import BaseDiT +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum + + +class PatchEmbed2D(nn.Module): + """2D Image to Patch Embedding + + Image to Patch Embedding using Conv2d + + A convolution based approach to patchifying a 2D image w/ embedding projection. 
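+
+    Shape sketch (illustrative sizes, assuming flatten=True and a square patch p):
+        input   (B, in_chans, H, W)
+        -> proj    (B, embed_dim, H // p, W // p)
+        -> flatten (B, (H // p) * (W // p), embed_dim)   # BCHW -> BNC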
+ + Based on the impl in https://github.com/google-research/vision_transformer + + Hacked together by / Copyright 2020 Ross Wightman + + Remove the _assert function in forward function to be compatible with multi-resolution images. + """ + + def __init__( + self, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + bias=True, + dtype=None, + prefix: str = "", + ): + super().__init__() + # Convert patch_size to 2-tuple + if isinstance(patch_size, list | tuple): + if len(patch_size) == 1: + patch_size = (patch_size[0], patch_size[0]) + else: + patch_size = (patch_size, patch_size) + + self.patch_size = patch_size + self.flatten = flatten + + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_size, + bias=bias, + dtype=dtype, + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +class StepVideoRMSNorm(nn.Module): + + def __init__( + self, + dim: int, + elementwise_affine=True, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x) -> torch.Tensor: + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. 
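+
+            Equivalently (matches the implementation below; `weight` is applied
+            only when elementwise_affine is enabled):
+                y = x * rsqrt(mean(x**2, dim=-1, keepdim=True) + eps) [* weight]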
+ + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +class SelfAttention(nn.Module): + + def __init__( + self, + hidden_dim, + head_dim, + rope_split: tuple[int, int, int] = (64, 32, 32), + bias: bool = False, + with_rope: bool = True, + with_qk_norm: bool = True, + attn_type: str = "torch", + supported_attention_backends=( + AttentionBackendEnum.FA3, + AttentionBackendEnum.TORCH_SDPA, + ), + ): + super().__init__() + self.head_dim = head_dim + self.hidden_dim = hidden_dim + self.rope_split = list(rope_split) + self.n_heads = hidden_dim // head_dim + + self.wqkv = ReplicatedLinear(hidden_dim, hidden_dim * 3, bias=bias) + self.wo = ReplicatedLinear(hidden_dim, hidden_dim, bias=bias) + + self.with_rope = with_rope + self.with_qk_norm = with_qk_norm + if self.with_qk_norm: + self.q_norm = StepVideoRMSNorm(head_dim, elementwise_affine=True) + self.k_norm = StepVideoRMSNorm(head_dim, elementwise_affine=True) + + # self.core_attention = self.attn_processor(attn_type=attn_type) + self.parallel = attn_type == "parallel" + self.attn = USPAttention( + num_heads=self.n_heads, + head_size=head_dim, + causal=False, + supported_attention_backends=supported_attention_backends, + ) + + def _apply_rope(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): + """ + x: [B, S, H, D] + cos: [S, D/2] where D = head_dim = sum(self.rope_split) + sin: [S, D/2] + returns x with rotary applied exactly as v0 did + """ + B, S, H, D = x.shape + # 1) split cos/sin per chunk + half_splits = [c // 2 for c in self.rope_split] # [32,16,16] for [64,32,32] + cos_splits = cos.split(half_splits, dim=1) + sin_splits = sin.split(half_splits, dim=1) + + outs = [] + idx = 0 + for chunk_size, cos_i, sin_i in zip( + self.rope_split, cos_splits, sin_splits, strict=True + ): + # slice the corresponding channels + x_chunk = x[..., idx : idx + chunk_size] # [B,S,H,chunk_size] + idx += chunk_size + + # flatten to [S, B*H, chunk_size] + x_flat = rearrange(x_chunk, "b s h d -> s (b h) d") + + # apply rotary on *that* chunk + out_flat = _apply_rotary_emb(x_flat, cos_i, sin_i, is_neox_style=True) + + # restore [B,S,H,chunk_size] + out = rearrange(out_flat, "s (b h) d -> b s h d", b=B, h=H) + outs.append(out) + + # concatenate back to [B,S,H,D] + return torch.cat(outs, dim=-1) + + def forward( + self, + x, + cu_seqlens=None, + max_seqlen=None, + rope_positions=None, + cos_sin=None, + attn_mask=None, + mask_strategy=None, + ): + + B, S, _ = x.shape + xqkv, _ = self.wqkv(x) + xqkv = xqkv.view(*x.shape[:-1], self.n_heads, 3 * self.head_dim) + q, k, v = torch.split(xqkv, [self.head_dim] * 3, dim=-1) # [B,S,H,D] + + if self.with_qk_norm: + q = self.q_norm(q) + k = self.k_norm(k) + + if self.with_rope: + if rope_positions is not None: + F, Ht, W = rope_positions + assert F * Ht * W == S, "rope_positions mismatches sequence length" + + cos, sin = cos_sin + cos = cos.to(x.device, dtype=x.dtype) + sin = sin.to(x.device, dtype=x.dtype) + + q = self._apply_rope(q, cos, sin) + k = self._apply_rope(k, cos, sin) + + output, _ = self.attn(q, k, v) # [B,heads,S,D] + + output = rearrange(output, "b s h d -> b s (h d)") + output, _ = self.wo(output) + + return output + + +class CrossAttention(nn.Module): + + def __init__( + self, + hidden_dim, + head_dim, + bias=False, + with_qk_norm=True, + supported_attention_backends=( + AttentionBackendEnum.FA3, + AttentionBackendEnum.TORCH_SDPA, + ), + ) -> None: + super().__init__() + self.head_dim = head_dim + self.n_heads = 
hidden_dim // head_dim + + self.wq = ReplicatedLinear(hidden_dim, hidden_dim, bias=bias) + self.wkv = ReplicatedLinear(hidden_dim, hidden_dim * 2, bias=bias) + self.wo = ReplicatedLinear(hidden_dim, hidden_dim, bias=bias) + + self.with_qk_norm = with_qk_norm + if self.with_qk_norm: + self.q_norm = StepVideoRMSNorm(head_dim, elementwise_affine=True) + self.k_norm = StepVideoRMSNorm(head_dim, elementwise_affine=True) + + self.attn = LocalAttention( + num_heads=self.n_heads, + head_size=head_dim, + causal=False, + supported_attention_backends=supported_attention_backends, + ) + + def forward( + self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, attn_mask=None + ) -> torch.Tensor: + + xq, _ = self.wq(x) + xq = xq.view(*xq.shape[:-1], self.n_heads, self.head_dim) + + xkv, _ = self.wkv(encoder_hidden_states) + xkv = xkv.view(*xkv.shape[:-1], self.n_heads, 2 * self.head_dim) + + xk, xv = torch.split(xkv, [self.head_dim] * 2, dim=-1) ## seq_len, n, dim + + if self.with_qk_norm: + xq = self.q_norm(xq) + xk = self.k_norm(xk) + + output = self.attn(xq, xk, xv) + + output = rearrange(output, "b s h d -> b s (h d)") + output, _ = self.wo(output) + + return output + + +class AdaLayerNormSingle(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, time_step_rescale=1000): + super().__init__() + + self.emb = TimestepEmbedder(embedding_dim) + + self.silu = nn.SiLU() + self.linear = ReplicatedLinear(embedding_dim, 6 * embedding_dim, bias=True) + + self.time_step_rescale = time_step_rescale ## timestep usually in [0, 1], we rescale it to [0,1000] for stability + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: dict[str, torch.Tensor] | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + embedded_timestep = self.emb(timestep * self.time_step_rescale) + + out, _ = self.linear(self.silu(embedded_timestep)) + + return out, embedded_timestep + + +class StepVideoTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. 
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + attention_head_dim: int, + norm_eps: float = 1e-5, + ff_inner_dim: int | None = None, + ff_bias: bool = False, + attention_type: str = "torch", + ): + super().__init__() + self.dim = dim + self.norm1 = LayerNormScaleShift( + dim, norm_type="layer", elementwise_affine=True, eps=norm_eps + ) + self.attn1 = SelfAttention( + dim, + attention_head_dim, + bias=False, + with_rope=True, + with_qk_norm=True, + ) + + self.norm2 = LayerNormScaleShift( + dim, norm_type="layer", elementwise_affine=True, eps=norm_eps + ) + self.attn2 = CrossAttention( + dim, attention_head_dim, bias=False, with_qk_norm=True + ) + + self.ff = MLP( + input_dim=dim, + mlp_hidden_dim=dim * 4 if ff_inner_dim is None else ff_inner_dim, + act_type="gelu_pytorch_tanh", + bias=ff_bias, + ) + + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + @torch.no_grad() + def forward( + self, + q: torch.Tensor, + kv: torch.Tensor, + t_expand: torch.LongTensor, + attn_mask=None, + rope_positions: list | None = None, + cos_sin=None, + mask_strategy=None, + ) -> torch.Tensor: + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + torch.clone(chunk) + for chunk in ( + self.scale_shift_table[None] + t_expand.reshape(-1, 6, self.dim) + ).chunk(6, dim=1) + ) + + scale_shift_q = self.norm1( + q, scale=scale_msa.squeeze(1), shift=shift_msa.squeeze(1) + ) + + attn_q = self.attn1( + scale_shift_q, + rope_positions=rope_positions, + cos_sin=cos_sin, + mask_strategy=mask_strategy, + ) + + q = attn_q * gate_msa + q + + attn_q = self.attn2(q, kv, attn_mask) + + q = attn_q + q + + scale_shift_q = self.norm2( + q, scale=scale_mlp.squeeze(1), shift=shift_mlp.squeeze(1) + ) + + ff_output = self.ff(scale_shift_q) + + q = ff_output * gate_mlp + q + + return q + + +class StepVideoModel(BaseDiT): + # (Optional) Keep the same attribute for compatibility with splitting, etc. + _fsdp_shard_conditions = [ + lambda n, m: "transformer_blocks" in n and n.split(".")[-1].isdigit(), + # lambda n, m: "pos_embed" in n # If needed for the patch embedding. 
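+        # Each entry is a predicate over (module_name, module); the active one above
+        # matches "transformer_blocks.<idx>" submodules so that every block is
+        # wrapped and sharded as its own FSDP unit.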
+ ] + param_names_mapping = StepVideoConfig().param_names_mapping + reverse_param_names_mapping = StepVideoConfig().reverse_param_names_mapping + lora_param_names_mapping = StepVideoConfig().lora_param_names_mapping + _supported_attention_backends = StepVideoConfig()._supported_attention_backends + + def __init__(self, config: StepVideoConfig, hf_config: dict[str, Any]) -> None: + super().__init__(config=config, hf_config=hf_config) + self.num_attention_heads = config.num_attention_heads + self.attention_head_dim = config.attention_head_dim + self.in_channels = config.in_channels + self.out_channels = config.out_channels + self.num_layers = config.num_layers + self.dropout = config.dropout + self.patch_size = config.patch_size + self.norm_type = config.norm_type + self.norm_elementwise_affine = config.norm_elementwise_affine + self.norm_eps = config.norm_eps + self.use_additional_conditions = config.use_additional_conditions + self.caption_channels = config.caption_channels + self.attention_type = config.attention_type + self.num_channels_latents = config.num_channels_latents + # Compute inner dimension. + self.hidden_size = config.hidden_size + + # Image/video patch embedding. + self.pos_embed = PatchEmbed2D( + patch_size=self.patch_size, + in_chans=self.in_channels, + embed_dim=self.hidden_size, + ) + + self._rope_cache: dict[tuple, tuple[torch.Tensor, torch.Tensor]] = {} + # Transformer blocks. + self.transformer_blocks = nn.ModuleList( + [ + StepVideoTransformerBlock( + dim=self.hidden_size, + attention_head_dim=self.attention_head_dim, + attention_type=self.attention_type, + ) + for _ in range(self.num_layers) + ] + ) + + # Output blocks. + self.norm_out = LayerNormScaleShift( + self.hidden_size, + norm_type="layer", + eps=self.norm_eps, + elementwise_affine=self.norm_elementwise_affine, + ) + self.scale_shift_table = nn.Parameter( + torch.randn(2, self.hidden_size) / (self.hidden_size**0.5) + ) + self.proj_out = ReplicatedLinear( + self.hidden_size, self.patch_size * self.patch_size * self.out_channels + ) + # Time modulation via adaptive layer norm. + self.adaln_single = AdaLayerNormSingle(self.hidden_size) + + # Set up caption conditioning. + if isinstance(self.caption_channels, int): + caption_channel = self.caption_channels + else: + caption_channel, clip_channel = self.caption_channels + self.clip_projection = ReplicatedLinear(clip_channel, self.hidden_size) + self.caption_norm = nn.LayerNorm( + caption_channel, + eps=self.norm_eps, + elementwise_affine=self.norm_elementwise_affine, + ) + self.caption_projection = MLP( + input_dim=caption_channel, + mlp_hidden_dim=self.hidden_size, + act_type="gelu_pytorch_tanh", + ) + + # Flag to indicate if using parallel attention. 
+ self.parallel = self.attention_type == "parallel" + + self.__post_init__() + + def patchfy(self, hidden_states) -> torch.Tensor: + hidden_states = rearrange(hidden_states, "b f c h w -> (b f) c h w") + hidden_states = self.pos_embed(hidden_states) + return hidden_states + + def prepare_attn_mask( + self, encoder_attention_mask, encoder_hidden_states, q_seqlen + ) -> tuple[torch.Tensor, torch.Tensor]: + kv_seqlens = encoder_attention_mask.sum(dim=1).int() + mask = torch.zeros( + [len(kv_seqlens), q_seqlen, max(kv_seqlens)], + dtype=torch.bool, + device=encoder_attention_mask.device, + ) + encoder_hidden_states = encoder_hidden_states[:, : max(kv_seqlens)] + for i, kv_len in enumerate(kv_seqlens): + mask[i, :, :kv_len] = 1 + return encoder_hidden_states, mask + + def block_forward( + self, + hidden_states, + encoder_hidden_states=None, + t_expand=None, + rope_positions=None, + cos_sin=None, + attn_mask=None, + parallel=True, + mask_strategy=None, + ) -> torch.Tensor: + + for i, block in enumerate(self.transformer_blocks): + hidden_states = block( + hidden_states, + encoder_hidden_states, + t_expand=t_expand, + attn_mask=attn_mask, + rope_positions=rope_positions, + cos_sin=cos_sin, + mask_strategy=mask_strategy[i], + ) + + return hidden_states + + def _get_rope( + self, + rope_positions: tuple[int, int, int], + dtype: torch.dtype, + device: torch.device, + ): + F, Ht, W = rope_positions + key = (F, Ht, W, dtype) + if key not in self._rope_cache: + cos, sin = get_rotary_pos_embed( + rope_sizes=(F * get_sp_world_size(), Ht, W), + hidden_size=self.hidden_size, + heads_num=self.hidden_size // self.attention_head_dim, + rope_dim_list=(64, 32, 32), # same split you used + rope_theta=1.0e4, + dtype=torch.float32, # build once in fp32 + ) + # move & cast once + self._rope_cache[key] = ( + cos.to(device, dtype=dtype), + sin.to(device, dtype=dtype), + ) + return self._rope_cache[key] + + @torch.inference_mode() + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None = None, + t_expand: torch.LongTensor | None = None, + encoder_hidden_states_2: torch.Tensor | None = None, + added_cond_kwargs: dict[str, torch.Tensor] | None = None, + encoder_attention_mask: torch.Tensor | None = None, + fps: torch.Tensor | None = None, + return_dict: bool = True, + mask_strategy=None, + guidance=None, + ): + assert hidden_states.ndim == 5 + "hidden_states's shape should be (bsz, f, ch, h ,w)" + frame = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> b f c h w", f=frame) + if mask_strategy is None: + mask_strategy = [None, None] + bsz, frame, _, height, width = hidden_states.shape + height, width = height // self.patch_size, width // self.patch_size + + hidden_states = self.patchfy(hidden_states) + len_frame = hidden_states.shape[1] + + t_expand, embedded_timestep = self.adaln_single(t_expand) + encoder_hidden_states = self.caption_projection( + self.caption_norm(encoder_hidden_states) + ) + + if encoder_hidden_states_2 is not None and hasattr(self, "clip_projection"): + clip_embedding, _ = self.clip_projection(encoder_hidden_states_2) + encoder_hidden_states = torch.cat( + [clip_embedding, encoder_hidden_states], dim=1 + ) + + hidden_states = rearrange( + hidden_states, "(b f) l d-> b (f l) d", b=bsz, f=frame, l=len_frame + ).contiguous() + encoder_hidden_states, attn_mask = self.prepare_attn_mask( + encoder_attention_mask, encoder_hidden_states, q_seqlen=frame * len_frame + ) + + cos_sin = self._get_rope( + (frame, height, width), 
hidden_states.dtype, hidden_states.device + ) + + hidden_states = self.block_forward( + hidden_states, + encoder_hidden_states, + t_expand=t_expand, + rope_positions=[frame, height, width], + cos_sin=cos_sin, + attn_mask=attn_mask, + parallel=self.parallel, + mask_strategy=mask_strategy, + ) + + hidden_states = rearrange( + hidden_states, "b (f l) d -> (b f) l d", b=bsz, f=frame, l=len_frame + ) + + embedded_timestep = repeat( + embedded_timestep, "b d -> (b f) d", f=frame + ).contiguous() + + shift, scale = ( + self.scale_shift_table[None] + embedded_timestep[:, None] + ).chunk(2, dim=1) + hidden_states = self.norm_out( + hidden_states, shift=shift.squeeze(1), scale=scale.squeeze(1) + ) + # Modulation + hidden_states, _ = self.proj_out(hidden_states) + + # unpatchify + hidden_states = hidden_states.reshape( + shape=( + -1, + height, + width, + self.patch_size, + self.patch_size, + self.out_channels, + ) + ) + + hidden_states = rearrange(hidden_states, "n h w p q c -> n c h p w q") + output = hidden_states.reshape( + shape=( + -1, + self.out_channels, + height * self.patch_size, + width * self.patch_size, + ) + ) + + output = rearrange(output, "(b f) c h w -> b c f h w", f=frame) + return output + + +EntryClass = StepVideoModel diff --git a/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py new file mode 100644 index 000000000..8a746fa08 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/dits/wanvideo.py @@ -0,0 +1,945 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import math +from typing import Any + +import numpy as np +import torch +import torch.nn as nn + +from sglang.multimodal_gen.configs.models.dits import WanVideoConfig +from sglang.multimodal_gen.configs.sample.wan import WanTeaCacheParams +from sglang.multimodal_gen.runtime.distributed.parallel_state import get_sp_world_size +from sglang.multimodal_gen.runtime.layers.attention import ( + LocalAttention, + UlyssesAttention_VSA, + USPAttention, +) +from sglang.multimodal_gen.runtime.layers.layernorm import ( + FP32LayerNorm, + LayerNormScaleShift, + RMSNorm, + ScaleResidual, + ScaleResidualLayerNormScaleShift, +) +from sglang.multimodal_gen.runtime.layers.linear import ReplicatedLinear +from sglang.multimodal_gen.runtime.layers.mlp import MLP +from sglang.multimodal_gen.runtime.layers.rotary_embedding import ( + NDRotaryEmbedding, + _apply_rotary_emb, +) +from sglang.multimodal_gen.runtime.layers.visual_embedding import ( + ModulateProjection, + PatchEmbed, + TimestepEmbedder, +) +from sglang.multimodal_gen.runtime.managers.forward_context import get_forward_context +from sglang.multimodal_gen.runtime.models.dits.base import CachableDiT +from sglang.multimodal_gen.runtime.platforms import ( + AttentionBackendEnum, + current_platform, +) +from sglang.multimodal_gen.runtime.server_args import get_global_server_args +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class WanImageEmbedding(torch.nn.Module): + + def __init__(self, in_features: int, out_features: int): + super().__init__() + + self.norm1 = FP32LayerNorm(in_features) + self.ff = MLP(in_features, in_features, out_features, act_type="gelu") + self.norm2 = FP32LayerNorm(out_features) + + def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor: + dtype = encoder_hidden_states_image.dtype + hidden_states = 
self.norm1(encoder_hidden_states_image) + hidden_states = self.ff(hidden_states) + hidden_states = self.norm2(hidden_states).to(dtype) + return hidden_states + + +class WanTimeTextImageEmbedding(nn.Module): + + def __init__( + self, + dim: int, + time_freq_dim: int, + text_embed_dim: int, + image_embed_dim: int | None = None, + ): + super().__init__() + + self.time_embedder = TimestepEmbedder( + dim, frequency_embedding_size=time_freq_dim, act_layer="silu" + ) + self.time_modulation = ModulateProjection(dim, factor=6, act_layer="silu") + self.text_embedder = MLP( + text_embed_dim, dim, dim, bias=True, act_type="gelu_pytorch_tanh" + ) + + self.image_embedder = None + if image_embed_dim is not None: + self.image_embedder = WanImageEmbedding(image_embed_dim, dim) + + def forward( + self, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: torch.Tensor | None = None, + timestep_seq_len: int | None = None, + ): + temb = self.time_embedder(timestep, timestep_seq_len) + timestep_proj = self.time_modulation(temb) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + assert self.image_embedder is not None + encoder_hidden_states_image = self.image_embedder( + encoder_hidden_states_image + ) + + return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image + + +class WanSelfAttention(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6, + parallel_attention=False, + supported_attention_backends: set[AttentionBackendEnum] | None = None, + ) -> None: + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + self.parallel_attention = parallel_attention + + # layers + self.to_q = ReplicatedLinear(dim, dim) + self.to_k = ReplicatedLinear(dim, dim) + self.to_v = ReplicatedLinear(dim, dim) + self.to_out = ReplicatedLinear(dim, dim) + self.norm_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + # Scaled dot product attention + self.attn = LocalAttention( + num_heads=num_heads, + head_size=self.head_dim, + dropout_rate=0, + softmax_scale=None, + causal=False, + supported_attention_backends=supported_attention_backends, + ) + + def forward(self, x: torch.Tensor, context: torch.Tensor, context_lens: int): + r""" + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + seq_lens(Tensor): Shape [B] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + pass + + +class WanT2VCrossAttention(WanSelfAttention): + + def forward(self, x, context, context_lens, crossattn_cache=None): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.to_q(x)[0]).view(b, -1, n, d) + + if crossattn_cache is not None: + if not crossattn_cache["is_init"]: + crossattn_cache["is_init"] = True + k = self.norm_k(self.to_k(context)[0]).view(b, -1, n, d) + v = self.to_v(context)[0].view(b, -1, n, d) + crossattn_cache["k"] = k + crossattn_cache["v"] = v + else: + k = crossattn_cache["k"] + v = crossattn_cache["v"] + else: + k = 
self.norm_k(self.to_k(context)[0]).view(b, -1, n, d) + v = self.to_v(context)[0].view(b, -1, n, d) + + # compute attention + x = self.attn(q, k, v) + + # output + x = x.flatten(2) + x, _ = self.to_out(x) + return x + + +class WanI2VCrossAttention(WanSelfAttention): + + def __init__( + self, + dim: int, + num_heads: int, + window_size=(-1, -1), + qk_norm=True, + eps=1e-6, + supported_attention_backends: set[AttentionBackendEnum] | None = None, + ) -> None: + # VSA should not be in supported_attention_backends + super().__init__( + dim, + num_heads, + window_size, + qk_norm, + eps, + supported_attention_backends=supported_attention_backends, + ) + + self.add_k_proj = ReplicatedLinear(dim, dim) + self.add_v_proj = ReplicatedLinear(dim, dim) + self.norm_added_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_added_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward(self, x, context, context_lens): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + context_img = context[:, :257] + context = context[:, 257:] + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.to_q(x)[0]).view(b, -1, n, d) + k = self.norm_k(self.to_k(context)[0]).view(b, -1, n, d) + v = self.to_v(context)[0].view(b, -1, n, d) + k_img = self.norm_added_k(self.add_k_proj(context_img)[0]).view(b, -1, n, d) + v_img = self.add_v_proj(context_img)[0].view(b, -1, n, d) + img_x = self.attn(q, k_img, v_img) + # compute attention + x = self.attn(q, k, v) + + # output + x = x.flatten(2) + img_x = img_x.flatten(2) + x = x + img_x + x, _ = self.to_out(x) + return x + + +class WanTransformerBlock(nn.Module): + + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + qk_norm: str = "rms_norm_across_heads", + cross_attn_norm: bool = False, + eps: float = 1e-6, + added_kv_proj_dim: int | None = None, + supported_attention_backends: set[AttentionBackendEnum] | None = None, + prefix: str = "", + ): + super().__init__() + + # 1. Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.to_q = ReplicatedLinear(dim, dim, bias=True) + self.to_k = ReplicatedLinear(dim, dim, bias=True) + self.to_v = ReplicatedLinear(dim, dim, bias=True) + + self.to_out = ReplicatedLinear(dim, dim, bias=True) + self.attn1 = USPAttention( + num_heads=num_heads, + head_size=dim // num_heads, + causal=False, + supported_attention_backends=supported_attention_backends, + prefix=f"{prefix}.attn1", + ) + + self.hidden_dim = dim + self.num_attention_heads = num_heads + dim_head = dim // num_heads + if qk_norm == "rms_norm": + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + elif qk_norm == "rms_norm_across_heads": + # LTX applies qk norm across all heads + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + else: + logger.error("QK Norm type not supported") + raise Exception + assert cross_attn_norm is True + self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift( + dim, + norm_type="layer", + eps=eps, + elementwise_affine=True, + dtype=torch.float32, + compute_dtype=torch.float32, + ) + + # 2. 
Cross-attention + if added_kv_proj_dim is not None: + # I2V + self.attn2 = WanI2VCrossAttention( + dim, + num_heads, + qk_norm=qk_norm, + eps=eps, + supported_attention_backends=supported_attention_backends, + ) + else: + # T2V + self.attn2 = WanT2VCrossAttention( + dim, + num_heads, + qk_norm=qk_norm, + eps=eps, + supported_attention_backends=supported_attention_backends, + ) + self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift( + dim, + norm_type="layer", + eps=eps, + elementwise_affine=False, + dtype=torch.float32, + compute_dtype=torch.float32, + ) + + # 3. Feed-forward + self.ffn = MLP(dim, ffn_dim, act_type="gelu_pytorch_tanh") + self.mlp_residual = ScaleResidual() + + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + freqs_cis: tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + if hidden_states.dim() == 4: + hidden_states = hidden_states.squeeze(1) + bs, seq_length, _ = hidden_states.shape + orig_dtype = hidden_states.dtype + + if temb.dim() == 4: + # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table.unsqueeze(0) + temb.float() + ).chunk(6, dim=2) + # batch_size, seq_len, 1, inner_dim + shift_msa = shift_msa.squeeze(2) + scale_msa = scale_msa.squeeze(2) + gate_msa = gate_msa.squeeze(2) + c_shift_msa = c_shift_msa.squeeze(2) + c_scale_msa = c_scale_msa.squeeze(2) + c_gate_msa = c_gate_msa.squeeze(2) + else: + # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B) + e = self.scale_shift_table + temb.float() + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + e.chunk(6, dim=1) + ) + assert shift_msa.dtype == torch.float32 + + # 1. Self-attention + norm_hidden_states = ( + self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa + ).to(orig_dtype) + query, _ = self.to_q(norm_hidden_states) + key, _ = self.to_k(norm_hidden_states) + value, _ = self.to_v(norm_hidden_states) + + if self.norm_q is not None: + query = self.norm_q(query) + if self.norm_k is not None: + key = self.norm_k(key) + + query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1)) + key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1)) + value = value.squeeze(1).unflatten(2, (self.num_attention_heads, -1)) + + # Apply rotary embeddings + cos, sin = freqs_cis + query, key = _apply_rotary_emb( + query, cos, sin, is_neox_style=False + ), _apply_rotary_emb(key, cos, sin, is_neox_style=False) + attn_output, _ = self.attn1(query, key, value) + attn_output = attn_output.flatten(2) + attn_output, _ = self.to_out(attn_output) + attn_output = attn_output.squeeze(1) + + null_shift = null_scale = torch.zeros( + (1,), device=hidden_states.device, dtype=hidden_states.dtype + ) + norm_hidden_states, hidden_states = self.self_attn_residual_norm( + hidden_states, attn_output, gate_msa, null_shift, null_scale + ) + norm_hidden_states, hidden_states = norm_hidden_states.to( + orig_dtype + ), hidden_states.to(orig_dtype) + + # 2. Cross-attention + attn_output = self.attn2( + norm_hidden_states, context=encoder_hidden_states, context_lens=None + ) + norm_hidden_states, hidden_states = self.cross_attn_residual_norm( + hidden_states, attn_output, 1, c_shift_msa, c_scale_msa + ) + norm_hidden_states, hidden_states = norm_hidden_states.to( + orig_dtype + ), hidden_states.to(orig_dtype) + + # 3. 
Feed-forward + ff_output = self.ffn(norm_hidden_states) + hidden_states = self.mlp_residual(hidden_states, ff_output, c_gate_msa) + hidden_states = hidden_states.to(orig_dtype) + + return hidden_states + + +class WanTransformerBlock_VSA(nn.Module): + + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + qk_norm: str = "rms_norm_across_heads", + cross_attn_norm: bool = False, + eps: float = 1e-6, + added_kv_proj_dim: int | None = None, + supported_attention_backends: set[AttentionBackendEnum] | None = None, + prefix: str = "", + ): + super().__init__() + + # 1. Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.to_q = ReplicatedLinear(dim, dim, bias=True) + self.to_k = ReplicatedLinear(dim, dim, bias=True) + self.to_v = ReplicatedLinear(dim, dim, bias=True) + self.to_gate_compress = ReplicatedLinear(dim, dim, bias=True) + + self.to_out = ReplicatedLinear(dim, dim, bias=True) + self.attn1 = UlyssesAttention_VSA( + num_heads=num_heads, + head_size=dim // num_heads, + causal=False, + supported_attention_backends=supported_attention_backends, + prefix=f"{prefix}.attn1", + ) + self.hidden_dim = dim + self.num_attention_heads = num_heads + dim_head = dim // num_heads + if qk_norm == "rms_norm": + self.norm_q = RMSNorm(dim_head, eps=eps) + self.norm_k = RMSNorm(dim_head, eps=eps) + elif qk_norm == "rms_norm_across_heads": + # LTX applies qk norm across all heads + self.norm_q = RMSNorm(dim, eps=eps) + self.norm_k = RMSNorm(dim, eps=eps) + else: + logger.error("QK Norm type not supported") + raise Exception + assert cross_attn_norm is True + self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift( + dim, + norm_type="layer", + eps=eps, + elementwise_affine=True, + dtype=torch.float32, + compute_dtype=torch.float32, + ) + + if AttentionBackendEnum.VIDEO_SPARSE_ATTN in supported_attention_backends: + supported_attention_backends.remove(AttentionBackendEnum.VIDEO_SPARSE_ATTN) + # 2. Cross-attention + if added_kv_proj_dim is not None: + # I2V + self.attn2 = WanI2VCrossAttention( + dim, + num_heads, + qk_norm=qk_norm, + eps=eps, + supported_attention_backends=supported_attention_backends, + ) + else: + # T2V + self.attn2 = WanT2VCrossAttention( + dim, + num_heads, + qk_norm=qk_norm, + eps=eps, + supported_attention_backends=supported_attention_backends, + ) + self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift( + dim, + norm_type="layer", + eps=eps, + elementwise_affine=False, + dtype=torch.float32, + compute_dtype=torch.float32, + ) + + # 3. Feed-forward + self.ffn = MLP(dim, ffn_dim, act_type="gelu_pytorch_tanh") + self.mlp_residual = ScaleResidual() + + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + freqs_cis: tuple[torch.Tensor, torch.Tensor], + ) -> torch.Tensor: + if hidden_states.dim() == 4: + hidden_states = hidden_states.squeeze(1) + bs, seq_length, _ = hidden_states.shape + orig_dtype = hidden_states.dtype + # assert orig_dtype != torch.float32 + e = self.scale_shift_table + temb.float() + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = e.chunk( + 6, dim=1 + ) + assert shift_msa.dtype == torch.float32 + + # 1. 
Self-attention + norm_hidden_states = ( + self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa + ).to(orig_dtype) + query, _ = self.to_q(norm_hidden_states) + key, _ = self.to_k(norm_hidden_states) + value, _ = self.to_v(norm_hidden_states) + gate_compress, _ = self.to_gate_compress(norm_hidden_states) + + if self.norm_q is not None: + query = self.norm_q(query) + if self.norm_k is not None: + key = self.norm_k(key) + + query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1)) + key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1)) + value = value.squeeze(1).unflatten(2, (self.num_attention_heads, -1)) + gate_compress = gate_compress.squeeze(1).unflatten( + 2, (self.num_attention_heads, -1) + ) + + # Apply rotary embeddings + cos, sin = freqs_cis + query, key = _apply_rotary_emb( + query, cos, sin, is_neox_style=False + ), _apply_rotary_emb(key, cos, sin, is_neox_style=False) + + attn_output, _ = self.attn1(query, key, value, gate_compress=gate_compress) + attn_output = attn_output.flatten(2) + attn_output, _ = self.to_out(attn_output) + attn_output = attn_output.squeeze(1) + + null_shift = null_scale = torch.zeros((1,), device=hidden_states.device) + norm_hidden_states, hidden_states = self.self_attn_residual_norm( + hidden_states, attn_output, gate_msa, null_shift, null_scale + ) + norm_hidden_states, hidden_states = norm_hidden_states.to( + orig_dtype + ), hidden_states.to(orig_dtype) + + # 2. Cross-attention + attn_output = self.attn2( + norm_hidden_states, context=encoder_hidden_states, context_lens=None + ) + norm_hidden_states, hidden_states = self.cross_attn_residual_norm( + hidden_states, attn_output, 1, c_shift_msa, c_scale_msa + ) + norm_hidden_states, hidden_states = norm_hidden_states.to( + orig_dtype + ), hidden_states.to(orig_dtype) + + # 3. Feed-forward + ff_output = self.ffn(norm_hidden_states) + hidden_states = self.mlp_residual(hidden_states, ff_output, c_gate_msa) + hidden_states = hidden_states.to(orig_dtype) + + return hidden_states + + +class WanTransformer3DModel(CachableDiT): + _fsdp_shard_conditions = WanVideoConfig()._fsdp_shard_conditions + _compile_conditions = WanVideoConfig()._compile_conditions + _supported_attention_backends = WanVideoConfig()._supported_attention_backends + param_names_mapping = WanVideoConfig().param_names_mapping + reverse_param_names_mapping = WanVideoConfig().reverse_param_names_mapping + lora_param_names_mapping = WanVideoConfig().lora_param_names_mapping + + def __init__(self, config: WanVideoConfig, hf_config: dict[str, Any]) -> None: + super().__init__(config=config, hf_config=hf_config) + + inner_dim = config.num_attention_heads * config.attention_head_dim + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.in_channels = config.in_channels + self.out_channels = config.out_channels + self.num_channels_latents = config.num_channels_latents + self.patch_size = config.patch_size + self.text_len = config.text_len + + # 1. Patch & position embedding + self.patch_embedding = PatchEmbed( + in_chans=config.in_channels, + embed_dim=inner_dim, + patch_size=config.patch_size, + flatten=False, + ) + + # 2. Condition embeddings + self.condition_embedder = WanTimeTextImageEmbedding( + dim=inner_dim, + time_freq_dim=config.freq_dim, + text_embed_dim=config.text_dim, + image_embed_dim=config.image_dim, + ) + + # 3. 
Transformer blocks + attn_backend = get_global_server_args().attention_backend + transformer_block = ( + WanTransformerBlock_VSA + if (attn_backend and attn_backend.lower() == "video_sparse_attn") + else WanTransformerBlock + ) + self.blocks = nn.ModuleList( + [ + transformer_block( + inner_dim, + config.ffn_dim, + config.num_attention_heads, + config.qk_norm, + config.cross_attn_norm, + config.eps, + config.added_kv_proj_dim, + self._supported_attention_backends + | {AttentionBackendEnum.VIDEO_SPARSE_ATTN}, + prefix=f"{config.prefix}.blocks.{i}", + ) + for i in range(config.num_layers) + ] + ) + + # 4. Output norm & projection + self.norm_out = LayerNormScaleShift( + inner_dim, + norm_type="layer", + eps=config.eps, + elementwise_affine=False, + dtype=torch.float32, + compute_dtype=torch.float32, + ) + self.proj_out = nn.Linear( + inner_dim, config.out_channels * math.prod(config.patch_size) + ) + self.scale_shift_table = nn.Parameter( + torch.randn(1, 2, inner_dim) / inner_dim**0.5 + ) + + # For type checking + self.previous_e0_even = None + self.previous_e0_odd = None + self.previous_residual_even = None + self.previous_residual_odd = None + self.is_even = True + self.should_calc_even = True + self.should_calc_odd = True + self.accumulated_rel_l1_distance_even = 0 + self.accumulated_rel_l1_distance_odd = 0 + self.cnt = 0 + self.__post_init__() + + # misc + self.sp_size = get_sp_world_size() + + # Get rotary embeddings + d = self.hidden_size // self.num_attention_heads + self.rope_dim_list = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)] + + self.rope = NDRotaryEmbedding( + rope_dim_list=self.rope_dim_list, + rope_theta=10000, + dtype=torch.float32 if current_platform.is_mps() else torch.float64, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | list[torch.Tensor], + timestep: torch.LongTensor, + encoder_hidden_states_image: torch.Tensor | list[torch.Tensor] | None = None, + guidance=None, + **kwargs, + ) -> torch.Tensor: + forward_batch = get_forward_context().forward_batch + enable_teacache = forward_batch is not None and forward_batch.enable_teacache + + orig_dtype = hidden_states.dtype + if not isinstance(encoder_hidden_states, torch.Tensor): + encoder_hidden_states = encoder_hidden_states[0] + if ( + isinstance(encoder_hidden_states_image, list) + and len(encoder_hidden_states_image) > 0 + ): + encoder_hidden_states_image = encoder_hidden_states_image[0] + else: + encoder_hidden_states_image = None + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + freqs_cos, freqs_sin = self.rope.forward_from_grid( + ( + post_patch_num_frames * self.sp_size, + post_patch_height, + post_patch_width, + ), + shard_dim=0, + start_frame=0, + device=hidden_states.device, + ) + assert freqs_cos.dtype == torch.float32 + assert freqs_cos.device == hidden_states.device + freqs_cis = ( + (freqs_cos.float(), freqs_sin.float()) if freqs_cos is not None else None + ) + + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v) + if timestep.dim() == 2: + ts_seq_len = timestep.shape[1] + timestep = timestep.flatten() # batch_size * seq_len + else: + ts_seq_len = None + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = ( + 
self.condition_embedder( + timestep, + encoder_hidden_states, + encoder_hidden_states_image, + timestep_seq_len=ts_seq_len, + ) + ) + if ts_seq_len is not None: + # batch_size, seq_len, 6, inner_dim + timestep_proj = timestep_proj.unflatten(2, (6, -1)) + else: + # batch_size, 6, inner_dim + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat( + [encoder_hidden_states_image, encoder_hidden_states], dim=1 + ) + + encoder_hidden_states = ( + encoder_hidden_states.to(orig_dtype) + if current_platform.is_mps() + else encoder_hidden_states + ) # cast to orig_dtype for MPS + + assert encoder_hidden_states.dtype == orig_dtype + + # 4. Transformer blocks + # if caching is enabled, we might be able to skip the forward pass + should_skip_forward = self.should_skip_forward_for_cached_states( + timestep_proj=timestep_proj, temb=temb + ) + + if should_skip_forward: + hidden_states = self.retrieve_cached_states(hidden_states) + else: + # if teacache is enabled, we need to cache the original hidden states + if enable_teacache: + original_hidden_states = hidden_states.clone() + + for block in self.blocks: + hidden_states = block( + hidden_states, encoder_hidden_states, timestep_proj, freqs_cis + ) + # if teacache is enabled, we need to cache the original hidden states + if enable_teacache: + self.maybe_cache_states(hidden_states, original_hidden_states) + # 5. Output norm, projection & unpatchify + if temb.dim() == 3: + # batch_size, seq_len, inner_dim (wan 2.2 ti2v) + shift, scale = ( + self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2) + ).chunk(2, dim=2) + shift = shift.squeeze(2) + scale = scale.squeeze(2) + else: + # batch_size, inner_dim + shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) + + hidden_states = self.norm_out(hidden_states, shift, scale) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, + post_patch_num_frames, + post_patch_height, + post_patch_width, + p_t, + p_h, + p_w, + -1, + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + return output + + def maybe_cache_states( + self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor + ) -> None: + if self.is_even: + self.previous_residual_even = ( + hidden_states.squeeze(0) - original_hidden_states + ) + else: + self.previous_residual_odd = ( + hidden_states.squeeze(0) - original_hidden_states + ) + + def should_skip_forward_for_cached_states(self, **kwargs) -> bool: + + forward_context = get_forward_context() + forward_batch = forward_context.forward_batch + if forward_batch is None or not forward_batch.enable_teacache: + return False + teacache_params = forward_batch.teacache_params + assert teacache_params is not None, "teacache_params is not initialized" + assert isinstance( + teacache_params, WanTeaCacheParams + ), "teacache_params is not a WanTeaCacheParams" + current_timestep = forward_context.current_timestep + num_inference_steps = forward_batch.num_inference_steps + + # initialize the coefficients, cutoff_steps, and ret_steps + coefficients = teacache_params.coefficients + use_ret_steps = teacache_params.use_ret_steps + cutoff_steps = teacache_params.get_cutoff_steps(num_inference_steps) + ret_steps = teacache_params.ret_steps + teacache_thresh = teacache_params.teacache_thresh + + if current_timestep == 0: + self.cnt = 0 + + timestep_proj = 
kwargs["timestep_proj"] + temb = kwargs["temb"] + modulated_inp = timestep_proj if use_ret_steps else temb + + if self.cnt % 2 == 0: # even -> condition + self.is_even = True + if self.cnt < ret_steps or self.cnt >= cutoff_steps: + self.should_calc_even = True + self.accumulated_rel_l1_distance_even = 0 + else: + assert ( + self.previous_e0_even is not None + ), "previous_e0_even is not initialized" + assert ( + self.accumulated_rel_l1_distance_even is not None + ), "accumulated_rel_l1_distance_even is not initialized" + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance_even += rescale_func( + ( + (modulated_inp - self.previous_e0_even).abs().mean() + / self.previous_e0_even.abs().mean() + ) + .cpu() + .item() + ) + if self.accumulated_rel_l1_distance_even < teacache_thresh: + self.should_calc_even = False + else: + self.should_calc_even = True + self.accumulated_rel_l1_distance_even = 0 + self.previous_e0_even = modulated_inp.clone() + + else: # odd -> unconditon + self.is_even = False + if self.cnt < ret_steps or self.cnt >= cutoff_steps: + self.should_calc_odd = True + self.accumulated_rel_l1_distance_odd = 0 + else: + assert ( + self.previous_e0_odd is not None + ), "previous_e0_odd is not initialized" + assert ( + self.accumulated_rel_l1_distance_odd is not None + ), "accumulated_rel_l1_distance_odd is not initialized" + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance_odd += rescale_func( + ( + (modulated_inp - self.previous_e0_odd).abs().mean() + / self.previous_e0_odd.abs().mean() + ) + .cpu() + .item() + ) + if self.accumulated_rel_l1_distance_odd < teacache_thresh: + self.should_calc_odd = False + else: + self.should_calc_odd = True + self.accumulated_rel_l1_distance_odd = 0 + self.previous_e0_odd = modulated_inp.clone() + self.cnt += 1 + should_skip_forward = False + if self.is_even: + if not self.should_calc_even: + should_skip_forward = True + else: + if not self.should_calc_odd: + should_skip_forward = True + + return should_skip_forward + + def retrieve_cached_states(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.is_even: + return hidden_states + self.previous_residual_even + else: + return hidden_states + self.previous_residual_odd + + +EntryClass = WanTransformer3DModel diff --git a/python/sglang/multimodal_gen/runtime/models/encoders/base.py b/python/sglang/multimodal_gen/runtime/models/encoders/base.py new file mode 100644 index 000000000..a36c616cc --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/encoders/base.py @@ -0,0 +1,71 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from abc import ABC, abstractmethod +from dataclasses import field + +import torch +from torch import nn + +from sglang.multimodal_gen.configs.models.encoders import ( + BaseEncoderOutput, + ImageEncoderConfig, + TextEncoderConfig, +) +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum + + +class TextEncoder(nn.Module, ABC): + _fsdp_shard_conditions: list = field(default_factory=lambda: []) + _stacked_params_mapping: list[tuple[str, str, str]] = field(default_factory=list) + _supported_attention_backends: set[AttentionBackendEnum] = ( + TextEncoderConfig()._supported_attention_backends + ) + + def __init__(self, config: TextEncoderConfig) -> None: + super().__init__() + self.config = config + self._fsdp_shard_conditions = config._fsdp_shard_conditions + self._stacked_params_mapping = config.arch_config.stacked_params_mapping + if not 
self.supported_attention_backends: + raise ValueError( + f"Subclass {self.__class__.__name__} must define _supported_attention_backends" + ) + + @abstractmethod + def forward( + self, + input_ids: torch.Tensor | None, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + output_hidden_states: bool | None = None, + **kwargs, + ) -> BaseEncoderOutput: + pass + + @property + def supported_attention_backends(self) -> set[AttentionBackendEnum]: + return self._supported_attention_backends + + +class ImageEncoder(nn.Module, ABC): + _supported_attention_backends: set[AttentionBackendEnum] = ( + ImageEncoderConfig()._supported_attention_backends + ) + + def __init__(self, config: ImageEncoderConfig) -> None: + super().__init__() + self.config = config + if not self.supported_attention_backends: + raise ValueError( + f"Subclass {self.__class__.__name__} must define _supported_attention_backends" + ) + + @abstractmethod + def forward(self, pixel_values: torch.Tensor, **kwargs) -> BaseEncoderOutput: + pass + + @property + def supported_attention_backends(self) -> set[AttentionBackendEnum]: + return self._supported_attention_backends diff --git a/python/sglang/multimodal_gen/runtime/models/encoders/bert.py b/python/sglang/multimodal_gen/runtime/models/encoders/bert.py new file mode 100644 index 000000000..5a423e51b --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/encoders/bert.py @@ -0,0 +1,46 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# type: ignore +import os + +import torch +import torch.nn as nn +from transformers import BertModel, BertTokenizer + + +class HunyuanClip(nn.Module): + """ + Hunyuan clip code copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py + hunyuan's clip used BertModel and BertTokenizer, so we copy it. 
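+
+    Usage sketch (model_dir layout assumed to contain `tokenizer/` and
+    `clip_text_encoder/` subfolders, matching the paths used in __init__;
+    the directory below is illustrative):
+        clip = HunyuanClip("/path/to/hunyuan_text_encoder", max_length=77)
+        last_hidden_state, pooled = clip(["an astronaut riding a horse"])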
+ """ + + def __init__(self, model_dir, max_length=77): + super().__init__() + + self.max_length = max_length + self.tokenizer = BertTokenizer.from_pretrained( + os.path.join(model_dir, "tokenizer") + ) + self.text_encoder = BertModel.from_pretrained( + os.path.join(model_dir, "clip_text_encoder") + ) + + @torch.no_grad + def forward(self, prompts, with_mask=True): + self.device = next(self.text_encoder.parameters()).device + text_inputs = self.tokenizer( + prompts, + padding="max_length", + max_length=self.max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + prompt_embeds = self.text_encoder( + text_inputs.input_ids.to(self.device), + attention_mask=( + text_inputs.attention_mask.to(self.device) if with_mask else None + ), + ) + return prompt_embeds.last_hidden_state, prompt_embeds.pooler_output diff --git a/python/sglang/multimodal_gen/runtime/models/encoders/clip.py b/python/sglang/multimodal_gen/runtime/models/encoders/clip.py new file mode 100644 index 000000000..ec80e387f --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/encoders/clip.py @@ -0,0 +1,700 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/clip.py +# Adapted from transformers: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py +"""Minimal implementation of CLIPVisionModel intended to be only used +within a vision language model.""" +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn + +from sglang.multimodal_gen.configs.models.encoders import ( + BaseEncoderOutput, + CLIPTextConfig, + CLIPVisionConfig, +) +from sglang.multimodal_gen.runtime.distributed import divide, get_tp_world_size +from sglang.multimodal_gen.runtime.layers.activation import get_act_fn +from sglang.multimodal_gen.runtime.layers.attention import LocalAttention +from sglang.multimodal_gen.runtime.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig + +# TODO: support quantization +# from vllm.model_executor.layers.quantization import QuantizationConfig +from sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader +from sglang.multimodal_gen.runtime.models.encoders.base import ImageEncoder, TextEncoder +from sglang.multimodal_gen.runtime.models.encoders.vision import ( + resolve_visual_encoder_outputs, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa +class CLIPVisionEmbeddings(nn.Module): + + def __init__(self, config: CLIPVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + assert self.image_size % self.patch_size == 0 + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + 
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer( + "position_ids", + torch.arange(self.num_positions).expand((1, -1)), + persistent=False, + ) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype) + ) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + + return embeddings + + +class CLIPTextEmbeddings(nn.Module): + + def __init__(self, config: CLIPTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding( + config.max_position_embeddings, embed_dim + ) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", + torch.arange(config.max_position_embeddings).expand((1, -1)), + persistent=False, + ) + + def forward( + self, + input_ids: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + ) -> torch.Tensor: + if input_ids is not None: + seq_length = input_ids.shape[-1] + elif inputs_embeds is not None: + seq_length = inputs_embeds.shape[-2] + else: + raise ValueError("Either input_ids or inputs_embeds must be provided.") + + max_position_embedding = self.position_embedding.weight.shape[0] + + if seq_length > max_position_embedding: + raise ValueError( + f"Sequence length must be less than max_position_embeddings (got `sequence length`: " + f"{seq_length} and max_position_embeddings: {max_position_embedding}" + ) + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class CLIPAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: CLIPVisionConfig | CLIPTextConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.out_proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + self.tp_size = get_tp_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + self.attn = LocalAttention( + self.num_heads_per_partition, + self.head_dim, + self.num_heads_per_partition, + softmax_scale=self.scale, + causal=False, + supported_attention_backends=config._supported_attention_backends, + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + ): + """Input shape: Batch x Time x Channel""" + + qkv_states, _ = self.qkv_proj(hidden_states) + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + # use flash_attn_func + query_states = query_states.reshape( + query_states.shape[0], + query_states.shape[1], + self.num_heads_per_partition, + self.head_dim, + ) + key_states = key_states.reshape( + key_states.shape[0], + key_states.shape[1], + self.num_heads_per_partition, + self.head_dim, + ) + value_states = value_states.reshape( + value_states.shape[0], + value_states.shape[1], + self.num_heads_per_partition, + self.head_dim, + ) + attn_output = self.attn(query_states, key_states, value_states) + + attn_output = attn_output.reshape( + attn_output.shape[0], + attn_output.shape[1], + self.num_heads_per_partition * self.head_dim, + ) + attn_output, _ = self.out_proj(attn_output) + + return attn_output, None + + +class CLIPMLP(nn.Module): + + def __init__( + self, + config: CLIPVisionConfig | CLIPTextConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + + return hidden_states + + +class CLIPEncoderLayer(nn.Module): + + def __init__( + self, + config: CLIPTextConfig | CLIPVisionConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.self_attn = CLIPAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = CLIPMLP(config, quant_config=quant_config, prefix=f"{prefix}.mlp") + self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn(hidden_states=hidden_states) + hidden_states = residual + hidden_states + + 
residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class CLIPEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self + attention layers. Each layer is a [`CLIPEncoderLayer`]. + + Args: + config: CLIPConfig + """ + + def __init__( + self, + config: CLIPVisionConfig | CLIPTextConfig, + quant_config: QuantizationConfig | None = None, + num_hidden_layers_override: int | None = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + + if num_hidden_layers_override is None: + num_hidden_layers = config.num_hidden_layers + else: + num_hidden_layers = num_hidden_layers_override + self.layers = nn.ModuleList( + [ + CLIPEncoderLayer( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + ) + for layer_idx in range(num_hidden_layers) + ] + ) + + def forward( + self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool + ) -> torch.Tensor | list[torch.Tensor]: + hidden_states_pool = [inputs_embeds] + hidden_states = inputs_embeds + + for idx, encoder_layer in enumerate(self.layers): + hidden_states = encoder_layer(hidden_states) + if return_all_hidden_states: + hidden_states_pool.append(hidden_states) + # If we have multiple feature sample layers, we return all hidden + # states in order and grab the ones we need by index. + if return_all_hidden_states: + return hidden_states_pool + return [hidden_states] + + +class CLIPTextTransformer(nn.Module): + + def __init__( + self, + config: CLIPTextConfig, + quant_config: QuantizationConfig | None = None, + num_hidden_layers_override: int | None = None, + prefix: str = "", + ): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = CLIPTextEmbeddings(config) + + self.encoder = CLIPEncoder( + config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + prefix=prefix, + ) + + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + # For `pooled_output` computation + self.eos_token_id = config.eos_token_id + + def forward( + self, + input_ids: torch.Tensor | None, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + output_hidden_states: bool | None = None, + ) -> BaseEncoderOutput: + r""" + Returns: + + """ + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # CLIP's text model uses causal mask, prepare it here. 
+        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
+        # NOTE: the mask construction below is intentionally left disabled in this
+        # adaptation; CLIPAttention above creates LocalAttention with causal=False,
+        # so no causal mask is applied inside the encoder here.
+        # causal_attention_mask = _create_4d_causal_attention_mask(
+        #     input_shape, hidden_states.dtype, device=hidden_states.device
+        # )
+
+        # # expand attention_mask
+        # if attention_mask is not None and not self._use_flash_attention_2:
+        #     raise NotImplementedError("attention_mask is not supported for CLIPTextTransformer")
+        #     # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        #     attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
+
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            # attention_mask=attention_mask,
+            # causal_attention_mask=causal_attention_mask,
+            # output_attentions=output_attentions,
+            return_all_hidden_states=output_hidden_states,
+            # return_dict=return_dict,
+        )
+
+        last_hidden_state = encoder_outputs[-1]
+        last_hidden_state = self.final_layer_norm(last_hidden_state)
+
+        if self.eos_token_id == 2:
+            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
+            # A CLIP model with such an `eos_token_id` in the config can't work correctly with extra new tokens added.
+            # ------------------------------------------------------------
+            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
+            # take features from the eot embedding (eot_token is the highest number in each sequence)
+            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
+            pooled_output = last_hidden_state[
+                torch.arange(
+                    last_hidden_state.shape[0], device=last_hidden_state.device
+                ),
+                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(
+                    dim=-1
+                ),
+            ]
+        else:
+            # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible).
+            pooled_output = last_hidden_state[
+                torch.arange(
+                    last_hidden_state.shape[0], device=last_hidden_state.device
+                ),
+                # We need to get the first position of the `eos_token_id` value (`pad_token_id` might equal `eos_token_id`).
+                # Note: we assume each sequence (along batch dim.) contains an `eos_token_id` (e.g.
prepared by the tokenizer) + ( + input_ids.to(dtype=torch.int, device=last_hidden_state.device) + == self.eos_token_id + ) + .int() + .argmax(dim=-1), + ] + + return BaseEncoderOutput( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs, + # attentions=encoder_outputs.attentions, + ) + + +class CLIPTextModel(TextEncoder): + + def __init__( + self, + config: CLIPTextConfig, + ) -> None: + super().__init__(config) + self.text_model = CLIPTextTransformer( + config=config, quant_config=config.quant_config, prefix=config.prefix + ) + + def forward( + self, + input_ids: torch.Tensor | None, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + output_hidden_states: bool | None = None, + **kwargs, + ) -> BaseEncoderOutput: + + outputs: BaseEncoderOutput = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_hidden_states=output_hidden_states, + ) + return outputs + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + + # Define mapping for stacked parameters + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + # Handle q_proj, k_proj, v_proj -> qkv_proj mapping + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name in name: + # Replace the weight name with the parameter name + model_param_name = name.replace(weight_name, param_name) + + if model_param_name in params_dict: + param = params_dict[model_param_name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(model_param_name) + break + else: + # Use default weight loader for all other parameters + if name in params_dict: + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class CLIPVisionTransformer(nn.Module): + + def __init__( + self, + config: CLIPVisionConfig, + quant_config: QuantizationConfig | None = None, + num_hidden_layers_override: int | None = None, + require_post_norm: bool | None = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + embed_dim = config.hidden_size + + self.embeddings = CLIPVisionEmbeddings(config) + + # NOTE: This typo of "layrnorm" is not fixed on purpose to match + # the original transformers code and name of the model weights. + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.encoder = CLIPEncoder( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + prefix=f"{prefix}.encoder", + ) + + num_hidden_layers = config.num_hidden_layers + if len(self.encoder.layers) > config.num_hidden_layers: + raise ValueError( + f"The original encoder only has {num_hidden_layers} " + f"layers, but you requested {len(self.encoder.layers)} layers." 
+ ) + + # If possible, skip post_layernorm to conserve memory + if require_post_norm is None: + require_post_norm = len(self.encoder.layers) == num_hidden_layers + + if require_post_norm: + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + else: + self.post_layernorm = None + + def forward( + self, + pixel_values: torch.Tensor, + output_hidden_states: Optional[bool] = None, + feature_sample_layers: list[int] | None = None, + ) -> BaseEncoderOutput: + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layrnorm(hidden_states) + + return_all_hidden_states = output_hidden_states or ( + feature_sample_layers is not None + ) + + # Produces either the last layer output or all of the hidden states, + # depending on if we have feature_sample_layers or not + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + return_all_hidden_states=return_all_hidden_states, + ) + + if not return_all_hidden_states: + encoder_outputs = encoder_outputs[0] + + # Handle post-norm (if applicable) and stacks feature layers if needed + encoder_outputs = resolve_visual_encoder_outputs( + encoder_outputs, + feature_sample_layers, + self.post_layernorm, + self.config.num_hidden_layers, + ) + + if return_all_hidden_states: + return BaseEncoderOutput(hidden_states=encoder_outputs) + + return BaseEncoderOutput(last_hidden_state=encoder_outputs) + + +class CLIPVisionModel(ImageEncoder): + config_class = CLIPVisionConfig + main_input_name = "pixel_values" + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + + def __init__(self, config: CLIPVisionConfig) -> None: + super().__init__(config) + self.vision_model = CLIPVisionTransformer( + config=config, + quant_config=config.quant_config, + num_hidden_layers_override=config.num_hidden_layers_override, + require_post_norm=config.require_post_norm, + prefix=f"{config.prefix}.vision_model", + ) + + def forward( + self, + pixel_values: torch.Tensor, + feature_sample_layers: list[int] | None = None, + output_hidden_states: Optional[bool] = None, + **kwargs, + ) -> BaseEncoderOutput: + base_encoder_output = self.vision_model( + pixel_values, + output_hidden_states=output_hidden_states, + feature_sample_layers=feature_sample_layers, + ) + + return base_encoder_output + + @property + def device(self): + return next(self.parameters()).device + + # (TODO) Add prefix argument for filtering out weights to be loaded + # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + layer_count = len(self.vision_model.encoder.layers) + + for name, loaded_weight in weights: + if name.startswith("visual_projection"): + continue + # post_layernorm is not needed in CLIPVisionModel + if ( + name.startswith("vision_model.post_layernorm") + and self.vision_model.post_layernorm is None + ): + continue + + # omit layers when num_hidden_layers_override is set + if name.startswith("vision_model.encoder.layers"): + layer_idx = int(name.split(".")[3]) + if layer_idx >= layer_count: + continue + + for ( + param_name, + weight_name, + shard_id, + ) in self.config.arch_config.stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + 
weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class BertModel(CLIPTextModel): + pass + + +EntryClass = [CLIPTextModel, CLIPVisionModel] diff --git a/python/sglang/multimodal_gen/runtime/models/encoders/llama.py b/python/sglang/multimodal_gen/runtime/models/encoders/llama.py new file mode 100644 index 000000000..ea208f124 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/encoders/llama.py @@ -0,0 +1,459 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/llama.py + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
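Aside, before the LLaMA text encoder that follows: the `stacked_params_mapping` used by `CLIPTextModel.load_weights` above simply rewrites per-projection checkpoint names (`q_proj`/`k_proj`/`v_proj`) to the fused `qkv_proj` parameter and copies each shard into its slice. A minimal, self-contained sketch of that idea with plain tensors (illustrative names and shapes; the real loader dispatches through each parameter's `weight_loader`):

    import torch

    # Mirrors the ("qkv_proj", "q_proj"/"k_proj"/"v_proj") mapping above.
    stacked_params_mapping = [
        ("qkv_proj", "q_proj", 0),
        ("qkv_proj", "k_proj", 1),
        ("qkv_proj", "v_proj", 2),
    ]

    def fuse_qkv(checkpoint: dict, hidden_size: int) -> torch.Tensor:
        qkv = torch.empty(3 * hidden_size, hidden_size)
        for name, weight in checkpoint.items():
            for _fused_name, shard_name, shard_idx in stacked_params_mapping:
                if shard_name in name:
                    # Copy this shard into its slice of the fused parameter.
                    qkv[shard_idx * hidden_size:(shard_idx + 1) * hidden_size] = weight
                    break
        return qkv

    ckpt = {f"text_model.encoder.layers.0.self_attn.{p}.weight": torch.randn(16, 16)
            for p in ("q_proj", "k_proj", "v_proj")}
    print(fuse_qkv(ckpt, 16).shape)  # torch.Size([48, 16])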
+"""Inference-only LLaMA model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Any + +import torch +from torch import nn + +# from ..utils import (extract_layer_index) +from sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput, LlamaConfig +from sglang.multimodal_gen.runtime.distributed import get_tp_world_size +from sglang.multimodal_gen.runtime.layers.activation import SiluAndMul + +# from vllm.model_executor.layers.quantization import QuantizationConfig +from sglang.multimodal_gen.runtime.layers.attention import LocalAttention +from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm +from sglang.multimodal_gen.runtime.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig +from sglang.multimodal_gen.runtime.layers.rotary_embedding import get_rope +from sglang.multimodal_gen.runtime.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, +) +from sglang.multimodal_gen.runtime.loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from sglang.multimodal_gen.runtime.models.encoders.base import TextEncoder + + +class LlamaMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: QuantizationConfig | None = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + # output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class LlamaAttention(nn.Module): + + def __init__( + self, + config: LlamaConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: dict[str, Any] | None = None, + max_position_embeddings: int = 8192, + quant_config: QuantizationConfig | None = None, + bias: bool = False, + bias_o_proj: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + # layer_idx = extract_layer_index(prefix) + self.hidden_size = hidden_size + tp_size = get_tp_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr( + config, "head_dim", self.hidden_size // self.total_num_heads + ) + # Phi models introduced a partial_rotary_factor parameter in the config + partial_rotary_factor = getattr(config, "partial_rotary_factor", 1) + self.rotary_dim = int(partial_rotary_factor * self.head_dim) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias_o_proj, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + is_neox_style = True + is_gguf = ( + quant_config + and hasattr(quant_config, "get_name") + and quant_config.get_name() == "gguf" + ) + if is_gguf and config.model_type == "llama": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.rotary_dim, + max_position=max_position_embeddings, + base=int(rope_theta), + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + ) + + self.attn = LocalAttention( + self.num_heads, + self.head_dim, + self.num_kv_heads, + softmax_scale=self.scaling, + causal=True, + supported_attention_backends=config._supported_attention_backends, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + # attn_output = self.attn(q, k, v) + # use flash_attn_func + # TODO (Attn abstraction and backend) + # reshape q, k, v to (batch_size, seq_len, num_heads, head_dim) + batch_size = q.shape[0] + seq_len = q.shape[1] + q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim) + k = k.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim) + v = v.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim) + # import pdb; pdb.set_trace() + # attn_output = flash_attn_varlen_func(q, k, v, softmax_scale=self.scaling, causal=True) + attn_output = self.attn(q, k, v) + attn_output = attn_output.reshape( + batch_size, seq_len, self.num_heads * self.head_dim + ) + + output, _ = self.o_proj(attn_output) + return output + + +class LlamaDecoderLayer(nn.Module): + + def __init__( + self, + config: LlamaConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None + ): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings + ) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) 
or getattr( + config, "bias", False + ) + bias_o_proj = attention_bias + # support internlm/internlm3-8b with qkv_bias + if hasattr(config, "qkv_bias"): + attention_bias = config.qkv_bias + + self.self_attn = LlamaAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr( + config, "num_key_value_heads", config.num_attention_heads + ), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + bias_o_proj=bias_o_proj, + prefix=f"{prefix}.self_attn", + ) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class LlamaModel(TextEncoder): + + def __init__( + self, + config: LlamaConfig, + ): + super().__init__(config) + + self.config = config + self.quant_config = self.config.quant_config + if config.lora_config is not None: + max_loras = 1 + lora_vocab_size = 1 + if hasattr(config.lora_config, "max_loras"): + max_loras = config.lora_config.max_loras + if hasattr(config.lora_config, "lora_extra_vocab_size"): + lora_vocab_size = config.lora_config.lora_extra_vocab_size + lora_vocab = lora_vocab_size * max_loras + else: + lora_vocab = 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=config.quant_config, + ) + + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer( + config=config, + quant_config=config.quant_config, + prefix=f"{config.prefix}.layers.{i}", + ) + for i in range(config.num_hidden_layers) + ] + ) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + output_hidden_states: bool | None = None, + **kwargs, + ) -> BaseEncoderOutput: + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + + if position_ids is None: + position_ids = torch.arange( + 0, hidden_states.shape[1], device=hidden_states.device + ).unsqueeze(0) + + 
all_hidden_states: tuple[Any, ...] | None = () if output_hidden_states else None + for layer in self.layers: + if all_hidden_states is not None: + # TODO + all_hidden_states += ( + (hidden_states,) + if residual is None + else (hidden_states + residual,) + ) + hidden_states, residual = layer(position_ids, hidden_states, residual) + + hidden_states, _ = self.norm(hidden_states, residual) + + # add hidden states from the last decoder layer + if all_hidden_states is not None: + all_hidden_states += (hidden_states,) + + # TODO(will): maybe unify the output format with other models and use + # our own class + output = BaseEncoderOutput( + last_hidden_state=hidden_states, + # past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + # attentions=all_self_attns, + ) + + return output + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # if (self.quant_config is not None and + # (scale_name := self.quant_config.get_cache_scale(name))): + # # Loading kv cache quantization scales + # param = params_dict[scale_name] + # weight_loader = getattr(param, "weight_loader", + # default_weight_loader) + # loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + # loaded_weight[0]) + # weight_loader(param, loaded_weight) + # loaded_params.add(scale_name) + # continue + if "scale" in name: + # Remapping the name of FP8 kv-scale. + kv_scale_name: str | None = maybe_remap_kv_scale_name(name, params_dict) + if kv_scale_name is None: + continue + else: + name = kv_scale_name + for ( + param_name, + weight_name, + shard_id, + ) in self.config.arch_config.stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +EntryClass = LlamaModel diff --git a/python/sglang/multimodal_gen/runtime/models/encoders/qwen2_5vl.py b/python/sglang/multimodal_gen/runtime/models/encoders/qwen2_5vl.py new file mode 100644 index 000000000..08184cccb --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/encoders/qwen2_5vl.py @@ -0,0 +1,1181 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from types import SimpleNamespace + +from transformers import ( + Cache, + DynamicCache, + PretrainedConfig, + Qwen2_5_VLTextConfig, + Qwen2RMSNorm, +) +from transformers.masking_utils import ( + create_causal_mask, + create_sliding_window_causal_mask, +) +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.utils import TransformersKwargs, is_torchdynamo_compiling + +from sglang.multimodal_gen.configs.models.encoders.qwen_image import Qwen2_5VLConfig +from sglang.multimodal_gen.runtime.layers.attention import LocalAttention +from sglang.multimodal_gen.runtime.layers.linear import ( + MergedColumnParallelLinear, + RowParallelLinear, +) +from sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig +from sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader +from sglang.multimodal_gen.runtime.models.encoders.base import TextEncoder +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.runtime.utils.common import add_prefix + +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
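Before the Qwen2.5-VL encoder that follows, one convention worth calling out: `LlamaDecoderLayer` and `LlamaModel` above thread a running `residual` through the norm calls, i.e. `RMSNorm(hidden_states, residual)` returns a `(normalized, residual)` pair, rather than adding the residual before each norm as in stock HuggingFace LLaMA. A minimal pure-PyTorch sketch of that fused-add behavior (illustrative only, not the actual `RMSNorm` kernel in this patch):

    import torch

    def fused_add_rms_norm(x: torch.Tensor, residual: torch.Tensor | None,
                           weight: torch.Tensor, eps: float = 1e-6):
        # Add the incoming residual first (if any), keep the pre-norm sum as the
        # new residual, and return (normalized, residual).
        if residual is not None:
            x = x + residual
        residual = x
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        return x * rms * weight, residual

    h, w = torch.randn(1, 4, 8), torch.ones(8)
    out, res = fused_add_rms_norm(h, None, w)                   # first layer
    out, res = fused_add_rms_norm(torch.randn_like(h), res, w)  # later layers reuse res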
+"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" +import logging +from typing import Callable, Iterable, Optional, Tuple, Union, Unpack + +import torch +import torch.nn as nn +from transformers.activations import ACT2FN +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLAttention, + Qwen2_5_VLCausalLMOutputWithPast, + Qwen2_5_VLModelOutputWithPast, + Qwen2_5_VLRotaryEmbedding, + Qwen2MLP, + apply_multimodal_rotary_pos_emb, + eager_attention_forward, +) + +logger = logging.getLogger(__name__) + + +class Qwen2_5_VLAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warn( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.rope_scaling = config.rope_scaling + self.scaling = self.head_dim**-0.5 + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=True + ) + self.k_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True + ) + self.v_proj = nn.Linear( + self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True + ) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, self.hidden_size, bias=False + ) + self.sliding_window = ( + config.sliding_window + if config.layer_types[layer_idx] == "sliding_attention" + else None + ) + + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + self.attn = LocalAttention( + num_heads=self.num_heads, + head_size=self.head_dim, + num_kv_heads=self.num_key_value_heads, + softmax_scale=self.scaling, + causal=True, + supported_attention_backends=( + AttentionBackendEnum.FA3, + AttentionBackendEnum.TORCH_SDPA, + ), + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_values is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } # Specific to RoPE models + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attention_interface: Callable = eager_attention_forward + # if self.config._attn_implementation != "eager": + # attention_interface = ALL_ATTENTION_FUNCTIONS["sdpa"] + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + attn_output = self.attn(query_states, key_states, value_states) + # + # attn_output, attn_weights = attention_interface( + # self, + # query_states, + # key_states, + # value_states, + # attention_mask, + # dropout=0.0 if not self.training else self.attention_dropout, + # scaling=self.scaling, + # sliding_window=self.sliding_window, + # position_ids=position_ids, # pass positions for FA2 + # **kwargs, + # ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output + + +class Qwen2_5_VLDecoderLayer(nn.Module): + def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if ( + config.use_sliding_window + and config._attn_implementation != "flash_attention_2" + ): + logger.warning( + f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " + "unexpected results may be encountered." 
+ ) + self.self_attn = Qwen2_5_VLAttention(config, layer_idx) + + self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.attention_type = config.layer_types[layer_idx] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + tuple[torch.Tensor, torch.Tensor] + ] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[ + torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Qwen2_5_VLMLP(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: int = None, + bias: bool = True, + hidden_act="silu", + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=in_features, + output_sizes=[hidden_features] * 2, # [gate_proj, up_proj] + bias=bias, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), + ) + self.down_proj = RowParallelLinear( + hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), + ) + self.act = ACT2FN[hidden_act] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + gate, up = gate_up.chunk(2, dim=-1) + x = self.act(gate) * up + x_down, _ = self.down_proj(x) + return x_down + + +class Qwen2_5_VLTextModel(nn.Module): + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + Qwen2_5_VLDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + self.has_sliding_layers = "sliding_attention" in self.config.layer_types + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + # self.post_init() + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if (input_ids is 
None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # torch.jit.trace() doesn't support cache objects in the output + if use_cache and past_key_values is None and not torch.jit.is_tracing(): + past_key_values = DynamicCache(config=self.config) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + # the hard coded `3` is for temporal, height and width. + if position_ids is None: + position_ids = cache_position.view(1, 1, -1).expand( + 3, inputs_embeds.shape[0], -1 + ) + elif position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + # NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions + # where each dim indicates visual spatial positions for temporal/height/width grids. + # There are two scenarios when FA2-like packed masking might be activated. + # 1. User specifically passed packed `position_ids` and no attention mask. + # In this case we expect the user to create correct position ids for all 3 grids + # and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len] + # 2. User runs forward with no attention mask and no position ids. In this case, position ids + # are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are + # prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass + # text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation` + if position_ids.ndim == 3 and position_ids.shape[0] == 4: + text_position_ids = position_ids[0] + position_ids = position_ids[1:] + else: + text_position_ids = position_ids[0] + + # It may already have been prepared by e.g. 
`generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + # Prepare mask arguments + mask_kwargs = { + "config": self.config, + "input_embeds": inputs_embeds, + "attention_mask": attention_mask, + "cache_position": cache_position, + "past_key_values": past_key_values, + "position_ids": text_position_ids, + } + # Create the masks + causal_mask_mapping = { + "full_attention": create_causal_mask(**mask_kwargs), + } + # The sliding window alternating layers are not always activated depending on the config + if self.has_sliding_layers: + causal_mask_mapping["sliding_attention"] = ( + create_sliding_window_causal_mask(**mask_kwargs) + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=text_position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_self_attns, + ] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class Qwen2_5_VLModel(nn.Module): + base_model_prefix = "" + _checkpoint_conversion_mapping = {"^model": "language_model"} + # Reference: fix gemma3 grad acc #37208 + accepts_loss_kwargs = False + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + + def __init__(self, config): + super().__init__() + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config( + config.vision_config + ) + self.language_model = Qwen2_5_VLTextModel(config.text_config) + self.visual.to(torch.get_default_dtype()) + self.rope_deltas = None # cache rope_deltas here + self.config = config + # Initialize weights and apply final processing + # self.post_init() + + def get_input_embeddings(self): + return self.language_model.embed_tokens + + def set_input_embeddings(self, value): + self.language_model.embed_tokens = value + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. 
+            Examples:
+                input_ids: [T T T T T], here T is for text.
+                temporal position_ids: [0, 1, 2, 3, 4]
+                height position_ids: [0, 1, 2, 3, 4]
+                width position_ids: [0, 1, 2, 3, 4]
+
+            For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
+            and 1D rotary position embedding for text part.
+            Examples:
+                Temporal (Time): 3 patches, representing different segments of the video in time.
+                Height: 2 patches, dividing each frame vertically.
+                Width: 2 patches, dividing each frame horizontally.
+                We also have some important parameters:
+                fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
+                tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
+                temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
+                interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will have a difference of 50 in the temporal position IDs.
+                input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
+                vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
+                vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
+                vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
+                text temporal position_ids: [101, 102, 103, 104, 105]
+                text height position_ids: [101, 102, 103, 104, 105]
+                text width position_ids: [101, 102, 103, 104, 105]
+                Here we calculate the text start position_ids as the max vision position_ids plus 1.
+
+        Args:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+                it.
+            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+                The temporal, height and width of feature shape of each image in LLM.
+            video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+                The temporal, height and width of feature shape of each video in LLM.
+            second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+                The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
+            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+                - 1 for tokens that are **not masked**,
+                - 0 for tokens that are **masked**.
+ + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and ( + image_grid_thw is not None or video_grid_thw is not None + ): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere( + input_ids == vision_start_token_id + ).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + ## normalize type, send to device. 
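+                    # Illustrative scale, matching the docstring example above: with
+                    # second_per_grid_t = 2.0 s per temporal patch and
+                    # tokens_per_second = 25, the product computed below spaces
+                    # consecutive temporal patches 2.0 * 25 = 50 positions apart.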
+ second_per_grid_t = torch.as_tensor( + second_per_grid_t, + dtype=range_tensor.dtype, + device=range_tensor.device, + ) + + time_tensor = ( + expanded_range + * second_per_grid_t + * self.config.vision_config.tokens_per_second + ) + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = ( + torch.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( + position_ids.device + ) + mrope_position_deltas.append( + llm_positions.max() + 1 - len(total_input_ids[i]) + ) + mrope_position_deltas = torch.tensor( + mrope_position_deltas, device=input_ids.device + ).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = ( + position_ids.unsqueeze(0) + .expand(3, -1, -1) + .to(attention_mask.device) + ) + max_position_ids = position_ids.max(0, keepdim=False)[0].max( + -1, keepdim=True + )[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def get_video_features( + self, + pixel_values_videos: torch.FloatTensor, + video_grid_thw: Optional[torch.LongTensor] = None, + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input videos. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + """ + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + split_sizes = ( + video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2 + ).tolist() + video_embeds = torch.split(video_embeds, split_sizes) + return video_embeds + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_grid_thw: Optional[torch.LongTensor] = None, + ): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): + The tensors corresponding to the input images. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. 
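+
+        For example (illustrative numbers only): an image patched into grid_thw = (1, 28, 28)
+        with spatial_merge_size = 2 contributes 1 * 28 * 28 // 2**2 = 196 embedding vectors to
+        the returned split.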
+ """ + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + split_sizes = ( + image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2 + ).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + return image_embeds + + def get_placeholder_mask( + self, + input_ids: torch.LongTensor, + inputs_embeds: torch.FloatTensor, + image_features: torch.FloatTensor = None, + video_features: torch.FloatTensor = None, + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor( + self.config.image_token_id, + dtype=torch.long, + device=inputs_embeds.device, + ) + ) + special_image_mask = special_image_mask.all(-1) + special_video_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor( + self.config.video_token_id, + dtype=torch.long, + device=inputs_embeds.device, + ) + ) + special_video_mask = special_video_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + special_video_mask = input_ids == self.config.video_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = ( + special_image_mask.unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + if ( + image_features is not None + and inputs_embeds[special_image_mask].numel() != image_features.numel() + ): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}" + ) + + n_video_tokens = special_video_mask.sum() + special_video_mask = ( + special_video_mask.unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + if ( + video_features is not None + and inputs_embeds[special_video_mask].numel() != video_features.numel() + ): + raise ValueError( + f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}" + ) + + return special_image_mask, special_video_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, Qwen2_5_VLModelOutputWithPast]: + r""" + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. 
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + """ + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0).to( + inputs_embeds.device, inputs_embeds.dtype + ) + image_mask, _ = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0).to( + inputs_embeds.device, inputs_embeds.dtype + ) + _, video_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, video_features=video_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if position_ids is None: + # Calculate RoPE index once per generation in the pre-fill stage only. + # When compiling, we can't check tensor values thus we check only input length + # It is safe to assume that `length!=1` means we're in pre-fill because compiled + # models currently cannot do asssisted decoding + prefill_compiled_stage = is_torchdynamo_compiling() and ( + (input_ids is not None and input_ids.shape[1] != 1) + or (inputs_embeds is not None and inputs_embeds.shape[1] != 1) + ) + prefill_noncompiled_stage = not is_torchdynamo_compiling() and ( + (cache_position is not None and cache_position[0] == 0) + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ) + if ( + prefill_compiled_stage or prefill_noncompiled_stage + ) or self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + attention_mask=attention_mask, + ) + self.rope_deltas = rope_deltas + else: + batch_size, seq_length, _ = inputs_embeds.shape + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1) + if cache_position is not None: + delta = (cache_position[0] + self.rope_deltas).to( + inputs_embeds.device + ) + else: + delta = torch.zeros( + (batch_size, seq_length), device=inputs_embeds.device + ) + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1) + position_ids += delta.to(position_ids.device) + + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + output = Qwen2_5_VLModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + 
past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            rope_deltas=self.rope_deltas,
+        )
+        return output if return_dict else output.to_tuple()
+
+
+class DotDict(dict):
+    def __init__(self, mapping):
+        super().__init__()
+        for key, value in mapping.items():
+            if isinstance(value, dict):
+                value = DotDict(value)  # recursively convert nested dicts
+            elif isinstance(value, list):
+                # for lists, recursively convert any dict elements as well
+                value = [
+                    DotDict(item) if isinstance(item, dict) else item for item in value
+                ]
+            self[key] = value
+
+    def __getattr__(self, item):
+        try:
+            return self[item]
+        except KeyError:
+            raise AttributeError(f"No attribute '{item}'")
+
+    def __setattr__(self, key, value):
+        self[key] = value
+
+    def __delattr__(self, key):
+        del self[key]
+
+
+def dict_to_namespace(d):
+    for k, v in d.items():
+        if isinstance(v, dict):
+            d[k] = dict_to_namespace(v)
+        elif isinstance(v, list):
+            d[k] = [dict_to_namespace(i) if isinstance(i, dict) else i for i in v]
+    return SimpleNamespace(**d)
+
+
+class Qwen2_5_VLForConditionalGeneration(TextEncoder):
+    # BitsAndBytes-specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_up_proj.",
+        ".down_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(
+        self,
+        config: Qwen2_5VLConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config)
+        config = config.arch_config
+        self.model = Qwen2_5_VLModel(config)
+        self.lm_head = nn.Linear(
+            config.text_config.hidden_size, config.text_config.vocab_size, bias=False
+        )
+
+        self.config = config
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        pixel_values_videos: Optional[torch.FloatTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        rope_deltas: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[TransformersKwargs],
+    ):
+        """Run the forward pass for Qwen2_5-VL.
+
+        Args:
+            input_ids: Flattened (concatenated) input_ids corresponding to a
+                batch.
+            positions: Flattened (concatenated) position ids corresponding to a
+                batch.
+                **NOTE**: If mrope is enabled (the default for the open-source
+                Qwen2-VL models), the shape will be `(3, seq_len)`,
+                otherwise it will be `(seq_len,)`.
+ (Use input_metadata.mrope_positions to replace it) + """ + output_attentions = False + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :]) + return Qwen2_5_VLCausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=outputs.rope_deltas, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loaded_params: set[str] = set() + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + name = name.replace("model.", "model.language_model.") + if "visual." in name: + name = name.replace("visual.", "model.visual.") + try: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + except KeyError: + print(params_dict.keys()) + raise + + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = loaded_weight.to(param.dtype) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight + + +EntryClass = Qwen2_5_VLForConditionalGeneration diff --git a/python/sglang/multimodal_gen/runtime/models/encoders/stepllm.py b/python/sglang/multimodal_gen/runtime/models/encoders/stepllm.py new file mode 100644 index 000000000..18f10046c --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/encoders/stepllm.py @@ -0,0 +1,614 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# type: ignore +# Copyright 2025 StepFun Inc. All Rights Reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# ============================================================================== +import os +from functools import wraps + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers.modeling_utils import PretrainedConfig, PreTrainedModel + +from sglang.multimodal_gen.runtime.models.dits.stepvideo import StepVideoRMSNorm + + +class EmptyInitOnDevice(torch.overrides.TorchFunctionMode): + + def __init__(self, device=None): + self.device = device + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + if getattr(func, "__module__", None) == "torch.nn.init": + if "tensor" in kwargs: + return kwargs["tensor"] + else: + return args[0] + if ( + self.device is not None + and func in torch.utils._device._device_constructors() + and kwargs.get("device") is None + ): + kwargs["device"] = self.device + return func(*args, **kwargs) + + +def with_empty_init(func): + + @wraps(func) + def wrapper(*args, **kwargs): + with EmptyInitOnDevice("cpu"): + return func(*args, **kwargs) + + return wrapper + + +class LLaMaEmbedding(nn.Module): + """Language model embeddings. + + Arguments: + hidden_size: hidden size + vocab_size: vocabulary size + max_sequence_length: maximum size of sequence. This + is used for positional embedding + embedding_dropout_prob: dropout probability for embeddings + init_method: weight initialization method + num_tokentypes: size of the token-type embeddings. 0 value + will ignore this embedding + """ + + def __init__( + self, + cfg, + ): + super().__init__() + self.hidden_size = cfg.hidden_size + self.params_dtype = cfg.params_dtype + self.fp32_residual_connection = cfg.fp32_residual_connection + self.embedding_weights_in_fp32 = cfg.embedding_weights_in_fp32 + self.word_embeddings = torch.nn.Embedding( + cfg.padded_vocab_size, + self.hidden_size, + ) + self.embedding_dropout = torch.nn.Dropout(cfg.hidden_dropout) + + def forward(self, input_ids): + # Embeddings. + if self.embedding_weights_in_fp32: + self.word_embeddings = self.word_embeddings.to(torch.float32) + embeddings = self.word_embeddings(input_ids) + if self.embedding_weights_in_fp32: + embeddings = embeddings.to(self.params_dtype) + self.word_embeddings = self.word_embeddings.to(self.params_dtype) + + # Data format change to avoid explicit transposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + + # Dropout. 
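+        # (dropout is a no-op in eval mode; it is applied on the [s, b, h] layout produced above)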
+ embeddings = self.embedding_dropout(embeddings) + + return embeddings + + +class StepChatTokenizer: + """Step Chat Tokenizer""" + + def __init__( + self, + model_file, + name="StepChatTokenizer", + bot_token="<|BOT|>", # Begin of Turn + eot_token="<|EOT|>", # End of Turn + call_start_token="<|CALL_START|>", # Call Start + call_end_token="<|CALL_END|>", # Call End + think_start_token="<|THINK_START|>", # Think Start + think_end_token="<|THINK_END|>", # Think End + mask_start_token="<|MASK_1e69f|>", # Mask start + mask_end_token="<|UNMASK_1e69f|>", # Mask end + ): + import sentencepiece + + self._tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file) + + self._vocab = {} + self._inv_vocab = {} + + self._special_tokens = {} + self._inv_special_tokens = {} + + self._t5_tokens = [] + + for idx in range(self._tokenizer.get_piece_size()): + text = self._tokenizer.id_to_piece(idx) + self._inv_vocab[idx] = text + self._vocab[text] = idx + + if self._tokenizer.is_control(idx) or self._tokenizer.is_unknown(idx): + self._special_tokens[text] = idx + self._inv_special_tokens[idx] = text + + self._unk_id = self._tokenizer.unk_id() + self._bos_id = self._tokenizer.bos_id() + self._eos_id = self._tokenizer.eos_id() + + for token in [ + bot_token, + eot_token, + call_start_token, + call_end_token, + think_start_token, + think_end_token, + ]: + assert token in self._vocab, f"Token '{token}' not found in tokenizer" + assert ( + token in self._special_tokens + ), f"Token '{token}' is not a special token" + + for token in [mask_start_token, mask_end_token]: + assert token in self._vocab, f"Token '{token}' not found in tokenizer" + + self._bot_id = self._tokenizer.piece_to_id(bot_token) + self._eot_id = self._tokenizer.piece_to_id(eot_token) + self._call_start_id = self._tokenizer.piece_to_id(call_start_token) + self._call_end_id = self._tokenizer.piece_to_id(call_end_token) + self._think_start_id = self._tokenizer.piece_to_id(think_start_token) + self._think_end_id = self._tokenizer.piece_to_id(think_end_token) + self._mask_start_id = self._tokenizer.piece_to_id(mask_start_token) + self._mask_end_id = self._tokenizer.piece_to_id(mask_end_token) + + self._underline_id = self._tokenizer.piece_to_id("\u2581") + + @property + def vocab(self): + return self._vocab + + @property + def inv_vocab(self): + return self._inv_vocab + + @property + def vocab_size(self): + return self._tokenizer.vocab_size() + + def tokenize(self, text: str) -> list[int]: + return self._tokenizer.encode_as_ids(text) + + def detokenize(self, token_ids: list[int]) -> str: + return self._tokenizer.decode_ids(token_ids) + + +class Tokens: + + def __init__( + self, input_ids, cu_input_ids, attention_mask, cu_seqlens, max_seq_len + ) -> None: + self.input_ids = input_ids + self.attention_mask = attention_mask + self.cu_input_ids = cu_input_ids + self.cu_seqlens = cu_seqlens + self.max_seq_len = max_seq_len + + def to(self, device): + self.input_ids = self.input_ids.to(device) + self.attention_mask = self.attention_mask.to(device) + self.cu_input_ids = self.cu_input_ids.to(device) + self.cu_seqlens = self.cu_seqlens.to(device) + return self + + +class Wrapped_StepChatTokenizer(StepChatTokenizer): + + def __call__( + self, + text, + max_length=320, + padding="max_length", + truncation=True, + return_tensors="pt", + ): + # [bos, ..., eos, pad, pad, ..., pad] + self.BOS = 1 + self.EOS = 2 + self.PAD = 2 + out_tokens = [] + attn_mask = [] + if len(text) == 0: + part_tokens = [self.BOS] + [self.EOS] + valid_size = len(part_tokens) + 
if len(part_tokens) < max_length: + part_tokens += [self.PAD] * (max_length - valid_size) + out_tokens.append(part_tokens) + attn_mask.append([1] * valid_size + [0] * (max_length - valid_size)) + else: + for part in text: + part_tokens = self.tokenize(part) + part_tokens = part_tokens[ + : (max_length - 2) + ] # leave 2 space for bos and eos + part_tokens = [self.BOS] + part_tokens + [self.EOS] + valid_size = len(part_tokens) + if len(part_tokens) < max_length: + part_tokens += [self.PAD] * (max_length - valid_size) + out_tokens.append(part_tokens) + attn_mask.append([1] * valid_size + [0] * (max_length - valid_size)) + + out_tokens = torch.tensor(out_tokens, dtype=torch.long) + attn_mask = torch.tensor(attn_mask, dtype=torch.long) + + # padding y based on tp size + padded_len = 0 + padded_flag = False + if padded_len > 0: + padded_flag = True + if padded_flag: + pad_tokens = torch.tensor( + [[self.PAD] * max_length], device=out_tokens.device + ) + pad_attn_mask = torch.tensor( + [[1] * padded_len + [0] * (max_length - padded_len)], + device=attn_mask.device, + ) + out_tokens = torch.cat([out_tokens, pad_tokens], dim=0) + attn_mask = torch.cat([attn_mask, pad_attn_mask], dim=0) + + # cu_seqlens + cu_out_tokens = out_tokens.masked_select(attn_mask != 0).unsqueeze(0) + seqlen = attn_mask.sum(dim=1).tolist() + cu_seqlens = torch.cumsum(torch.tensor([0] + seqlen), 0).to( + device=out_tokens.device, dtype=torch.int32 + ) + max_seq_len = max(seqlen) + return Tokens(out_tokens, cu_out_tokens, attn_mask, cu_seqlens, max_seq_len) + + +def flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=True, + return_attn_probs=False, + tp_group_rank=0, + tp_group_size=1, +): + softmax_scale = q.size(-1) ** (-0.5) if softmax_scale is None else softmax_scale + return torch.ops.Optimus.fwd( + q, + k, + v, + None, + dropout_p, + softmax_scale, + causal, + return_attn_probs, + None, + tp_group_rank, + tp_group_size, + )[0] + + +class FlashSelfAttention(torch.nn.Module): + + def __init__( + self, + attention_dropout=0.0, + ): + super().__init__() + self.dropout_p = attention_dropout + + def forward(self, q, k, v, cu_seqlens=None, max_seq_len=None): + if cu_seqlens is None: + output = flash_attn_func(q, k, v, dropout_p=self.dropout_p) + else: + raise ValueError("cu_seqlens is not supported!") + + return output + + +def safediv(n, d): + q, r = divmod(n, d) + assert r == 0 + return q + + +class MultiQueryAttention(nn.Module): + + def __init__(self, cfg, layer_id=None): + super().__init__() + + self.head_dim = cfg.hidden_size // cfg.num_attention_heads + self.max_seq_len = cfg.seq_length + self.use_flash_attention = cfg.use_flash_attn + assert self.use_flash_attention, "FlashAttention is required!" + + self.n_groups = cfg.num_attention_groups + self.tp_size = 1 + self.n_local_heads = cfg.num_attention_heads + self.n_local_groups = self.n_groups + + self.wqkv = nn.Linear( + cfg.hidden_size, + cfg.hidden_size + self.head_dim * 2 * self.n_groups, + bias=False, + ) + self.wo = nn.Linear( + cfg.hidden_size, + cfg.hidden_size, + bias=False, + ) + + # assert self.use_flash_attention, 'non-Flash attention not supported yet.' 
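+        # Multi-query attention: queries use n_local_heads heads while K and V share
+        # n_local_groups heads, which is why wqkv projects to
+        # hidden_size + 2 * n_groups * head_dim above.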
+ self.core_attention = FlashSelfAttention( + attention_dropout=cfg.attention_dropout + ) + # self.core_attention = LocalAttention( + # num_heads = self.n_local_heads, + # head_size = self.head_dim, + # # num_kv_heads = self.n_local_groups, + # casual = True, + # supported_attention_backends = [_Backend.FLASH_ATTN, _Backend.TORCH_SDPA], # RIVER TODO + # ) + self.layer_id = layer_id + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor | None, + cu_seqlens: torch.Tensor | None, + max_seq_len: torch.Tensor | None, + ): + seqlen, bsz, dim = x.shape + xqkv = self.wqkv(x) + + xq, xkv = torch.split( + xqkv, + (dim // self.tp_size, self.head_dim * 2 * self.n_groups // self.tp_size), + dim=-1, + ) + + # gather on 1st dimension + xq = xq.view(seqlen, bsz, self.n_local_heads, self.head_dim) + xkv = xkv.view(seqlen, bsz, self.n_local_groups, 2 * self.head_dim) + xk, xv = xkv.chunk(2, -1) + + # rotary embedding + flash attn + xq = rearrange(xq, "s b h d -> b s h d") + xk = rearrange(xk, "s b h d -> b s h d") + xv = rearrange(xv, "s b h d -> b s h d") + + # q_per_kv = self.n_local_heads // self.n_local_groups + # if q_per_kv > 1: + # b, s, h, d = xk.size() + # if h == 1: + # xk = xk.expand(b, s, q_per_kv, d) + # xv = xv.expand(b, s, q_per_kv, d) + # else: + # ''' To cover the cases where h > 1, we have + # the following implementation, which is equivalent to: + # xk = xk.repeat_interleave(q_per_kv, dim=-2) + # xv = xv.repeat_interleave(q_per_kv, dim=-2) + # but can avoid calling aten::item() that involves cpu. + # ''' + # idx = torch.arange(q_per_kv * h, device=xk.device).reshape(q_per_kv, -1).permute(1, 0).flatten() + # xk = torch.index_select(xk.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous() + # xv = torch.index_select(xv.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous() + if self.use_flash_attention: + output = self.core_attention(xq, xk, xv) + # reduce-scatter only support first dimension now + output = rearrange(output, "b s h d -> s b (h d)").contiguous() + else: + xq, xk, xv = [ + rearrange(x, "b s ... 
-> s b ...").contiguous() for x in (xq, xk, xv) + ] + output = self.core_attention(xq, xk, xv) # , mask) + output = self.wo(output) + return output + + +class FeedForward(nn.Module): + + def __init__( + self, + cfg, + dim: int, + hidden_dim: int, + layer_id: int, + multiple_of: int = 256, + ): + super().__init__() + + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + def swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + + self.swiglu = swiglu + + self.w1 = nn.Linear( + dim, + 2 * hidden_dim, + bias=False, + ) + self.w2 = nn.Linear( + hidden_dim, + dim, + bias=False, + ) + + def forward(self, x): + x = self.swiglu(self.w1(x)) + output = self.w2(x) + return output + + +class TransformerBlock(nn.Module): + + def __init__(self, cfg, layer_id: int): + super().__init__() + + self.n_heads = cfg.num_attention_heads + self.dim = cfg.hidden_size + self.head_dim = cfg.hidden_size // cfg.num_attention_heads + self.attention = MultiQueryAttention( + cfg, + layer_id=layer_id, + ) + + self.feed_forward = FeedForward( + cfg, + dim=cfg.hidden_size, + hidden_dim=cfg.ffn_hidden_size, + layer_id=layer_id, + ) + self.layer_id = layer_id + self.attention_norm = StepVideoRMSNorm( + cfg.hidden_size, + eps=cfg.layernorm_epsilon, + ) + self.ffn_norm = StepVideoRMSNorm( + cfg.hidden_size, + eps=cfg.layernorm_epsilon, + ) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor | None, + cu_seqlens: torch.Tensor | None, + max_seq_len: torch.Tensor | None, + ): + residual = self.attention.forward( + self.attention_norm(x), mask, cu_seqlens, max_seq_len + ) + h = x + residual + ffn_res = self.feed_forward.forward(self.ffn_norm(h)) + out = h + ffn_res + return out + + +class Transformer(nn.Module): + + def __init__( + self, + config, + max_seq_size=8192, + ): + super().__init__() + self.num_layers = config.num_layers + self.layers = self._build_layers(config) + + def _build_layers(self, config): + layers = torch.nn.ModuleList() + for layer_id in range(self.num_layers): + layers.append( + TransformerBlock( + config, + layer_id=layer_id + 1, + ) + ) + return layers + + def forward( + self, + hidden_states, + attention_mask, + cu_seqlens=None, + max_seq_len=None, + ): + + if max_seq_len is not None and not isinstance(max_seq_len, torch.Tensor): + max_seq_len = torch.tensor(max_seq_len, dtype=torch.int32, device="cpu") + + for lid, layer in enumerate(self.layers): + hidden_states = layer( + hidden_states, + attention_mask, + cu_seqlens, + max_seq_len, + ) + return hidden_states + + +class Step1Model(PreTrainedModel): + config_class = PretrainedConfig + + @with_empty_init + def __init__( + self, + config, + ): + super().__init__(config) + self.tok_embeddings = LLaMaEmbedding(config) + self.transformer = Transformer(config) + + def forward( + self, + input_ids=None, + attention_mask=None, + ): + + hidden_states = self.tok_embeddings(input_ids) + + hidden_states = self.transformer( + hidden_states, + attention_mask, + ) + return hidden_states + + +class STEP1TextEncoder(torch.nn.Module): + + def __init__(self, model_dir, max_length=320): + super().__init__() + self.max_length = max_length + self.text_tokenizer = Wrapped_StepChatTokenizer( + os.path.join(model_dir, "step1_chat_tokenizer.model") + ) + text_encoder = Step1Model.from_pretrained(model_dir) + self.text_encoder = text_encoder.eval().to(torch.bfloat16) + + @torch.no_grad + def forward(self, prompts, with_mask=True, max_length=None): + self.device = next(self.text_encoder.parameters()).device + + with 
torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16): + if type(prompts) is str: + prompts = [prompts] + txt_tokens = self.text_tokenizer( + prompts, + max_length=max_length or self.max_length, + padding="max_length", + truncation=True, + return_tensors="pt", + ) + y = self.text_encoder( + txt_tokens.input_ids.to(self.device), + attention_mask=( + txt_tokens.attention_mask.to(self.device) if with_mask else None + ), + ) + y_mask = txt_tokens.attention_mask + return y.transpose(0, 1), y_mask + + +EntryClass = STEP1TextEncoder diff --git a/python/sglang/multimodal_gen/runtime/models/encoders/t5.py b/python/sglang/multimodal_gen/runtime/models/encoders/t5.py new file mode 100644 index 000000000..048308ad1 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/encoders/t5.py @@ -0,0 +1,716 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from transformers: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/t5/modeling_t5.py + +# Derived from T5 implementation posted on HuggingFace; license below: +# +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch T5 & UMT5 model.""" + +import math +from collections.abc import Iterable +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +from torch import nn + +from sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput, T5Config +from sglang.multimodal_gen.runtime.distributed import get_tp_rank, get_tp_world_size +from sglang.multimodal_gen.runtime.layers.activation import get_act_fn +from sglang.multimodal_gen.runtime.layers.layernorm import RMSNorm +from sglang.multimodal_gen.runtime.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.multimodal_gen.runtime.layers.quantization import QuantizationConfig +from sglang.multimodal_gen.runtime.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, +) +from sglang.multimodal_gen.runtime.loader.weight_utils import default_weight_loader +from sglang.multimodal_gen.runtime.models.encoders.base import TextEncoder +from sglang.multimodal_gen.runtime.platforms import current_platform + + +class AttentionType: + """ + Attention type. + Use string to be compatible with `torch.compile`. + """ + + # Decoder attention between previous layer Q/K/V + DECODER = "decoder" + # Encoder attention between previous layer Q/K/V for encoder-decoder + ENCODER = "encoder" + # Encoder attention between previous layer Q/K/V + ENCODER_ONLY = "encoder_only" + # Attention between dec. Q and enc. 
K/V for encoder-decoder
+    ENCODER_DECODER = "encoder_decoder"
+
+
+_seen_keys = set()  # use a set to record keys that have already been seen
+
+
+@dataclass
+class AttentionMetadata:
+    attn_bias: torch.Tensor
+
+
+class T5DenseActDense(nn.Module):
+
+    def __init__(
+        self, config: T5Config, quant_config: QuantizationConfig | None = None
+    ):
+        super().__init__()
+        self.wi = MergedColumnParallelLinear(config.d_model, [config.d_ff], bias=False)
+        self.wo = RowParallelLinear(
+            config.d_ff, config.d_model, bias=False, quant_config=quant_config
+        )
+        self.act = get_act_fn(config.dense_act_fn)
+
+    def forward(self, hidden_states) -> torch.Tensor:
+        hidden_states, _ = self.wi(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states, _ = self.wo(hidden_states)
+        return hidden_states
+
+
+class T5DenseGatedActDense(nn.Module):
+
+    def __init__(
+        self, config: T5Config, quant_config: QuantizationConfig | None = None
+    ):
+        super().__init__()
+        self.wi_0 = MergedColumnParallelLinear(
+            config.d_model, [config.d_ff], bias=False, quant_config=quant_config
+        )
+        self.wi_1 = MergedColumnParallelLinear(
+            config.d_model, [config.d_ff], bias=False, quant_config=quant_config
+        )
+        # Should not run in fp16 unless mixed-precision is used,
+        # see https://github.com/huggingface/transformers/issues/20287.
+        self.wo = RowParallelLinear(
+            config.d_ff, config.d_model, bias=False, quant_config=quant_config
+        )
+        self.act = get_act_fn(config.dense_act_fn)
+
+    def forward(self, hidden_states) -> torch.Tensor:
+        hidden_gelu = self.act(self.wi_0(hidden_states)[0])
+        hidden_linear, _ = self.wi_1(hidden_states)
+        hidden_states = hidden_gelu * hidden_linear
+        hidden_states, _ = self.wo(hidden_states)
+        return hidden_states
+
+
+class T5LayerFF(nn.Module):
+
+    def __init__(
+        self, config: T5Config, quant_config: QuantizationConfig | None = None
+    ):
+        super().__init__()
+        if config.is_gated_act:
+            self.DenseReluDense = T5DenseGatedActDense(
+                config, quant_config=quant_config
+            )
+        else:
+            self.DenseReluDense = T5DenseActDense(config, quant_config=quant_config)
+
+        self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+
+    def forward(self, hidden_states) -> torch.Tensor:
+        forwarded_states = self.layer_norm(hidden_states)
+        forwarded_states = self.DenseReluDense(forwarded_states)
+        hidden_states = hidden_states + forwarded_states
+        return hidden_states
+
+
+# T5 has attn_bias and does not use softmax scaling
+class T5MultiHeadAttention(nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, q, k, v, attn_bias=None):
+        b, _, n, c = q.shape
+        attn = torch.einsum("binc,bjnc->bnij", q, k)
+        if attn_bias is not None:
+            attn += attn_bias
+
+        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
+        x = torch.einsum("bnij,bjnc->binc", attn, v)
+        x = x.reshape(b, -1, n * c)
+        return x
+
+
+class T5Attention(nn.Module):
+
+    def __init__(
+        self,
+        config: T5Config,
+        attn_type: str,
+        has_relative_attention_bias=False,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.attn_type = attn_type
+        # Cross-attention has no relative pos encoding anyway
+        self.is_decoder = attn_type == AttentionType.DECODER
+        self.has_relative_attention_bias = has_relative_attention_bias
+        self.relative_attention_num_buckets = config.relative_attention_num_buckets
+        self.relative_attention_max_distance = config.relative_attention_max_distance
+        self.d_model = config.d_model
+        self.key_value_proj_dim = config.d_kv
+        self.total_num_heads = self.total_num_kv_heads =
config.num_heads + + # Partition heads across multiple tensor parallel GPUs. + tp_world_size = get_tp_world_size() + assert config.num_heads % tp_world_size == 0 + self.n_heads = config.num_heads // tp_world_size + + self.inner_dim = self.n_heads * self.key_value_proj_dim + # No GQA in t5. + # self.n_kv_heads = self.n_heads + + self.qkv_proj = QKVParallelLinear( + self.d_model, + self.d_model // self.total_num_heads, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.attn = T5MultiHeadAttention() + + if self.has_relative_attention_bias: + self.relative_attention_bias = VocabParallelEmbedding( + self.relative_attention_num_buckets, + self.total_num_heads, + org_num_embeddings=self.relative_attention_num_buckets, + padding_size=self.relative_attention_num_buckets, + quant_config=quant_config, + ) + self.o = RowParallelLinear( + self.d_model, + self.d_model, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + @staticmethod + def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 + ) -> torch.Tensor: + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + Translate relative position to a bucket number for relative attention. + The relative position is defined as memory_position - query_position, + i.e. the distance in tokens from the attending position to the + attended-to position. If bidirectional=False, then positive relative + positions are invalid. We use smaller buckets for small absolute + relative_position and larger buckets for larger absolute + relative_positions. All relative positions >=max_distance map to the + same bucket. All relative positions <=-max_distance map to the same + bucket. 
This should allow for more graceful generalization to longer + sequences than the model has been trained on + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + Returns: + a Tensor with the same shape as relative_position, containing int32 + values in the range [0, num_buckets) + """ # noqa: E501 + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min( + relative_position, torch.zeros_like(relative_position) + ) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins + # in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1), + ) + + relative_buckets += torch.where( + is_small, relative_position, relative_position_if_large + ) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None) -> torch.Tensor: + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[ + :, None + ] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ + None, : + ] + # max_seq_len, nh + relative_position = memory_position - context_position + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias( + relative_position_bucket + ) # shape (query_length, key_length, num_heads) + x = values.permute([2, 0, 1]).unsqueeze( + 0 + ) # shape (1, num_heads, query_length, key_length) + return x + + def forward( + self, + hidden_states: torch.Tensor, # (num_tokens, d_model) + attention_mask: torch.Tensor, + attn_metadata: AttentionMetadata | None = None, + ) -> torch.Tensor: + bs, seq_len, _ = hidden_states.shape + num_seqs = bs + n, c = self.n_heads, self.d_model // self.total_num_heads + qkv, _ = self.qkv_proj(hidden_states) + # Projection of 'own' hidden state (self-attention). No GQA here. + q, k, v = qkv.split(self.inner_dim, dim=-1) + q = q.reshape(bs, seq_len, n, c) + k = k.reshape(bs, seq_len, n, c) + v = v.reshape(bs, seq_len, n, c) + + assert attn_metadata is not None + attn_bias = attn_metadata.attn_bias + # Not compatible with CP here (as all encoder-decoder models), + # as it assumes homogeneous batch (prefills or decodes). + if self.has_relative_attention_bias: + # Self-attention. Compute T5 relative positional encoding. + # The bias term is computed on longest sequence in batch. Biases + # for shorter sequences are slices of the longest. 
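+            # The computed bias is stored on attn_metadata so later blocks, which are
+            # built without their own relative_attention_bias, can reuse it (see the
+            # `else` branch below).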
+ assert self.attn_type == AttentionType.ENCODER + attn_bias = self.compute_bias(seq_len, seq_len).repeat(num_seqs, 1, 1, 1) + attn_metadata.attn_bias = attn_bias + else: + # Encoder/Decoder Self-Attention Layer, attn bias already cached. + assert attn_bias is not None + + if attention_mask is not None: + attention_mask = ( + attention_mask.view(bs, 1, 1, -1) + if attention_mask.ndim == 2 + else attention_mask.unsqueeze(1) + ) + mask_val = -1e4 if current_platform.is_mps() else torch.finfo(q.dtype).min + attn_bias.masked_fill_(attention_mask == 0, mask_val) + + if get_tp_world_size() > 1: + rank = get_tp_rank() + attn_bias = attn_bias[ + :, rank * self.n_heads : (rank + 1) * self.n_heads, :, : + ] + + attn_output = self.attn(q, k, v, attn_bias) + output, _ = self.o(attn_output) + return output + + +class T5LayerSelfAttention(nn.Module): + + def __init__( + self, + config, + has_relative_attention_bias=False, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.SelfAttention = T5Attention( + config, + AttentionType.DECODER if "decoder" in prefix else AttentionType.ENCODER, + has_relative_attention_bias=has_relative_attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.SelfAttention", + ) + self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + attn_metadata: AttentionMetadata | None = None, + ) -> torch.Tensor: + normed_hidden_states = self.layer_norm(hidden_states) + + attention_output = self.SelfAttention( + hidden_states=normed_hidden_states, + attention_mask=attention_mask, + attn_metadata=attn_metadata, + ) + + hidden_states = hidden_states + attention_output + + return hidden_states + + +class T5LayerCrossAttention(nn.Module): + + def __init__( + self, config, quant_config: QuantizationConfig | None = None, prefix: str = "" + ): + super().__init__() + self.EncDecAttention = T5Attention( + config, + AttentionType.ENCODER_DECODER, + has_relative_attention_bias=False, + quant_config=quant_config, + prefix=f"{prefix}.EncDecAttention", + ) + self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata | None = None, + ) -> torch.Tensor: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + hidden_states=normed_hidden_states, + attn_metadata=attn_metadata, + ) + hidden_states = hidden_states + attention_output + return hidden_states + + +class T5Block(nn.Module): + + def __init__( + self, + config: T5Config, + is_decoder: bool, + has_relative_attention_bias=False, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.is_decoder = is_decoder + self.layer = nn.ModuleList() + self.layer.append( + T5LayerSelfAttention( + config, + has_relative_attention_bias=has_relative_attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + ) + + if self.is_decoder: + self.layer.append( + T5LayerCrossAttention( + config, quant_config=quant_config, prefix=f"{prefix}.cross_attn" + ) + ) + + self.layer.append(T5LayerFF(config, quant_config=quant_config)) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + attn_metadata: AttentionMetadata | None = None, + ) -> torch.Tensor: + + if attention_mask is None: + attention_mask = torch.ones( + hidden_states.shape[:2], device=hidden_states.device 
+ ) + + hidden_states = self.layer[0]( + hidden_states=hidden_states, + attention_mask=attention_mask, + attn_metadata=attn_metadata, + ) + + if self.is_decoder: + hidden_states = self.layer[1]( + hidden_states=hidden_states, attn_metadata=attn_metadata + ) + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + return hidden_states + + +class T5Stack(nn.Module): + + def __init__( + self, + config: T5Config, + is_decoder: bool, + n_layers: int, + embed_tokens=None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + is_umt5: bool = False, + ): + super().__init__() + self.embed_tokens = embed_tokens + self.is_umt5 = is_umt5 + if is_umt5: + self.block = nn.ModuleList( + [ + T5Block( + config, + is_decoder=is_decoder, + has_relative_attention_bias=True, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{i}", + ) + for i in range(n_layers) + ] + ) + else: + # Only the first block has relative positional encoding. + self.block = nn.ModuleList( + [ + T5Block( + config, + is_decoder=is_decoder, + has_relative_attention_bias=i == 0, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{i}", + ) + for i in range(n_layers) + ] + ) + self.final_layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + + for idx, block in enumerate(self.block): + hidden_states = block( + hidden_states=hidden_states, + attention_mask=attention_mask, + attn_metadata=attn_metadata, + ) + + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states + + +class T5EncoderModel(TextEncoder): + + def __init__(self, config: T5Config, prefix: str = ""): + super().__init__(config) + + quant_config = None + + self.shared = VocabParallelEmbedding( + config.vocab_size, config.d_model, org_num_embeddings=config.vocab_size + ) + + self.encoder = T5Stack( + config, + False, + config.num_layers, + self.shared, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + is_umt5=False, + ) + + def get_input_embeddings(self): + return self.shared + + def forward( + self, + input_ids: torch.Tensor | None, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + output_hidden_states: bool | None = None, + **kwargs, + ) -> BaseEncoderOutput: + attn_metadata = AttentionMetadata(None) + hidden_states = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + attn_metadata=attn_metadata, + ) + + return BaseEncoderOutput(last_hidden_state=hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q", "q"), + (".qkv_proj", ".k", "k"), + (".qkv_proj", ".v", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + loaded = False + if "decoder" in name or "lm_head" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded = True + break + if not loaded: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class UMT5EncoderModel(TextEncoder): + + def __init__(self, config: T5Config, prefix: str = ""): + super().__init__(config) + + quant_config = None + + self.shared = VocabParallelEmbedding( + config.vocab_size, config.d_model, org_num_embeddings=config.vocab_size + ) + + self.encoder = T5Stack( + config, + False, + config.num_layers, + self.shared, + quant_config=quant_config, + prefix=f"{prefix}.encoder", + is_umt5=True, + ) + + def get_input_embeddings(self): + return self.shared + + def forward( + self, + input_ids: torch.Tensor | None, + position_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + output_hidden_states: bool | None = None, + **kwargs, + ) -> BaseEncoderOutput: + attn_metadata = AttentionMetadata(None) + hidden_states = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + attn_metadata=attn_metadata, + ) + + return BaseEncoderOutput( + last_hidden_state=hidden_states, + attention_mask=attention_mask, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + loaded = False + if "decoder" in name or "lm_head" in name: + continue + for ( + param_name, + weight_name, + shard_id, + ) in self.config.arch_config.stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded = True + break + if not loaded: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +EntryClass = [T5EncoderModel, UMT5EncoderModel] diff --git a/python/sglang/multimodal_gen/runtime/models/encoders/vision.py b/python/sglang/multimodal_gen/runtime/models/encoders/vision.py new file mode 100644 index 000000000..3150abf1c --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/encoders/vision.py @@ -0,0 +1,96 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/vision.py + +from abc import ABC, abstractmethod +from typing import Generic, TypeVar + +import torch +from transformers import PretrainedConfig + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + +_C = TypeVar("_C", bound=PretrainedConfig) + + +class VisionEncoderInfo(ABC, Generic[_C]): + + def __init__(self, vision_config: _C) -> None: + super().__init__() + + self.vision_config = vision_config + + @abstractmethod + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + raise NotImplementedError + + @abstractmethod + def get_max_image_tokens(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_image_size(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_patch_size(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_patch_grid_length(self) -> int: + raise NotImplementedError + + +def resolve_visual_encoder_outputs( + encoder_outputs: torch.Tensor | list[torch.Tensor], + feature_sample_layers: list[int] | None, + post_layer_norm: torch.nn.LayerNorm | None, + max_possible_layers: int, +) -> torch.Tensor: + """Given the outputs a visual encoder module that may correspond to the + output of the last layer, or a list of hidden states to be stacked, + handle post normalization and resolve it into a single output tensor. + + Args: + encoder_outputs: Output of encoder's last layer or all hidden states. + feature_sample_layers: Optional layer indices to grab from the encoder + outputs; if provided, encoder outputs must be a list. + post_layer_norm: Post norm to apply to the output of the encoder. + max_possible_layers: Total layers in the fully loaded visual encoder. + + """ + if feature_sample_layers is None: + if post_layer_norm is not None: + return post_layer_norm(encoder_outputs) + return encoder_outputs + + # Get the hidden states corresponding to the layer indices. + # Negative values are relative to the full visual encoder, + # so offset them depending on how many layers were loaded. + # NOTE: this assumes that encoder_outputs is a list containing + # the inputs to the visual encoder, followed by the hidden states + # of each layer. 
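+    # Example (hypothetical sizes): with max_possible_layers = 24 and 23 loaded layers,
+    # offset = 1, so feature_sample_layers = [-2] (the second-to-last layer of the full
+    # encoder) resolves to encoder_outputs[-1], the last loaded hidden state.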
+ num_loaded_layers = len(encoder_outputs) - 1 + offset = max_possible_layers - num_loaded_layers + hs_pool = [ + ( + encoder_outputs[layer_idx] + if layer_idx >= 0 + else encoder_outputs[layer_idx + offset] + ) + for layer_idx in feature_sample_layers + ] + + # Apply post-norm on the final hidden state if we are using it + uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1) + if post_layer_norm is not None and uses_last_layer: + hs_pool[-1] = post_layer_norm(encoder_outputs) + return torch.cat(hs_pool, dim=-1) diff --git a/python/sglang/multimodal_gen/runtime/models/parameter.py b/python/sglang/multimodal_gen/runtime/models/parameter.py new file mode 100644 index 000000000..ba9b42c66 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/parameter.py @@ -0,0 +1,423 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/parameter.py + +from collections.abc import Callable +from fractions import Fraction +from typing import Any + +import torch +from torch.nn import Parameter + +from sglang.multimodal_gen.runtime.distributed import get_tp_rank +from sglang.multimodal_gen.runtime.models.utils import _make_synced_weight_loader +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class BasevLLMParameter(Parameter): + """ + Base parameter for vLLM linear layers. Extends the torch.nn.parameter + by taking in a linear weight loader. Will copy the loaded weight + into the parameter when the provided weight loader is called. + """ + + def __new__(cls, data: torch.Tensor, **kwargs): + + return super().__new__(cls, data=data, requires_grad=False) + + def __init__(self, data: torch.Tensor, weight_loader: Callable): + """ + Initialize the BasevLLMParameter + + :param data: torch tensor with the parameter data + :param weight_loader: weight loader callable + + :returns: a torch.nn.parameter + """ + + # During weight loading, we often do something like: + # narrowed_tensor = param.data.narrow(0, offset, len) + # narrowed_tensor.copy_(real_weight) + # expecting narrowed_tensor and param.data to share the same storage. + # However, on TPUs, narrowed_tensor will lazily propagate to the base + # tensor, which is param.data, leading to the redundant memory usage. + # This sometimes causes OOM errors during model loading. To avoid this, + # we sync the param tensor after its weight loader is called. 
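+        # On all other platforms the provided weight_loader is used unchanged.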
+ from sglang.multimodal_gen.runtime.platforms import current_platform + + if current_platform.is_tpu(): + weight_loader = _make_synced_weight_loader(weight_loader) + + self._weight_loader = weight_loader + + @property + def weight_loader(self): + return self._weight_loader + + def _is_1d_and_scalar(self, loaded_weight: torch.Tensor): + cond1 = self.data.ndim == 1 and self.data.numel() == 1 + cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1 + return cond1 and cond2 + + def _assert_and_load(self, loaded_weight: torch.Tensor) -> None: + assert self.data.shape == loaded_weight.shape or self._is_1d_and_scalar( + loaded_weight + ) + self.data.copy_(loaded_weight) + + def load_column_parallel_weight(self, loaded_weight: torch.Tensor) -> None: + self._assert_and_load(loaded_weight) + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor) -> None: + self._assert_and_load(loaded_weight) + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs) -> None: + self._assert_and_load(loaded_weight) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs) -> None: + self._assert_and_load(loaded_weight) + + +class _ColumnvLLMParameter(BasevLLMParameter): + """ + Private class defining weight loading functionality + (load_merged_column_weight, load_qkv_weight) + for parameters being loaded into linear layers with column + parallelism. This includes QKV and MLP layers which are + not already fused on disk. Requires an output dimension + to be defined. Called within the weight loader of + each of the column parallel linear layers. + """ + + def __init__(self, output_dim: int, **kwargs): + self._output_dim = output_dim + super().__init__(**kwargs) + + @property + def output_dim(self): + return self._output_dim + + def load_column_parallel_weight(self, loaded_weight: torch.Tensor) -> None: + tp_rank = get_tp_rank() + shard_size = self.data.shape[self.output_dim] + loaded_weight = loaded_weight.narrow( + self.output_dim, tp_rank * shard_size, shard_size + ) + assert self.data.shape == loaded_weight.shape + self.data.copy_(loaded_weight) + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs) -> None: + + shard_offset = kwargs.get("shard_offset") + shard_size = kwargs.get("shard_size") + if shard_offset is None or shard_size is None: + raise ValueError("shard_offset and shard_size must be provided") + if ( + isinstance(self, PackedColumnParameter | PackedvLLMParameter) + and self.packed_dim == self.output_dim + ): + shard_size, shard_offset = self.adjust_shard_indexes_for_packing( + shard_offset=shard_offset, shard_size=shard_size + ) + + param_data = self.data + + tp_rank = get_tp_rank() + param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) + loaded_weight = loaded_weight.narrow( + self.output_dim, tp_rank * shard_size, shard_size + ) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs) -> None: + + shard_offset = kwargs.get("shard_offset") + shard_size = kwargs.get("shard_size") + shard_id = kwargs.get("shard_id") + num_heads = kwargs.get("num_heads") + + assert shard_offset is not None + assert shard_size is not None + assert shard_id is not None + assert num_heads is not None + + if ( + isinstance(self, PackedColumnParameter | PackedvLLMParameter) + and self.output_dim == self.packed_dim + ): + shard_size, shard_offset = self.adjust_shard_indexes_for_packing( + shard_offset=shard_offset, shard_size=shard_size + ) 
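+
+        # Illustrative shard-index arithmetic (mirrors the narrowing below):
+        # with tp_rank=2 and num_heads=2, the "q" slice is read at index 2
+        # while the "k"/"v" slices are read at index 2 // 2 == 1, each slice
+        # being shard_size wide along output_dim.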
+ + param_data = self.data + tp_rank = get_tp_rank() + shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads + param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) + loaded_weight = loaded_weight.narrow( + self.output_dim, shard_id * shard_size, shard_size + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class RowvLLMParameter(BasevLLMParameter): + """ + Parameter class defining weight_loading functionality + (load_row_parallel_weight) for parameters being loaded + into linear layers with row parallel functionality. + Requires an input_dim to be defined. + """ + + def __init__(self, input_dim: int, **kwargs): + self._input_dim = input_dim + super().__init__(**kwargs) + + @property + def input_dim(self): + return self._input_dim + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor) -> None: + tp_rank = get_tp_rank() + shard_size = self.data.shape[self.input_dim] + loaded_weight = loaded_weight.narrow( + self.input_dim, tp_rank * shard_size, shard_size + ) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert self.data.shape == loaded_weight.shape + self.data.copy_(loaded_weight) + + +class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for linear layer weights. Uses both column and + row parallelism. + """ + + pass + + +class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + grouped quantization. Uses both column and row parallelism. + """ + + pass + + +class ChannelQuantScaleParameter(_ColumnvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + channel-wise quantization. Equivalent to _ColumnvLLMParameter. + """ + + pass + + +class PerTensorScaleParameter(BasevLLMParameter): + """ + Parameter class for scales where the number of scales is + equivalent to the number of logical matrices in fused linear + layers (e.g. for QKV, there are 3 scales loaded from disk). + This is relevant to weights with per-tensor quantization. + Adds functionality to map the scalers to a shard during + weight loading. + + Note: additional parameter manipulation may be handled + for each quantization config specifically, within + process_weights_after_loading + """ + + def __init__(self, **kwargs): + self.qkv_idxs = {"q": 0, "k": 1, "v": 2} + super().__init__(**kwargs) + + def _shard_id_as_int(self, shard_id: str | int) -> int: + if isinstance(shard_id, int): + return shard_id + + # if not int, assume shard_id for qkv + # map to int and return + assert isinstance(shard_id, str) + assert shard_id in self.qkv_idxs + return self.qkv_idxs[shard_id] + + # For row parallel layers, no sharding needed + # load weight into parameter as is + def load_row_parallel_weight(self, *args, **kwargs) -> None: + super().load_row_parallel_weight(*args, **kwargs) + + def load_merged_column_weight(self, *args, **kwargs) -> None: + self._load_into_shard_id(*args, **kwargs) + + def load_qkv_weight(self, *args, **kwargs) -> None: + self._load_into_shard_id(*args, **kwargs) + + def load_column_parallel_weight(self, *args, **kwargs) -> None: + super().load_row_parallel_weight(*args, **kwargs) + + def _load_into_shard_id( + self, loaded_weight: torch.Tensor, shard_id: str | int, **kwargs + ): + """ + Slice the parameter data based on the shard id for + loading. 
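+
+        Illustrative example: for a fused QKV per-tensor scale of shape [3],
+        shard_id="k" maps to index 1 and the loaded scalar is copied into
+        param.data[1].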
+ """ + + param_data = self.data + shard_id = self._shard_id_as_int(shard_id) + + # AutoFP8 scales do not have a shape + # compressed-tensors scales do have a shape + if len(loaded_weight.shape) != 0: + assert loaded_weight.shape[0] == 1 + loaded_weight = loaded_weight[0] + + param_data = param_data[shard_id] + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class PackedColumnParameter(_ColumnvLLMParameter): + """ + Parameter for model parameters which are packed on disk + and support column parallelism only. See PackedvLLMParameter + for more details on the packed properties. + """ + + def __init__(self, packed_factor: int | Fraction, packed_dim: int, **kwargs): + self._packed_factor = packed_factor + self._packed_dim = packed_dim + super().__init__(**kwargs) + + @property + def packed_dim(self): + return self._packed_dim + + @property + def packed_factor(self): + return self._packed_factor + + def adjust_shard_indexes_for_packing( + self, shard_size, shard_offset + ) -> tuple[Any, Any]: + return _adjust_shard_indexes_for_packing( + shard_size=shard_size, + shard_offset=shard_offset, + packed_factor=self.packed_factor, + ) + + +class PackedvLLMParameter(ModelWeightParameter): + """ + Parameter for model weights which are packed on disk. + Example: GPTQ Marlin weights are int4 or int8, packed into int32. + Extends the ModelWeightParameter to take in the + packed factor, the packed dimension, and optionally, marlin + tile size for marlin kernels. Adjusts the shard_size and + shard_offset for fused linear layers model weight loading + by accounting for packing and optionally, marlin tile size. + """ + + def __init__(self, packed_factor: int | Fraction, packed_dim: int, **kwargs): + self._packed_factor = packed_factor + self._packed_dim = packed_dim + super().__init__(**kwargs) + + @property + def packed_dim(self): + return self._packed_dim + + @property + def packed_factor(self): + return self._packed_factor + + def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): + return _adjust_shard_indexes_for_packing( + shard_size=shard_size, + shard_offset=shard_offset, + packed_factor=self.packed_factor, + ) + + +class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + block-wise quantization. Uses both column and row parallelism. 
+ """ + + pass + + +def permute_param_layout_( + param: BasevLLMParameter, input_dim: int, output_dim: int, **kwargs +) -> BasevLLMParameter: + """ + Permute a parameter's layout to the specified input and output dimensions, + useful for forcing the parameter into a known layout, for example, if I need + a packed (quantized) weight matrix to be in the layout + {input_dim = 0, output_dim = 1, packed_dim = 0} + then I can call: + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + to ensure x is in the correct layout (permuting it to the correct layout if + required, asserting if it cannot get it to the correct layout) + """ + + curr_input_dim = getattr(param, "input_dim", None) + curr_output_dim = getattr(param, "output_dim", None) + + if curr_input_dim is None or curr_output_dim is None: + assert param.data.dim() == 2, ( + "permute_param_layout_ only supports 2D parameters when either " + "input_dim or output_dim is not set" + ) + + # if one of the dimensions is not set, set it to the opposite of the other + # we can only do this since we asserted the parameter is 2D above + if curr_input_dim is None: + assert curr_output_dim is not None, "either input or output dim must be set" + curr_input_dim = (curr_output_dim + 1) % 2 + if curr_output_dim is None: + assert curr_input_dim is not None, "either input or output dim must be set" + curr_output_dim = (curr_input_dim + 1) % 2 + + # create permutation from the current layout to the layout with + # self.input_dim at input_dim and self.output_dim at output_dim preserving + # other dimensions + perm = [ + i for i in range(param.data.dim()) if i not in [curr_input_dim, curr_output_dim] + ] + perm.insert(input_dim, curr_input_dim) + perm.insert(output_dim, curr_output_dim) + + if "packed_dim" in kwargs: + assert ( + hasattr(param, "packed_dim") + and param.packed_dim == perm[kwargs["packed_dim"]] + ), "permute_param_layout_ currently doesn't support repacking" + + param.data = param.data.permute(*perm) + if hasattr(param, "_input_dim"): + param._input_dim = input_dim + if hasattr(param, "_output_dim"): + param._output_dim = output_dim + if "packed_dim" in kwargs and hasattr(param, "_packed_dim"): + param._packed_dim = kwargs["packed_dim"] + + return param + + +def _adjust_shard_indexes_for_packing( + shard_size, shard_offset, packed_factor +) -> tuple[Any, Any]: + shard_size = shard_size // packed_factor + shard_offset = shard_offset // packed_factor + return shard_size, shard_offset diff --git a/python/sglang/multimodal_gen/runtime/models/registry.py b/python/sglang/multimodal_gen/runtime/models/registry.py new file mode 100644 index 000000000..a3cb0934e --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/registry.py @@ -0,0 +1,366 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/registry.py + +import ast +import importlib +import os +import pickle +import subprocess +import sys +import tempfile +from abc import ABC, abstractmethod +from collections.abc import Callable, Set +from dataclasses import dataclass, field +from functools import lru_cache +from typing import NoReturn, TypeVar, cast + +import cloudpickle +from torch import nn + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + +MODELS_PATH = os.path.dirname(__file__) +COMPONENT_DIRS = [ + d + for d in os.listdir(MODELS_PATH) + if 
os.path.isdir(os.path.join(MODELS_PATH, d)) + and not d.startswith("__") + and not d.startswith(".") +] + +_IMAGE_ENCODER_MODELS: dict[str, tuple] = { + # "HunyuanVideoTransformer3DModel": ("image_encoder", "hunyuanvideo", "HunyuanVideoImageEncoder"), + "CLIPVisionModelWithProjection": ("encoders", "clip", "CLIPVisionModel"), +} + + +@lru_cache(maxsize=None) +def _discover_and_register_models() -> dict[str, tuple[str, str, str]]: + discovered_models = _IMAGE_ENCODER_MODELS + for component in COMPONENT_DIRS: + component_path = os.path.join(MODELS_PATH, component) + for filename in os.listdir(component_path): + if not filename.endswith(".py"): + continue + + mod_relname = filename[:-3] + filepath = os.path.join(component_path, filename) + try: + with open(filepath, "r", encoding="utf-8") as f: + source = f.read() + tree = ast.parse(source, filename=filename) + + entry_class_node = None + first_class_def = None + + for node in ast.walk(tree): + if isinstance(node, ast.Assign): + for target in node.targets: + if ( + isinstance(target, ast.Name) + and target.id == "EntryClass" + ): + entry_class_node = node + break + if first_class_def is None and isinstance(node, ast.ClassDef): + first_class_def = node + if entry_class_node and first_class_def: + model_cls_name_list = [] + value_node = entry_class_node.value + + # EntryClass = ClassName + if isinstance(value_node, ast.Name): + model_cls_name_list.append(value_node.id) + # EntryClass = ["...", ClassName, ...] + elif isinstance(value_node, (ast.List, ast.Tuple)): + for elt in value_node.elts: + if isinstance(elt, ast.Constant): + model_cls_name_list.append(elt.value) + elif isinstance(elt, ast.Name): + model_cls_name_list.append(elt.id) + + if model_cls_name_list: + for model_cls_str in model_cls_name_list: + if model_cls_str in discovered_models: + logger.warning( + f"Duplicate architecture found: {model_cls_str}. It will be overwritten." + ) + model_arch = model_cls_str + discovered_models[model_arch] = ( + component, + mod_relname, + model_cls_str, + ) + + except Exception as e: + logger.warning(f"Could not parse {filepath} to find models: {e}") + + return discovered_models + + +_SGL_DIFFUSION_MODELS = _discover_and_register_models() + +_SUBPROCESS_COMMAND = [ + sys.executable, + "-m", + "sglang.multimodal_gen.runtime.models.dits.registry", +] + +_T = TypeVar("_T") + + +@dataclass(frozen=True) +class _ModelInfo: + architecture: str + + @staticmethod + def from_model_cls(model: type[nn.Module]) -> "_ModelInfo": + return _ModelInfo( + architecture=model.__name__, + ) + + +class _BaseRegisteredModel(ABC): + + @abstractmethod + def inspect_model_cls(self) -> _ModelInfo: + raise NotImplementedError + + @abstractmethod + def load_model_cls(self) -> type[nn.Module]: + raise NotImplementedError + + +@dataclass(frozen=True) +class _RegisteredModel(_BaseRegisteredModel): + """ + Represents a model that has already been imported in the main process. 
+ """ + + interfaces: _ModelInfo + model_cls: type[nn.Module] + + @staticmethod + def from_model_cls(model_cls: type[nn.Module]): + return _RegisteredModel( + interfaces=_ModelInfo.from_model_cls(model_cls), + model_cls=model_cls, + ) + + def inspect_model_cls(self) -> _ModelInfo: + return self.interfaces + + def load_model_cls(self) -> type[nn.Module]: + return self.model_cls + + +def _run_in_subprocess(fn: Callable[[], _T]) -> _T: + # NOTE: We use a temporary directory instead of a temporary file to avoid + # issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file + with tempfile.TemporaryDirectory() as tempdir: + output_filepath = os.path.join(tempdir, "registry_output.tmp") + + # `cloudpickle` allows pickling lambda functions directly + input_bytes = cloudpickle.dumps((fn, output_filepath)) + + # cannot use `sys.executable __file__` here because the script + # contains relative imports + returned = subprocess.run( + _SUBPROCESS_COMMAND, input=input_bytes, capture_output=True + ) + + # check if the subprocess is successful + try: + returned.check_returncode() + except Exception as e: + # wrap raised exception to provide more information + raise RuntimeError( + f"Error raised in subprocess:\n" f"{returned.stderr.decode()}" + ) from e + + with open(output_filepath, "rb") as f: + return cast(_T, pickle.load(f)) + + +@dataclass(frozen=True) +class _LazyRegisteredModel(_BaseRegisteredModel): + """ + Represents a model that has not been imported in the main process. + """ + + module_name: str + component_name: str + class_name: str + + # Performed in another process to avoid initializing CUDA + def inspect_model_cls(self) -> _ModelInfo: + return _run_in_subprocess( + lambda: _ModelInfo.from_model_cls(self.load_model_cls()) + ) + + def load_model_cls(self) -> type[nn.Module]: + mod = importlib.import_module(self.module_name) + return cast(type[nn.Module], getattr(mod, self.class_name)) + + +@lru_cache(maxsize=128) +def _try_load_model_cls( + model_arch: str, + model: _BaseRegisteredModel, +) -> type[nn.Module] | None: + from sglang.multimodal_gen.runtime.platforms import current_platform + + current_platform.verify_model_arch(model_arch) + try: + return model.load_model_cls() + except Exception: + logger.exception("Ignore import error when loading '%s'", model_arch) + return None + + +@lru_cache(maxsize=128) +def _try_inspect_model_cls( + model_arch: str, + model: _BaseRegisteredModel, +) -> _ModelInfo | None: + try: + return model.inspect_model_cls() + except Exception: + logger.exception("Error in inspecting model architecture '%s'", model_arch) + return None + + +@dataclass +class _ModelRegistry: + # Keyed by model_arch + models: dict[str, _BaseRegisteredModel] = field(default_factory=dict) + + def get_supported_archs(self) -> Set[str]: + return self.models.keys() + + def register_model( + self, + model_arch: str, + model_cls: type[nn.Module] | str, + ) -> None: + """ + Register an external model to be used in vLLM. + + :code:`model_cls` can be either: + + - A :class:`torch.nn.Module` class directly referencing the model. + - A string in the format :code:`:` which can be used to + lazily import the model. This is useful to avoid initializing CUDA + when importing the model and thus the related error + :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`. 
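+
+        Example (sketch; :code:`MyVideoDiT` is a hypothetical
+        :class:`torch.nn.Module` subclass already imported by the caller)::
+
+            ModelRegistry.register_model("MyVideoDiT", MyVideoDiT)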
+ """ + if model_arch in self.models: + logger.warning( + "Model architecture %s is already registered, and will be " + "overwritten by the new model class %s.", + model_arch, + model_cls, + ) + + if isinstance(model_cls, str): + split_str = model_cls.split(":") + if len(split_str) != 2: + msg = "Expected a string in the format `:`" + raise ValueError(msg) + + model = _LazyRegisteredModel(*split_str) + else: + model = _RegisteredModel.from_model_cls(model_cls) + + self.models[model_arch] = model + + def _raise_for_unsupported(self, architectures: list[str]) -> NoReturn: + all_supported_archs = self.get_supported_archs() + + if any(arch in all_supported_archs for arch in architectures): + raise ValueError( + f"Model architectures {architectures} failed " + "to be inspected. Please check the logs for more details." + ) + + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {all_supported_archs}" + ) + + def _try_load_model_cls(self, model_arch: str) -> type[nn.Module] | None: + if model_arch not in self.models: + return None + + return _try_load_model_cls(model_arch, self.models[model_arch]) + + def _try_inspect_model_cls(self, model_arch: str) -> _ModelInfo | None: + if model_arch not in self.models: + return None + + return _try_inspect_model_cls(model_arch, self.models[model_arch]) + + def _normalize_archs( + self, + architectures: str | list[str], + ) -> list[str]: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + normalized_arch = [] + for model in architectures: + if model not in self.models: + raise Exception( + f"Unsupported model architecture: {model}. Registered architectures: {architectures}" + ) + model = "TransformersModel" + normalized_arch.append(model) + return normalized_arch + + def inspect_model_cls( + self, + architectures: str | list[str], + ) -> tuple[_ModelInfo, str]: + architectures = self._normalize_archs(architectures) + + for arch in architectures: + model_info = self._try_inspect_model_cls(arch) + if model_info is not None: + return (model_info, arch) + + return self._raise_for_unsupported(architectures) + + def resolve_model_cls( + self, + architectures: str | list[str], + ) -> tuple[type[nn.Module], str]: + architectures = self._normalize_archs(architectures) + + for arch in architectures: + model_cls = self._try_load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) + + return self._raise_for_unsupported(architectures) + + +ModelRegistry = _ModelRegistry( + { + model_arch: _LazyRegisteredModel( + module_name=f"sglang.multimodal_gen.runtime.models.{component_name}.{mod_relname}", + component_name=component_name, + class_name=cls_name, + ) + for model_arch, ( + component_name, + mod_relname, + cls_name, + ) in _SGL_DIFFUSION_MODELS.items() + } +) diff --git a/python/sglang/multimodal_gen/runtime/models/schedulers/base.py b/python/sglang/multimodal_gen/runtime/models/schedulers/base.py new file mode 100644 index 000000000..eb4e3bdda --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/schedulers/base.py @@ -0,0 +1,37 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod + +import torch + + +class BaseScheduler(ABC): + timesteps: torch.Tensor + order: int + num_train_timesteps: int + + def __init__(self, *args, **kwargs) -> None: + # Check if subclass has defined all 
required properties + required_attributes = ["timesteps", "order", "num_train_timesteps"] + + for attr in required_attributes: + if not hasattr(self, attr): + raise AttributeError( + f"Subclasses of BaseScheduler must define '{attr}' property" + ) + + @abstractmethod + def set_shift(self, shift: float) -> None: + pass + + @abstractmethod + def set_timesteps(self, *args, **kwargs) -> None: + pass + + @abstractmethod + def scale_model_input( + self, sample: torch.Tensor, timestep: int | None = None + ) -> torch.Tensor: + pass diff --git a/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_match_euler_discrete.py b/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_match_euler_discrete.py new file mode 100644 index 000000000..d184802b8 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_match_euler_discrete.py @@ -0,0 +1,698 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Modified from diffusers==0.29.2 +# +# ============================================================================== +import math +from dataclasses import dataclass +from typing import Any + +import numpy as np +import scipy.stats +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput + +from sglang.multimodal_gen.runtime.models.schedulers.base import BaseScheduler +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +@dataclass +class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin, BaseScheduler): + """ + Euler scheduler. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + use_dynamic_shifting (`bool`, defaults to False): + Whether to apply timestep shifting on-the-fly based on the image resolution. + base_shift (`float`, defaults to 0.5): + Value to stabilize image generation. 
Increasing `base_shift` reduces variation and image is more consistent + with desired output. + max_shift (`float`, defaults to 1.15): + Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be + more exaggerated or stylized. + base_image_seq_len (`int`, defaults to 256): + The base image sequence length. + max_image_seq_len (`int`, defaults to 4096): + The maximum image sequence length. + invert_sigmas (`bool`, defaults to False): + Whether to invert the sigmas. + shift_terminal (`float`, defaults to None): + The end value of the shifted timestep schedule. + use_karras_sigmas (`bool`, defaults to False): + Whether to use Karras sigmas for step sizes in the noise schedule during sampling. + use_exponential_sigmas (`bool`, defaults to False): + Whether to use exponential sigmas for step sizes in the noise schedule during sampling. + use_beta_sigmas (`bool`, defaults to False): + Whether to use beta sigmas for step sizes in the noise schedule during sampling. + time_shift_type (`str`, defaults to "exponential"): + The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear". + stochastic_sampling (`bool`, defaults to False): + Whether to use stochastic sampling. + """ + + _compatibles: list[Any] = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting: bool = False, + base_shift: float | None = 0.5, + max_shift: float | None = 1.15, + base_image_seq_len: int | None = 256, + max_image_seq_len: int | None = 4096, + invert_sigmas: bool = False, + shift_terminal: float | None = None, + use_karras_sigmas: bool | None = False, + use_exponential_sigmas: bool | None = False, + use_beta_sigmas: bool | None = False, + time_shift_type: str = "exponential", + stochastic_sampling: bool = False, + ): + if ( + sum( + [ + self.config.use_beta_sigmas, + self.config.use_exponential_sigmas, + self.config.use_karras_sigmas, + ] + ) + > 1 + ): + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) + if time_shift_type not in {"exponential", "linear"}: + raise ValueError( + "`time_shift_type` must either be 'exponential' or 'linear'." + ) + + timesteps = np.linspace( + 1, num_train_timesteps, num_train_timesteps, dtype=np.float32 + )[::-1].copy() + timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) + + sigmas = timesteps / num_train_timesteps + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.timesteps = sigmas * num_train_timesteps + self.num_train_timesteps = num_train_timesteps + + self._step_index: int | None = None + self._begin_index: int | None = None + + self._shift = shift + + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + BaseScheduler.__init__(self) + + @property + def shift(self) -> float: + """ + The value used for shifting. + """ + return self._shift + + @property + def step_index(self) -> int | None: + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self) -> int | None: + """ + The index for the first timestep. 
It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0) -> None: + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_shift(self, shift: float) -> None: + self._shift = shift + + def scale_noise( + self, + sample: torch.FloatTensor, + timestep: float | torch.FloatTensor, + noise: torch.FloatTensor | None = None, + ) -> torch.FloatTensor: + """ + Forward process in flow-matching + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype) + + if sample.device.type == "mps" and torch.is_floating_point(timestep): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) + assert isinstance(timestep, torch.Tensor) + timestep = timestep.to(sample.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(sample.device) + assert isinstance(timestep, torch.Tensor) + timestep = timestep.to(sample.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) for t in timestep + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timestep.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timestep.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(sample.shape): + sigma = sigma.unsqueeze(-1) + + sample = sigma * noise + (1.0 - sigma) * sample + + return sample + + def _sigma_to_t(self, sigma: float) -> float: + return sigma * self.config.num_train_timesteps + + def time_shift( + self, mu: float, sigma: float, t: torch.Tensor | np.ndarray + ) -> torch.Tensor | np.ndarray: + if self.config.time_shift_type == "exponential": + return self._time_shift_exponential(mu, sigma, t) + elif self.config.time_shift_type == "linear": + return self._time_shift_linear(mu, sigma, t) + else: + raise ValueError(f"Unknown time_shift_type: {self.config.time_shift_type}") + + def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor: + r""" + Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config + value. + + Reference: + https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51 + + Args: + t (`torch.Tensor`): + A tensor of timesteps to be stretched and shifted. + + Returns: + `torch.Tensor`: + A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`. 
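+
+        Example (illustrative): with `t = [1.0, 0.5, 0.1]` and
+        `shift_terminal = 0.2`, the schedule is rescaled to approximately
+        `[1.0, 0.556, 0.2]`, i.e. the final value lands exactly on
+        `shift_terminal`.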
+ """ + one_minus_z = 1 - t + scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal) + stretched_t = 1 - (one_minus_z / scale_factor) + return stretched_t + + def set_timesteps( + self, + num_inference_steps: int | None = None, + device: str | torch.device = None, + sigmas: list[float] | None = None, + mu: float | None = None, + timesteps: list[float] | None = None, + ) -> None: + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`, *optional*): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + sigmas (`List[float]`, *optional*): + Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed + automatically. + mu (`float`, *optional*): + Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep + shifting. + timesteps (`List[float]`, *optional*): + Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed + automatically. + """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + "`mu` must be passed when `use_dynamic_shifting` is set to be `True`" + ) + + if ( + sigmas is not None + and timesteps is not None + and len(sigmas) != len(timesteps) + ): + raise ValueError("`sigmas` and `timesteps` should have the same length") + + if num_inference_steps is not None: + if (sigmas is not None and len(sigmas) != num_inference_steps) or ( + timesteps is not None and len(timesteps) != num_inference_steps + ): + raise ValueError( + "`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided" + ) + else: + if sigmas is not None: + num_inference_steps = len(sigmas) + elif timesteps is not None: + num_inference_steps = len(timesteps) + else: + raise ValueError( + "Either num_inference_steps, sigmas, or timesteps must be provided" + ) + + self.num_inference_steps = num_inference_steps + + # 1. Prepare default sigmas + is_timesteps_provided = timesteps is not None + + timesteps_array: np.ndarray | None = None + if is_timesteps_provided: + assert timesteps is not None + timesteps_array = np.array(timesteps).astype(np.float32) + + sigmas_array: np.ndarray + if sigmas is None: + if timesteps_array is None: + timesteps_array = np.linspace( + self._sigma_to_t(self.sigma_max), + self._sigma_to_t(self.sigma_min), + num_inference_steps, + ) + sigmas_array = timesteps_array / self.config.num_train_timesteps + else: + sigmas_array = np.array(sigmas).astype(np.float32) + num_inference_steps = len(sigmas_array) + + # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of + # "exponential" or "linear" type is applied + if self.config.use_dynamic_shifting: + assert mu is not None, "mu cannot be None when use_dynamic_shifting is True" + sigmas_array = self.time_shift(mu, 1.0, sigmas_array) + else: + sigmas_array = ( + self.shift * sigmas_array / (1 + (self.shift - 1) * sigmas_array) + ) + + # 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value + if self.config.shift_terminal: + sigmas_tensor = torch.from_numpy(sigmas_array).to(dtype=torch.float32) + sigmas_tensor = self.stretch_shift_to_terminal(sigmas_tensor) + sigmas_array = sigmas_tensor.numpy() + + # 4. 
If required, convert sigmas to one of karras, exponential, or beta sigma schedules + if self.config.use_karras_sigmas: + sigmas_tensor = torch.from_numpy(sigmas_array).to(dtype=torch.float32) + sigmas_tensor = self._convert_to_karras( + in_sigmas=sigmas_tensor, num_inference_steps=num_inference_steps + ) + sigmas_array = sigmas_tensor.numpy() + elif self.config.use_exponential_sigmas: + sigmas_tensor = torch.from_numpy(sigmas_array).to(dtype=torch.float32) + sigmas_tensor = self._convert_to_exponential( + in_sigmas=sigmas_tensor, num_inference_steps=num_inference_steps + ) + sigmas_array = sigmas_tensor.numpy() + elif self.config.use_beta_sigmas: + sigmas_tensor = torch.from_numpy(sigmas_array).to(dtype=torch.float32) + sigmas_tensor = self._convert_to_beta( + in_sigmas=sigmas_tensor, num_inference_steps=num_inference_steps + ) + sigmas_array = sigmas_tensor.numpy() + + # 5. Convert sigmas and timesteps to tensors and move to specified device + sigmas_tensor = torch.from_numpy(sigmas_array).to( + dtype=torch.float32, device=device + ) + if not is_timesteps_provided: + timesteps_tensor = sigmas_tensor * self.config.num_train_timesteps + else: + assert timesteps_array is not None + timesteps_tensor = torch.from_numpy(timesteps_array).to( + dtype=torch.float32, device=device + ) + + # 6. Append the terminal sigma value. + # If a model requires inverted sigma schedule for denoising but timesteps without inversion, the + # `invert_sigmas` flag can be set to `True`. This case is only required in Mochi + if self.config.invert_sigmas: + sigmas_tensor = 1.0 - sigmas_tensor + timesteps_tensor = sigmas_tensor * self.config.num_train_timesteps + sigmas_tensor = torch.cat( + [sigmas_tensor, torch.ones(1, device=sigmas_tensor.device)] + ) + else: + sigmas_tensor = torch.cat( + [sigmas_tensor, torch.zeros(1, device=sigmas_tensor.device)] + ) + + self.timesteps = timesteps_tensor + self.sigmas = sigmas_tensor + self._step_index = None + self._begin_index = None + + def index_for_timestep( + self, + timestep: float | torch.FloatTensor, + schedule_timesteps: torch.Tensor | None = None, + ) -> int: + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep: float | torch.FloatTensor) -> None: + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.FloatTensor, + timestep: int | torch.Tensor, + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: torch.Generator | None = None, + per_token_timesteps: torch.Tensor | None = None, + return_dict: bool = True, + ) -> FlowMatchEulerDiscreteSchedulerOutput | tuple[torch.FloatTensor, ...]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). 
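+        In the default (non-stochastic) configuration this reduces to a single
+        Euler update, roughly `prev_sample = sample + (sigma_next - sigma) *
+        model_output`; with `stochastic_sampling` enabled, the predicted `x0`
+        is instead re-noised to the next sigma level.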
+ + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`int` or `torch.Tensor`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + per_token_timesteps (`torch.Tensor`, *optional*): + The timesteps for each token in the sample. + return_dict (`bool`): + Whether or not to return a + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or tuple. + + Returns: + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, + [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + """ + + if isinstance(timestep, int | torch.IntTensor | torch.LongTensor): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `FlowMatchEulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + if per_token_timesteps is not None: + per_token_sigmas = per_token_timesteps / self.config.num_train_timesteps + + sigmas = self.sigmas[:, None, None] + lower_mask = sigmas < per_token_sigmas[None] - 1e-6 + lower_sigmas = lower_mask * sigmas + lower_sigmas, _ = lower_sigmas.max(dim=0) + + current_sigma = per_token_sigmas[..., None] + next_sigma = lower_sigmas[..., None] + dt = current_sigma - next_sigma + else: + assert self.step_index is not None, "step_index should not be None" + sigma_idx = self.step_index + sigma = self.sigmas[sigma_idx] + sigma_next = self.sigmas[sigma_idx + 1] + + current_sigma = sigma + next_sigma = sigma_next + dt = sigma_next - sigma + + if self.config.stochastic_sampling: + x0 = sample - current_sigma * model_output + noise = torch.randn_like(sample) + prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise + else: + prev_sample = sample + dt * model_output + + # upon completion increase step index by one + assert self._step_index is not None, "_step_index should not be None" + self._step_index += 1 + if per_token_timesteps is None: + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + if isinstance(prev_sample, torch.Tensor | float) and not return_dict: + return (prev_sample,) + + return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras( + self, in_sigmas: torch.Tensor, num_inference_steps: int + ) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. 
(2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential( + self, in_sigmas: torch.Tensor, num_inference_steps: int + ) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.exp( + np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps) + ) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, + in_sigmas: torch.Tensor, + num_inference_steps: int, + alpha: float = 0.6, + beta: float = 0.6, + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. 
al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.array( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + + def _time_shift_exponential( + self, mu: float, sigma: float, t: torch.Tensor | np.ndarray + ) -> torch.Tensor | np.ndarray: + if isinstance(t, np.ndarray): + return np.exp(mu) / (np.exp(mu) + (1 / t - 1) ** sigma) + else: + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def _time_shift_linear( + self, mu: float, sigma: float, t: torch.Tensor | np.ndarray + ) -> torch.Tensor | np.ndarray: + return mu / (mu + (1 / t - 1) ** sigma) + + def add_noise( + self, + clean_latent: torch.Tensor, + noise: torch.Tensor, + timestep: torch.IntTensor, + ) -> torch.Tensor: + """ + Args: + clean_latent: the clean latent with shape [B, C, H, W], + where B is batch_size or batch_size * num_frames + noise: the noise with shape [B, C, H, W] + timestep: the timestep with shape [1] or [bs * num_frames] or [bs, num_frames] + + Returns: + the corrupted latent with shape [B, C, H, W] + """ + # If timestep is [bs, num_frames] + if timestep.ndim == 2: + timestep = timestep.flatten(0, 1) + assert timestep.numel() == clean_latent.shape[0] + elif timestep.ndim == 1: + # If timestep is [1] + if timestep.shape[0] == 1: + timestep = timestep.expand(clean_latent.shape[0]) + else: + assert timestep.numel() == clean_latent.shape[0] + else: + raise ValueError(f"[add_noise] Invalid timestep shape: {timestep.shape}") + # timestep shape should be [B] + self.sigmas = self.sigmas.to(noise.device) + self.timesteps = self.timesteps.to(noise.device) + timestep_id = torch.argmin( + (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1 + ) + sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1) + sample = (1 - sigma) * clean_latent + sigma * noise + return sample.type_as(noise) + + def scale_model_input( + self, sample: torch.Tensor, timestep: int | None = None + ) -> torch.Tensor: + return sample + + def __len__(self) -> int: + return 0 + + +EntryClass = FlowMatchEulerDiscreteScheduler diff --git a/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_unipc_multistep.py b/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_unipc_multistep.py new file mode 100644 index 000000000..1e6b84e04 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_flow_unipc_multistep.py @@ -0,0 +1,853 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py +# Convert unipc for flow matching +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+ +import math +from typing import Any + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput, +) +from diffusers.utils import deprecate + +from sglang.multimodal_gen.runtime.models.schedulers.base import BaseScheduler + + +class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin, BaseScheduler): + """ + `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, default `2`): + The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` + due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for + unconditional sampling. + prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + predict_x0 (`bool`, defaults to `True`): + Whether to use the updating algorithm on the predicted x0. + solver_type (`str`, default `bh2`): + Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` + otherwise. + lower_order_final (`bool`, default `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + disable_corrector (`list`, default `[]`): + Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` + and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is + usually disabled during the first few steps. + solver_p (`SchedulerMixin`, default `None`): + Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. 
+ steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: float | None = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: tuple = (), + solver_p: SchedulerMixin = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: str | None = "zero", # "zero", "sigma_min" + **kwargs, + ): + + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}" + ) + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps: int | None = None + alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps)[ + ::-1 + ].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + assert shift is not None + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore + + self.sigmas = sigmas + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + self.timesteps = sigmas * num_train_timesteps + self.num_train_timesteps = num_train_timesteps + + self.model_outputs = [None] * solver_order + self.timestep_list: list[Any | None] = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = list(disable_corrector) + self.solver_p = solver_p + self.last_sample = None + self._step_index: int | None = None + self._begin_index: int | None = None + + BaseScheduler.__init__(self) + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + def set_shift(self, shift: float) -> None: + self.config.shift = shift + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. 
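+
+        Typical pipeline-side usage (sketch): call
+        `scheduler.set_begin_index(0)` once before entering the denoising loop.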
+ """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: int | None = None, + device: str | torch.device = None, + sigmas: list[float] | None = None, + mu: float | None | None = None, + shift: float | None | None = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is None: + assert num_inference_steps is not None + sigmas = np.linspace( + self.sigma_max, self.sigma_min, num_inference_steps + 1 + ).copy()[ + :-1 + ] # pyright: ignore + + if self.config.use_dynamic_shifting: + assert mu is not None + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + assert isinstance(sigmas, np.ndarray) + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last]]).astype( + np.float32 + ) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas).to(device=device) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64 + ) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
+ + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = ( + sample.float() + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = ( + torch.clamp(sample, -s, s) / s + ) # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma) -> tuple[Any, Any]: + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + r""" + Convert the model output to the corresponding type the UniPC algorithm needs. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyword argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." 
+ ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + def multistep_uni_p_bh_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + order: int | None = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model at the current timestep. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + order (`int`): + The order of UniP at this timestep (corresponds to the *p* in UniPC-p). + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError(" missing `sample` as a required keyword argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError(" missing `order` as a required keyword argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs + + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + sigma_t, sigma_s0 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + ) # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s: list[Any] | None = [] + sigmas = self.sigmas.to(device=device) + for i in range(1, order): + si = self.step_index - i # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + assert mi is not None + D1s.append((mi - m0) / rk) # pyright: ignore + + if len(rks) > 0: + rks = torch.stack(rks) + one = torch.ones(1, device=device, dtype=rks.dtype) + rks = torch.cat([rks, one]) + else: + rks = torch.ones(1, device=device, dtype=h.dtype) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.stack(b) + + if D1s is not None and len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = 0.5 * torch.ones(1, dtype=x.dtype, device=device) + else: + assert 
isinstance(R, torch.Tensor) + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum( + "k,bkc...->bc...", rhos_p, D1s + ) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum( + "k,bkc...->bc...", rhos_p, D1s + ) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor = None, + this_sample: torch.Tensor = None, + order: int | None = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniC (B(h) version). + + Args: + this_model_output (`torch.Tensor`): + The model outputs at `x_t`. + this_timestep (`int`): + The current timestep `t`. + last_sample (`torch.Tensor`): + The generated sample before the last predictor `x_{t-1}`. + this_sample (`torch.Tensor`): + The generated sample after the last predictor `x_{t}`. + order (`int`): + The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. + + Returns: + `torch.Tensor`: + The corrected sample tensor at the current timestep. + """ + this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError(" missing`last_sample` as a required keyword argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError(" missing`this_sample` as a required keyword argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError(" missing`order` as a required keyword argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = ( + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + # Build rks and D1s fully on device to avoid any host-device sync + # Fast paths for small orders (common cases: 1 or 2) + if order == 1: + rks = torch.ones(1, device=device, dtype=h.dtype) + D1s = None + elif order == 2: + # order == 2 -> only one historical point is used + si = self.step_index - 2 # i = 1 + mi = model_output_list[-2] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h # 0-dim tensor on device + # rks = [rk, 1.0] but keep it on device without list->tensor sync + rks = torch.stack((rk, torch.ones_like(rk))) + assert mi is not None + # D1s shape: (B, K=1, C, ...) 
to match later einsum over K + D1s = ((mi - m0) / rk).unsqueeze(1) # pyright: ignore + else: + rks_list = [] + D1s_list = [] + for i in range(1, order): + si = self.step_index - (i + 1) + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks_list.append(rk) + assert mi is not None + D1s_list.append((mi - m0) / rk) # pyright: ignore + + # Append 1.0 as a device tensor to rks + rks = torch.stack(rks_list + [torch.ones_like(rks_list[0])]) + D1s = torch.stack(D1s_list, dim=1) if len(D1s_list) > 0 else None + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + # Avoid torch.tensor(list_of_gpu_scalars) which syncs to host + b = torch.stack(b) + + # D1s is already prepared above for order==2; remains None for order==1 + + # for order 1, we use a simplified version + if order == 1: + rhos_c = 0.5 * torch.ones(1, dtype=x.dtype, device=device) + elif order == 2: + # Manually solve the 2x2 linear system to avoid device synchronization from torch.linalg.solve + # R = [[1, 1], [rk, 1]], where rk = rks[0] + rk = rks[0] + det = 1 - rk + # Using Cramer's rule to solve for rhos_c = [x0, x1] + # x0 = (b0 - b1) / det + # x1 = (b1 - rk * b0) / det + rhos_c_0 = (b[0] - b[1]) / det + rhos_c_1 = (b[1] - rk * b[0]) / det + rhos_c = torch.stack([rhos_c_0, rhos_c_1]) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None) -> int: + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + step_index: int = indices[pos].item() + + return step_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep) -> None: + """ + Initialize the step_index counter for the scheduler. 
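+ If `begin_index` has not been set by the pipeline, the index is inferred from
+ `timestep` via `index_for_timestep`; otherwise `_begin_index` is used directly.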
+ """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: int | torch.Tensor, + sample: torch.Tensor, + return_dict: bool = True, + generator=None, + ) -> SchedulerOutput | tuple: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep UniPC. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to call 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 + and self.step_index - 1 not in self.disable_corrector + and self.last_sample is not None # pyright: ignore + ) + + sample = sample.to(model_output.device) + model_output_convert = self.convert_model_output(model_output, sample=sample) + + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep # pyright: ignore + + if self.config.lower_order_final: + this_order = min( + self.config.solver_order, len(self.timesteps) - self.step_index + ) # pyright: ignore + else: + this_order = self.config.solver_order + + self.this_order: int = min( + this_order, self.lower_order_nums + 1 + ) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + assert self._step_index is not None + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. 
+ """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype + ) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32 + ) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) for t in timesteps + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps + + +EntryClass = FlowUniPCMultistepScheduler diff --git a/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_self_forcing_flow_match.py b/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_self_forcing_flow_match.py new file mode 100644 index 000000000..08fc4d8bb --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_self_forcing_flow_match.py @@ -0,0 +1,172 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput + +from sglang.multimodal_gen.runtime.models.schedulers.base import BaseScheduler +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class SelfForcingFlowMatchSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. 
+ """ + + prev_sample: torch.FloatTensor + + +class SelfForcingFlowMatchScheduler(BaseScheduler, ConfigMixin, SchedulerMixin): + + config_name = "scheduler_config.json" + order = 1 + + @register_to_config + def __init__( + self, + num_inference_steps=100, + num_train_timesteps=1000, + shift=3.0, + sigma_max=1.0, + sigma_min=0.003 / 1.002, + inverse_timesteps=False, + extra_one_step=False, + reverse_sigmas=False, + training=False, + ): + self.num_train_timesteps = num_train_timesteps + self.shift = shift + self.sigma_max = sigma_max + self.sigma_min = sigma_min + self.inverse_timesteps = inverse_timesteps + self.extra_one_step = extra_one_step + self.reverse_sigmas = reverse_sigmas + self.set_timesteps(num_inference_steps, training=training) + + def set_timesteps( + self, + num_inference_steps=100, + denoising_strength=1.0, + training=False, + return_dict=False, + **kwargs, + ): + sigma_start = ( + self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength + ) + if self.extra_one_step: + self.sigmas = torch.linspace( + sigma_start, self.sigma_min, num_inference_steps + 1 + )[:-1] + else: + self.sigmas = torch.linspace( + sigma_start, self.sigma_min, num_inference_steps + ) + if self.inverse_timesteps: + self.sigmas = torch.flip(self.sigmas, dims=[0]) + self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas) + if self.reverse_sigmas: + self.sigmas = 1 - self.sigmas + self.timesteps = self.sigmas * self.num_train_timesteps + if training: + x = self.timesteps + y = torch.exp( + -2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2 + ) + y_shifted = y - y.min() + bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum()) + self.linear_timesteps_weights = bsmntw_weighing + + def step( + self, + model_output: torch.FloatTensor, + timestep: torch.FloatTensor, + sample: torch.FloatTensor, + to_final=False, + return_dict=False, + **kwargs, + ): + if timestep.ndim == 2: + timestep = timestep.flatten(0, 1) + elif timestep.ndim == 0: + # handles the case where timestep is a scalar, this occurs when we + # use this scheduler for ODE trajectory + timestep = timestep.unsqueeze(0) + + self.sigmas = self.sigmas.to(model_output.device) + self.timesteps = self.timesteps.to(model_output.device) + timestep = timestep.to(model_output.device) + + timestep_id = torch.argmin( + (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1 + ) + sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1) + if to_final or (timestep_id + 1 >= len(self.timesteps)).any(): + sigma_ = 1 if (self.inverse_timesteps or self.reverse_sigmas) else 0 + else: + sigma_ = self.sigmas[timestep_id + 1].reshape(-1, 1, 1, 1) + prev_sample = sample + model_output * (sigma_ - sigma) + if isinstance(prev_sample, torch.Tensor | float) and not return_dict: + return (prev_sample,) + return SelfForcingFlowMatchSchedulerOutput(prev_sample=prev_sample) + + def add_noise(self, original_samples, noise, timestep): + """ + Diffusion forward corruption process. 
+ Input: + - clean_latent: the clean latent with shape [B*T, C, H, W] + - noise: the noise with shape [B*T, C, H, W] + - timestep: the timestep with shape [B*T] + Output: the corrupted latent with shape [B*T, C, H, W] + """ + if timestep.ndim == 2: + timestep = timestep.flatten(0, 1) + self.sigmas = self.sigmas.to(noise.device) + self.timesteps = self.timesteps.to(noise.device) + timestep_id = torch.argmin( + (self.timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1 + ) + sigma = self.sigmas[timestep_id].reshape(-1, 1, 1, 1) + sample = (1 - sigma) * original_samples + sigma * noise + return sample.type_as(noise) + + def training_target(self, sample, noise, timestep): + target = noise - sample + return target + + def training_weight(self, timestep): + """ + Input: + - timestep: the timestep with shape [B*T] + Output: the corresponding weighting [B*T] + """ + if timestep.ndim == 2: + timestep = timestep.flatten(0, 1) + self.linear_timesteps_weights = self.linear_timesteps_weights.to( + timestep.device + ) + timestep_id = torch.argmin( + (self.timesteps.unsqueeze(1) - timestep.unsqueeze(0)).abs(), dim=0 + ) + weights = self.linear_timesteps_weights[timestep_id] + return weights + + def scale_model_input( + self, sample: torch.Tensor, timestep: int | None = None + ) -> torch.Tensor: + return sample + + def set_shift(self, shift: float) -> None: + self.shift = shift + + +EntryClass = SelfForcingFlowMatchScheduler diff --git a/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_unipc_multistep.py b/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_unipc_multistep.py new file mode 100644 index 000000000..df5e9b834 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/schedulers/scheduling_unipc_multistep.py @@ -0,0 +1,1207 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# Copyright 2025 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# DISCLAIMER: check https://huggingface.co/papers/2302.04867 and https://github.com/wl-zhao/UniPC for more info +# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +# ============================================================================== +# +# Modified from diffusers==0.35.0.dev0 +# +# ============================================================================== + +import math + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput, +) +from diffusers.utils import deprecate, is_scipy_available + +from sglang.multimodal_gen.runtime.models.schedulers.base import BaseScheduler + +if is_scipy_available(): + import scipy.stats + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. 
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin, BaseScheduler): + """ + `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + solver_order (`int`, default `2`): + The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` + due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for + unconditional sampling. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + predict_x0 (`bool`, defaults to `True`): + Whether to use the updating algorithm on the predicted x0. + solver_type (`str`, default `bh2`): + Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` + otherwise. + lower_order_final (`bool`, default `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + disable_corrector (`list`, default `[]`): + Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` + and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is + usually disabled during the first few steps. + solver_p (`SchedulerMixin`, default `None`): + Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. 
If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + use_beta_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta + Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information. + use_flow_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use flow sigmas for step sizes in the noise schedule during the sampling process. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: np.ndarray | list[float] | None = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: list[int] = [], + solver_p: SchedulerMixin = None, + use_karras_sigmas: bool | None = False, + use_exponential_sigmas: bool | None = False, + use_beta_sigmas: bool | None = False, + use_flow_sigmas: bool | None = False, + flow_shift: float | None = 1.0, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: str | None = "zero", # "zero", "sigma_min" + rescale_betas_zero_snr: bool = False, + use_dynamic_shifting: bool = False, + time_shift_type: str = "exponential", + ): + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError( + "Make sure to install scipy if you want to use beta sigmas." + ) + if ( + sum( + [ + self.config.use_beta_sigmas, + self.config.use_exponential_sigmas, + self.config.use_karras_sigmas, + ] + ) + > 1 + ): + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." 
+ ) + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace( + beta_start, beta_end, num_train_timesteps, dtype=torch.float32 + ) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError( + f"{beta_schedule} is not implemented for {self.__class__}" + ) + + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + if rescale_betas_zero_snr: + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + self.alphas_cumprod[-1] = 2**-24 + + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}" + ) + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + timesteps = np.linspace( + 0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32 + )[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.num_train_timesteps = num_train_timesteps + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + BaseScheduler.__init__(self) + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + def set_shift(self, shift: float) -> None: + self.config.flow_shift = shift + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_timesteps( + self, + num_inference_steps: int, + device: str | torch.device = None, + mu: float | None = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. 
+ device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891 + if mu is not None: + assert ( + self.config.use_dynamic_shifting + and self.config.time_shift_type == "exponential" + ) + self.config.flow_shift = np.exp(mu) + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace( + 0, self.config.num_train_timesteps - 1, num_inference_steps + 1 + ) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(0, num_inference_steps + 1) * step_ratio) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + np.arange(self.config.num_train_timesteps, 0, -step_ratio) + .round() + .copy() + .astype(np.int64) + ) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + if self.config.use_karras_sigmas: + log_sigmas = np.log(sigmas) + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_karras( + in_sigmas=sigmas, num_inference_steps=num_inference_steps + ) + timesteps = np.array( + [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas] + ).round() + if self.config.final_sigmas_type == "sigma_min": + sigma_last = sigmas[-1] + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + elif self.config.use_exponential_sigmas: + log_sigmas = np.log(sigmas) + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_exponential( + in_sigmas=sigmas, num_inference_steps=num_inference_steps + ) + timesteps = np.array( + [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas] + ) + if self.config.final_sigmas_type == "sigma_min": + sigma_last = sigmas[-1] + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + elif self.config.use_beta_sigmas: + log_sigmas = np.log(sigmas) + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_beta( + in_sigmas=sigmas, num_inference_steps=num_inference_steps + ) + timesteps = np.array( + [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas] + ) + if self.config.final_sigmas_type == "sigma_min": + sigma_last = sigmas[-1] + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + sigmas = 
np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + elif self.config.use_flow_sigmas: + alphas = np.linspace( + 1, 1 / self.config.num_train_timesteps, num_inference_steps + 1 + ) + sigmas = 1.0 - alphas + sigmas = np.flip( + self.config.flow_shift + * sigmas + / (1 + (self.config.flow_shift - 1) * sigmas) + )[:-1].copy() + timesteps = (sigmas * self.config.num_train_timesteps).copy() + if self.config.final_sigmas_type == "sigma_min": + sigma_last = sigmas[-1] + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ( + (1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0] + ) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64 + ) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
+ + https://huggingface.co/papers/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = ( + sample.float() + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = ( + torch.clamp(sample, -s, s) / s + ) # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = ( + np.cumsum((dists >= 0), axis=0) + .argmax(axis=0) + .clip(max=log_sigmas.shape[0] - 2) + ) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t + def _sigma_to_alpha_sigma_t(self, sigma): + if self.config.use_flow_sigmas: + alpha_t = 1 - sigma + sigma_t = sigma + else: + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras( + self, in_sigmas: torch.Tensor, num_inference_steps + ) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. 
(2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential( + self, in_sigmas: torch.Tensor, num_inference_steps: int + ) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.exp( + np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps) + ) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, + in_sigmas: torch.Tensor, + num_inference_steps: int, + alpha: float = 0.6, + beta: float = 0.6, + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.array( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + r""" + Convert the model output to the corresponding type the UniPC algorithm needs. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. 
+ """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyword argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "epsilon": + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + x0_pred = alpha_t * sample - sigma_t * model_output + elif self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, " + "`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "epsilon": + return model_output + elif self.config.prediction_type == "sample": + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.config.prediction_type == "v_prediction": + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the UniPCMultistepScheduler." + ) + + def multistep_uni_p_bh_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + order: int = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model at the current timestep. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + order (`int`): + The order of UniP at this timestep (corresponds to the *p* in UniPC-p). + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. 
+ """ + prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyword argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError("missing `order` as a required keyword argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs + + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + sigma_t, sigma_s0 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + ) + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - i + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s) + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor = None, + this_sample: torch.Tensor = None, + order: int = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniC (B(h) version). + + Args: + this_model_output (`torch.Tensor`): + The model outputs at `x_t`. + this_timestep (`int`): + The current timestep `t`. + last_sample (`torch.Tensor`): + The generated sample before the last predictor `x_{t-1}`. + this_sample (`torch.Tensor`): + The generated sample after the last predictor `x_{t}`. + order (`int`): + The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. 
+ + Returns: + `torch.Tensor`: + The corrected sample tensor at the current timestep. + """ + this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError("missing `last_sample` as a required keyword argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError("missing `this_sample` as a required keyword argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError("missing `order` as a required keyword argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = ( + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + index_candidates = (schedule_timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index 
that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + return step_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: int | torch.Tensor, + sample: torch.Tensor, + return_dict: bool = True, + ) -> SchedulerOutput | tuple: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep UniPC. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
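+
+        Example (minimal usage sketch; `model`, `latents` and `num_inference_steps`
+        are illustrative names):
+
+            scheduler.set_timesteps(num_inference_steps)
+            for t in scheduler.timesteps:
+                model_output = model(latents, t)
+                latents = scheduler.step(model_output, t, latents).prev_sample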
+ + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to call 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 + and self.step_index - 1 not in self.disable_corrector + and self.last_sample is not None + ) + + model_output_convert = self.convert_model_output(model_output, sample=sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + if self.config.lower_order_final: + this_order = min( + self.config.solver_order, len(self.timesteps) - self.step_index + ) + else: + this_order = self.config.solver_order + + self.this_order = min( + this_order, self.lower_order_nums + 1 + ) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. 
+ """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype + ) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32 + ) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) for t in timesteps + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps + + +EntryClass = UniPCMultistepScheduler diff --git a/python/sglang/multimodal_gen/runtime/models/utils.py b/python/sglang/multimodal_gen/runtime/models/utils.py new file mode 100644 index 000000000..6761593ed --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/utils.py @@ -0,0 +1,194 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/utils.py +"""Utils for model executor.""" +from typing import Any + +import torch + + +# TODO(PY): move it elsewhere +def auto_attributes(init_func): + """ + Decorator that automatically adds all initialization arguments as object attributes. 
+ + Example: + @auto_attributes + def __init__(self, a=1, b=2): + pass + + # This will automatically set: + # - self.a = 1 and self.b = 2 + # - self.config.a = 1 and self.config.b = 2 + """ + + def wrapper(self, *args, **kwargs): + # Get the function signature + import inspect + + signature = inspect.signature(init_func) + parameters = signature.parameters + + # Get parameter names (excluding 'self') + param_names = list(parameters.keys())[1:] + + # Bind arguments to parameters + bound_args = signature.bind(self, *args, **kwargs) + bound_args.apply_defaults() + + # Create config object if it doesn't exist + if not hasattr(self, "config"): + self.config = type("Config", (), {})() + + # Set attributes on self and self.config + for name in param_names: + if name in bound_args.arguments: + value = bound_args.arguments[name] + setattr(self, name, value) + setattr(self.config, name, value) + + # Call the original __init__ function + return init_func(self, *args, **kwargs) + + return wrapper + + +def set_weight_attrs( + weight: torch.Tensor, + weight_attrs: dict[str, Any] | None, +): + """Set attributes on a weight tensor. + + This method is used to set attributes on a weight tensor. This method + will not overwrite existing attributes. + + Args: + weight: The weight tensor. + weight_attrs: A dictionary of attributes to set on the weight tensor. + """ + if weight_attrs is None: + return + for key, value in weight_attrs.items(): + assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}" + + # NOTE(woosuk): During weight loading, we often do something like: + # narrowed_tensor = param.data.narrow(0, offset, len) + # narrowed_tensor.copy_(real_weight) + # expecting narrowed_tensor and param.data to share the same storage. + # However, on TPUs, narrowed_tensor will lazily propagate to the base + # tensor, which is param.data, leading to the redundant memory usage. + # This sometimes causes OOM errors during model loading. To avoid this, + # we sync the param tensor after its weight loader is called. + # TODO(woosuk): Remove this hack once we have a better solution. + from sglang.multimodal_gen.runtime.platforms import current_platform + + if current_platform.is_tpu() and key == "weight_loader": + value = _make_synced_weight_loader(value) + setattr(weight, key, value) + + +def _make_synced_weight_loader(original_weight_loader) -> Any: + + def _synced_weight_loader(param, *args, **kwargs): + original_weight_loader(param, *args, **kwargs) + torch._sync(param) + + return _synced_weight_loader + + +def extract_layer_index(layer_name: str) -> int: + """ + Extract the layer index from the module name. + Examples: + - "encoder.layers.0" -> 0 + - "encoder.layers.1.self_attn" -> 1 + - "2.self_attn" -> 2 + - "model.encoder.layers.0.sub.1" -> ValueError + """ + subnames = layer_name.split(".") + int_vals: list[int] = [] + for subname in subnames: + try: + int_vals.append(int(subname)) + except ValueError: + continue + assert len(int_vals) == 1, ( + f"layer name {layer_name} should" " only contain one integer" + ) + return int_vals[0] + + +def modulate( + x: torch.Tensor, + shift: torch.Tensor | None = None, + scale: torch.Tensor | None = None, +) -> torch.Tensor: + """modulate by shift and scale + + Args: + x (torch.Tensor): input tensor. + shift (torch.Tensor, optional): shift tensor. Defaults to None. + scale (torch.Tensor, optional): scale tensor. Defaults to None. + + Returns: + torch.Tensor: the output tensor after modulate. 
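+
+    Example (illustrative shapes):
+        x: [B, L, D]; shift, scale: [B, D]
+        out = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)  # -> [B, L, D]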
+ """ + if scale is None and shift is None: + return x + elif shift is None: + return x * (1 + scale.unsqueeze(1)) # type: ignore[union-attr] + elif scale is None: + return x + shift.unsqueeze(1) # type: ignore[union-attr] + else: + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze( + 1 + ) # type: ignore[union-attr] + + +def pred_noise_to_pred_video( + pred_noise: torch.Tensor, + noise_input_latent: torch.Tensor, + timestep: torch.Tensor, + scheduler: Any, +) -> torch.Tensor: + """ + Convert predicted noise to clean latent. + + Args: + pred_noise: the predicted noise with shape [B, C, H, W] + where B is batch_size or batch_size * num_frames + noise_input_latent: the noisy latent with shape [B, C, H, W], + timestep: the timestep with shape [1] or [bs * num_frames] or [bs, num_frames] + scheduler: the scheduler + + Returns: + the predicted video with shape [B, C, H, W] + """ + # If timestep is [bs, num_frames] + if timestep.ndim == 2: + timestep = timestep.flatten(0, 1) + assert timestep.numel() == noise_input_latent.shape[0] + elif timestep.ndim == 1: + # If timestep is [1] + if timestep.shape[0] == 1: + timestep = timestep.expand(noise_input_latent.shape[0]) + else: + assert timestep.numel() == noise_input_latent.shape[0] + else: + raise ValueError( + f"[pred_noise_to_pred_video] Invalid timestep shape: {timestep.shape}" + ) + # timestep shape should be [B] + dtype = pred_noise.dtype + device = pred_noise.device + pred_noise = pred_noise.double().to(device) + noise_input_latent = noise_input_latent.double().to(device) + sigmas = scheduler.sigmas.double().to(device) + timesteps = scheduler.timesteps.double().to(device) + timestep_id = torch.argmin( + (timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1 + ) + sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1) + pred_video = noise_input_latent - sigma_t * pred_noise + return pred_video.to(dtype) diff --git a/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder.py b/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder.py new file mode 100644 index 000000000..91fa447e0 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder.py @@ -0,0 +1,585 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from typing import Dict, Optional, Tuple, Union + +import torch +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, + FusedAttnProcessor2_0, +) +from diffusers.models.autoencoders.vae import ( + Decoder, + DecoderOutput, + DiagonalGaussianDistribution, + Encoder, +) +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from torch import nn + +from sglang.multimodal_gen.configs.models.vaes.flux import FluxVAEConfig + + +class AutoencoderKL(nn.Module): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. 
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast` + can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + mid_block_add_attention (`bool`, *optional*, default to `True`): + If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the + mid_block will only have resnet blocks + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"] + + def __init__( + self, + config: FluxVAEConfig, + ): + super().__init__() + self.config = config + arch_config = config.arch_config + + in_channels = arch_config.in_channels + out_channels = arch_config.out_channels + down_block_types = arch_config.down_block_types + up_block_types = arch_config.up_block_types + block_out_channels = arch_config.block_out_channels + layers_per_block = arch_config.layers_per_block + act_fn = arch_config.act_fn + latent_channels = arch_config.latent_channels + norm_num_groups = arch_config.norm_num_groups + sample_size = arch_config.sample_size + use_quant_conv = arch_config.use_quant_conv + use_post_quant_conv = arch_config.use_post_quant_conv + mid_block_add_attention = arch_config.mid_block_add_attention + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + mid_block_add_attention=mid_block_add_attention, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + mid_block_add_attention=mid_block_add_attention, + ) + + self.quant_conv = ( + nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + if use_quant_conv + else None + ) + self.post_quant_conv = ( + nn.Conv2d(latent_channels, latent_channels, 1) + if use_post_quant_conv + else None + ) + + self.use_slicing = False + 
self.use_tiling = False + + # only relevant if vae tiling is enabled + self.tile_sample_min_size = sample_size + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int( + sample_size / (2 ** (len(self.config.block_out_channels) - 1)) + ) + self.tile_overlap_factor = 0.25 + + def enable_tiling(self, use_tiling: bool = True): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.use_tiling = use_tiling + + def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.enable_tiling(False) + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all( + proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS + for proc in self.attn_processors.values() + ): + processor = AttnAddedKVProcessor() + elif all( + proc.__class__ in CROSS_ATTENTION_PROCESSORS + for proc in self.attn_processors.values() + ): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, height, width = x.shape + + if self.use_tiling and ( + width > self.tile_sample_min_size or height > self.tile_sample_min_size + ): + return self._tiled_encode(x) + + enc = self.encoder(x) + if self.quant_conv is not None: + enc = self.quant_conv(enc) + + return enc + + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode( + self, z: torch.Tensor, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + if self.use_tiling and ( + z.shape[-1] > self.tile_latent_min_size + or z.shape[-2] > self.tile_latent_min_size + ): + return self.tiled_decode(z, return_dict=return_dict) + + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def decode(self, z: torch.FloatTensor) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. 
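+
+        Note:
+            This adapted implementation returns the decoded sample tensor directly
+            (there is no `return_dict` argument here).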
+ + """ + + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + return decoded + + def blend_v( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[ + :, :, y, : + ] * (y / blend_extent) + return b + + def blend_h( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[ + :, :, :, x + ] * (x / blend_extent) + return b + + def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[ + :, + :, + i : i + self.tile_sample_min_size, + j : j + self.tile_sample_min_size, + ] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + enc = torch.cat(result_rows, dim=2) + return enc + + def tiled_encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of images. 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain + `tuple` is returned. + """ + deprecation_message = ( + "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the " + "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able " + "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value." + ) + # deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False) + + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[ + :, + :, + i : i + self.tile_sample_min_size, + j : j + self.tile_sample_min_size, + ] + tile = self.encoder(tile) + if self.config.use_quant_conv: + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def tiled_decode( + self, z: torch.Tensor, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
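+        # Illustration: with e.g. tile_latent_min_size=64 and tile_overlap_factor=0.25,
+        # overlap_size = int(64 * 0.75) = 48, so consecutive latent tiles start 48 latents
+        # apart and overlap by 16; the decoded tiles are then blended over `blend_extent`
+        # pixels and cropped to `row_limit` before being stitched together.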
+ rows = [] + for i in range(0, z.shape[2], overlap_size): + row = [] + for j in range(0, z.shape[3], overlap_size): + tile = z[ + :, + :, + i : i + self.tile_latent_min_size, + j : j + self.tile_latent_min_size, + ] + if self.config.use_post_quant_conv: + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + return dec + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + > [!WARNING] > This API is 🧪 experimental. + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError( + "`fuse_qkv_projections()` is not supported for models having added KV projections." 
+ ) + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + self.set_attn_processor(FusedAttnProcessor2_0()) + + +EntryClass = AutoencoderKL diff --git a/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder_kl_qwenimage.py b/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder_kl_qwenimage.py new file mode 100644 index 000000000..26d682f48 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/vaes/autoencoder_kl_qwenimage.py @@ -0,0 +1,1183 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.models.activations import get_activation +from diffusers.models.autoencoders.vae import ( + DecoderOutput, + DiagonalGaussianDistribution, +) +from diffusers.models.modeling_outputs import AutoencoderKLOutput + +from sglang.multimodal_gen.configs.models.vaes.qwenimage import QwenImageVAEConfig +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) # pylint: disable=invalid-name + +CACHE_T = 2 + + +class QwenImageCausalConv3d(nn.Conv3d): + r""" + A custom 3D causal convolution layer with feature caching support. + + This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature + caching for efficient inference. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + # Set up causal padding + self._padding = ( + self.padding[2], + self.padding[2], + self.padding[1], + self.padding[1], + 2 * self.padding[0], + 0, + ) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + return super().forward(x) + + +class QwenImageRMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. 
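+
+    Note (illustrative formula, for `channel_first=True`):
+        out = F.normalize(x, dim=1) * dim**0.5 * gamma + bias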
+ """ + + def __init__( + self, + dim: int, + channel_first: bool = True, + images: bool = True, + bias: bool = False, + ) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return ( + F.normalize(x, dim=(1 if self.channel_first else -1)) + * self.scale + * self.gamma + + self.bias + ) + + +class QwenImageUpsample(nn.Upsample): + r""" + Perform upsampling while ensuring the output tensor has the same data type as the input. + + Returns: + torch.Tensor: Upsampled tensor with the same data type as the input. + """ + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class QwenImageResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). + - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. + """ + + def __init__(self, dim: int, mode: str) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim // 2, 3, padding=1), + ) + self.time_conv = QwenImageCausalConv3d( + dim, dim * 2, (3, 1, 1), padding=(1, 0, 0) + ) + + elif mode == "downsample2d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)) + ) + elif mode == "downsample3d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)) + ) + self.time_conv = QwenImageCausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0) + ) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if ( + cache_x.shape[2] < 2 + and feat_cache[idx] is not None + and feat_cache[idx] != "Rep" + ): + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + if ( + cache_x.shape[2] < 2 + and feat_cache[idx] is not None + and feat_cache[idx] == "Rep" + ): + cache_x = torch.cat( + [torch.zeros_like(cache_x).to(cache_x.device), cache_x], + dim=2, + ) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = 
x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2) + ) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + +class QwenImageResidualBlock(nn.Module): + r""" + A custom residual block module. + + Args: + in_dim (int): Number of input channels. + out_dim (int): Number of output channels. + dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0. + non_linearity (str, optional): Type of non-linearity to use. Default is "silu". + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dropout: float = 0.0, + non_linearity: str = "silu", + ) -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.nonlinearity = get_activation(non_linearity) + + # layers + self.norm1 = QwenImageRMS_norm(in_dim, images=False) + self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1) + self.norm2 = QwenImageRMS_norm(out_dim, images=False) + self.dropout = nn.Dropout(dropout) + self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1) + self.conv_shortcut = ( + QwenImageCausalConv3d(in_dim, out_dim, 1) + if in_dim != out_dim + else nn.Identity() + ) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # Apply shortcut connection + h = self.conv_shortcut(x) + + # First normalization and activation + x = self.norm1(x) + x = self.nonlinearity(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + # Second normalization and activation + x = self.norm2(x) + x = self.nonlinearity(x) + + # Dropout + x = self.dropout(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + + x = self.conv2(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv2(x) + + # Add residual connection + return x + h + + +class QwenImageAttentionBlock(nn.Module): + r""" + Causal self-attention with a single head. + + Args: + dim (int): The number of channels in the input tensor. 
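+
+    Note:
+        The input of shape [B, C, T, H, W] is reshaped into B*T frames, and single-head
+        scaled dot-product attention is applied over the H*W spatial positions of each frame.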
+ """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = QwenImageRMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + identity = x + batch_size, channels, time, height, width = x.size() + + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width) + x = self.norm(x) + + # compute query, key, value + qkv = self.to_qkv(x) + qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + qkv = qkv.permute(0, 1, 3, 2).contiguous() + q, k, v = qkv.chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention(q, k, v) + + x = ( + x.squeeze(1) + .permute(0, 2, 1) + .reshape(batch_size * time, channels, height, width) + ) + + # output projection + x = self.proj(x) + + # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w] + x = x.view(batch_size, time, channels, height, width) + x = x.permute(0, 2, 1, 3, 4) + + return x + identity + + +class QwenImageMidBlock(nn.Module): + """ + Middle block for QwenImageVAE encoder and decoder. + + Args: + dim (int): Number of input/output channels. + dropout (float): Dropout rate. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + dim: int, + dropout: float = 0.0, + non_linearity: str = "silu", + num_layers: int = 1, + ): + super().__init__() + self.dim = dim + + # Create the components + resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)] + attentions = [] + for _ in range(num_layers): + attentions.append(QwenImageAttentionBlock(dim)) + resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity)) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + # First residual block + x = self.resnets[0](x, feat_cache, feat_idx) + + # Process through attention and residual blocks + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) + + x = resnet(x, feat_cache, feat_idx) + + return x + + +class QwenImageEncoder3d(nn.Module): + r""" + A 3D encoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_downsample (list of bool): Whether to downsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. 
+ """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + non_linearity: str = "silu", + ): + super().__init__() + # dim = config.arch_config.dim + # z_dim = config.arch_config.z_dim + # dim_mult = config.arch_config.dim_mult + # num_res_blocks = config.arch_config.num_res_blocks + # attn_scales = config.arch_config.attn_scales + # temperal_downsample = config.arch_config.temperal_downsample + # dropout = config.arch_config.dropout + # non_linearity = config.arch_config.non_linearity + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + self.down_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + self.down_blocks.append( + QwenImageResidualBlock(in_dim, out_dim, dropout) + ) + if scale in attn_scales: + self.down_blocks.append(QwenImageAttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(QwenImageResample(out_dim, mode=mode)) + scale /= 2.0 + + # middle blocks + self.mid_block = QwenImageMidBlock( + out_dim, dropout, non_linearity, num_layers=1 + ) + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## downsamples + for layer in self.down_blocks: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class QwenImageUpBlock(nn.Module): + """ + A block that handles upsampling for the QwenImageVAE decoder. 
+ + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d') + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + upsample_mode: Optional[str] = None, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # Create layers list + resnets = [] + # Add residual blocks and attention if needed + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append( + QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity) + ) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + self.upsamplers = None + if upsample_mode is not None: + self.upsamplers = nn.ModuleList( + [QwenImageResample(out_dim, mode=upsample_mode)] + ) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + """ + Forward pass through the upsampling block. + + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + + Returns: + torch.Tensor: Output tensor + """ + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache, feat_idx) + else: + x = resnet(x) + + if self.upsamplers is not None: + if feat_cache is not None: + x = self.upsamplers[0](x, feat_cache, feat_idx) + else: + x = self.upsamplers[0](x) + return x + + +class QwenImageDecoder3d(nn.Module): + r""" + A 3D decoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_upsample (list of bool): Whether to upsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. 
+ """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + non_linearity: str = "silu", + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + self.nonlinearity = get_activation(non_linearity) + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + + # init block + self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.mid_block = QwenImageMidBlock( + dims[0], dropout, non_linearity, num_layers=1 + ) + + # upsample blocks + self.up_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i > 0: + in_dim = in_dim // 2 + + # Determine if we need upsampling + upsample_mode = None + if i != len(dim_mult) - 1: + upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d" + + # Create and add the upsampling block + up_block = QwenImageUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) + self.up_blocks.append(up_block) + + # Update scale for next iteration + if upsample_mode is not None: + scale *= 2.0 + + # output blocks + self.norm_out = QwenImageRMS_norm(out_dim, images=False) + self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + ## middle + x = self.mid_block(x, feat_cache, feat_idx) + + ## upsamples + for up_block in self.up_blocks: + x = up_block(x, feat_cache, feat_idx) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +class AutoencoderKLQwenImage(nn.Module): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). 
+ """ + + _supports_gradient_checkpointing = False + + # fmt: off + def __init__( + self, + config: QwenImageVAEConfig, + # base_dim: int = 96, + # z_dim: int = 16, + # dim_mult: Tuple[int] = [1, 2, 4, 4], + # num_res_blocks: int = 2, + # attn_scales: List[float] = [], + # temperal_downsample: List[bool] = [False, True, True], + # dropout: float = 0.0, + # latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, + # -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921], + # latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, + # 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160], + ) -> None: + # fmt: on + super().__init__() + base_dim = config.arch_config.base_dim + z_dim = config.arch_config.z_dim + dim_mult = config.arch_config.dim_mult + num_res_blocks = config.arch_config.num_res_blocks + attn_scales = config.arch_config.attn_scales + temperal_downsample = config.arch_config.temperal_downsample + dropout = config.arch_config.dropout + # non_linearity = config.arch_config.non_linearity + self.z_dim = z_dim + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + self.encoder = QwenImageEncoder3d( + base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout + ) + self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1) + + self.decoder = QwenImageDecoder3d( + base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout + ) + + self.spatial_compression_ratio = 2 ** len(self.temperal_downsample) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. 
+ self.use_tiling = False + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + + # Precompute and cache conv counts for encoder and decoder for clear_cache speedup + self._cached_conv_counts = { + "decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules()) + if self.decoder is not None + else 0, + "encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules()) + if self.encoder is not None + else 0, + } + cuda_device = get_local_torch_device() + # FIXME: hardcode + dtype = torch.bfloat16 + latent_channels = config.arch_config.z_dim + + self.shift_factor = ( + torch.tensor( + config.arch_config.latents_mean + ) + .view(1, latent_channels, 1, 1, 1) + .to(cuda_device, dtype) + ) + latents_std_tensor = torch.tensor(config.arch_config.latents_std, dtype=dtype, device=cuda_device) + self.scaling_factor = (1.0 / latents_std_tensor).view(1, latent_channels, 1, 1, 1) + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. 
+ """ + self.use_slicing = False + + def clear_cache(self): + def _count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, QwenImageCausalConv3d): + count += 1 + return count + + self._conv_num = _count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = _count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + def _encode(self, x: torch.Tensor): + _, _, num_frame, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + self.clear_cache() + iter_ = 1 + (num_frame - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1): 1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + + enc = self.quant_conv(out) + self.clear_cache() + return enc + + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> DiagonalGaussianDistribution: + r""" + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return posterior + + def _decode(self, z: torch.Tensor, return_dict: bool = True): + _, _, num_frame, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + self.clear_cache() + x = self.post_quant_conv(z) + for i in range(num_frame): + self._conv_idx = [0] + if i == 0: + out = self.decoder(x[:, :, i: i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + else: + out_ = self.decoder(x[:, :, i: i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + + out = torch.clamp(out, min=-1.0, max=1.0) + self.clear_cache() + if not return_dict: + return (out,) + + return DecoderOutput(sample=out) + + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. 
+ """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + return decoded + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + _, _, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + self.clear_cache() + time = [] + frame_range = 1 + (num_frames - 1) // 4 + for k in range(frame_range): + self._enc_conv_idx = [0] + if k == 0: + tile = x[:, :, :1, i: i + self.tile_sample_min_height, j: j + self.tile_sample_min_width] + else: + tile = x[ + :, + :, + 1 + 4 * (k - 1): 1 + 4 * k, + i: i + self.tile_sample_min_height, + j: j + self.tile_sample_min_width, + ] + tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + tile = self.quant_conv(tile) + time.append(tile) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + _, _, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + self.clear_cache() + time = [] + for k in range(num_frames): + self._conv_idx = [0] + tile = z[:, :, k: k + 1, i: i + tile_latent_min_height, j: j + tile_latent_min_width] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx) + time.append(decoded) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + """ + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. 
+ """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec + + +EntryClass = AutoencoderKLQwenImage diff --git a/python/sglang/multimodal_gen/runtime/models/vaes/common.py b/python/sglang/multimodal_gen/runtime/models/vaes/common.py new file mode 100644 index 000000000..af1189d4f --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/vaes/common.py @@ -0,0 +1,647 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from collections.abc import Iterator +from math import prod +from typing import Optional, cast + +import numpy as np +import torch +import torch.distributed as dist +from diffusers.utils.torch_utils import randn_tensor + +from sglang.multimodal_gen.configs.models import VAEConfig +from sglang.multimodal_gen.runtime.distributed import ( + get_sp_parallel_rank, + get_sp_world_size, +) + + +class ParallelTiledVAE(ABC): + tile_sample_min_height: int + tile_sample_min_width: int + tile_sample_min_num_frames: int + tile_sample_stride_height: int + tile_sample_stride_width: int + tile_sample_stride_num_frames: int + blend_num_frames: int + use_tiling: bool + use_temporal_tiling: bool + use_parallel_tiling: bool + + def __init__(self, config: VAEConfig, **kwargs) -> None: + self.config = config + self.tile_sample_min_height = config.tile_sample_min_height + self.tile_sample_min_width = config.tile_sample_min_width + self.tile_sample_min_num_frames = config.tile_sample_min_num_frames + self.tile_sample_stride_height = config.tile_sample_stride_height + self.tile_sample_stride_width = config.tile_sample_stride_width + self.tile_sample_stride_num_frames = config.tile_sample_stride_num_frames + self.blend_num_frames = config.blend_num_frames + self.use_tiling = config.use_tiling + self.use_temporal_tiling = config.use_temporal_tiling + self.use_parallel_tiling = config.use_parallel_tiling + + def to(self, device) -> "ParallelTiledVAE": + # TODO: implement this + return self + + @property + def device(self): + return next(self.parameters()).device + + @property + def temporal_compression_ratio(self) -> int: + return cast(int, self.config.temporal_compression_ratio) + + @property + def spatial_compression_ratio(self) -> int: + return cast(int, self.config.spatial_compression_ratio) + + @property + def scaling_factor(self) -> float | torch.Tensor: + return cast(float | torch.Tensor, self.config.scaling_factor) + + @abstractmethod + def _encode(self, *args, **kwargs) -> torch.Tensor: + pass + + @abstractmethod + def _decode(self, *args, **kwargs) -> torch.Tensor: + pass + + def encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1 + + if ( + self.use_tiling + and self.use_temporal_tiling + and num_frames > self.tile_sample_min_num_frames + ): + latents = self.tiled_encode(x)[:, :, :latent_num_frames] + elif self.use_tiling and ( + width > self.tile_sample_min_width or height > self.tile_sample_min_height + ): + latents = self.spatial_tiled_encode(x)[:, :, :latent_num_frames] + else: + latents = self._encode(x)[:, :, :latent_num_frames] + return DiagonalGaussianDistribution(latents) + + def decode(self, z: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = z.shape 
+ tile_latent_min_height = ( + self.tile_sample_min_height // self.spatial_compression_ratio + ) + tile_latent_min_width = ( + self.tile_sample_stride_width // self.spatial_compression_ratio + ) + tile_latent_min_num_frames = ( + self.tile_sample_min_num_frames // self.temporal_compression_ratio + ) + num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1 + + if self.use_tiling and self.use_parallel_tiling and get_sp_world_size() > 1: + return self.parallel_tiled_decode(z)[:, :, :num_sample_frames] + if ( + self.use_tiling + and self.use_temporal_tiling + and num_frames > tile_latent_min_num_frames + ): + return self.tiled_decode(z)[:, :, :num_sample_frames] + + if self.use_tiling and ( + width > tile_latent_min_width or height > tile_latent_min_height + ): + return self.spatial_tiled_decode(z)[:, :, :num_sample_frames] + + return self._decode(z)[:, :, :num_sample_frames] + + def blend_v( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * ( + 1 - y / blend_extent + ) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_h( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * ( + 1 - x / blend_extent + ) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def blend_t( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[-3], b.shape[-3], blend_extent) + for x in range(blend_extent): + b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * ( + 1 - x / blend_extent + ) + b[:, :, x, :, :] * (x / blend_extent) + return b + + def spatial_tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + _, _, _, height, width = x.shape + # latent_height = height // self.spatial_compression_ratio + # latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = ( + self.tile_sample_min_height // self.spatial_compression_ratio + ) + tile_latent_min_width = ( + self.tile_sample_min_width // self.spatial_compression_ratio + ) + tile_latent_stride_height = ( + self.tile_sample_stride_height // self.spatial_compression_ratio + ) + tile_latent_stride_width = ( + self.tile_sample_stride_width // self.spatial_compression_ratio + ) + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. 
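+        # Worked example (assumed values): with tile_sample_min_height/width = 256,
+        # tile_sample_stride_height/width = 192 and 8x spatial compression, latent tiles
+        # are 32 x 32 with a stride of 24, so neighbouring tiles share
+        # blend_height = blend_width = 32 - 24 = 8 latent rows/columns; those bands are
+        # cross-faded by blend_v / blend_h below before each tile is cropped to its
+        # 24-latent-pixel stride region.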
+ rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + tile = x[ + :, + :, + :, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self._encode(tile) + row.append(tile) + rows.append(row) + + return self._merge_spatial_tiles( + rows, + blend_height, + blend_width, + tile_latent_stride_height, + tile_latent_stride_width, + ) + + def _parallel_data_generator( + self, gathered_results, gathered_dim_metadata + ) -> Iterator[tuple[torch.Tensor, int]]: + global_idx = 0 + for i, per_rank_metadata in enumerate(gathered_dim_metadata): + _start_shape = 0 + for shape in per_rank_metadata: + mul_shape = prod(shape) + yield ( + gathered_results[ + i, _start_shape : _start_shape + mul_shape + ].reshape(shape), + global_idx, + ) + _start_shape += mul_shape + global_idx += 1 + + def parallel_tiled_decode(self, z: torch.FloatTensor) -> torch.FloatTensor: + """ + Parallel version of tiled_decode that distributes both temporal and spatial computation across GPUs + """ + world_size, rank = get_sp_world_size(), get_sp_parallel_rank() + B, C, T, H, W = z.shape + + # Calculate parameters + tile_latent_min_height = ( + self.tile_sample_min_height // self.spatial_compression_ratio + ) + tile_latent_min_width = ( + self.tile_sample_min_width // self.spatial_compression_ratio + ) + tile_latent_min_num_frames = ( + self.tile_sample_min_num_frames // self.temporal_compression_ratio + ) + tile_latent_stride_height = ( + self.tile_sample_stride_height // self.spatial_compression_ratio + ) + tile_latent_stride_width = ( + self.tile_sample_stride_width // self.spatial_compression_ratio + ) + tile_latent_stride_num_frames = ( + self.tile_sample_stride_num_frames // self.temporal_compression_ratio + ) + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Calculate tile dimensions + num_t_tiles = ( + T + tile_latent_stride_num_frames - 1 + ) // tile_latent_stride_num_frames + num_h_tiles = (H + tile_latent_stride_height - 1) // tile_latent_stride_height + num_w_tiles = (W + tile_latent_stride_width - 1) // tile_latent_stride_width + total_spatial_tiles = num_h_tiles * num_w_tiles + total_tiles = num_t_tiles * total_spatial_tiles + + # Calculate tiles per rank and padding + tiles_per_rank = (total_tiles + world_size - 1) // world_size + start_tile_idx = rank * tiles_per_rank + end_tile_idx = min((rank + 1) * tiles_per_rank, total_tiles) + + local_results = [] + local_dim_metadata = [] + # Process assigned tiles + for local_idx, global_idx in enumerate(range(start_tile_idx, end_tile_idx)): + t_idx = global_idx // total_spatial_tiles + spatial_idx = global_idx % total_spatial_tiles + h_idx = spatial_idx // num_w_tiles + w_idx = spatial_idx % num_w_tiles + + # Calculate positions + t_start = t_idx * tile_latent_stride_num_frames + h_start = h_idx * tile_latent_stride_height + w_start = w_idx * tile_latent_stride_width + + # Extract and process tile + tile = z[ + :, + :, + t_start : t_start + tile_latent_min_num_frames + 1, + h_start : h_start + tile_latent_min_height, + w_start : w_start + tile_latent_min_width, + ] + + # Process tile + tile = self._decode(tile) + + if t_start > 0: + tile = tile[:, :, 1:, :, :] + + # Store metadata + shape = tile.shape + # Store decoded data (flattened) + decoded_flat = tile.reshape(-1) + local_results.append(decoded_flat) + local_dim_metadata.append(shape) + + 
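+        # Example of the static tile-to-rank assignment above (assumed sizes): a latent
+        # split into 2 temporal x (3 x 3) spatial tiles gives 18 tiles in total; with a
+        # sequence-parallel world size of 4, tiles_per_rank = ceil(18 / 4) = 5, so ranks
+        # 0-2 each decode 5 tiles and rank 3 decodes the remaining 3. The gather below
+        # then reassembles every decoded tile in global (t, h, w) order before blending.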
results = torch.cat(local_results, dim=0).contiguous() + del local_results + # first gather size to pad the results + local_size = torch.tensor( + [results.size(0)], device=results.device, dtype=torch.int64 + ) + all_sizes = [ + torch.zeros(1, device=results.device, dtype=torch.int64) + for _ in range(world_size) + ] + dist.all_gather(all_sizes, local_size) + max_size = max(size.item() for size in all_sizes) + padded_results = torch.zeros(max_size, device=results.device) + padded_results[: results.size(0)] = results + del results + + # Gather all results + gathered_dim_metadata = [None] * world_size + gathered_results = ( + torch.zeros_like(padded_results) + .repeat(world_size, *[1] * len(padded_results.shape)) + .contiguous() + ) # use contiguous to make sure it won't copy data in the following operations + # TODO (PY): use sgl_diffusion distributed methods + dist.all_gather_into_tensor(gathered_results, padded_results) + dist.all_gather_object(gathered_dim_metadata, local_dim_metadata) + # Process gathered results + data: list = [ + [[[] for _ in range(num_w_tiles)] for _ in range(num_h_tiles)] + for _ in range(num_t_tiles) + ] + for current_data, global_idx in self._parallel_data_generator( + gathered_results, gathered_dim_metadata + ): + t_idx = global_idx // total_spatial_tiles + spatial_idx = global_idx % total_spatial_tiles + h_idx = spatial_idx // num_w_tiles + w_idx = spatial_idx % num_w_tiles + data[t_idx][h_idx][w_idx] = current_data + # Merge results + result_slices = [] + last_slice_data = None + for i, tem_data in enumerate(data): + slice_data = self._merge_spatial_tiles( + tem_data, + blend_height, + blend_width, + self.tile_sample_stride_height, + self.tile_sample_stride_width, + ) + if i > 0: + slice_data = self.blend_t( + last_slice_data, slice_data, self.blend_num_frames + ) + result_slices.append( + slice_data[:, :, : self.tile_sample_stride_num_frames, :, :] + ) + else: + result_slices.append( + slice_data[:, :, : self.tile_sample_stride_num_frames + 1, :, :] + ) + last_slice_data = slice_data + dec = torch.cat(result_slices, dim=2) + + return dec + + def _merge_spatial_tiles( + self, tiles, blend_height, blend_width, stride_height, stride_width + ) -> torch.Tensor: + """Helper function to merge spatial tiles with blending""" + result_rows = [] + for i, row in enumerate(tiles): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(tiles[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :stride_height, :stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + return torch.cat(result_rows, dim=-2) + + def spatial_tiled_decode(self, z: torch.Tensor) -> torch.Tensor: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + + Returns: + `torch.Tensor`: + The decoded images. 
+ """ + + _, _, _, height, width = z.shape + # sample_height = height * self.spatial_compression_ratio + # sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = ( + self.tile_sample_min_height // self.spatial_compression_ratio + ) + tile_latent_min_width = ( + self.tile_sample_min_width // self.spatial_compression_ratio + ) + tile_latent_stride_height = ( + self.tile_sample_stride_height // self.spatial_compression_ratio + ) + tile_latent_stride_width = ( + self.tile_sample_stride_width // self.spatial_compression_ratio + ) + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + tile = z[ + :, + :, + :, + i : i + tile_latent_min_height, + j : j + tile_latent_min_width, + ] + decoded = self._decode(tile) + row.append(decoded) + rows.append(row) + return self._merge_spatial_tiles( + rows, + blend_height, + blend_width, + self.tile_sample_stride_height, + self.tile_sample_stride_width, + ) + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + _, _, num_frames, height, width = x.shape + + # tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio + tile_latent_stride_num_frames = ( + self.tile_sample_stride_num_frames // self.temporal_compression_ratio + ) + + row = [] + for i in range(0, num_frames, self.tile_sample_stride_num_frames): + tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :] + if self.use_tiling and ( + height > self.tile_sample_min_height + or width > self.tile_sample_min_width + ): + tile = self.spatial_tiled_encode(tile) + else: + tile = self._encode(tile) + if i > 0: + tile = tile[:, :, 1:, :, :] + row.append(tile) + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, self.blend_num_frames) + result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :]) + else: + result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :]) + enc = torch.cat(result_row, dim=2) + return enc + + def tiled_decode(self, z: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = z.shape + + tile_latent_min_height = ( + self.tile_sample_min_height // self.spatial_compression_ratio + ) + tile_latent_min_width = ( + self.tile_sample_min_width // self.spatial_compression_ratio + ) + tile_latent_min_num_frames = ( + self.tile_sample_min_num_frames // self.temporal_compression_ratio + ) + tile_latent_stride_num_frames = ( + self.tile_sample_stride_num_frames // self.temporal_compression_ratio + ) + + row = [] + for i in range(0, num_frames, tile_latent_stride_num_frames): + tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :] + if self.use_tiling and ( + tile.shape[-1] > tile_latent_min_width + or tile.shape[-2] > tile_latent_min_height + ): + decoded = self.spatial_tiled_decode(tile) + else: + decoded = self._decode(tile) + if i > 0: + decoded = decoded[:, :, 1:, :, :] + row.append(decoded) + result_row = [] + for i, tile in enumerate(row): + if i > 0: + tile = self.blend_t(row[i - 1], tile, self.blend_num_frames) + result_row.append( + tile[:, :, : self.tile_sample_stride_num_frames, :, :] + ) + else: + result_row.append( + tile[:, :, : 
self.tile_sample_stride_num_frames + 1, :, :] + ) + + dec = torch.cat(result_row, dim=2) + return dec + + def enable_tiling( + self, + tile_sample_min_height: int | None = None, + tile_sample_min_width: int | None = None, + tile_sample_min_num_frames: int | None = None, + tile_sample_stride_height: int | None = None, + tile_sample_stride_width: int | None = None, + tile_sample_stride_num_frames: int | None = None, + blend_num_frames: int | None = None, + use_tiling: bool | None = None, + use_temporal_tiling: bool | None = None, + use_parallel_tiling: bool | None = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_min_num_frames (`int`, *optional*): + The minimum number of frames required for a sample to be separated into tiles across the frame + dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + tile_sample_stride_num_frames (`int`, *optional*): + The stride between two consecutive frame tiles. This is to ensure that there are no tiling artifacts + produced across the frame dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = ( + tile_sample_min_height or self.tile_sample_min_height + ) + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_min_num_frames = ( + tile_sample_min_num_frames or self.tile_sample_min_num_frames + ) + self.tile_sample_stride_height = ( + tile_sample_stride_height or self.tile_sample_stride_height + ) + self.tile_sample_stride_width = ( + tile_sample_stride_width or self.tile_sample_stride_width + ) + self.tile_sample_stride_num_frames = ( + tile_sample_stride_num_frames or self.tile_sample_stride_num_frames + ) + if blend_num_frames is not None: + self.blend_num_frames = blend_num_frames + else: + self.blend_num_frames = ( + self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames + ) + self.use_tiling = use_tiling or self.use_tiling + self.use_temporal_tiling = use_temporal_tiling or self.use_temporal_tiling + self.use_parallel_tiling = use_parallel_tiling or self.use_parallel_tiling + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. 
+ """ + self.use_tiling = False + + +# adapted from https://github.com/huggingface/diffusers/blob/e7ffeae0a191f710881d1fbde00cd6ff025e81f2/src/diffusers/models/autoencoders/vae.py#L691 +class DiagonalGaussianDistribution: + + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean, device=self.parameters.device, dtype=self.parameters.dtype + ) + + def sample(self, generator: torch.Generator | None = None) -> torch.Tensor: + # make sure sample is on the same device as the parameters and has same dtype + sample = randn_tensor( + self.mean.shape, + generator=generator, + device=self.parameters.device, + dtype=self.parameters.dtype, + ) + x = self.mean + self.std * sample + return x + + def kl( + self, other: Optional["DiagonalGaussianDistribution"] = None + ) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll( + self, sample: torch.Tensor, dims: tuple[int, ...] = (1, 2, 3) + ) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self) -> torch.Tensor: + return self.mean diff --git a/python/sglang/multimodal_gen/runtime/models/vaes/hunyuanvae.py b/python/sglang/multimodal_gen/runtime/models/vaes/hunyuanvae.py new file mode 100644 index 000000000..d0e611db2 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/vaes/hunyuanvae.py @@ -0,0 +1,852 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from diffusers + +# Copyright 2024 The Hunyuan Team, The HuggingFace Team and The sgl-diffusion Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sglang.multimodal_gen.configs.models.vaes import HunyuanVAEConfig +from sglang.multimodal_gen.runtime.layers.activation import get_act_fn +from sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE + + +def prepare_causal_attention_mask( + num_frames: int, + height_width: int, + dtype: torch.dtype, + device: torch.device, + batch_size: int | None = None, +) -> torch.Tensor: + indices = torch.arange(1, num_frames + 1, dtype=torch.int32, device=device) + indices_blocks = indices.repeat_interleave(height_width) + x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing="xy") + mask = torch.where(x <= y, 0, -float("inf")).to(dtype=dtype) + + if batch_size is not None: + mask = mask.unsqueeze(0).expand(batch_size, -1, -1) + return mask + + +class HunyuanVAEAttention(nn.Module): + + def __init__( + self, in_channels, heads, dim_head, eps, norm_num_groups, bias + ) -> None: + super().__init__() + self.in_channels = in_channels + self.heads = heads + self.dim_head = dim_head + self.eps = eps + self.norm_num_groups = norm_num_groups + self.bias = bias + + inner_dim = heads * dim_head + + # Define the projection layers + self.to_q = nn.Linear(in_channels, inner_dim, bias=bias) + self.to_k = nn.Linear(in_channels, inner_dim, bias=bias) + self.to_v = nn.Linear(in_channels, inner_dim, bias=bias) + self.to_out = nn.Sequential(nn.Linear(inner_dim, in_channels, bias=bias)) + + # Optional normalization layers + self.group_norm = nn.GroupNorm( + norm_num_groups, in_channels, eps=eps, affine=True + ) + + def forward( + self, hidden_states: torch.Tensor, attention_mask: torch.Tensor | None = None + ) -> torch.Tensor: + residual = hidden_states + + batch_size, sequence_length, _ = hidden_states.shape + + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + # Project to query, key, value + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + # Reshape for multi-head attention + head_dim = self.dim_head + + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + # Perform scaled dot-product attention + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + # Reshape back + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, self.heads * head_dim + ) + hidden_states = hidden_states.to(query.dtype) + + # Linear projection + hidden_states = self.to_out(hidden_states) + + # Residual connection and rescale + hidden_states = hidden_states + residual + + return hidden_states + + +class HunyuanVideoCausalConv3d(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int, int] = 3, + stride: int | tuple[int, int, int] = 1, + padding: int | tuple[int, int, int] = 0, + dilation: int | tuple[int, int, int] = 1, + bias: bool = True, + pad_mode: str = "replicate", + ) -> None: + super().__init__() + + kernel_size = ( + (kernel_size, kernel_size, kernel_size) + if isinstance(kernel_size, int) + else kernel_size + ) + + self.pad_mode = pad_mode + self.time_causal_padding = ( + kernel_size[0] // 2, + kernel_size[0] // 2, + kernel_size[1] // 2, + kernel_size[1] // 2, + kernel_size[2] - 1, + 0, + ) + + self.conv 
= nn.Conv3d( + in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = F.pad( + hidden_states, self.time_causal_padding, mode=self.pad_mode + ) + return self.conv(hidden_states) + + +class HunyuanVideoUpsampleCausal3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + kernel_size: int = 3, + stride: int = 1, + bias: bool = True, + upsample_factor: tuple[int, ...] = (2, 2, 2), + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + self.upsample_factor = upsample_factor + + self.conv = HunyuanVideoCausalConv3d( + in_channels, out_channels, kernel_size, stride, bias=bias + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_frames = hidden_states.size(2) + + first_frame, other_frames = hidden_states.split((1, num_frames - 1), dim=2) + first_frame = F.interpolate( + first_frame.squeeze(2), + scale_factor=self.upsample_factor[1:], + mode="nearest", + ).unsqueeze(2) + + if num_frames > 1: + # See: https://github.com/pytorch/pytorch/issues/81665 + # Unless you have a version of pytorch where non-contiguous implementation of F.interpolate + # is fixed, this will raise either a runtime error, or fail silently with bad outputs. + # If you are encountering an error here, make sure to try running encoding/decoding with + # `vae.enable_tiling()` first. If that doesn't work, open an issue at: + # https://github.com/huggingface/diffusers/issues + other_frames = other_frames.contiguous() + other_frames = F.interpolate( + other_frames, scale_factor=self.upsample_factor, mode="nearest" + ) + hidden_states = torch.cat((first_frame, other_frames), dim=2) + else: + hidden_states = first_frame + + hidden_states = self.conv(hidden_states) + return hidden_states + + +class HunyuanVideoDownsampleCausal3D(nn.Module): + + def __init__( + self, + channels: int, + out_channels: int | None = None, + padding: int = 1, + kernel_size: int = 3, + bias: bool = True, + stride=2, + ) -> None: + super().__init__() + out_channels = out_channels or channels + + self.conv = HunyuanVideoCausalConv3d( + channels, out_channels, kernel_size, stride, padding, bias=bias + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv(hidden_states) + return hidden_states + + +class HunyuanVideoResnetBlockCausal3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + non_linearity: str = "silu", + ) -> None: + super().__init__() + out_channels = out_channels or in_channels + + self.nonlinearity = get_act_fn(non_linearity) + + self.norm1 = nn.GroupNorm(groups, in_channels, eps=eps, affine=True) + self.conv1 = HunyuanVideoCausalConv3d(in_channels, out_channels, 3, 1, 0) + + self.norm2 = nn.GroupNorm(groups, out_channels, eps=eps, affine=True) + self.dropout = nn.Dropout(dropout) + self.conv2 = HunyuanVideoCausalConv3d(out_channels, out_channels, 3, 1, 0) + + self.conv_shortcut = None + if in_channels != out_channels: + self.conv_shortcut = HunyuanVideoCausalConv3d( + in_channels, out_channels, 1, 1, 0 + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = hidden_states.contiguous() + residual = hidden_states + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + 
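+        # Second GroupNorm -> SiLU -> dropout -> causal conv; the residual saved above is
+        # then added back (through a 1x1 causal conv shortcut when in/out channel counts
+        # differ), completing the pre-activation residual block.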
hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + hidden_states = hidden_states + residual + return hidden_states + + +class HunyuanVideoMidBlock3D(nn.Module): + + def __init__( + self, + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "silu", + resnet_groups: int = 32, + add_attention: bool = True, + attention_head_dim: int = 1, + ) -> None: + super().__init__() + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + self.add_attention = add_attention + + # There is always at least one resnet + resnets = [ + HunyuanVideoResnetBlockCausal3D( + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ] + attentions: list[HunyuanVAEAttention | None] = [] + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + HunyuanVAEAttention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + eps=resnet_eps, + norm_num_groups=resnet_groups, + bias=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + HunyuanVideoResnetBlockCausal3D( + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + self.resnets[0], hidden_states + ) + + for attn, resnet in zip(self.attentions, self.resnets[1:], strict=True): + if attn is not None: + batch_size, num_channels, num_frames, height, width = ( + hidden_states.shape + ) + hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3) + attention_mask = prepare_causal_attention_mask( + num_frames, + height * width, + hidden_states.dtype, + hidden_states.device, + batch_size=batch_size, + ) + hidden_states = attn(hidden_states, attention_mask=attention_mask) + hidden_states = hidden_states.unflatten( + 1, (num_frames, height, width) + ).permute(0, 4, 1, 2, 3) + + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states) + + else: + hidden_states = self.resnets[0](hidden_states) + + for attn, resnet in zip(self.attentions, self.resnets[1:], strict=True): + if attn is not None: + batch_size, num_channels, num_frames, height, width = ( + hidden_states.shape + ) + hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3) + attention_mask = prepare_causal_attention_mask( + num_frames, + height * width, + hidden_states.dtype, + hidden_states.device, + batch_size=batch_size, + ) + hidden_states = attn(hidden_states, attention_mask=attention_mask) + hidden_states = hidden_states.unflatten( + 1, (num_frames, height, width) + ).permute(0, 4, 1, 2, 3) + + hidden_states = resnet(hidden_states) + + return hidden_states + + +class HunyuanVideoDownBlock3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "silu", + resnet_groups: int = 
32, + add_downsample: bool = True, + downsample_stride: tuple[int, ...] | int = 2, + downsample_padding: int = 1, + ) -> None: + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + HunyuanVideoResnetBlockCausal3D( + in_channels=in_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + HunyuanVideoDownsampleCausal3D( + out_channels, + out_channels=out_channels, + padding=downsample_padding, + stride=downsample_stride, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if torch.is_grad_enabled() and self.gradient_checkpointing: + for resnet in self.resnets: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states) + else: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class HunyuanVideoUpBlock3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "silu", + resnet_groups: int = 32, + add_upsample: bool = True, + upsample_scale_factor: tuple[int, ...] = (2, 2, 2), + ) -> None: + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + HunyuanVideoResnetBlockCausal3D( + in_channels=input_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + HunyuanVideoUpsampleCausal3D( + out_channels, + out_channels=out_channels, + upsample_factor=upsample_scale_factor, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if torch.is_grad_enabled() and self.gradient_checkpointing: + for resnet in self.resnets: + hidden_states = self._gradient_checkpointing_func(resnet, hidden_states) + + else: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class HunyuanVideoEncoder3D(nn.Module): + r""" + Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603). + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: tuple[str, ...] = ( + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + "HunyuanVideoDownBlock3D", + ), + block_out_channels: tuple[int, ...] 
= (128, 256, 512, 512), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + mid_block_add_attention=True, + temporal_compression_ratio: int = 4, + spatial_compression_ratio: int = 8, + ) -> None: + super().__init__() + + self.conv_in = HunyuanVideoCausalConv3d( + in_channels, block_out_channels[0], kernel_size=3, stride=1 + ) + self.mid_block: HunyuanVideoMidBlock3D | None = None + self.down_blocks = nn.ModuleList([]) + + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + if down_block_type != "HunyuanVideoDownBlock3D": + raise ValueError(f"Unsupported down_block_type: {down_block_type}") + + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio)) + num_time_downsample_layers = int(np.log2(temporal_compression_ratio)) + + if temporal_compression_ratio == 4: + add_spatial_downsample = bool(i < num_spatial_downsample_layers) + add_time_downsample = bool( + i >= (len(block_out_channels) - 1 - num_time_downsample_layers) + and not is_final_block + ) + elif temporal_compression_ratio == 8: + add_spatial_downsample = bool(i < num_spatial_downsample_layers) + add_time_downsample = bool(i < num_time_downsample_layers) + else: + raise ValueError( + f"Unsupported time_compression_ratio: {temporal_compression_ratio}" + ) + + downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1) + downsample_stride_T = (2,) if add_time_downsample else (1,) + downsample_stride = tuple(downsample_stride_T + downsample_stride_HW) + + down_block = HunyuanVideoDownBlock3D( + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=bool(add_spatial_downsample or add_time_downsample), + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + downsample_stride=downsample_stride, + downsample_padding=0, + ) + + self.down_blocks.append(down_block) + + self.mid_block = HunyuanVideoMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + add_attention=mid_block_add_attention, + ) + + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6 + ) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = HunyuanVideoCausalConv3d( + block_out_channels[-1], conv_out_channels, kernel_size=3 + ) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for down_block in self.down_blocks: + hidden_states = self._gradient_checkpointing_func( + down_block, hidden_states + ) + + hidden_states = self._gradient_checkpointing_func( + self.mid_block, hidden_states + ) + else: + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states) + assert self.mid_block is not None + hidden_states = self.mid_block(hidden_states) + + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class HunyuanVideoDecoder3D(nn.Module): + r""" + Causal decoder for 3D video-like data introduced in [Hunyuan 
Video](https://huggingface.co/papers/2412.03603). + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: tuple[str, ...] = ( + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + "HunyuanVideoUpBlock3D", + ), + block_out_channels: tuple[int, ...] = (128, 256, 512, 512), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + mid_block_add_attention=True, + time_compression_ratio: int = 4, + spatial_compression_ratio: int = 8, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = HunyuanVideoCausalConv3d( + in_channels, block_out_channels[-1], kernel_size=3, stride=1 + ) + self.up_blocks = nn.ModuleList([]) + + # mid + self.mid_block = HunyuanVideoMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + add_attention=mid_block_add_attention, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + if up_block_type != "HunyuanVideoUpBlock3D": + raise ValueError(f"Unsupported up_block_type: {up_block_type}") + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio)) + num_time_upsample_layers = int(np.log2(time_compression_ratio)) + + if time_compression_ratio == 4: + add_spatial_upsample = bool(i < num_spatial_upsample_layers) + add_time_upsample = bool( + i >= len(block_out_channels) - 1 - num_time_upsample_layers + and not is_final_block + ) + else: + raise ValueError( + f"Unsupported time_compression_ratio: {time_compression_ratio}" + ) + + upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1) + upsample_scale_factor_T = (2,) if add_time_upsample else (1,) + upsample_scale_factor = tuple( + upsample_scale_factor_T + upsample_scale_factor_HW + ) + + up_block = HunyuanVideoUpBlock3D( + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=bool(add_spatial_upsample or add_time_upsample), + upsample_scale_factor=upsample_scale_factor, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + ) + + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 + ) + self.conv_act = nn.SiLU() + self.conv_out = HunyuanVideoCausalConv3d( + block_out_channels[0], out_channels, kernel_size=3 + ) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + self.mid_block, hidden_states + ) + + for up_block in self.up_blocks: + hidden_states = self._gradient_checkpointing_func( + up_block, hidden_states + ) + else: + hidden_states = self.mid_block(hidden_states) + + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states) + + # post-process + hidden_states = self.conv_norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +class 
AutoencoderKLHunyuanVideo(nn.Module, ParallelTiledVAE): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. + Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + config: HunyuanVAEConfig, + ) -> None: + nn.Module.__init__(self) + ParallelTiledVAE.__init__(self, config) + + # TODO(will): only pass in config. We do this by manually defining a + # config for hunyuan vae + self.block_out_channels = config.block_out_channels + + if config.load_encoder: + self.encoder = HunyuanVideoEncoder3D( + in_channels=config.in_channels, + out_channels=config.latent_channels, + down_block_types=config.down_block_types, + block_out_channels=config.block_out_channels, + layers_per_block=config.layers_per_block, + norm_num_groups=config.norm_num_groups, + act_fn=config.act_fn, + double_z=True, + mid_block_add_attention=config.mid_block_add_attention, + temporal_compression_ratio=config.temporal_compression_ratio, + spatial_compression_ratio=config.spatial_compression_ratio, + ) + self.quant_conv = nn.Conv3d( + 2 * config.latent_channels, 2 * config.latent_channels, kernel_size=1 + ) + + if config.load_decoder: + self.decoder = HunyuanVideoDecoder3D( + in_channels=config.latent_channels, + out_channels=config.out_channels, + up_block_types=config.up_block_types, + block_out_channels=config.block_out_channels, + layers_per_block=config.layers_per_block, + norm_num_groups=config.norm_num_groups, + act_fn=config.act_fn, + time_compression_ratio=config.temporal_compression_ratio, + spatial_compression_ratio=config.spatial_compression_ratio, + mid_block_add_attention=config.mid_block_add_attention, + ) + self.post_quant_conv = nn.Conv3d( + config.latent_channels, config.latent_channels, kernel_size=1 + ) + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + x = self.encoder(x) + enc = self.quant_conv(x) + return enc + + def _decode(self, z: torch.Tensor) -> torch.Tensor: + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + generator: torch.Generator | None = None, + ) -> torch.Tensor: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z) + return dec + + +EntryClass = AutoencoderKLHunyuanVideo diff --git a/python/sglang/multimodal_gen/runtime/models/vaes/stepvideovae.py b/python/sglang/multimodal_gen/runtime/models/vaes/stepvideovae.py new file mode 100644 index 000000000..d202b7a61 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/vaes/stepvideovae.py @@ -0,0 +1,1184 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Copyright 2025 StepFun Inc. All Rights Reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# ============================================================================== +from typing import Any + +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional as F + +from sglang.multimodal_gen.configs.models.vaes import StepVideoVAEConfig +from sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE + + +def base_group_norm(x, norm_layer, act_silu=False, channel_last=False) -> torch.Tensor: + if hasattr(base_group_norm, "spatial") and base_group_norm.spatial: + assert channel_last + x_shape = x.shape + x = x.flatten(0, 1) + if channel_last: + # Permute to NCHW format + x = x.permute(0, 3, 1, 2) + + out = F.group_norm( + x.contiguous(), + norm_layer.num_groups, + norm_layer.weight, + norm_layer.bias, + norm_layer.eps, + ) + if act_silu: + out = F.silu(out) + + if channel_last: + # Permute back to NHWC format + out = out.permute(0, 2, 3, 1) + + out = out.view(x_shape) + else: + if channel_last: + # Permute to NCHW format + x = x.permute(0, 3, 1, 2) + out = F.group_norm( + x.contiguous(), + norm_layer.num_groups, + norm_layer.weight, + norm_layer.bias, + norm_layer.eps, + ) + if act_silu: + out = F.silu(out) + if channel_last: + # Permute back to NHWC format + out = out.permute(0, 2, 3, 1) + return out + + +def base_conv2d(x, conv_layer, channel_last=False, residual=None) -> torch.Tensor: + if channel_last: + x = x.permute(0, 3, 1, 2) # NHWC to NCHW + out = F.conv2d( + x, + conv_layer.weight, + conv_layer.bias, + stride=conv_layer.stride, + padding=conv_layer.padding, + ) + if residual is not None: + if channel_last: + residual = residual.permute(0, 3, 1, 2) # NHWC to NCHW + out += residual + if channel_last: + out = out.permute(0, 2, 3, 1) # NCHW to NHWC + return out + + +def base_conv3d( + x, conv_layer, channel_last=False, residual=None, only_return_output=False +) -> torch.Tensor: + if only_return_output: + size = cal_outsize( + x.shape, conv_layer.weight.shape, conv_layer.stride, conv_layer.padding + ) + return torch.empty(size, device=x.device, dtype=x.dtype) + if channel_last: + x = x.permute(0, 4, 1, 2, 3) # NDHWC to NCDHW + out = F.conv3d( + x, + conv_layer.weight, + conv_layer.bias, + stride=conv_layer.stride, + padding=conv_layer.padding, + ) + if residual is not None: + if channel_last: + residual = residual.permute(0, 4, 1, 2, 3) # NDHWC to NCDHW + out += residual + if channel_last: + out = out.permute(0, 2, 3, 4, 1) # NCDHW to NDHWC + return out + + +def cal_outsize(input_sizes, kernel_sizes, stride, padding) -> list: + stride_d, stride_h, stride_w = stride + padding_d, padding_h, padding_w = padding + dilation_d, dilation_h, dilation_w = 1, 1, 1 + + in_d = input_sizes[1] + in_h = input_sizes[2] + in_w = input_sizes[3] + + kernel_d = kernel_sizes[2] + kernel_h = kernel_sizes[3] + kernel_w = kernel_sizes[4] + out_channels = kernel_sizes[0] + + out_d = calc_out_(in_d, padding_d, dilation_d, kernel_d, stride_d) + out_h = calc_out_(in_h, 
padding_h, dilation_h, kernel_h, stride_h) + out_w = calc_out_(in_w, padding_w, dilation_w, kernel_w, stride_w) + size = [input_sizes[0], out_d, out_h, out_w, out_channels] + return size + + +def calc_out_( + in_size: int, padding: int, dilation: int, kernel: int, stride: int +) -> int: + return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1 + + +def base_conv3d_channel_last(x, conv_layer, residual=None) -> torch.Tensor: + in_numel = x.numel() + out_numel = int(x.numel() * conv_layer.out_channels / conv_layer.in_channels) + if (in_numel >= 2**30) or (out_numel >= 2**30): + assert conv_layer.stride[0] == 1, "time split asks time stride = 1" + + B, T, H, W, C = x.shape + K = conv_layer.kernel_size[0] + + chunks = 4 + chunk_size = T // chunks + + if residual is None: + out_nhwc = base_conv3d( + x, + conv_layer, + channel_last=True, + residual=residual, + only_return_output=True, + ) + else: + out_nhwc = residual + + assert B == 1 + for i in range(chunks): + if i == chunks - 1: + xi = x[:1, chunk_size * i :] + out_nhwci = out_nhwc[:1, chunk_size * i :] + else: + xi = x[:1, chunk_size * i : chunk_size * (i + 1) + K - 1] + out_nhwci = out_nhwc[:1, chunk_size * i : chunk_size * (i + 1)] + if residual is not None: + if i == chunks - 1: + ri = residual[:1, chunk_size * i :] + else: + ri = residual[:1, chunk_size * i : chunk_size * (i + 1)] + else: + ri = None + out_nhwci.copy_(base_conv3d(xi, conv_layer, channel_last=True, residual=ri)) + else: + out_nhwc = base_conv3d(x, conv_layer, channel_last=True, residual=residual) + return out_nhwc + + +class Upsample2D(nn.Module): + + def __init__( + self, channels, use_conv=False, use_conv_transpose=False, out_channels=None + ) -> None: + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + + if use_conv: + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) + else: + assert "Not Supported" + self.conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) + + def forward(self, x, output_size=None) -> torch.Tensor: + assert x.shape[-1] == self.channels + + if self.use_conv_transpose: + return self.conv(x) + + if output_size is None: + x = ( + F.interpolate( + x.permute(0, 3, 1, 2).to(memory_format=torch.channels_last), + scale_factor=2.0, + mode="nearest", + ) + .permute(0, 2, 3, 1) + .contiguous() + ) + else: + x = ( + F.interpolate( + x.permute(0, 3, 1, 2).to(memory_format=torch.channels_last), + size=output_size, + mode="nearest", + ) + .permute(0, 2, 3, 1) + .contiguous() + ) + + # x = self.conv(x) + x = base_conv2d(x, self.conv, channel_last=True) + return x + + +class Downsample2D(nn.Module): + + def __init__(self, channels, use_conv=False, out_channels=None, padding=1) -> None: + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + + if use_conv: + self.conv = nn.Conv2d( + self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.conv = nn.AvgPool2d(kernel_size=stride, stride=stride) + + def forward(self, x) -> torch.Tensor: + assert x.shape[-1] == self.channels + if self.use_conv and self.padding == 0: + pad = (0, 0, 0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + + assert x.shape[-1] == self.channels + # x = self.conv(x) + x = base_conv2d(x, self.conv, channel_last=True) + return x + + +class 
CausalConv(nn.Module): + + def __init__(self, chan_in, chan_out, kernel_size, **kwargs) -> None: + super().__init__() + + if isinstance(kernel_size, int): + kernel_size = ( + kernel_size if isinstance(kernel_size, tuple) else ((kernel_size,) * 3) + ) + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + self.dilation = kwargs.pop("dilation", 1) + self.stride = kwargs.pop("stride", 1) + if isinstance(self.stride, int): + self.stride = (self.stride, 1, 1) + time_pad = self.dilation * (time_kernel_size - 1) + max((1 - self.stride[0]), 0) + height_pad = height_kernel_size // 2 + width_pad = width_kernel_size // 2 + self.time_causal_padding = ( + width_pad, + width_pad, + height_pad, + height_pad, + time_pad, + 0, + ) + self.time_uncausal_padding = ( + width_pad, + width_pad, + height_pad, + height_pad, + 0, + 0, + ) + + self.conv = nn.Conv3d( + chan_in, + chan_out, + kernel_size, + stride=self.stride, + dilation=self.dilation, + **kwargs, + ) + self.chan_in = chan_in + self.chan_out = chan_out + self.is_first_run = True + + def forward(self, x, is_init=True, residual=None) -> torch.Tensor: + x = nn.functional.pad( + x, self.time_causal_padding if is_init else self.time_uncausal_padding + ) + x = self.conv(x) + if residual is not None: + x.add_(residual) + return x + + +class ChannelDuplicatingPixelUnshuffleUpSampleLayer3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + factor: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor = factor + assert out_channels * factor**3 % in_channels == 0 + self.repeats = out_channels * factor**3 // in_channels + + def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = x.view( + x.size(0), + self.out_channels, + self.factor, + self.factor, + self.factor, + x.size(2), + x.size(3), + x.size(4), + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view( + x.size(0), + self.out_channels, + x.size(2) * self.factor, + x.size(4) * self.factor, + x.size(6) * self.factor, + ) + x = x[:, :, self.factor - 1 :, :, :] + return x + + +class ConvPixelShuffleUpSampleLayer3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + factor: int, + ) -> None: + super().__init__() + self.factor = factor + out_ratio = factor**3 + self.conv = CausalConv( + in_channels, out_channels * out_ratio, kernel_size=kernel_size + ) + + def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor: + x = self.conv(x, is_init) + x = self.pixel_shuffle_3d(x, self.factor) + return x + + @staticmethod + def pixel_shuffle_3d(x: torch.Tensor, factor: int) -> torch.Tensor: + batch_size, channels, depth, height, width = x.size() + new_channels = channels // (factor**3) + new_depth = depth * factor + new_height = height * factor + new_width = width * factor + + x = x.view( + batch_size, new_channels, factor, factor, factor, depth, height, width + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view(batch_size, new_channels, new_depth, new_height, new_width) + x = x[:, :, factor - 1 :, :, :] + return x + + +class ConvPixelUnshuffleDownSampleLayer3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + factor: int, + ) -> None: + super().__init__() + self.factor = factor + out_ratio = factor**3 + assert out_channels % out_ratio == 0 + self.conv = CausalConv( + in_channels, out_channels // out_ratio, 
kernel_size=kernel_size + ) + + def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor: + x = self.conv(x, is_init) + x = self.pixel_unshuffle_3d(x, self.factor) + return x + + @staticmethod + def pixel_unshuffle_3d(x: torch.Tensor, factor: int) -> torch.Tensor: + pad = (0, 0, 0, 0, factor - 1, 0) # (left, right, top, bottom, front, back) + x = F.pad(x, pad) + B, C, D, H, W = x.shape + x = x.view(B, C, D // factor, factor, H // factor, factor, W // factor, factor) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view(B, C * factor**3, D // factor, H // factor, W // factor) + return x + + +class PixelUnshuffleChannelAveragingDownSampleLayer3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + factor: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor = factor + assert in_channels * factor**3 % out_channels == 0 + self.group_size = in_channels * factor**3 // out_channels + + def forward(self, x: torch.Tensor, is_init=True) -> torch.Tensor: + pad = ( + 0, + 0, + 0, + 0, + self.factor - 1, + 0, + ) # (left, right, top, bottom, front, back) + x = F.pad(x, pad) + B, C, D, H, W = x.shape + x = x.view( + B, + C, + D // self.factor, + self.factor, + H // self.factor, + self.factor, + W // self.factor, + self.factor, + ) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view( + B, C * self.factor**3, D // self.factor, H // self.factor, W // self.factor + ) + x = x.view( + B, + self.out_channels, + self.group_size, + D // self.factor, + H // self.factor, + W // self.factor, + ) + x = x.mean(dim=2) + return x + + +def base_group_norm_with_zero_pad( + x, norm_layer, act_silu=True, pad_size=2 +) -> torch.Tensor: + out_shape = list(x.shape) + out_shape[1] += pad_size + out = torch.empty(out_shape, dtype=x.dtype, device=x.device) + out[:, pad_size:] = base_group_norm( + x, norm_layer, act_silu=act_silu, channel_last=True + ) + out[:, :pad_size] = 0 + return out + + +class CausalConvChannelLast(CausalConv): + time_causal_padding: tuple[Any, ...] + time_uncausal_padding: tuple[Any, ...] 
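+    # NOTE: F.pad consumes padding pairs starting from the *last* dimension.
+    # In the channel-last (NDHWC) layout used by this subclass, the trailing
+    # dimension is the channel axis, so __init__ below prepends an extra
+    # (0, 0) pair to the inherited NCDHW padding tuples, leaving channels
+    # unpadded while keeping the causal time/height/width padding.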
+ + def __init__(self, chan_in, chan_out, kernel_size, **kwargs) -> None: + super().__init__(chan_in, chan_out, kernel_size, **kwargs) + + self.time_causal_padding = (0, 0) + self.time_causal_padding + self.time_uncausal_padding = (0, 0) + self.time_uncausal_padding + + def forward(self, x, is_init=True, residual=None) -> torch.Tensor: + if self.is_first_run: + self.is_first_run = False + # self.conv.weight = nn.Parameter(self.conv.weight.permute(0,2,3,4,1).contiguous()) + + x = nn.functional.pad( + x, self.time_causal_padding if is_init else self.time_uncausal_padding + ) + + x = base_conv3d_channel_last(x, self.conv, residual=residual) + return x + + +class CausalConvAfterNorm(CausalConv): + + def __init__(self, chan_in, chan_out, kernel_size, **kwargs) -> None: + super().__init__(chan_in, chan_out, kernel_size, **kwargs) + + if self.time_causal_padding == (1, 1, 1, 1, 2, 0): + self.conv = nn.Conv3d( + chan_in, + chan_out, + kernel_size, + stride=self.stride, + dilation=self.dilation, + padding=(0, 1, 1), + **kwargs, + ) + else: + self.conv = nn.Conv3d( + chan_in, + chan_out, + kernel_size, + stride=self.stride, + dilation=self.dilation, + **kwargs, + ) + self.is_first_run = True + + def forward(self, x, is_init=True, residual=None) -> torch.Tensor: + if self.is_first_run: + self.is_first_run = False + + if self.time_causal_padding == (1, 1, 1, 1, 2, 0): + pass + else: + x = nn.functional.pad(x, self.time_causal_padding).contiguous() + + x = base_conv3d_channel_last(x, self.conv, residual=residual) + return x + + +class AttnBlock(nn.Module): + + def __init__(self, in_channels) -> None: + super().__init__() + + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels) + self.q = CausalConvChannelLast(in_channels, in_channels, kernel_size=1) + self.k = CausalConvChannelLast(in_channels, in_channels, kernel_size=1) + self.v = CausalConvChannelLast(in_channels, in_channels, kernel_size=1) + self.proj_out = CausalConvChannelLast(in_channels, in_channels, kernel_size=1) + + def attention(self, x, is_init=True) -> torch.Tensor: + x = base_group_norm(x, self.norm, act_silu=False, channel_last=True) + q = self.q(x, is_init) + k = self.k(x, is_init) + v = self.v(x, is_init) + + b, t, h, w, c = q.shape + q, k, v = map(lambda x: rearrange(x, "b t h w c -> b 1 (t h w) c"), (q, k, v)) + x = nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True) + x = rearrange(x, "b 1 (t h w) c -> b t h w c", t=t, h=h, w=w) + + return x + + def forward(self, x): + x = x.permute(0, 2, 3, 4, 1).contiguous() + h = self.attention(x) + x = self.proj_out(h, residual=x) + x = x.permute(0, 4, 1, 2, 3) + return x + + +class Resnet3DBlock(nn.Module): + + def __init__( + self, + in_channels, + out_channels=None, + temb_channels=512, + conv_shortcut=False, + ) -> None: + super().__init__() + + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels) + self.conv1 = CausalConvAfterNorm(in_channels, out_channels, kernel_size=3) + if temb_channels > 0: + self.temb_proj = nn.Linear(temb_channels, out_channels) + + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels) + self.conv2 = CausalConvAfterNorm(out_channels, out_channels, kernel_size=3) + + assert conv_shortcut is False + self.use_conv_shortcut = conv_shortcut + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = CausalConvAfterNorm( + in_channels, 
out_channels, kernel_size=3 + ) + else: + self.nin_shortcut = CausalConvAfterNorm( + in_channels, out_channels, kernel_size=1 + ) + + def forward(self, x, temb=None, is_init=True) -> torch.Tensor: + x = x.permute(0, 2, 3, 4, 1).contiguous() + + h = base_group_norm_with_zero_pad(x, self.norm1, act_silu=True, pad_size=2) + h = self.conv1(h) + if temb is not None: + h = h + self.temb_proj(nn.functional.silu(temb))[:, :, None, None] + + x = self.nin_shortcut(x) if self.in_channels != self.out_channels else x + + h = base_group_norm_with_zero_pad(h, self.norm2, act_silu=True, pad_size=2) + x = self.conv2(h, residual=x) + + x = x.permute(0, 4, 1, 2, 3) + return x + + +class Downsample3D(nn.Module): + + def __init__(self, in_channels, with_conv, stride) -> None: + super().__init__() + + self.with_conv = with_conv + if with_conv: + self.conv = CausalConv( + in_channels, in_channels, kernel_size=3, stride=stride + ) + + def forward(self, x, is_init=True) -> torch.Tensor: + if self.with_conv: + x = self.conv(x, is_init) + else: + x = nn.functional.avg_pool3d(x, kernel_size=2, stride=2) + return x + + +class VideoEncoder(nn.Module): + + def __init__( + self, + ch=32, + ch_mult=(4, 8, 16, 16), + num_res_blocks=2, + in_channels=3, + z_channels=16, + double_z=True, + down_sampling_layer=(1, 2), + resamp_with_conv=True, + version=1, + ) -> None: + super().__init__() + + temb_ch = 0 + + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + + # downsampling + self.conv_in = CausalConv(in_channels, ch, kernel_size=3) + self.down_sampling_layer = down_sampling_layer + + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + Resnet3DBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=temb_ch, + ) + ) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + if i_level in self.down_sampling_layer: + down.downsample = Downsample3D( + block_in, resamp_with_conv, stride=(2, 2, 2) + ) + else: + down.downsample = Downsample2D( + block_in, resamp_with_conv, padding=0 + ) # DIFF + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = Resnet3DBlock( + in_channels=block_in, out_channels=block_in, temb_channels=temb_ch + ) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = Resnet3DBlock( + in_channels=block_in, out_channels=block_in, temb_channels=temb_ch + ) + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in) + self.version = version + if version == 2: + channels = 4 * z_channels * 2**3 + self.conv_patchify = ConvPixelUnshuffleDownSampleLayer3D( + block_in, channels, kernel_size=3, factor=2 + ) + self.shortcut_pathify = PixelUnshuffleChannelAveragingDownSampleLayer3D( + block_in, channels, 2 + ) + self.shortcut_out = PixelUnshuffleChannelAveragingDownSampleLayer3D( + channels, 2 * z_channels if double_z else z_channels, 1 + ) + self.conv_out = CausalConvChannelLast( + channels, 2 * z_channels if double_z else z_channels, kernel_size=3 + ) + else: + self.conv_out = CausalConvAfterNorm( + block_in, 2 * z_channels if double_z else z_channels, kernel_size=3 + ) + + @torch.inference_mode() + def forward(self, x, video_frame_num, is_init=True) -> torch.Tensor: + # timestep embedding + temb = None + + t = 
video_frame_num + + # downsampling + h = self.conv_in(x, is_init) + + # make it real channel last, but behave like normal layout + h = h.permute(0, 2, 3, 4, 1).contiguous().permute(0, 4, 1, 2, 3) + + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](h, temb, is_init) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + + if i_level != self.num_resolutions - 1: + if isinstance(self.down[i_level].downsample, Downsample2D): + _, _, t, _, _ = h.shape + h = rearrange(h, "b c t h w -> (b t) h w c", t=t) + h = self.down[i_level].downsample(h) + h = rearrange(h, "(b t) h w c -> b c t h w", t=t) + else: + h = self.down[i_level].downsample(h, is_init) + + h = self.mid.block_1(h, temb, is_init) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb, is_init) + + h = h.permute(0, 2, 3, 4, 1).contiguous() # b c l h w -> b l h w c + if self.version == 2: + h = base_group_norm(h, self.norm_out, act_silu=True, channel_last=True) + h = h.permute(0, 4, 1, 2, 3).contiguous() + shortcut = self.shortcut_pathify(h, is_init) + h = self.conv_patchify(h, is_init) + h = h.add_(shortcut) + shortcut = self.shortcut_out(h, is_init).permute(0, 2, 3, 4, 1) + h = self.conv_out(h.permute(0, 2, 3, 4, 1).contiguous(), is_init) + h = h.add_(shortcut) + else: + h = base_group_norm_with_zero_pad( + h, self.norm_out, act_silu=True, pad_size=2 + ) + h = self.conv_out(h, is_init) + h = h.permute(0, 4, 1, 2, 3) # b l h w c -> b c l h w + + h = rearrange(h, "b c t h w -> b t c h w") + return h + + +class Res3DBlockUpsample(nn.Module): + + def __init__( + self, input_filters, num_filters, down_sampling_stride, down_sampling=False + ) -> None: + super().__init__() + + self.input_filters = input_filters + self.num_filters = num_filters + + self.act_ = nn.SiLU(inplace=True) + + self.conv1 = CausalConvChannelLast( + num_filters, num_filters, kernel_size=[3, 3, 3] + ) + self.norm1 = nn.GroupNorm(32, num_filters) + + self.conv2 = CausalConvChannelLast( + num_filters, num_filters, kernel_size=[3, 3, 3] + ) + self.norm2 = nn.GroupNorm(32, num_filters) + + self.down_sampling = down_sampling + if down_sampling: + self.down_sampling_stride = down_sampling_stride + else: + self.down_sampling_stride = [1, 1, 1] + + if num_filters != input_filters or down_sampling: + self.conv3 = CausalConvChannelLast( + input_filters, + num_filters, + kernel_size=[1, 1, 1], + stride=self.down_sampling_stride, + ) + self.norm3 = nn.GroupNorm(32, num_filters) + + def forward(self, x, is_init=False) -> torch.Tensor: + x = x.permute(0, 2, 3, 4, 1).contiguous() + + residual = x + + h = self.conv1(x, is_init) + h = base_group_norm(h, self.norm1, act_silu=True, channel_last=True) + + h = self.conv2(h, is_init) + h = base_group_norm(h, self.norm2, act_silu=False, channel_last=True) + + if self.down_sampling or self.num_filters != self.input_filters: + x = self.conv3(x, is_init) + x = base_group_norm(x, self.norm3, act_silu=False, channel_last=True) + + h.add_(x) + h = self.act_(h) + if residual is not None: + h.add_(residual) + + h = h.permute(0, 4, 1, 2, 3) + return h + + +class Upsample3D(nn.Module): + + def __init__(self, in_channels, scale_factor=2) -> None: + super().__init__() + + self.scale_factor = scale_factor + self.conv3d = Res3DBlockUpsample( + input_filters=in_channels, + num_filters=in_channels, + down_sampling_stride=(1, 1, 1), + down_sampling=False, + ) + + def forward(self, x, is_init=True, is_split=True) -> torch.Tensor: + b, c, t, h, w = 
x.shape + + # x = x.permute(0,2,3,4,1).contiguous().permute(0,4,1,2,3).to(memory_format=torch.channels_last_3d) + if is_split: + split_size = c // 8 + x_slices = torch.split(x, split_size, dim=1) + x = [ + nn.functional.interpolate(x, scale_factor=self.scale_factor) + for x in x_slices + ] + x = torch.cat(x, dim=1) + else: + x = nn.functional.interpolate(x, scale_factor=self.scale_factor) + + x = self.conv3d(x, is_init) + return x + + +class VideoDecoder(nn.Module): + + def __init__( + self, + ch=128, + z_channels=16, + out_channels=3, + ch_mult=(1, 2, 4, 4), + num_res_blocks=2, + temporal_up_layers=(2, 3), + temporal_downsample=4, + resamp_with_conv=True, + version=1, + ) -> None: + super().__init__() + + temb_ch = 0 + + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.temporal_downsample = temporal_downsample + + block_in = ch * ch_mult[self.num_resolutions - 1] + self.version = version + if version == 2: + channels = 4 * z_channels * 2**3 + self.conv_in = CausalConv(z_channels, channels, kernel_size=3) + self.shortcut_in = ChannelDuplicatingPixelUnshuffleUpSampleLayer3D( + z_channels, channels, 1 + ) + self.conv_unpatchify = ConvPixelShuffleUpSampleLayer3D( + channels, block_in, kernel_size=3, factor=2 + ) + self.shortcut_unpathify = ChannelDuplicatingPixelUnshuffleUpSampleLayer3D( + channels, block_in, 2 + ) + else: + self.conv_in = CausalConv(z_channels, block_in, kernel_size=3) + + # middle + self.mid = nn.Module() + self.mid.block_1 = Resnet3DBlock( + in_channels=block_in, out_channels=block_in, temb_channels=temb_ch + ) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = Resnet3DBlock( + in_channels=block_in, out_channels=block_in, temb_channels=temb_ch + ) + + # upsampling + self.up_id = len(temporal_up_layers) + self.video_frame_num = 1 + self.cur_video_frame_num = self.video_frame_num // 2**self.up_id + 1 + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + Resnet3DBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=temb_ch, + ) + ) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + if i_level in temporal_up_layers: + up.upsample = Upsample3D(block_in) + self.cur_video_frame_num = self.cur_video_frame_num * 2 + else: + up.upsample = Upsample2D(block_in, resamp_with_conv) + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in) + self.conv_out = CausalConvAfterNorm(block_in, out_channels, kernel_size=3) + + @torch.inference_mode() + def forward(self, z, is_init=True) -> torch.Tensor: + z = rearrange(z, "b t c h w -> b c t h w") + h = self.conv_in(z, is_init=is_init) + if self.version == 2: + shortcut = self.shortcut_in(z, is_init=is_init) + h = h.add_(shortcut) + shortcut = self.shortcut_unpathify(h, is_init=is_init) + h = self.conv_unpatchify(h, is_init=is_init) + h = h.add_(shortcut) + + temb = None + + h = h.permute(0, 2, 3, 4, 1).contiguous().permute(0, 4, 1, 2, 3) + h = self.mid.block_1(h, temb, is_init=is_init) + h = self.mid.attn_1(h) + h = h.permute(0, 2, 3, 4, 1).contiguous().permute(0, 4, 1, 2, 3) + h = self.mid.block_2(h, temb, is_init=is_init) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = h.permute(0, 2, 3, 4, 
1).contiguous().permute(0, 4, 1, 2, 3) + h = self.up[i_level].block[i_block](h, temb, is_init=is_init) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + if isinstance(self.up[i_level].upsample, Upsample2D): + B = h.size(0) + h = h.permute(0, 2, 3, 4, 1).flatten(0, 1) + h = self.up[i_level].upsample(h) + h = h.unflatten(0, (B, -1)).permute(0, 4, 1, 2, 3) + else: + h = self.up[i_level].upsample(h, is_init=is_init) + + # end + h = h.permute(0, 2, 3, 4, 1) # b c l h w -> b l h w c + h = base_group_norm_with_zero_pad(h, self.norm_out, act_silu=True, pad_size=2) + h = self.conv_out(h) + h = h.permute(0, 4, 1, 2, 3) + + if is_init: + h = h[:, :, (self.temporal_downsample - 1) :] + return h + + +def rms_norm(input, normalized_shape, eps=1e-6) -> torch.Tensor: + dtype = input.dtype + input = input.to(torch.float32) + variance = ( + input.pow(2) + .flatten(-len(normalized_shape)) + .mean(-1)[(...,) + (None,) * len(normalized_shape)] + ) + input = input * torch.rsqrt(variance + eps) + return input.to(dtype) + + +class DiagonalGaussianDistribution: + + def __init__( + self, + parameters, + deterministic=False, + rms_norm_mean=False, + only_return_mean=False, + ) -> None: + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=-3) # N,[X],C,H,W + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + self.deterministic = deterministic + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean, device=self.parameters.device, dtype=self.parameters.dtype + ) + if rms_norm_mean: + self.mean = rms_norm(self.mean, self.mean.size()[1:]) + self.only_return_mean = only_return_mean + + def sample(self, generator=None) -> torch.Tensor: + # make sure sample is on the same device + # as the parameters and has same dtype + sample = torch.randn( + self.mean.shape, generator=generator, device=self.parameters.device + ) + sample = sample.to(dtype=self.parameters.dtype) + x = self.mean + self.std * sample + if self.only_return_mean: + return self.mean + else: + return x + + +class AutoencoderKLStepvideo(nn.Module, ParallelTiledVAE): + + def __init__( + self, + config: StepVideoVAEConfig, + ) -> None: + nn.Module.__init__(self) + ParallelTiledVAE.__init__(self, config) + + self.frame_len = config.frame_len + + if config.version == 2: + self.latent_len = 3 + base_group_norm.spatial = True # type: ignore[attr-defined] + else: + self.latent_len = 5 + base_group_norm.spatial = False # type: ignore[attr-defined] + + self.encoder = VideoEncoder( + in_channels=config.in_channels, + z_channels=config.z_channels, + num_res_blocks=config.num_res_blocks, + version=config.version, + ) + + self.decoder = VideoDecoder( + z_channels=config.z_channels, + out_channels=config.out_channels, + num_res_blocks=config.num_res_blocks, + version=config.version, + ) + + self.world_size = config.world_size + # self.is_init = True + + def load_state_dict(self, state_dict, strict=True): + remapped = {} + for key, value in state_dict.items(): + if key.startswith("decoder.conv_out."): + # move “decoder.conv_out.weight” → “decoder.conv_out.conv.weight” + suffix = key[len("decoder.conv_out.") :] + remapped[f"decoder.conv_out.conv.{suffix}"] = value + else: + remapped[key] = value + super().load_state_dict(remapped, strict=strict) + + def _encode(self, x, is_init_image=True) -> torch.Tensor: + # b, len, c, h, w = x.size() + b, c, len, h, w = x.size() + # x = rearrange(x, 'b l c h 
w -> b c l h w').contiguous() + z = self.encoder(x, len, True) # 下采样[1, 4, 8, 16, 16] + return z + + @torch.inference_mode() + def encode(self, x): + # b (nc cf) c h w -> (b nc) cf c h w -> encode -> (b nc) cf c h w -> b (nc cf) c h w + chunks = list(x.split(self.frame_len, dim=1)) + for i in range(len(chunks)): + chunks[i] = self._encode(chunks[i], True) + z = torch.cat(chunks, dim=1) + + posterior = DiagonalGaussianDistribution(z) + return posterior.sample() + + def _decode(self, z) -> torch.Tensor: + + chunks = list(z.split(self.latent_len, dim=2)) + for i in range(len(chunks)): + chunks[i] = chunks[i].permute(0, 2, 1, 3, 4) + chunks[i] = chunks[i].to(next(self.decoder.parameters()).dtype) + chunks[i] = self.decoder(chunks[i], is_init=True) + x = torch.cat(chunks, dim=2) + return x + + def decode(self, z) -> torch.Tensor: + num_frames = z.size(2) + dec = ParallelTiledVAE.decode(self, z).permute(0, 2, 1, 3, 4) + dec = self.mix(dec).permute(0, 2, 1, 3, 4) + num_sample_frames = num_frames // 3 * 17 + return dec[:, :, :num_sample_frames] + + def mix(self, x) -> torch.Tensor: + remain_scale = 0.6 + mix_scale = 1.0 - remain_scale + front = slice(self.frame_len - 1, x.size(1) - 1, self.frame_len) + back = slice(self.frame_len, x.size(1), self.frame_len) + x[:, back] = x[:, back] * remain_scale + x[:, front] * mix_scale + x[:, front] = x[:, front] * remain_scale + x[:, back] * mix_scale + return x + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + generator: torch.Generator | None = None, + ) -> torch.Tensor: + """ + Args: + sample (`torch.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z) + return dec + + +EntryClass = AutoencoderKLStepvideo diff --git a/python/sglang/multimodal_gen/runtime/models/vaes/wanvae.py b/python/sglang/multimodal_gen/runtime/models/vaes/wanvae.py new file mode 100644 index 000000000..1018d43be --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/vaes/wanvae.py @@ -0,0 +1,1343 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
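+
+# NOTE: Chunked encode/decode in this module passes the causal-convolution
+# feature cache through module-level contextvars (`feat_cache`, `feat_idx`,
+# `is_first_frame`, `first_chunk`) instead of threading extra arguments
+# through every layer. A minimal sketch of the pattern used by
+# `AutoencoderKLWan.decode` below (names other than `forward_context` and
+# `feat_idx` are illustrative only):
+#
+#     cache = [None] * num_causal_convs   # one slot per WanCausalConv3d
+#     with forward_context(feat_cache_arg=cache, feat_idx_arg=0):
+#         for chunk in latent_chunks:     # one temporal chunk at a time
+#             feat_idx.set(0)             # restart the per-conv cache index
+#             out = decoder(chunk)        # convs read/update `cache` in place
+#
+# Each WanCausalConv3d keeps the trailing CACHE_T frames of its input so the
+# next chunk can be convolved causally without recomputing earlier frames.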
+ +import contextvars +from contextlib import contextmanager + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from sglang.multimodal_gen.configs.models.vaes import WanVAEConfig +from sglang.multimodal_gen.runtime.layers.activation import get_act_fn +from sglang.multimodal_gen.runtime.models.vaes.common import ( + DiagonalGaussianDistribution, + ParallelTiledVAE, +) +from sglang.multimodal_gen.runtime.platforms import current_platform + +CACHE_T = 2 + +is_first_frame = contextvars.ContextVar("is_first_frame", default=False) +feat_cache = contextvars.ContextVar("feat_cache", default=None) +feat_idx = contextvars.ContextVar("feat_idx", default=0) +first_chunk = contextvars.ContextVar("first_chunk", default=None) + + +@contextmanager +def forward_context( + first_frame_arg=False, feat_cache_arg=None, feat_idx_arg=None, first_chunk_arg=None +): + is_first_frame_token = is_first_frame.set(first_frame_arg) + feat_cache_token = feat_cache.set(feat_cache_arg) + feat_idx_token = feat_idx.set(feat_idx_arg) + first_chunk_token = first_chunk.set(first_chunk_arg) + try: + yield + finally: + is_first_frame.reset(is_first_frame_token) + feat_cache.reset(feat_cache_token) + feat_idx.reset(feat_idx_token) + first_chunk.reset(first_chunk_token) + + +class AvgDown3D(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert in_channels * self.factor % out_channels == 0 + self.group_size = in_channels * self.factor // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t + pad = (0, 0, 0, 0, pad_t, 0) + x = F.pad(x, pad) + B, C, T, H, W = x.shape + x = x.view( + B, + C, + T // self.factor_t, + self.factor_t, + H // self.factor_s, + self.factor_s, + W // self.factor_s, + self.factor_s, + ) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view( + B, + C * self.factor, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.view( + B, + self.out_channels, + self.group_size, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.mean(dim=2) + return x + + +class DupUp3D(nn.Module): + + def __init__( + self, + in_channels: int, + out_channels: int, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert out_channels * self.factor % in_channels == 0 + self.repeats = out_channels * self.factor // in_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = x.view( + x.size(0), + self.out_channels, + self.factor_t, + self.factor_s, + self.factor_s, + x.size(2), + x.size(3), + x.size(4), + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view( + x.size(0), + self.out_channels, + x.size(2) * self.factor_t, + x.size(4) * self.factor_s, + x.size(6) * self.factor_s, + ) + + _first_chunk = first_chunk.get() + if _first_chunk: + x = x[:, :, self.factor_t - 1 :, :, :] + return x + + +class WanCausalConv3d(nn.Conv3d): + r""" + A custom 3D causal convolution layer with feature caching support. 
+ + This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature + caching for efficient inference. + + Args: + in_channels (int): Number of channels in the input image + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int or tuple, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int | tuple[int, int, int], + stride: int | tuple[int, int, int] = 1, + padding: int | tuple[int, int, int] = 0, + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + self.padding: tuple[int, int, int] + # Set up causal padding + self._padding: tuple[int, ...] = ( + self.padding[2], + self.padding[2], + self.padding[1], + self.padding[1], + 2 * self.padding[0], + 0, + ) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + x = ( + x.to(self.weight.dtype) if current_platform.is_mps() else x + ) # casting needed for mps since amp isn't supported + return super().forward(x) + + +class WanRMS_norm(nn.Module): + r""" + A custom RMS normalization layer. + + Args: + dim (int): The number of dimensions to normalize over. + channel_first (bool, optional): Whether the input tensor has channels as the first dimension. + Default is True. + images (bool, optional): Whether the input represents image data. Default is True. + bias (bool, optional): Whether to include a learnable bias term. Default is False. + """ + + def __init__( + self, + dim: int, + channel_first: bool = True, + images: bool = True, + bias: bool = False, + ) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return ( + F.normalize(x, dim=(1 if self.channel_first else -1)) + * self.scale + * self.gamma + + self.bias + ) + + +class WanUpsample(nn.Upsample): + r""" + Perform upsampling while ensuring the output tensor has the same data type as the input. + + Args: + x (torch.Tensor): Input tensor to be upsampled. + + Returns: + torch.Tensor: Upsampled tensor with the same data type as the input. + """ + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class WanResample(nn.Module): + r""" + A custom resampling module for 2D and 3D data. + + Args: + dim (int): The number of input/output channels. + mode (str): The resampling mode. Must be one of: + - 'none': No resampling (identity operation). + - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution. + - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution. + - 'downsample2d': 2D downsampling with zero-padding and convolution. + - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution. 
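+        upsample_out_dim (int, optional): Number of output channels of the
+            2D upsample convolution. Defaults to `dim // 2` when not provided.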
+ """ + + def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + # default to dim //2 + if upsample_out_dim is None: + upsample_out_dim = dim // 2 + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, upsample_out_dim, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, upsample_out_dim, 3, padding=1), + ) + self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)) + ) + elif mode == "downsample3d": + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)) + ) + self.time_conv = WanCausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0) + ) + + else: + self.resample = nn.Identity() + + def forward(self, x): + b, c, t, h, w = x.size() + first_frame = is_first_frame.get() + if first_frame: + assert t == 1 + _feat_cache = feat_cache.get() + _feat_idx = feat_idx.get() + if self.mode == "upsample3d": + if _feat_cache is not None: + idx = _feat_idx + if _feat_cache[idx] is None: + _feat_cache[idx] = "Rep" + _feat_idx += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if ( + cache_x.shape[2] < 2 + and _feat_cache[idx] is not None + and _feat_cache[idx] != "Rep" + ): + # cache last frame of last two chunk + cache_x = torch.cat( + [ + _feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + if ( + cache_x.shape[2] < 2 + and _feat_cache[idx] is not None + and _feat_cache[idx] == "Rep" + ): + cache_x = torch.cat( + [torch.zeros_like(cache_x).to(cache_x.device), cache_x], + dim=2, + ) + if _feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, _feat_cache[idx]) + _feat_cache[idx] = cache_x + _feat_idx += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + feat_cache.set(_feat_cache) + feat_idx.set(_feat_idx) + elif not first_frame and hasattr(self, "time_conv"): + x = self.time_conv(x) + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.resample(x) + x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + _feat_cache = feat_cache.get() + _feat_idx = feat_idx.get() + if self.mode == "downsample3d": + if _feat_cache is not None: + idx = _feat_idx + if _feat_cache[idx] is None: + _feat_cache[idx] = x.clone() + _feat_idx += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv( + torch.cat([_feat_cache[idx][:, :, -1:, :, :], x], 2) + ) + _feat_cache[idx] = cache_x + _feat_idx += 1 + feat_cache.set(_feat_cache) + feat_idx.set(_feat_idx) + elif not first_frame and hasattr(self, "time_conv"): + x = self.time_conv(x) + return x + + +class WanResidualBlock(nn.Module): + r""" + A custom residual block module. + + Args: + in_dim (int): Number of input channels. + out_dim (int): Number of output channels. + dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0. + non_linearity (str, optional): Type of non-linearity to use. 
Default is "silu". + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dropout: float = 0.0, + non_linearity: str = "silu", + ) -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.nonlinearity = get_act_fn(non_linearity) + + # layers + self.norm1 = WanRMS_norm(in_dim, images=False) + self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1) + self.norm2 = WanRMS_norm(out_dim, images=False) + self.dropout = nn.Dropout(dropout) + self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1) + self.conv_shortcut = ( + WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + ) + + def forward(self, x): + # Apply shortcut connection + h = self.conv_shortcut(x) + + # First normalization and activation + x = self.norm1(x) + x = self.nonlinearity(x) + + _feat_cache = feat_cache.get() + _feat_idx = feat_idx.get() + if _feat_cache is not None: + idx = _feat_idx + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and _feat_cache[idx] is not None: + cache_x = torch.cat( + [ + _feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + + x = self.conv1(x, _feat_cache[idx]) + _feat_cache[idx] = cache_x + _feat_idx += 1 + feat_cache.set(_feat_cache) + feat_idx.set(_feat_idx) + else: + x = self.conv1(x) + + # Second normalization and activation + x = self.norm2(x) + x = self.nonlinearity(x) + + # Dropout + x = self.dropout(x) + + _feat_cache = feat_cache.get() + _feat_idx = feat_idx.get() + if _feat_cache is not None: + idx = _feat_idx + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and _feat_cache[idx] is not None: + cache_x = torch.cat( + [ + _feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + + x = self.conv2(x, _feat_cache[idx]) + _feat_cache[idx] = cache_x + _feat_idx += 1 + feat_cache.set(_feat_cache) + feat_idx.set(_feat_idx) + else: + x = self.conv2(x) + + # Add residual connection + return x + h + + +class WanAttentionBlock(nn.Module): + r""" + Causal self-attention with a single head. + + Args: + dim (int): The number of channels in the input tensor. + """ + + def __init__(self, dim) -> None: + super().__init__() + self.dim = dim + + # layers + self.norm = WanRMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + identity = x + batch_size, channels, time, height, width = x.size() + + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width) + x = self.norm(x) + + # compute query, key, value + qkv = self.to_qkv(x) + qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + qkv = qkv.permute(0, 1, 3, 2).contiguous() + q, k, v = qkv.chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention(q, k, v) + + x = ( + x.squeeze(1) + .permute(0, 2, 1) + .reshape(batch_size * time, channels, height, width) + ) + + # output projection + x = self.proj(x) + + # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w] + x = x.view(batch_size, time, channels, height, width) + x = x.permute(0, 2, 1, 3, 4) + + return x + identity + + +class WanMidBlock(nn.Module): + """ + Middle block for WanVAE encoder and decoder. + + Args: + dim (int): Number of input/output channels. + dropout (float): Dropout rate. + non_linearity (str): Type of non-linearity to use. 
+ """ + + def __init__( + self, + dim: int, + dropout: float = 0.0, + non_linearity: str = "silu", + num_layers: int = 1, + ): + super().__init__() + self.dim = dim + + # Create the components + resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)] + attentions = [] + for _ in range(num_layers): + attentions.append(WanAttentionBlock(dim)) + resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity)) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, x): + # First residual block + x = self.resnets[0](x) + + # Process through attention and residual blocks + for attn, resnet in zip(self.attentions, self.resnets[1:], strict=True): + if attn is not None: + x = attn(x) + + x = resnet(x) + + return x + + +class WanResidualDownBlock(nn.Module): + + def __init__( + self, + in_dim, + out_dim, + dropout, + num_res_blocks, + temperal_downsample=False, + down_flag=False, + ): + super().__init__() + + # Shortcut path with downsample + self.avg_shortcut = AvgDown3D( + in_dim, + out_dim, + factor_t=2 if temperal_downsample else 1, + factor_s=2 if down_flag else 1, + ) + + # Main path with residual blocks and downsample + resnets = [] + for _ in range(num_res_blocks): + resnets.append(WanResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + self.resnets = nn.ModuleList(resnets) + + # Add the final downsample block + if down_flag: + mode = "downsample3d" if temperal_downsample else "downsample2d" + self.downsampler = WanResample(out_dim, mode=mode) + else: + self.downsampler = None + + def forward(self, x): + x_copy = x.clone() + for resnet in self.resnets: + x = resnet(x) + if self.downsampler is not None: + x = self.downsampler(x) + + return x + self.avg_shortcut(x_copy) + + +class WanEncoder3d(nn.Module): + r""" + A 3D encoder module. + + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_downsample (list of bool): Whether to downsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. 
+ """ + + def __init__( + self, + in_channels: int = 3, + dim=128, + z_dim=4, + dim_mult=(1, 2, 4, 4), + num_res_blocks=2, + attn_scales=(), + temperal_downsample=(True, True, False), + dropout=0.0, + non_linearity: str = "silu", + is_residual: bool = False, # wan 2.2 vae use a residual downblock + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + dim_mult = list(dim_mult) + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = list(attn_scales) + self.temperal_downsample = list(temperal_downsample) + self.nonlinearity = get_act_fn(non_linearity) + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1) + + # downsample blocks + self.down_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:], strict=True)): + # residual (+attention) blocks + if is_residual: + self.down_blocks.append( + WanResidualDownBlock( + in_dim, + out_dim, + dropout, + num_res_blocks, + temperal_downsample=( + temperal_downsample[i] if i != len(dim_mult) - 1 else False + ), + down_flag=i != len(dim_mult) - 1, + ) + ) + else: + for _ in range(num_res_blocks): + self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + self.down_blocks.append(WanAttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = "downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(WanResample(out_dim, mode=mode)) + scale /= 2.0 + + # middle blocks + self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1) + + # output blocks + self.norm_out = WanRMS_norm(out_dim, images=False) + self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x): + _feat_cache = feat_cache.get() + _feat_idx = feat_idx.get() + if _feat_cache is not None: + idx = _feat_idx + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and _feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat( + [ + _feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv_in(x, _feat_cache[idx]) + _feat_cache[idx] = cache_x + _feat_idx += 1 + feat_cache.set(_feat_cache) + feat_idx.set(_feat_idx) + else: + x = self.conv_in(x) + + ## downsamples + for layer in self.down_blocks: + x = layer(x) + + ## middle + x = self.mid_block(x) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + + _feat_cache = feat_cache.get() + _feat_idx = feat_idx.get() + if _feat_cache is not None: + idx = _feat_idx + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and _feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat( + [ + _feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv_out(x, _feat_cache[idx]) + _feat_cache[idx] = cache_x + _feat_idx += 1 + feat_cache.set(_feat_cache) + feat_idx.set(_feat_idx) + else: + x = self.conv_out(x) + return x + + +# adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +class WanResidualUpBlock(nn.Module): + """ + A block that handles upsampling for the WanVAE decoder. 
+ Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + temperal_upsample (bool): Whether to upsample on temporal dimension + up_flag (bool): Whether to upsample or not + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + temperal_upsample: bool = False, + up_flag: bool = False, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + if up_flag: + self.avg_shortcut = DupUp3D( + in_dim, + out_dim, + factor_t=2 if temperal_upsample else 1, + factor_s=2, + ) + else: + self.avg_shortcut = None + + # create residual blocks + resnets = [] + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append( + WanResidualBlock(current_dim, out_dim, dropout, non_linearity) + ) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + if up_flag: + upsample_mode = "upsample3d" if temperal_upsample else "upsample2d" + self.upsampler = WanResample( + out_dim, mode=upsample_mode, upsample_out_dim=out_dim + ) + else: + self.upsampler = None + + self.gradient_checkpointing = False + + def forward(self, x): + """ + Forward pass through the upsampling block. + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + Returns: + torch.Tensor: Output tensor + """ + if self.avg_shortcut is not None: + x_copy = x.clone() + + for resnet in self.resnets: + x = resnet(x) + + if self.upsampler is not None: + x = self.upsampler(x) + + if self.avg_shortcut is not None: + x = x + self.avg_shortcut(x_copy) + + return x + + +class WanUpBlock(nn.Module): + """ + A block that handles upsampling for the WanVAE decoder. + + Args: + in_dim (int): Input dimension + out_dim (int): Output dimension + num_res_blocks (int): Number of residual blocks + dropout (float): Dropout rate + upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d') + non_linearity (str): Type of non-linearity to use + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + upsample_mode: str | None = None, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # Create layers list + resnets = [] + # Add residual blocks and attention if needed + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append( + WanResidualBlock(current_dim, out_dim, dropout, non_linearity) + ) + current_dim = out_dim + + self.resnets = nn.ModuleList(resnets) + + # Add upsampling layer if needed + self.upsamplers = None + if upsample_mode is not None: + self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)]) + + self.gradient_checkpointing = False + + def forward(self, x): + """ + Forward pass through the upsampling block. + + Args: + x (torch.Tensor): Input tensor + feat_cache (list, optional): Feature cache for causal convolutions + feat_idx (list, optional): Feature index for cache management + + Returns: + torch.Tensor: Output tensor + """ + for resnet in self.resnets: + x = resnet(x) + + if self.upsamplers is not None: + x = self.upsamplers[0](x) + return x + + +class WanDecoder3d(nn.Module): + r""" + A 3D decoder module. 
+ + Args: + dim (int): The base number of channels in the first layer. + z_dim (int): The dimensionality of the latent space. + dim_mult (list of int): Multipliers for the number of channels in each block. + num_res_blocks (int): Number of residual blocks in each block. + attn_scales (list of float): Scales at which to apply attention mechanisms. + temperal_upsample (list of bool): Whether to upsample temporally in each block. + dropout (float): Dropout rate for the dropout layers. + non_linearity (str): Type of non-linearity to use. + """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=(1, 2, 4, 4), + num_res_blocks=2, + attn_scales=(), + temperal_upsample=(False, True, True), + dropout=0.0, + non_linearity: str = "silu", + out_channels: int = 3, + is_residual: bool = False, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + dim_mult = list(dim_mult) + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = list(attn_scales) + self.temperal_upsample = list(temperal_upsample) + + self.nonlinearity = get_act_fn(non_linearity) + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + + # init block + self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1) + + # upsample blocks + self.up_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:], strict=True)): + # residual (+attention) blocks + if i > 0 and not is_residual: + # wan vae 2.1 + in_dim = in_dim // 2 + + # determine if we need upsampling + up_flag = i != len(dim_mult) - 1 + # determine upsampling mode, if not upsampling, set to None + upsample_mode = None + if up_flag and temperal_upsample[i]: + upsample_mode = "upsample3d" + elif up_flag: + upsample_mode = "upsample2d" + + # Create and add the upsampling block + if is_residual: + up_block = WanResidualUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + temperal_upsample=temperal_upsample[i] if up_flag else False, + up_flag=up_flag, + non_linearity=non_linearity, + ) + else: + up_block = WanUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) + self.up_blocks.append(up_block) + + # output blocks + self.norm_out = WanRMS_norm(out_dim, images=False) + self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, x): + ## conv1 + _feat_cache = feat_cache.get() + _feat_idx = feat_idx.get() + if _feat_cache is not None: + idx = _feat_idx + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and _feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat( + [ + _feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv_in(x, _feat_cache[idx]) + _feat_cache[idx] = cache_x + _feat_idx += 1 + feat_cache.set(_feat_cache) + feat_idx.set(_feat_idx) + else: + x = self.conv_in(x) + + ## middle + x = self.mid_block(x) + + ## upsamples + for up_block in self.up_blocks: + x = up_block(x) + + ## head + x = self.norm_out(x) + x = self.nonlinearity(x) + _feat_cache = feat_cache.get() + _feat_idx = feat_idx.get() + if _feat_cache is not None: + idx = _feat_idx + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and _feat_cache[idx] is not None: + # 
cache last frame of last two chunk + cache_x = torch.cat( + [ + _feat_cache[idx][:, :, -1, :, :] + .unsqueeze(2) + .to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv_out(x, _feat_cache[idx]) + _feat_cache[idx] = cache_x + _feat_idx += 1 + feat_cache.set(_feat_cache) + feat_idx.set(_feat_idx) + else: + x = self.conv_out(x) + return x + + +def patchify(x, patch_size): + if patch_size == 1: + return x + + if x.dim() == 4: + x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b c f (h q) (w r) -> b (c r q) f h w", + q=patch_size, + r=patch_size, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x, patch_size): + if patch_size == 1: + return x + + if x.dim() == 4: + x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b (c r q) f h w -> b c f (h q) (w r)", + q=patch_size, + r=patch_size, + ) + + return x + + +class AutoencoderKLWan(nn.Module, ParallelTiledVAE): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. + Introduced in [Wan 2.1]. + """ + + _supports_gradient_checkpointing = False + + def __init__( + self, + config: WanVAEConfig, + ) -> None: + nn.Module.__init__(self) + ParallelTiledVAE.__init__(self, config) + + self.z_dim = config.z_dim + self.temperal_downsample = list(config.temperal_downsample) + self.temperal_upsample = list(config.temperal_downsample)[::-1] + + if config.decoder_base_dim is None: + decoder_base_dim = config.base_dim + else: + decoder_base_dim = config.decoder_base_dim + + self.latents_mean = list(config.latents_mean) + self.latents_std = list(config.latents_std) + self.shift_factor = config.shift_factor + + if config.load_encoder: + self.encoder = WanEncoder3d( + in_channels=config.in_channels, + dim=config.base_dim, + z_dim=self.z_dim * 2, + dim_mult=config.dim_mult, + num_res_blocks=config.num_res_blocks, + attn_scales=config.attn_scales, + temperal_downsample=self.temperal_downsample, + dropout=config.dropout, + is_residual=config.is_residual, + ) + self.quant_conv = WanCausalConv3d(self.z_dim * 2, self.z_dim * 2, 1) + self.post_quant_conv = WanCausalConv3d(self.z_dim, self.z_dim, 1) + + if config.load_decoder: + self.decoder = WanDecoder3d( + dim=decoder_base_dim, + z_dim=self.z_dim, + dim_mult=config.dim_mult, + num_res_blocks=config.num_res_blocks, + attn_scales=config.attn_scales, + temperal_upsample=self.temperal_upsample, + dropout=config.dropout, + out_channels=config.out_channels, + is_residual=config.is_residual, + ) + + self.use_feature_cache = config.use_feature_cache + + def clear_cache(self) -> None: + + def _count_conv3d(model) -> int: + count = 0 + for m in model.modules(): + if isinstance(m, WanCausalConv3d): + count += 1 + return count + + if self.config.load_decoder: + self._conv_num = _count_conv3d(self.decoder) + self._conv_idx = 0 + self._feat_map = [None] * self._conv_num + # cache encode + if self.config.load_encoder: + self._enc_conv_num = _count_conv3d(self.encoder) + self._enc_conv_idx = 0 + self._enc_feat_map = [None] * self._enc_conv_num + + def encode(self, x: torch.Tensor) -> torch.Tensor: + if self.use_feature_cache: + self.clear_cache() + if self.config.patch_size is not None: + x = patchify(x, patch_size=self.config.patch_size) + with forward_context( + feat_cache_arg=self._enc_feat_map, feat_idx_arg=self._enc_conv_idx + ): + t = x.shape[2] + iter_ = 1 + 
(t - 1) // 4 + for i in range(iter_): + feat_idx.set(0) + if i == 0: + out = self.encoder(x[:, :, :1, :, :]) + else: + out_ = self.encoder(x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :]) + out = torch.cat([out, out_], 2) + enc = self.quant_conv(out) + mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :] + enc = torch.cat([mu, logvar], dim=1) + enc = DiagonalGaussianDistribution(enc) + self.clear_cache() + else: + for block in self.encoder.down_blocks: + if isinstance(block, WanResample) and block.mode == "downsample3d": + _padding = list(block.time_conv._padding) + _padding[4] = 2 + block.time_conv._padding = tuple(_padding) + enc = ParallelTiledVAE.encode(self, x) + + return enc + + def _encode(self, x: torch.Tensor, first_frame=False) -> torch.Tensor: + with forward_context(first_frame_arg=first_frame): + out = self.encoder(x) + enc = self.quant_conv(out) + mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :] + enc = torch.cat([mu, logvar], dim=1) + return enc + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + first_frame = x[:, :, 0, :, :].unsqueeze(2) + first_frame = self._encode(first_frame, first_frame=True) + + enc = ParallelTiledVAE.tiled_encode(self, x) + enc = enc[:, :, 1:] + enc = torch.cat([first_frame, enc], dim=2) + return enc + + def spatial_tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + first_frame = x[:, :, 0, :, :].unsqueeze(2) + first_frame = self._encode(first_frame, first_frame=True) + + enc = ParallelTiledVAE.spatial_tiled_encode(self, x) + enc = enc[:, :, 1:] + enc = torch.cat([first_frame, enc], dim=2) + return enc + + def decode(self, z: torch.Tensor) -> torch.Tensor: + if self.use_feature_cache: + self.clear_cache() + iter_ = z.shape[2] + x = self.post_quant_conv(z) + with forward_context( + feat_cache_arg=self._feat_map, feat_idx_arg=self._conv_idx + ): + for i in range(iter_): + feat_idx.set(0) + if i == 0: + first_chunk.set(True) + out = self.decoder(x[:, :, i : i + 1, :, :]) + else: + first_chunk.set(False) + out_ = self.decoder(x[:, :, i : i + 1, :, :]) + out = torch.cat([out, out_], 2) + + if self.config.patch_size is not None: + out = unpatchify(out, patch_size=self.config.patch_size) + + out = out.float() + out = torch.clamp(out, min=-1.0, max=1.0) + self.clear_cache() + else: + out = ParallelTiledVAE.decode(self, z) + + return out + + def _decode(self, z: torch.Tensor, first_frame=False) -> torch.Tensor: + x = self.post_quant_conv(z) + with forward_context(first_frame_arg=first_frame): + out = self.decoder(x) + + out = torch.clamp(out, min=-1.0, max=1.0) + + return out + + def tiled_decode(self, z: torch.Tensor) -> torch.Tensor: + self.blend_num_frames *= 2 + dec = ParallelTiledVAE.tiled_decode(self, z) + start_frame_idx = self.temporal_compression_ratio - 1 + dec = dec[:, :, start_frame_idx:] + return dec + + def spatial_tiled_decode(self, z: torch.Tensor) -> torch.Tensor: + dec = ParallelTiledVAE.spatial_tiled_decode(self, z) + start_frame_idx = self.temporal_compression_ratio - 1 + dec = dec[:, :, start_frame_idx:] + return dec + + def parallel_tiled_decode(self, z: torch.FloatTensor) -> torch.FloatTensor: + self.blend_num_frames *= 2 + dec = ParallelTiledVAE.parallel_tiled_decode(self, z) + start_frame_idx = self.temporal_compression_ratio - 1 + dec = dec[:, :, start_frame_idx:] + return dec + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + generator: torch.Generator | None = None, + ) -> torch.Tensor: + """ + Args: + sample (`torch.Tensor`): Input sample. 
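+            sample_posterior (`bool`, *optional*, defaults to `False`):
+                Whether to sample from the encoded posterior; if `False`, the posterior mode is used.
+            generator (`torch.Generator`, *optional*):
+                Generator used when sampling from the posterior.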
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z) + return dec + + +EntryClass = AutoencoderKLWan diff --git a/python/sglang/multimodal_gen/runtime/models/vision_utils.py b/python/sglang/multimodal_gen/runtime/models/vision_utils.py new file mode 100644 index 000000000..ac22579a0 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/models/vision_utils.py @@ -0,0 +1,301 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import os +import tempfile +from collections.abc import Callable +from urllib.parse import unquote, urlparse + +import imageio +import numpy as np +import PIL.Image +import PIL.ImageOps +import requests +import torch +from packaging import version + +if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): + PIL_INTERPOLATION = { + "linear": PIL.Image.Resampling.BILINEAR, + "bilinear": PIL.Image.Resampling.BILINEAR, + "bicubic": PIL.Image.Resampling.BICUBIC, + "lanczos": PIL.Image.Resampling.LANCZOS, + "nearest": PIL.Image.Resampling.NEAREST, + } +else: + PIL_INTERPOLATION = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + "nearest": PIL.Image.NEAREST, + } + + +def pil_to_numpy(images: list[PIL.Image.Image] | PIL.Image.Image) -> np.ndarray: + r""" + Convert a PIL image or a list of PIL images to NumPy arrays. + + Args: + images (`PIL.Image.Image` or `List[PIL.Image.Image]`): + The PIL image or list of images to convert to NumPy format. + + Returns: + `np.ndarray`: + A NumPy array representation of the images. + """ + if not isinstance(images, list): + images = [images] + images = [np.array(image).astype(np.float32) / 255.0 for image in images] + images_arr: np.ndarray = np.stack(images, axis=0) + + return images_arr + + +def numpy_to_pt(images: np.ndarray) -> torch.Tensor: + r""" + Convert a NumPy image to a PyTorch tensor. + + Args: + images (`np.ndarray`): + The NumPy image array to convert to PyTorch format. + + Returns: + `torch.Tensor`: + A PyTorch tensor representation of the images. + """ + if images.ndim == 3: + images = images[..., None] + + images = torch.from_numpy(images.transpose(0, 3, 1, 2)) + return images + + +def normalize(images: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor: + r""" + Normalize an image array to [-1,1]. + + Args: + images (`np.ndarray` or `torch.Tensor`): + The image array to normalize. + + Returns: + `np.ndarray` or `torch.Tensor`: + The normalized image array. + """ + return 2.0 * images - 1.0 + + +# adapted from diffusers.utils import load_image +def load_image( + image: str | PIL.Image.Image, + convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] | None = None, +) -> PIL.Image.Image: + """ + Loads `image` to a PIL Image. + + Args: + image (`str` or `PIL.Image.Image`): + The image to convert to the PIL Image format. + convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], *optional*): + A conversion method to apply to the image after loading it. When set to `None` the image will be converted + "RGB". + + Returns: + `PIL.Image.Image`: + A PIL Image. 
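+
+    Example (illustrative; the URL below is a placeholder, not a real asset):
+        >>> img = load_image("https://example.com/frame.png")
+        >>> img.mode  # converted to RGB by default
+        'RGB'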
+    """
+    if isinstance(image, str):
+        if image.startswith("http://") or image.startswith("https://"):
+            image = PIL.Image.open(requests.get(image, stream=True).raw)
+        elif os.path.isfile(image):
+            image = PIL.Image.open(image)
+        else:
+            raise ValueError(
+                f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {image} is not a valid path."
+            )
+    elif isinstance(image, PIL.Image.Image):
+        image = image
+    else:
+        raise ValueError(
+            "Incorrect format used for the image. Should be a URL linking to an image, a local path, or a PIL image."
+        )
+
+    image = PIL.ImageOps.exif_transpose(image)
+
+    if convert_method is not None:
+        image = convert_method(image)
+    else:
+        image = image.convert("RGB")
+
+    return image
+
+
+# adapted from diffusers.utils import load_video
+def load_video(
+    video: str,
+    convert_method: (
+        Callable[[list[PIL.Image.Image]], list[PIL.Image.Image]] | None
+    ) = None,
+) -> list[PIL.Image.Image]:
+    """
+    Loads `video` into a list of PIL Images.
+    Args:
+        video (`str`):
+            A URL or path to a video to convert to a list of PIL Images.
+        convert_method (Callable[[List[PIL.Image.Image]], List[PIL.Image.Image]], *optional*):
+            A conversion method to apply to the video after loading it. When set to `None` the images will be converted
+            to "RGB".
+    Returns:
+        `List[PIL.Image.Image]`:
+            The video as a list of PIL images.
+    """
+    is_url = video.startswith("http://") or video.startswith("https://")
+    is_file = os.path.isfile(video)
+    was_tempfile_created = False
+
+    if not (is_url or is_file):
+        raise ValueError(
+            f"Incorrect path or URL. URLs must start with `http://` or `https://`, and {video} is not a valid path."
+        )
+
+    if is_url:
+        response = requests.get(video, stream=True)
+        if response.status_code != 200:
+            raise ValueError(
+                f"Failed to download video. Status code: {response.status_code}"
+            )
+
+        parsed_url = urlparse(video)
+        file_name = os.path.basename(unquote(parsed_url.path))
+
+        suffix = os.path.splitext(file_name)[1] or ".mp4"
+        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:
+            video_path = temp_file.name
+            video_data = response.iter_content(chunk_size=8192)
+            for chunk in video_data:
+                temp_file.write(chunk)
+
+        video = video_path
+        was_tempfile_created = True  # mark the temporary file for cleanup after reading
+
+    pil_images = []
+    if video.endswith(".gif"):
+        gif = PIL.Image.open(video)
+        try:
+            while True:
+                pil_images.append(gif.copy())
+                gif.seek(gif.tell() + 1)
+        except EOFError:
+            pass
+
+    else:
+        try:
+            imageio.plugins.ffmpeg.get_exe()
+        except AttributeError:
+            raise AttributeError(
+                "Unable to find an ffmpeg installation on your machine. Please install via `pip install imageio-ffmpeg`."
+            ) from None
+
+        with imageio.get_reader(video) as reader:
+            # Read all frames
+            for frame in reader:
+                pil_images.append(PIL.Image.fromarray(frame))
+
+    if was_tempfile_created:
+        os.remove(video_path)
+
+    if convert_method is not None:
+        pil_images = convert_method(pil_images)
+
+    return pil_images
+
+
+def get_default_height_width(
+    image: PIL.Image.Image | np.ndarray | torch.Tensor,
+    vae_scale_factor: int,
+    height: int | None = None,
+    width: int | None = None,
+) -> tuple[int, int]:
+    r"""
+    Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
+
+    Args:
+        image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
+            The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
+            should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
+            tensor, it should have shape `[batch, channels, height, width]`.
+        height (`Optional[int]`, *optional*, defaults to `None`):
+            The height of the preprocessed image. If `None`, the height of the `image` input will be used.
+        width (`Optional[int]`, *optional*, defaults to `None`):
+            The width of the preprocessed image. If `None`, the width of the `image` input will be used.
+
+    Returns:
+        `Tuple[int, int]`:
+            A tuple containing the height and width, both resized to the nearest integer multiple of
+            `vae_scale_factor`.
+    """
+
+    if height is None:
+        if isinstance(image, PIL.Image.Image):
+            height = image.height
+        elif isinstance(image, torch.Tensor):
+            height = image.shape[2]
+        else:
+            height = image.shape[1]
+
+    if width is None:
+        if isinstance(image, PIL.Image.Image):
+            width = image.width
+        elif isinstance(image, torch.Tensor):
+            width = image.shape[3]
+        else:
+            width = image.shape[2]
+
+    width, height = (
+        x - x % vae_scale_factor for x in (width, height)
+    )  # resize to integer multiple of vae_scale_factor
+
+    return height, width
+
+
+def resize(
+    image: PIL.Image.Image | np.ndarray | torch.Tensor,
+    height: int,
+    width: int,
+    resize_mode: str = "default",  # "default", "fill", "crop"
+    resample: str = "lanczos",
+) -> PIL.Image.Image | np.ndarray | torch.Tensor:
+    """
+    Resize image.
+
+    Args:
+        image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
+            The image input, can be a PIL image, numpy array or pytorch tensor.
+        height (`int`):
+            The height to resize to.
+        width (`int`):
+            The width to resize to.
+        resize_mode (`str`, *optional*, defaults to `default`):
+            The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
+            within the specified width and height, and may not maintain the original aspect ratio. If `fill`, will
+            resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
+            center the image within the dimensions, filling the empty space with data from the image. If `crop`,
+            will resize the image to fit within the specified width and height, maintaining the aspect ratio, and
+            then center the image within the dimensions, cropping the excess. Note that resize_mode `fill` and
+            `crop` are only supported for PIL image input.
+
+    Returns:
+        `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
+            The resized image.
+    """
+    if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
+        raise ValueError(
+            f"Only PIL image input is supported for resize_mode {resize_mode}"
+        )
+    assert isinstance(image, PIL.Image.Image)
+    if resize_mode == "default":
+        image = image.resize((width, height), resample=PIL_INTERPOLATION[resample])
+    else:
+        raise ValueError(f"resize_mode {resize_mode} is not supported")
+    return image
diff --git a/python/sglang/multimodal_gen/runtime/pipelines/README.md b/python/sglang/multimodal_gen/runtime/pipelines/README.md
new file mode 100644
index 000000000..14a9531a4
--- /dev/null
+++ b/python/sglang/multimodal_gen/runtime/pipelines/README.md
@@ -0,0 +1,18 @@
+# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
+
+# Adding a New Custom Pipeline
+
+Please see the documentation [here](https://hao-ai-lab.github.io/sgl-diffusion/contributing/add_pipeline.html)
+
+# PipelineStages
+
+Basic components in a pipeline, which can be reused by custom pipelines for different models.
+
+The stages form a partial order.
+
+
+# PipelineExecutors
+
+Runs the stages in a pipeline in various ways. Supported ways:
+1. sync
+2.
async diff --git a/python/sglang/multimodal_gen/runtime/pipelines/__init__.py b/python/sglang/multimodal_gen/runtime/pipelines/__init__.py new file mode 100644 index 000000000..8139975b8 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/__init__.py @@ -0,0 +1,93 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Diffusion pipelines for sglang.multimodal_gen. + +This package contains diffusion pipelines for generating videos and images. +""" + +from typing import cast + +from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import ( + ComposedPipelineBase, +) +from sglang.multimodal_gen.runtime.pipelines.lora_pipeline import LoRAPipeline +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.pipelines.pipeline_registry import ( + PipelineType, + get_pipeline_registry, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import ( + maybe_download_model, + verify_model_config_and_directory, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class PipelineWithLoRA(LoRAPipeline, ComposedPipelineBase): + """Type for a pipeline that has both ComposedPipelineBase and LoRAPipeline functionality.""" + + pass + + +def build_pipeline( + server_args: ServerArgs, + pipeline_type: PipelineType | str = PipelineType.BASIC, +) -> PipelineWithLoRA: + """ + Only works with valid hf diffusers configs. (model_index.json) + We want to build a pipeline based on the inference args mode_path: + 1. download the model from the hub if it's not already downloaded + 2. verify the model config and directory + 3. based on the config, determine the pipeline class + """ + # Get pipeline type + model_path = server_args.model_path + model_path = maybe_download_model(model_path) + # server_args.downloaded_model_path = model_path + logger.info("Model path: %s", model_path) + + config = verify_model_config_and_directory(model_path) + pipeline_name = config.get("_class_name") + if pipeline_name is None: + raise ValueError( + "Model config does not contain a _class_name attribute. " + "Only diffusers format is supported." + ) + + # Get the appropriate pipeline registry based on pipeline_type + logger.info( + "Building pipeline of type: %s", + ( + pipeline_type.value + if isinstance(pipeline_type, PipelineType) + else pipeline_type + ), + ) + pipeline_registry = get_pipeline_registry(pipeline_type) + + if isinstance(pipeline_type, str): + pipeline_type = PipelineType.from_string(pipeline_type) + + pipeline_cls = pipeline_registry.resolve_pipeline_cls( + pipeline_name, pipeline_type, server_args.workload_type + ) + + # instantiate the pipelines + pipeline = pipeline_cls(model_path, server_args) + + logger.info("Pipelines instantiated") + + return cast(PipelineWithLoRA, pipeline) + + +__all__ = [ + "build_pipeline", + "ComposedPipelineBase", + "Req", + "LoRAPipeline", +] diff --git a/python/sglang/multimodal_gen/runtime/pipelines/composed_pipeline_base.py b/python/sglang/multimodal_gen/runtime/pipelines/composed_pipeline_base.py new file mode 100644 index 000000000..d5fcf357e --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/composed_pipeline_base.py @@ -0,0 +1,354 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Base class for composed pipelines. 
+ +This module defines the base class for pipelines that are composed of multiple stages. +""" + +import argparse +import os +from abc import ABC, abstractmethod +from typing import Any, cast + +import torch +from tqdm import tqdm + +from sglang.multimodal_gen.configs.pipelines import PipelineConfig +from sglang.multimodal_gen.runtime.loader.component_loader import ( + PipelineComponentLoader, +) +from sglang.multimodal_gen.runtime.pipelines.executors.pipeline_executor import ( + PipelineExecutor, +) +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.pipelines.stages import PipelineStage +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import ( + maybe_download_model, + verify_model_config_and_directory, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class ComposedPipelineBase(ABC): + """ + Base class for pipelines composed of multiple stages. + + This class provides the framework for creating pipelines by composing multiple + stages together. Each stage is responsible for a specific part of the diffusion + process, and the pipeline orchestrates the execution of these stages. + """ + + is_video_pipeline: bool = False # To be overridden by video pipelines + # should contains only the modules to be loaded + _required_config_modules: list[str] = [] + _extra_config_module_map: dict[str, str] = {} + server_args: ServerArgs | None = None + modules: dict[str, Any] = {} + post_init_called: bool = False + executor: PipelineExecutor | None = None + + # the name of the pipeline it associated with, in diffusers + pipeline_name: str + + def __init__( + self, + model_path: str, + server_args: ServerArgs, + required_config_modules: list[str] | None = None, + loaded_modules: dict[str, torch.nn.Module] | None = None, + executor: PipelineExecutor | None = None, + ): + """ + Initialize the pipeline. After __init__, the pipeline should be ready to + use. The pipeline should be stateless and not hold any batch state. 
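+
+        Args:
+            model_path: Local path or Hugging Face repo id of the diffusers-format model.
+            server_args: Server arguments used to configure the pipeline.
+            required_config_modules: Optional override of the module names to load.
+            loaded_modules: Optional pre-loaded modules used instead of loading from the
+                config/pretrained weights.
+            executor: Optional executor; if omitted, one is created via `build_executor`.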
+ """ + self.server_args = server_args + + self.model_path: str = model_path + self._stages: list[PipelineStage] = [] + self._stage_name_mapping: dict[str, PipelineStage] = {} + self.executor = executor or self.build_executor(server_args=server_args) + + if required_config_modules is not None: + self._required_config_modules = required_config_modules + + if self._required_config_modules is None: + raise NotImplementedError("Subclass must set _required_config_modules") + # temp disable for duplicate initialing tp + # maybe_init_distributed_environment_and_model_parallel( + # server_args.tp_size, server_args.sp_size + # ) + + # Load modules directly in initialization + logger.info("Loading pipeline modules...") + self.modules = self.load_modules(server_args, loaded_modules) + + def build_executor(self, server_args: ServerArgs): + # TODO + from sglang.multimodal_gen.runtime.pipelines.executors.parallel_executor import ( + ParallelExecutor, + ) + + # return SyncExecutor(server_args=server_args) + return ParallelExecutor(server_args=server_args) + + def post_init(self) -> None: + assert self.server_args is not None, "server_args must be set" + if self.post_init_called: + return + self.post_init_called = True + + self.initialize_pipeline(self.server_args) + if self.server_args.enable_torch_compile: + self.modules["transformer"] = torch.compile(self.modules["transformer"]) + logger.info("Torch Compile enabled for DiT") + + logger.info("Creating pipeline stages...") + self.create_pipeline_stages(self.server_args) + + @classmethod + def from_pretrained( + cls, + model_path: str, + device: str | None = None, + torch_dtype: torch.dtype | None = None, + pipeline_config: str | PipelineConfig | None = None, + args: argparse.Namespace | None = None, + required_config_modules: list[str] | None = None, + loaded_modules: dict[str, torch.nn.Module] | None = None, + **kwargs, + ) -> "ComposedPipelineBase": + """ + Load a pipeline from a pretrained model. + loaded_modules: Optional[Dict[str, torch.nn.Module]] = None, + If provided, loaded_modules will be used instead of loading from config/pretrained weights. + """ + kwargs["model_path"] = model_path + server_args = ServerArgs.from_kwargs(**kwargs) + + logger.info("server_args in from_pretrained: %s", server_args) + + pipe = cls( + model_path, + server_args, + required_config_modules=required_config_modules, + loaded_modules=loaded_modules, + ) + pipe.post_init() + return pipe + + def get_module(self, module_name: str, default_value: Any = None) -> Any: + if module_name not in self.modules: + return default_value + return self.modules[module_name] + + def add_module(self, module_name: str, module: Any): + self.modules[module_name] = module + + def _load_config(self) -> dict[str, Any]: + model_path = maybe_download_model(self.model_path) + self.model_path = model_path + # server_args.downloaded_model_path = model_path + logger.info("Model path: %s", model_path) + config = verify_model_config_and_directory(model_path) + return cast(dict[str, Any], config) + + @property + def required_config_modules(self) -> list[str]: + """ + List of modules that are required by the pipeline. The names should match + the diffusers directory and model_index.json file. These modules will be + loaded using the PipelineComponentLoader and made available in the + modules dictionary. Access these modules using the get_module method. 
+ + class ConcretePipeline(ComposedPipelineBase): + _required_config_modules = ["vae", "text_encoder", "transformer", "scheduler", "tokenizer"] + + + @property + def required_config_modules(self): + return self._required_config_modules + """ + return self._required_config_modules + + @property + def stages(self) -> list[PipelineStage]: + """ + List of stages in the pipeline. + """ + return self._stages + + @abstractmethod + def create_pipeline_stages(self, server_args: ServerArgs): + """ + Create the inference pipeline stages. + """ + raise NotImplementedError + + def initialize_pipeline(self, server_args: ServerArgs): + """ + Initialize the pipeline. + """ + return + + def load_modules( + self, + server_args: ServerArgs, + loaded_modules: dict[str, torch.nn.Module] | None = None, + ) -> dict[str, Any]: + """ + Load the modules from the config. + loaded_modules: Optional[Dict[str, torch.nn.Module]] = None, + If provided, loaded_modules will be used instead of loading from config/pretrained weights. + """ + + model_index = self._load_config() + logger.info("Loading pipeline modules from config: %s", model_index) + + # remove keys that are not pipeline modules + model_index.pop("_class_name") + model_index.pop("_diffusers_version") + if ( + "boundary_ratio" in model_index + and model_index["boundary_ratio"] is not None + ): + logger.info( + "MoE pipeline detected. Adding transformer_2 to self.required_config_modules..." + ) + self.required_config_modules.append("transformer_2") + logger.info( + "MoE pipeline detected. Setting boundary ratio to %s", + model_index["boundary_ratio"], + ) + server_args.pipeline_config.dit_config.boundary_ratio = model_index[ + "boundary_ratio" + ] + + model_index.pop("boundary_ratio", None) + # used by Wan2.2 ti2v + model_index.pop("expand_timesteps", None) + + # some sanity checks + assert ( + len(model_index) > 1 + ), "model_index.json must contain at least one pipeline module" + + model_index = { + required_module: model_index[required_module] + for required_module in self.required_config_modules + } + + for module_name in self.required_config_modules: + if ( + module_name not in model_index + and module_name in self._extra_config_module_map + ): + extra_module_value = self._extra_config_module_map[module_name] + logger.warning( + "model_index.json does not contain a %s module, but found {%s: %s} in _extra_config_module_map, adding to model_index.", + module_name, + module_name, + extra_module_value, + ) + if extra_module_value in model_index: + logger.info( + "Using module %s for %s", extra_module_value, module_name + ) + model_index[module_name] = model_index[extra_module_value] + continue + else: + raise ValueError( + f"Required module key: {module_name} value: {model_index.get(module_name)} was not found in loaded modules {model_index.keys()}" + ) + + # all the component models used by the pipeline + required_modules = self.required_config_modules + logger.info("Loading required components: %s", required_modules) + + components = {} + for module_name, ( + transformers_or_diffusers, + architecture, + ) in tqdm(iterable=model_index.items(), desc="Loading required modules"): + + if transformers_or_diffusers is None: + logger.warning( + "Module %s in model_index.json has null value, removing from required_config_modules", + module_name, + ) + if module_name in self.required_config_modules: + self.required_config_modules.remove(module_name) + continue + if module_name not in required_modules: + logger.info("Skipping module %s", module_name) + continue + if 
loaded_modules is not None and module_name in loaded_modules: + logger.info("Using module %s already provided", module_name) + components[module_name] = loaded_modules[module_name] + continue + + # we load the module from the extra config module map if it exists + if module_name in self._extra_config_module_map: + load_module_name = self._extra_config_module_map[module_name] + else: + load_module_name = module_name + + component_model_path = os.path.join(self.model_path, load_module_name) + module = PipelineComponentLoader.load_module( + module_name=load_module_name, + component_model_path=component_model_path, + transformers_or_diffusers=transformers_or_diffusers, + server_args=server_args, + ) + logger.info("Loaded module %s from %s", module_name, component_model_path) + + if module_name in components: + logger.warning("Overwriting module %s", module_name) + components[module_name] = module + + # Check if all required modules were loaded + for module_name in required_modules: + if module_name not in components or components[module_name] is None: + raise ValueError( + f"Required module key: {module_name} value: {components.get(module_name)} was not found in loaded modules {components.keys()}" + ) + + return components + + def add_stage(self, stage_name: str, stage: PipelineStage): + assert self.modules is not None, "No modules are registered" + self._stages.append(stage) + self._stage_name_mapping[stage_name] = stage + setattr(self, stage_name, stage) + + # TODO(will): don't hardcode no_grad + @torch.no_grad() + def forward( + self, + batch: Req, + server_args: ServerArgs, + ) -> Req: + """ + Generate a video or image using the pipeline. + + Args: + batch: The batch to generate from. + server_args: The inference arguments. + Returns: + Req: The batch with the generated video or image. 
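+
+        Example (illustrative; assumes a concrete pipeline instance and a populated `Req`):
+            output_batch = pipeline.forward(batch, server_args)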
+ """ + if not self.post_init_called: + self.post_init() + + # Execute each stage + logger.info( + "Running pipeline stages: %s", + list(self._stage_name_mapping.keys()), + main_process_only=True, + ) + return self.executor.execute(self.stages, batch, server_args) diff --git a/python/sglang/multimodal_gen/runtime/pipelines/executors/parallel_executor.py b/python/sglang/multimodal_gen/runtime/pipelines/executors/parallel_executor.py new file mode 100644 index 000000000..a1a9d88fa --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/executors/parallel_executor.py @@ -0,0 +1,92 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from typing import List + +import torch + +from sglang.multimodal_gen.runtime.distributed import get_sp_group +from sglang.multimodal_gen.runtime.distributed.parallel_state import ( + get_cfg_group, + get_classifier_free_guidance_rank, +) +from sglang.multimodal_gen.runtime.pipelines import Req +from sglang.multimodal_gen.runtime.pipelines.executors.pipeline_executor import ( + PipelineExecutor, + Timer, +) +from sglang.multimodal_gen.runtime.pipelines.stages.base import ( + PipelineStage, + StageParallelismType, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.distributed import broadcast_pyobj + + +class ParallelExecutor(PipelineExecutor): + """ + The correctness of the execution relies on the parallelism_type declared by stages + + """ + + def collect_from_main(self, batches: list[Req]): + + # TODO: fix this condition + if self.server_args.sp_degree != 1: + sp_group = get_sp_group() + batches = broadcast_pyobj( + batches, + sp_group.rank, + sp_group.cpu_group, + src=sp_group.ranks[0], + ) + + if self.server_args.enable_cfg_parallel: + batches = broadcast_pyobj( + batches, + self.worker.cfg_group.rank, + self.worker.cfg_cpu_group, + src=self.worker.cfg_group.ranks[0], + ) + + def execute( + self, + stages: List[PipelineStage], + batch: Req, + server_args: ServerArgs, + ) -> Req: + rank = get_classifier_free_guidance_rank() + cfg_rank = get_classifier_free_guidance_rank() + cfg_group = get_cfg_group() + + # TODO: decide when to gather on main when CFG_PARALLEL -> MAIN_RANK_ONLY + for stage in stages: + with Timer(stage.__class__.__name__): + paradigm = stage.parallelism_type + + if paradigm == StageParallelismType.MAIN_RANK_ONLY: + if rank == 0: + batch = stage(batch, server_args) + # obj_list = [batch] if rank == 0 else [] + # + # broadcasted_list = broadcast_pyobj( + # obj_list, rank=rank, dist_group=cfg_group.cpu_group, src=0 + # ) + # if rank != 0: + # batch = broadcasted_list[0] + torch.distributed.barrier() + + elif paradigm == StageParallelismType.CFG_PARALLEL: + obj_list = [batch] if rank == 0 else [] + broadcasted_list = broadcast_pyobj( + obj_list, rank=rank, dist_group=cfg_group.cpu_group, src=0 + ) + if rank != 0: + batch = broadcasted_list[0] + batch = stage(batch, server_args) + + torch.distributed.barrier() + + elif paradigm == StageParallelismType.REPLICATED: + batch = stage(batch, server_args) + + return batch diff --git a/python/sglang/multimodal_gen/runtime/pipelines/executors/pipeline_executor.py b/python/sglang/multimodal_gen/runtime/pipelines/executors/pipeline_executor.py new file mode 100644 index 000000000..08dc0ceb5 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/executors/pipeline_executor.py @@ -0,0 +1,71 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" 
+Base class for all pipeline executors.
+"""
+import time
+from abc import ABC, abstractmethod
+from typing import List
+
+from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
+from sglang.multimodal_gen.runtime.pipelines.stages import PipelineStage
+from sglang.multimodal_gen.runtime.server_args import ServerArgs
+from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class Timer:
+    """
+    A very simple timer that does not wait for the CUDA stream to be synced.
+    """
+
+    def __init__(self, name="Stage"):
+        self.name = name
+        self.start = None
+        self.end = None
+        self.elapsed = None
+
+    def __enter__(self):
+        self.start = time.perf_counter()
+        logger.info(f"[{self.name}] started...")
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.end = time.perf_counter()
+        self.elapsed = self.end - self.start
+        logger.info(f"[{self.name}] finished in {self.elapsed:.4f} seconds")
+        return False
+
+
+class PipelineExecutor(ABC):
+    """
+    Abstract base class for all pipeline executors.
+
+    Executors orchestrate the execution of pipeline stages, managing the parallelism
+    and communication that the stages require.
+    """
+
+    def __init__(self, server_args):
+        self.server_args = server_args
+
+    @abstractmethod
+    def execute(
+        self,
+        stages: List[PipelineStage],
+        batch: Req,
+        server_args: ServerArgs,
+    ) -> Req:
+        """
+        Execute the pipeline stages.
+
+        Args:
+            stages: A list of pipeline stages to execute.
+            batch: The batch to process.
+            server_args: The server arguments.
+
+        Returns:
+            The processed batch.
+        """
+        raise NotImplementedError
diff --git a/python/sglang/multimodal_gen/runtime/pipelines/executors/sync_executor.py b/python/sglang/multimodal_gen/runtime/pipelines/executors/sync_executor.py
new file mode 100644
index 000000000..88528c51f
--- /dev/null
+++ b/python/sglang/multimodal_gen/runtime/pipelines/executors/sync_executor.py
@@ -0,0 +1,39 @@
+# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
+
+# SPDX-License-Identifier: Apache-2.0
+"""
+Synchronous pipeline executor implementation.
+"""
+from typing import List
+
+from sglang.multimodal_gen.runtime.pipelines.executors.pipeline_executor import (
+    PipelineExecutor,
+    Timer,
+    logger,
+)
+from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req
+from sglang.multimodal_gen.runtime.pipelines.stages import PipelineStage
+from sglang.multimodal_gen.runtime.server_args import ServerArgs
+
+
+class SyncExecutor(PipelineExecutor):
+    """
+    A simple synchronous executor that runs stages sequentially.
+    """
+
+    def execute(
+        self,
+        stages: List[PipelineStage],
+        batch: Req,
+        server_args: ServerArgs,
+    ) -> Req:
+        """
+        Execute the pipeline stages sequentially.
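+
+        Each stage is wrapped in a `Timer`, so per-stage wall-clock time is
+        logged as the batch flows through the pipeline.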
+ """ + logger.info("Running pipeline stages sequentially with SyncExecutor.") + + for stage in stages: + with Timer(stage.__class__.__name__): + batch = stage(batch, server_args) + + return batch diff --git a/python/sglang/multimodal_gen/runtime/pipelines/lora_pipeline.py b/python/sglang/multimodal_gen/runtime/pipelines/lora_pipeline.py new file mode 100644 index 000000000..4e7bc0901 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/lora_pipeline.py @@ -0,0 +1,227 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +from collections import defaultdict +from collections.abc import Hashable +from typing import Any + +import torch +import torch.distributed as dist +from safetensors.torch import load_file + +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.layers.lora.linear import ( + BaseLayerWithLoRA, + get_lora_layer, + replace_submodule, +) +from sglang.multimodal_gen.runtime.loader.utils import get_param_names_mapping +from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import ( + ComposedPipelineBase, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.hf_diffusers_utils import maybe_download_lora +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class LoRAPipeline(ComposedPipelineBase): + """ + Pipeline that supports injecting LoRA adapters into the diffusion transformer. + TODO: support training. + """ + + lora_adapters: dict[str, dict[str, torch.Tensor]] = defaultdict( + dict + ) # state dicts of loaded lora adapters + cur_adapter_name: str = "" + cur_adapter_path: str = "" + lora_layers: dict[str, BaseLayerWithLoRA] = {} + lora_layers_critic: dict[str, BaseLayerWithLoRA] = {} + server_args: ServerArgs + exclude_lora_layers: list[str] = [] + device: torch.device = get_local_torch_device() + lora_target_modules: list[str] | None = None + lora_path: str | None = None + lora_nickname: str = "default" + lora_rank: int | None = None + lora_alpha: int | None = None + lora_initialized: bool = False + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.device = get_local_torch_device() + self.exclude_lora_layers = self.modules[ + "transformer" + ].config.arch_config.exclude_lora_layers + self.lora_target_modules = self.server_args.lora_target_modules + self.lora_path = self.server_args.lora_path + self.lora_nickname = self.server_args.lora_nickname + if self.lora_path is not None: + self.convert_to_lora_layers() + self.set_lora_adapter( + self.lora_nickname, self.lora_path # type: ignore + ) # type: ignore + + def is_target_layer(self, module_name: str) -> bool: + if self.lora_target_modules is None: + return True + return any( + target_name in module_name for target_name in self.lora_target_modules + ) + + def convert_to_lora_layers(self) -> None: + """ + Unified method to convert the transformer to a LoRA transformer. 
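+
+        Wraps every targeted (and not excluded) submodule of the transformer, and
+        of the `fake_score_transformer` if present, with a LoRA layer obtained from
+        `get_lora_layer` and swapped in via `replace_submodule`.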
+ """ + if self.lora_initialized: + return + self.lora_initialized = True + converted_count = 0 + for name, layer in self.modules["transformer"].named_modules(): + if not self.is_target_layer(name): + continue + + excluded = False + for exclude_layer in self.exclude_lora_layers: + if exclude_layer in name: + excluded = True + break + if excluded: + continue + + layer = get_lora_layer( + layer, + lora_rank=self.lora_rank, + lora_alpha=self.lora_alpha, + ) + if layer is not None: + self.lora_layers[name] = layer + replace_submodule(self.modules["transformer"], name, layer) + converted_count += 1 + logger.info("Converted %d layers to LoRA layers", converted_count) + + if "fake_score_transformer" in self.modules: + for name, layer in self.modules["fake_score_transformer"].named_modules(): + if not self.is_target_layer(name): + continue + layer = get_lora_layer( + layer, + lora_rank=self.lora_rank, + lora_alpha=self.lora_alpha, + ) + if layer is not None: + self.lora_layers_critic[name] = layer + replace_submodule( + self.modules["fake_score_transformer"], name, layer + ) + converted_count += 1 + logger.info( + "Converted %d layers to LoRA layers in the critic model", + converted_count, + ) + + def set_lora_adapter( + self, lora_nickname: str, lora_path: str | None = None + ): # type: ignore + """ + Load a LoRA adapter into the pipeline and merge it into the transformer. + Args: + lora_nickname: The "nick name" of the adapter when referenced in the pipeline. + lora_path: The path to the adapter, either a local path or a Hugging Face repo id. + """ + + if lora_nickname not in self.lora_adapters and lora_path is None: + raise ValueError( + f"Adapter {lora_nickname} not found in the pipeline. Please provide lora_path to load it." + ) + if not self.lora_initialized: + self.convert_to_lora_layers() + adapter_updated = False + rank = dist.get_rank() + if lora_path is not None and lora_path != self.cur_adapter_path: + lora_local_path = maybe_download_lora(lora_path) + lora_state_dict = load_file(lora_local_path) + + # Map the hf layer names to our custom layer names + param_names_mapping_fn = get_param_names_mapping( + self.modules["transformer"].param_names_mapping + ) + lora_param_names_mapping_fn = get_param_names_mapping( + self.modules["transformer"].lora_param_names_mapping + ) + + to_merge_params: defaultdict[Hashable, dict[Any, Any]] = defaultdict(dict) + for name, weight in lora_state_dict.items(): + name = name.replace("diffusion_model.", "") + name = name.replace(".weight", "") + name, _, _ = lora_param_names_mapping_fn(name) + target_name, merge_index, num_params_to_merge = param_names_mapping_fn( + name + ) + # for (in_dim, r) @ (r, out_dim), we only merge (r, out_dim * n) where n is the number of linear layers to fuse + # see param mapping in HunyuanVideoArchConfig + if merge_index is not None and "lora_B" in name: + to_merge_params[target_name][merge_index] = weight + if len(to_merge_params[target_name]) == num_params_to_merge: + # cat at output dim according to the merge_index order + sorted_tensors = [ + to_merge_params[target_name][i] + for i in range(num_params_to_merge) + ] + weight = torch.cat(sorted_tensors, dim=1) + del to_merge_params[target_name] + else: + continue + + if target_name in self.lora_adapters[lora_nickname]: + raise ValueError( + f"Target name {target_name} already exists in lora_adapters[{lora_nickname}]" + ) + self.lora_adapters[lora_nickname][target_name] = weight.to(self.device) + adapter_updated = True + self.cur_adapter_path = lora_path + logger.info("Rank 
%d: loaded LoRA adapter %s", rank, lora_path) + + if not adapter_updated and self.cur_adapter_name == lora_nickname: + return + self.cur_adapter_name = lora_nickname + + # Merge the new adapter + adapted_count = 0 + for name, layer in self.lora_layers.items(): + lora_A_name = name + ".lora_A" + lora_B_name = name + ".lora_B" + if ( + lora_A_name in self.lora_adapters[lora_nickname] + and lora_B_name in self.lora_adapters[lora_nickname] + ): + layer.set_lora_weights( + self.lora_adapters[lora_nickname][lora_A_name], + self.lora_adapters[lora_nickname][lora_B_name], + lora_path=lora_path, + ) + adapted_count += 1 + else: + if rank == 0: + logger.warning( + "LoRA adapter %s does not contain the weights for layer %s. LoRA will not be applied to it.", + lora_path, + name, + ) + layer.disable_lora = True + logger.info( + "Rank %d: LoRA adapter %s applied to %d layers", + rank, + lora_path, + adapted_count, + ) + + def merge_lora_weights(self) -> None: + for name, layer in self.lora_layers.items(): + layer.merge_lora_weights() + + def unmerge_lora_weights(self) -> None: + for name, layer in self.lora_layers.items(): + layer.unmerge_lora_weights() diff --git a/python/sglang/multimodal_gen/runtime/pipelines/pipeline_batch_info.py b/python/sglang/multimodal_gen/runtime/pipelines/pipeline_batch_info.py new file mode 100644 index 000000000..8d6ba6dcc --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/pipeline_batch_info.py @@ -0,0 +1,271 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Inspired by SGLang: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/model_executor/forward_batch_info.py +""" +Data structures for functional pipeline processing. + +This module defines the dataclasses used to pass state between pipeline components +in a functional manner, reducing the need for explicit parameter passing. 
+""" + +import pprint +from dataclasses import asdict, dataclass, field +from typing import TYPE_CHECKING, Any + +import PIL.Image +import torch + +from sglang.multimodal_gen.configs.sample.base import DataType +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.performance_logger import PerformanceLogger + +if TYPE_CHECKING: + from torchcodec.decoders import VideoDecoder + +import time +from collections import OrderedDict + +from sglang.multimodal_gen.configs.sample.teacache import ( + TeaCacheParams, + WanTeaCacheParams, +) + + +class PipelineLoggingInfo: + """Simple approach using OrderedDict to track stage metrics.""" + + def __init__(self): + # OrderedDict preserves insertion order and allows easy access + self.stages: OrderedDict[str, dict[str, Any]] = OrderedDict() + + def add_stage_execution_time(self, stage_name: str, execution_time: float): + """Add execution time for a stage.""" + if stage_name not in self.stages: + self.stages[stage_name] = {} + self.stages[stage_name]["execution_time"] = execution_time + self.stages[stage_name]["timestamp"] = time.time() + + def add_stage_metric(self, stage_name: str, metric_name: str, value: Any): + """Add any metric for a stage.""" + if stage_name not in self.stages: + self.stages[stage_name] = {} + self.stages[stage_name][metric_name] = value + + def get_stage_info(self, stage_name: str) -> dict[str, Any]: + """Get all info for a specific stage.""" + return self.stages.get(stage_name, {}) + + def get_execution_order(self) -> list[str]: + """Get stages in execution order.""" + return list(self.stages.keys()) + + def get_total_execution_time(self) -> float: + """Get total pipeline execution time.""" + return sum(stage.get("execution_time", 0) for stage in self.stages.values()) + + +@dataclass +class Req: + """ + Complete state passed through the pipeline execution. + + This dataclass contains all information needed during the diffusion pipeline + execution, allowing methods to update specific components without needing + to manage numerous individual parameters. + """ + + # TODO(will): double check that args are separate from server_args + # properly. Also maybe think about providing an abstraction for pipeline + # specific arguments. 
+ data_type: DataType + + request_id: str | None = None + + generator: torch.Generator | list[torch.Generator] | None = None + + # Image inputs + image_path: str | None = None + # Image encoder hidden states + image_embeds: list[torch.Tensor] = field(default_factory=list) + pil_image: torch.Tensor | PIL.Image.Image | None = None + pixel_values: torch.Tensor | PIL.Image.Image | None = None + preprocessed_image: torch.Tensor | None = None + + # Text inputs + prompt: str | list[str] | None = None + negative_prompt: str | list[str] | None = None + prompt_path: str | None = None + output_path: str = "outputs/" + # without extension + output_file_name: str | None = None + output_file_ext: str | None = None + # Primary encoder embeddings + prompt_embeds: list[torch.Tensor] | torch.Tensor = field(default_factory=list) + negative_prompt_embeds: list[torch.Tensor] | None = None + prompt_attention_mask: list[torch.Tensor] | None = None + negative_attention_mask: list[torch.Tensor] | None = None + clip_embedding_pos: list[torch.Tensor] | None = None + clip_embedding_neg: list[torch.Tensor] | None = None + + pooled_embeds: list[torch.Tensor] = field(default_factory=list) + neg_pooled_embeds: list[torch.Tensor] = field(default_factory=list) + + # Additional text-related parameters + max_sequence_length: int | None = None + prompt_template: dict[str, Any] | None = None + do_classifier_free_guidance: bool = False + + # Batch info + num_outputs_per_prompt: int = 1 + seed: int | None = None + seeds: list[int] | None = None + + # Tracking if embeddings are already processed + is_prompt_processed: bool = False + + # Latent tensors + latents: torch.Tensor | None = None + raw_latent_shape: torch.Tensor | None = None + noise_pred: torch.Tensor | None = None + image_latent: torch.Tensor | None = None + + # Latent dimensions + height_latents: list[int] | int | None = None + width_latents: list[int] | int | None = None + num_frames: list[int] | int = 1 # Default for image models + num_frames_round_down: bool = ( + False # Whether to round down num_frames if it's not divisible by num_gpus + ) + + # Original dimensions (before VAE scaling) + height: list[int] | int | None = None + width: list[int] | int | None = None + fps: list[int] | int | None = None + height_not_provided: bool = False + width_not_provided: bool = False + + # Timesteps + timesteps: torch.Tensor | None = None + timestep: torch.Tensor | float | int | None = None + step_index: int | None = None + boundary_ratio: float | None = None + + # Scheduler parameters + num_inference_steps: int = 50 + guidance_scale: float = 1.0 + guidance_scale_2: float | None = None + guidance_rescale: float = 0.0 + eta: float = 0.0 + sigmas: list[float] | None = None + + n_tokens: int | None = None + + # Other parameters that may be needed by specific schedulers + extra_step_kwargs: dict[str, Any] = field(default_factory=dict) + + # Component modules (populated by the pipeline) + modules: dict[str, Any] = field(default_factory=dict) + + return_trajectory_latents: bool = False + return_trajectory_decoded: bool = False + trajectory_timesteps: list[torch.Tensor] | None = None + trajectory_latents: torch.Tensor | None = None + + # Extra parameters that might be needed by specific pipeline implementations + extra: dict[str, Any] = field(default_factory=dict) + + # Misc + save_output: bool = True + return_frames: bool = False + + # TeaCache parameters + enable_teacache: bool = False + teacache_params: TeaCacheParams | WanTeaCacheParams | None = None + + # STA parameters + 
STA_param: list | None = None + is_cfg_negative: bool = False + mask_search_final_result_pos: list[list] | None = None + mask_search_final_result_neg: list[list] | None = None + + # VSA parameters + VSA_sparsity: float = 0.0 + perf_logger: PerformanceLogger | None = None + + # profile + profile: bool = False + num_profiled_timesteps: int = 8 + + # debugging + debug: bool = False + + # results + output: torch.Tensor | None = None + + @property + def batch_size(self): + # Determine batch size + if isinstance(self.prompt, list): + batch_size = len(self.prompt) + elif self.prompt is not None: + batch_size = 1 + else: + batch_size = self.prompt_embeds[0].shape[0] + + # Adjust batch size for number of videos per prompt + batch_size *= self.num_outputs_per_prompt + return batch_size + + def __post_init__(self): + """Initialize dependent fields after dataclass initialization.""" + # Set do_classifier_free_guidance based on guidance scale and negative prompt + if self.guidance_scale > 1.0 and self.negative_prompt is not None: + self.do_classifier_free_guidance = True + if self.negative_prompt_embeds is None: + self.negative_prompt_embeds = [] + if self.guidance_scale_2 is None: + self.guidance_scale_2 = self.guidance_scale + + if self.perf_logger is None: + self.perf_logger = PerformanceLogger(self.request_id) + + def set_width_and_height(self, server_args: ServerArgs): + if self.height is None or self.width is None: + width, height = server_args.pipeline_config.set_width_and_height( + self.width, self.height, self.pil_image + ) + self.width = width + self.height = height + if self.height is None or self.width is None: + self.width = 1280 + self.height = 720 + + def __str__(self): + return pprint.pformat(asdict(self), indent=2, width=120) + + +@dataclass +class ForwardBatch: ... 
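+
+# A minimal construction sketch for `Req` (illustrative only; `DataType.VIDEO` and the
+# field values below are assumptions, not taken from this module):
+#
+#     req = Req(
+#         data_type=DataType.VIDEO,
+#         prompt="a red panda riding a bicycle",
+#         num_inference_steps=30,
+#         seed=42,
+#     )
+#     # guidance_scale defaults to 1.0 and negative_prompt is None, so
+#     # __post_init__ leaves do_classifier_free_guidance as False.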
+ + +@dataclass +class OutputBatch: + """ + Final output (after pipeline completion) + """ + + output: torch.Tensor | None = None + trajectory_timesteps: list[torch.Tensor] | None = None + trajectory_latents: torch.Tensor | None = None + trajectory_decoded: list[torch.Tensor] | None = None + error: str | None = None + + # Logging info + logging_info: PipelineLoggingInfo = field(default_factory=PipelineLoggingInfo) + + +@dataclass +class PreprocessBatch(Req): + video_loader: list["VideoDecoder"] | list[str] = field(default_factory=list) + video_file_name: list[str] = field(default_factory=list) diff --git a/python/sglang/multimodal_gen/runtime/pipelines/pipeline_registry.py b/python/sglang/multimodal_gen/runtime/pipelines/pipeline_registry.py new file mode 100644 index 000000000..a1605f5ca --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/pipeline_registry.py @@ -0,0 +1,239 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/models/registry.py +# and https://github.com/sgl-project/sglang/blob/v0.4.3/python/sglang/srt/models/registry.py +import dataclasses +import importlib +import pkgutil +from collections.abc import Set +from dataclasses import dataclass +from enum import Enum +from functools import lru_cache + +from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import ( + ComposedPipelineBase, +) +from sglang.multimodal_gen.runtime.pipelines.lora_pipeline import LoRAPipeline +from sglang.multimodal_gen.runtime.server_args import WorkloadType +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + +_PREPROCESS_WORKLOAD_TYPE_TO_PIPELINE_NAME: dict[WorkloadType, str] = { + WorkloadType.I2V: "PreprocessPipelineI2V", + WorkloadType.T2V: "PreprocessPipelineT2V", +} + + +class PipelineType(str, Enum): + """ + Enumeration for different pipeline types. + + Inherits from str to allow string comparison for backward compatibility. + """ + + BASIC = "basic" + PREPROCESS = "preprocess" + + @classmethod + def from_string(cls, value: str) -> "PipelineType": + """Convert string to PipelineType enum.""" + try: + return cls(value.lower()) + except ValueError: + raise ValueError( + f"Invalid pipeline type: {value}. 
Must be one of: {', '.join([t.value for t in cls])}" + ) from None + + @classmethod + def choices(cls) -> list[str]: + """Get all available choices as strings.""" + return [pipeline_type.value for pipeline_type in cls] + + +@dataclass +class _PipelineRegistry: + # Keyed by pipeline_type -> architecture -> pipeline_name + # pipelines[pipeline_type][architecture][pipeline_name] = pipeline_cls + pipelines: dict[str, dict[str, type[ComposedPipelineBase] | None]] = ( + dataclasses.field(default_factory=dict) + ) + + def get_supported_archs( + self, pipeline_name_in_config: str, pipeline_type: PipelineType + ) -> Set[str]: + """Get supported architectures, optionally filtered by pipeline type and workload type.""" + return set(self.pipelines[pipeline_type.value].keys()) + + def _load_preprocess_pipeline_cls( + self, workload_type: WorkloadType + ) -> type[ComposedPipelineBase] | None: + pipeline_name = _PREPROCESS_WORKLOAD_TYPE_TO_PIPELINE_NAME[workload_type] + + return self.pipelines[PipelineType.PREPROCESS.value][pipeline_name] + + def _try_load_pipeline_cls( + self, + pipeline_name_in_config: str, + pipeline_type: PipelineType, + workload_type: WorkloadType, + ) -> type[ComposedPipelineBase] | type[LoRAPipeline] | None: + """Try to load a pipeline class for the given architecture, pipeline type, and workload type.""" + + if pipeline_type.value not in self.pipelines: + return None + + try: + if pipeline_type == PipelineType.PREPROCESS: + return self._load_preprocess_pipeline_cls(workload_type) + elif pipeline_type == PipelineType.BASIC: + return self.pipelines[pipeline_type.value][pipeline_name_in_config] + else: + raise ValueError(f"Invalid pipeline type: {pipeline_type.value}") + except KeyError as e: + logger.error( + f"Please check if the ComposedPipeline class has been defined associated with {pipeline_type.value}.{pipeline_name_in_config}" + ) + raise e + return None + + def resolve_pipeline_cls( + self, + pipeline_name_in_config: str, + pipeline_type: PipelineType, + workload_type: WorkloadType, + ) -> type[ComposedPipelineBase] | type[LoRAPipeline]: + """Resolve pipeline class based on pipeline name in the config, pipeline type, and workload type.""" + if not pipeline_name_in_config: + logger.warning("No pipeline architecture is specified") + + pipeline_cls = self._try_load_pipeline_cls( + pipeline_name_in_config, pipeline_type, workload_type + ) + if pipeline_cls is not None: + return pipeline_cls + supported_archs = self.get_supported_archs( + pipeline_name_in_config, pipeline_type + ) + raise ValueError( + f"Pipeline architecture '{pipeline_name_in_config}' is not supported for pipeline type '{pipeline_type.value}' " + f"and workload type '{workload_type.value}'. " + f"Supported architectures: {supported_archs}" + ) + + +@lru_cache +def import_pipeline_classes( + pipeline_types: list[PipelineType] | PipelineType | None = None, +) -> dict[str, dict[str, type[ComposedPipelineBase] | None]]: + """ + Import pipeline classes based on the pipeline type and workload type. + + Args: + pipeline_types: The pipeline types to load (basic, preprocess). + If None, loads all types. 
+ + Returns: + A three-level nested dictionary: + {pipeline_type: {architecture_name: {pipeline_name: pipeline_cls}}} + e.g., {"basic": {"wan": {"WanPipeline": WanPipeline}}} + """ + type_to_pipeline_dict: dict[str, dict[str, type[ComposedPipelineBase] | None]] = {} + package_name: str = "sglang.multimodal_gen.runtime.architectures" + + # Determine which pipeline types to scan + if isinstance(pipeline_types, list): + pipeline_types_to_scan = [ + pipeline_type.value for pipeline_type in pipeline_types + ] + elif isinstance(pipeline_types, PipelineType): + pipeline_types_to_scan = [pipeline_types.value] + else: + pipeline_types_to_scan = [pt.value for pt in PipelineType] + + logger.info("Loading pipelines for types: %s", pipeline_types_to_scan) + + for pipeline_type_str in pipeline_types_to_scan: + # Try to load from pipeline-type-specific directory first + pipeline_type_package_name = f"{package_name}.{pipeline_type_str}" + pipeline_dict: dict[str, type[ComposedPipelineBase] | None] = {} + + try: + pipeline_type_package = importlib.import_module(pipeline_type_package_name) + logger.debug("Successfully imported %s", pipeline_type_package_name) + + for _, arch, ispkg in pkgutil.iter_modules(pipeline_type_package.__path__): + + arch_package_name = f"{pipeline_type_package_name}.{arch}" + if ispkg: + arch_package = importlib.import_module(arch_package_name) + for _, module_name, ispkg in pkgutil.walk_packages( + arch_package.__path__, arch_package_name + "." + ): + if not ispkg: + pipeline_module = importlib.import_module(module_name) + if hasattr(pipeline_module, "EntryClass"): + entry_cls_list = pipeline_module.EntryClass + if not isinstance(entry_cls_list, list): + entry_cls_list = [entry_cls_list] + + if isinstance(pipeline_module.EntryClass, list): + pipeline_names = [ + pipeline.__name__ + for pipeline in pipeline_module.EntryClass + ] + else: + pipeline_names = [ + pipeline_module.EntryClass.__name__ + ] + + for entry_cls, pipeline_name in zip( + entry_cls_list, pipeline_names + ): + assert ( + pipeline_name not in pipeline_dict + ), f"Duplicated pipeline implementation for {pipeline_name} in {pipeline_type_str}.{arch_package_name}" + + assert hasattr( + entry_cls, "pipeline_name" + ), f"{entry_cls}" + pipeline_dict[pipeline_name] = entry_cls + + type_to_pipeline_dict[pipeline_type_str] = pipeline_dict + + except ImportError as e: + raise ImportError( + f"Could not import {pipeline_type_package_name} when importing pipeline classes: {e}" + ) from None + + # Log summary + total_pipelines = sum( + len(pipeline_dict) for pipeline_dict in type_to_pipeline_dict.values() + ) + logger.info( + "Loaded %d pipeline classes across %d types", + total_pipelines, + len(pipeline_types_to_scan), + ) + + return type_to_pipeline_dict + + +def get_pipeline_registry( + pipeline_type: PipelineType | str | None = None, +) -> _PipelineRegistry: + """ + Get a pipeline registry for the specified mode, pipeline type, and workload type. + + Args: + pipeline_type: Pipeline type to load. If None and mode is provided, will be derived from mode. + + Returns: + A pipeline registry instance. 
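+
+    Example (illustrative; "WanPipeline" is a hypothetical `_class_name` taken from a
+    model's model_index.json):
+        registry = get_pipeline_registry(PipelineType.BASIC)
+        pipeline_cls = registry.resolve_pipeline_cls(
+            "WanPipeline", PipelineType.BASIC, WorkloadType.T2V
+        )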
+ """ + if isinstance(pipeline_type, str): + pipeline_type = PipelineType.from_string(pipeline_type) + + pipeline_classes = import_pipeline_classes(pipeline_type) + return _PipelineRegistry(pipeline_classes) diff --git a/python/sglang/multimodal_gen/runtime/pipelines/stages/__init__.py b/python/sglang/multimodal_gen/runtime/pipelines/stages/__init__.py new file mode 100644 index 000000000..062d4cd8e --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/stages/__init__.py @@ -0,0 +1,59 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Pipeline stages for diffusion models. + +This package contains the various stages that can be composed to create +complete diffusion pipelines. +""" + +from sglang.multimodal_gen.runtime.pipelines.stages.base import PipelineStage +from sglang.multimodal_gen.runtime.pipelines.stages.causal_denoising import ( + CausalDMDDenoisingStage, +) +from sglang.multimodal_gen.runtime.pipelines.stages.conditioning import ( + ConditioningStage, +) +from sglang.multimodal_gen.runtime.pipelines.stages.decoding import DecodingStage +from sglang.multimodal_gen.runtime.pipelines.stages.denoising import DenoisingStage +from sglang.multimodal_gen.runtime.pipelines.stages.denoising_dmd import ( + DmdDenoisingStage, +) +from sglang.multimodal_gen.runtime.pipelines.stages.encoding import EncodingStage +from sglang.multimodal_gen.runtime.pipelines.stages.image_encoding import ( + ImageEncodingStage, + ImageVAEEncodingStage, +) +from sglang.multimodal_gen.runtime.pipelines.stages.input_validation import ( + InputValidationStage, +) +from sglang.multimodal_gen.runtime.pipelines.stages.latent_preparation import ( + LatentPreparationStage, +) +from sglang.multimodal_gen.runtime.pipelines.stages.stepvideo_encoding import ( + StepvideoPromptEncodingStage, +) +from sglang.multimodal_gen.runtime.pipelines.stages.text_encoding import ( + TextEncodingStage, +) +from sglang.multimodal_gen.runtime.pipelines.stages.timestep_preparation import ( + TimestepPreparationStage, +) + +__all__ = [ + "PipelineStage", + "InputValidationStage", + "TimestepPreparationStage", + "LatentPreparationStage", + "ConditioningStage", + "DenoisingStage", + "DmdDenoisingStage", + "CausalDMDDenoisingStage", + "EncodingStage", + "DecodingStage", + "ImageEncodingStage", + "ImageVAEEncodingStage", + "TextEncodingStage", + "StepvideoPromptEncodingStage", +] diff --git a/python/sglang/multimodal_gen/runtime/pipelines/stages/base.py b/python/sglang/multimodal_gen/runtime/pipelines/stages/base.py new file mode 100644 index 000000000..eb89dbe7c --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/stages/base.py @@ -0,0 +1,254 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Base classes for pipeline stages. + +This module defines the abstract base classes for pipeline stages that can be +composed to create complete diffusion pipelines. 
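+
+A minimal custom stage looks roughly like this (illustrative sketch; real stages
+usually also implement verify_input/verify_output):
+
+    class MyStage(PipelineStage):
+        def forward(self, batch, server_args):
+            # stage-specific work that reads and updates the batch
+            return batch
+
+Stages are invoked through __call__, which wraps forward() with optional
+verification and timing, so subclasses should only override forward().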
+""" + +import time +import traceback +from abc import ABC, abstractmethod +from enum import Enum, auto + +import torch + +import sglang.multimodal_gen.envs as envs +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult +from sglang.multimodal_gen.runtime.server_args import ServerArgs, get_global_server_args +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class StageParallelismType(Enum): + # execute on all gpus + REPLICATED = auto() + # executed on main rank only + MAIN_RANK_ONLY = auto() + # this stage requires a cfg-parallel + CFG_PARALLEL = auto() + + +class StageVerificationError(Exception): + """Exception raised when stage verification fails.""" + + pass + + +class PipelineStage(ABC): + """ + Abstract base class for all pipeline stages. + + A pipeline stage represents a discrete step in the diffusion process that can be + composed with other stages to create a complete pipeline. Each stage is responsible + for a specific part of the process, such as prompt encoding, latent preparation, etc. + """ + + def __init__(self): + self.server_args = get_global_server_args() + + def log_info(self, msg, *args): + """Logs an informational message with the stage name as a prefix.""" + logger.info(f"[{self.__class__.__name__}] {msg}", *args) + + def log_warning(self, msg, *args): + """Logs a warning message with the stage name as a prefix.""" + logger.warning(f"[{self.__class__.__name__}] {msg}", *args) + + def log_error(self, msg, *args): + """Logs an error message with the stage name as a prefix.""" + logger.error(f"[{self.__class__.__name__}] {msg}", *args) + + def log_debug(self, msg, *args): + """Logs a debug message with the stage name as a prefix.""" + logger.debug(f"[{self.__class__.__name__}] {msg}", *args) + + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """ + Verify the input for the stage. + + Example: + from sglang.multimodal_gen.runtime.pipelines.stages.validators import V, VerificationResult + + def verify_input(self, batch, server_args): + result = VerificationResult() + result.add_check("height", batch.height, V.positive_int_divisible(8)) + result.add_check("width", batch.width, V.positive_int_divisible(8)) + result.add_check("image_latent", batch.image_latent, V.is_tensor) + return result + + Args: + batch: The current batch information. + server_args: The inference arguments. + + Returns: + A VerificationResult containing the verification status. + + """ + # Default implementation - no verification + return VerificationResult() + + # execute on all ranks by default + @property + def parallelism_type(self) -> StageParallelismType: + # if get_global_server_args().enable_cfg_parallel: + # return StageParallelismType.MAIN_RANK_ONLY + return StageParallelismType.REPLICATED + + def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """ + Verify the output for the stage. + + Args: + batch: The current batch information. + server_args: The inference arguments. + + Returns: + A VerificationResult containing the verification status. + """ + # Default implementation - no verification + return VerificationResult() + + def _run_verification( + self, + verification_result: VerificationResult, + stage_name: str, + verification_type: str, + ) -> None: + """ + Run verification and raise errors if any checks fail. 
+ + Args: + verification_result: Results from verify_input or verify_output + stage_name: Name of the current stage + verification_type: "input" or "output" + """ + if not verification_result.is_valid(): + failed_fields = verification_result.get_failed_fields() + if failed_fields: + # Get detailed failure information + detailed_summary = verification_result.get_failure_summary() + + failed_fields_str = ", ".join(failed_fields) + error_msg = ( + f"{verification_type.capitalize()} verification failed for {stage_name}: " + f"Failed fields: {failed_fields_str}\n" + f"Details: {detailed_summary}" + ) + raise StageVerificationError(error_msg) + + @property + def device(self) -> torch.device: + """Get the device for this stage.""" + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def set_logging(self, enable: bool): + """ + Enable or disable logging for this stage. + + Args: + enable: Whether to enable logging. + """ + self._enable_logging = enable + + def __call__( + self, + batch: Req, + server_args: ServerArgs, + ) -> Req: + """ + Execute the stage's processing on the batch with optional verification and logging. + Should not be overridden by subclasses. + + Args: + batch: The current batch information. + server_args: The inference arguments. + + Returns: + The updated batch information after this stage's processing. + """ + stage_name = self.__class__.__name__ + # Check if verification is enabled (simple approach for prototype) + enable_verification = getattr(server_args, "enable_stage_verification", False) + + if enable_verification: + # Pre-execution input verification + try: + input_result = self.verify_input(batch, server_args) + self._run_verification(input_result, stage_name, "input") + except Exception as e: + logger.error("Input verification failed for %s: %s", stage_name, str(e)) + raise + + # Execute the actual stage logic + if envs.SGL_DIFFUSION_STAGE_LOGGING: + logger.info("[%s] Starting execution", stage_name) + start_time = time.perf_counter() + + try: + result = self.forward(batch, server_args) + execution_time = time.perf_counter() - start_time + logger.info( + "[%s] Execution completed in %s ms", + stage_name, + execution_time * 1000, + ) + batch.logging_info.add_stage_execution_time(stage_name, execution_time) + except Exception as e: + execution_time = time.perf_counter() - start_time + logger.error( + "[%s] Error during execution after %s ms: %s", + stage_name, + execution_time * 1000, + e, + ) + logger.error("[%s] Traceback: %s", stage_name, traceback.format_exc()) + raise + else: + # Direct execution (current behavior) + result = self.forward(batch, server_args) + + if enable_verification: + # Post-execution output verification + try: + output_result = self.verify_output(result, server_args) + self._run_verification(output_result, stage_name, "output") + except Exception as e: + logger.error( + "Output verification failed for %s: %s", stage_name, str(e) + ) + raise + + return result + + @abstractmethod + def forward( + self, + batch: Req, + server_args: ServerArgs, + ) -> Req: + """ + Forward pass of the stage's processing. + + This method should be implemented by subclasses to provide the forward + processing logic for the stage. + + Args: + batch: The current batch information. + server_args: The inference arguments. + + Returns: + The updated batch information after this stage's processing. 
+ """ + raise NotImplementedError + + def backward( + self, + batch: Req, + server_args: ServerArgs, + ) -> Req: + raise NotImplementedError diff --git a/python/sglang/multimodal_gen/runtime/pipelines/stages/causal_denoising.py b/python/sglang/multimodal_gen/runtime/pipelines/stages/causal_denoising.py new file mode 100644 index 000000000..689be4541 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/stages/causal_denoising.py @@ -0,0 +1,506 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import torch # type: ignore + +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context +from sglang.multimodal_gen.runtime.models.utils import pred_noise_to_pred_video +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.pipelines.stages.denoising import DenoisingStage +from sglang.multimodal_gen.runtime.pipelines.stages.validators import ( + StageValidators as V, +) +from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +try: + from sglang.multimodal_gen.runtime.layers.attention.backends.sliding_tile_attn import ( + SlidingTileAttentionBackend, + ) + + st_attn_available = True +except ImportError: + st_attn_available = False + SlidingTileAttentionBackend = None # type: ignore + +try: + from sglang.multimodal_gen.runtime.layers.attention.backends.video_sparse_attn import ( + VideoSparseAttentionBackend, + ) + + vsa_available = True +except ImportError: + vsa_available = False + VideoSparseAttentionBackend = None # type: ignore + +logger = init_logger(__name__) + + +class CausalDMDDenoisingStage(DenoisingStage): + """ + Denoising stage for causal diffusion. 
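+
+    Latents are denoised block-by-block along the temporal axis: each block of
+    num_frames_per_block frames is run through the DMD timesteps, while a per-block
+    KV cache and a cross-attention cache carry context from earlier frames so that
+    later blocks are generated causally.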
+ """ + + def __init__(self, transformer, scheduler) -> None: + super().__init__(transformer, scheduler) + # KV and cross-attention cache state (initialized on first forward) + self.kv_cache1: list | None = None + self.crossattn_cache: list | None = None + # Model-dependent constants (aligned with causal_inference.py assumptions) + self.num_transformer_blocks = self.transformer.config.arch_config.num_layers + self.num_frames_per_block = ( + self.transformer.config.arch_config.num_frames_per_block + ) + self.sliding_window_num_frames = ( + self.transformer.config.arch_config.sliding_window_num_frames + ) + + try: + self.local_attn_size = getattr( + self.transformer.model, "local_attn_size", -1 + ) # type: ignore + except Exception: + self.local_attn_size = -1 + + def forward( + self, + batch: Req, + server_args: ServerArgs, + ) -> Req: + target_dtype = torch.bfloat16 + autocast_enabled = ( + target_dtype != torch.float32 + ) and not server_args.disable_autocast + + latent_seq_length = batch.latents.shape[-1] * batch.latents.shape[-2] + patch_ratio = ( + self.transformer.config.arch_config.patch_size[-1] + * self.transformer.config.arch_config.patch_size[-2] + ) + self.frame_seq_length = latent_seq_length // patch_ratio + # TODO(will): make this a parameter once we add i2v support + independent_first_frame = self.transformer.independent_first_frame + + # Timesteps for DMD + timesteps = torch.tensor( + server_args.pipeline_config.dmd_denoising_steps, dtype=torch.long + ).cpu() + + if server_args.pipeline_config.warp_denoising_step: + logger.info("Warping timesteps...") + scheduler_timesteps = torch.cat( + (self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)) + ) + timesteps = scheduler_timesteps[1000 - timesteps] + timesteps = timesteps.to(get_local_torch_device()) + logger.info("Using timesteps: %s", timesteps) + + # Image kwargs (kept empty unless caller provides compatible args) + image_kwargs: dict = {} + + pos_cond_kwargs = self.prepare_extra_func_kwargs( + self.transformer.forward, + { + # "encoder_hidden_states_2": batch.clip_embedding_pos, + "encoder_attention_mask": batch.prompt_attention_mask, + }, + ) + + # STA + if st_attn_available and self.attn_backend == SlidingTileAttentionBackend: + self.prepare_sta_param(batch, server_args) + + # Latents and prompts + assert batch.latents is not None, "latents must be provided" + latents = batch.latents # [B, C, T, H, W] + b, c, t, h, w = latents.shape + prompt_embeds = batch.prompt_embeds + assert torch.isnan(prompt_embeds[0]).sum() == 0 + + # Initialize or reset caches + if self.kv_cache1 is None: + self._initialize_kv_cache( + batch_size=latents.shape[0], dtype=target_dtype, device=latents.device + ) + self._initialize_crossattn_cache( + batch_size=latents.shape[0], + max_text_len=server_args.pipeline_config.text_encoder_configs[ + 0 + ].arch_config.text_len, + dtype=target_dtype, + device=latents.device, + ) + else: + assert self.crossattn_cache is not None + # reset cross-attention cache + for block_index in range(self.num_transformer_blocks): + self.crossattn_cache[block_index]["is_init"] = False # type: ignore + # reset kv cache pointers + for block_index in range(len(self.kv_cache1)): + self.kv_cache1[block_index]["global_end_index"] = ( + torch.tensor( # type: ignore + [0], dtype=torch.long, device=latents.device + ) + ) + self.kv_cache1[block_index]["local_end_index"] = ( + torch.tensor( # type: ignore + [0], dtype=torch.long, device=latents.device + ) + ) + + # Optional: cache context features from provided 
image latents prior to generation + current_start_frame = 0 + if getattr(batch, "image_latent", None) is not None: + image_latent = batch.image_latent + assert image_latent is not None + input_frames = image_latent.shape[2] + # timestep zero (or configured context noise) for cache warm-up + t_zero = torch.zeros( + [latents.shape[0]], device=latents.device, dtype=torch.long + ) + if independent_first_frame and input_frames >= 1: + # warm-up with the very first frame independently + image_first_btchw = ( + image_latent[:, :, :1, :, :].to(target_dtype).permute(0, 2, 1, 3, 4) + ) + with torch.autocast( + device_type="cuda", dtype=target_dtype, enabled=autocast_enabled + ): + _ = self.transformer( + image_first_btchw, + prompt_embeds, + t_zero, + kv_cache=self.kv_cache1, + crossattn_cache=self.crossattn_cache, + current_start=current_start_frame * self.frame_seq_length, + **image_kwargs, + **pos_cond_kwargs, + ) + current_start_frame += 1 + remaining_frames = input_frames - 1 + else: + remaining_frames = input_frames + + # process remaining input frames in blocks of num_frame_per_block + while remaining_frames > 0: + block = min(self.num_frames_per_block, remaining_frames) + ref_btchw = ( + image_latent[ + :, :, current_start_frame : current_start_frame + block, :, : + ] + .to(target_dtype) + .permute(0, 2, 1, 3, 4) + ) + with torch.autocast( + device_type="cuda", dtype=target_dtype, enabled=autocast_enabled + ): + _ = self.transformer( + ref_btchw, + prompt_embeds, + t_zero, + kv_cache=self.kv_cache1, + crossattn_cache=self.crossattn_cache, + current_start=current_start_frame * self.frame_seq_length, + **image_kwargs, + **pos_cond_kwargs, + ) + current_start_frame += block + remaining_frames -= block + + # Base position offset from any cache warm-up + pos_start_base = current_start_frame + + # Determine block sizes + if not independent_first_frame or ( + independent_first_frame and batch.image_latent is not None + ): + if t % self.num_frames_per_block != 0: + raise ValueError( + "num_frames must be divisible by num_frames_per_block for causal DMD denoising" + ) + num_blocks = t // self.num_frames_per_block + block_sizes = [self.num_frames_per_block] * num_blocks + start_index = 0 + else: + if (t - 1) % self.num_frames_per_block != 0: + raise ValueError( + "(num_frames - 1) must be divisible by num_frame_per_block when independent_first_frame=True" + ) + num_blocks = (t - 1) // self.num_frames_per_block + block_sizes = [1] + [self.num_frames_per_block] * num_blocks + start_index = 0 + + # DMD loop in causal blocks + with self.progress_bar(total=len(block_sizes) * len(timesteps)) as progress_bar: + for current_num_frames in block_sizes: + current_latents = latents[ + :, :, start_index : start_index + current_num_frames, :, : + ] + # use BTCHW for DMD conversion routines + noise_latents_btchw = current_latents.permute(0, 2, 1, 3, 4) + video_raw_latent_shape = noise_latents_btchw.shape + + for i, t_cur in enumerate(timesteps): + # Copy for pred conversion + noise_latents = noise_latents_btchw.clone() + latent_model_input = current_latents.to(target_dtype) + + if ( + batch.image_latent is not None + and independent_first_frame + and start_index == 0 + ): + latent_model_input = torch.cat( + [latent_model_input, batch.image_latent.to(target_dtype)], + dim=2, + ) + + # Prepare inputs + t_expand = t_cur.repeat(latent_model_input.shape[0]) + + # Attention metadata if needed + if ( + vsa_available + and self.attn_backend == VideoSparseAttentionBackend + ): + self.attn_metadata_builder_cls = ( + 
self.attn_backend.get_builder_cls() + ) + if self.attn_metadata_builder_cls is not None: + self.attn_metadata_builder = ( + self.attn_metadata_builder_cls() + ) + attn_metadata = self.attn_metadata_builder.build( # type: ignore + current_timestep=i, # type: ignore + raw_latent_shape=( + current_num_frames, + h, + w, + ), # type: ignore + patch_size=server_args.pipeline_config.dit_config.patch_size, # type: ignore + STA_param=batch.STA_param, # type: ignore + VSA_sparsity=server_args.VSA_sparsity, # type: ignore + device=get_local_torch_device(), # type: ignore + ) # type: ignore + assert ( + attn_metadata is not None + ), "attn_metadata cannot be None" + else: + attn_metadata = None + else: + attn_metadata = None + + with ( + torch.autocast( + device_type="cuda", + dtype=target_dtype, + enabled=autocast_enabled, + ), + set_forward_context( + current_timestep=i, + attn_metadata=attn_metadata, + forward_batch=batch, + ), + ): + # Run transformer; follow DMD stage pattern + t_expanded_noise = t_cur * torch.ones( + (latent_model_input.shape[0], 1), + device=latent_model_input.device, + dtype=torch.long, + ) + pred_noise_btchw = self.transformer( + latent_model_input, + prompt_embeds, + t_expanded_noise, + kv_cache=self.kv_cache1, + crossattn_cache=self.crossattn_cache, + current_start=(pos_start_base + start_index) + * self.frame_seq_length, + start_frame=start_index, + **image_kwargs, + **pos_cond_kwargs, + ).permute(0, 2, 1, 3, 4) + + # Convert pred noise to pred video with FM Euler scheduler utilities + pred_video_btchw = pred_noise_to_pred_video( + pred_noise=pred_noise_btchw.flatten(0, 1), + noise_input_latent=noise_latents.flatten(0, 1), + timestep=t_expand, + scheduler=self.scheduler, + ).unflatten(0, pred_noise_btchw.shape[:2]) + + if i < len(timesteps) - 1: + next_timestep = timesteps[i + 1] * torch.ones( + [1], dtype=torch.long, device=pred_video_btchw.device + ) + noise = torch.randn( + video_raw_latent_shape, + dtype=pred_video_btchw.dtype, + generator=( + batch.generator[0] + if isinstance(batch.generator, list) + else batch.generator + ), + ).to(self.device) + noise_btchw = noise + noise_latents_btchw = self.scheduler.add_noise( + pred_video_btchw.flatten(0, 1), + noise_btchw.flatten(0, 1), + next_timestep, + ).unflatten(0, pred_video_btchw.shape[:2]) + current_latents = noise_latents_btchw.permute(0, 2, 1, 3, 4) + else: + current_latents = pred_video_btchw.permute(0, 2, 1, 3, 4) + + if progress_bar is not None: + progress_bar.update() + + # Write back and advance + latents[:, :, start_index : start_index + current_num_frames, :, :] = ( + current_latents + ) + + # Re-run with context timestep to update KV cache using clean context + context_noise = getattr(server_args.pipeline_config, "context_noise", 0) + t_context = torch.ones( + [latents.shape[0]], device=latents.device, dtype=torch.long + ) * int(context_noise) + context_bcthw = current_latents.to(target_dtype) + with ( + torch.autocast( + device_type="cuda", dtype=target_dtype, enabled=autocast_enabled + ), + set_forward_context( + current_timestep=0, + attn_metadata=attn_metadata, + forward_batch=batch, + ), + ): + t_expanded_context = t_context.unsqueeze(1) + _ = self.transformer( + context_bcthw, + prompt_embeds, + t_expanded_context, + kv_cache=self.kv_cache1, + crossattn_cache=self.crossattn_cache, + current_start=(pos_start_base + start_index) + * self.frame_seq_length, + start_frame=start_index, + **image_kwargs, + **pos_cond_kwargs, + ) + start_index += current_num_frames + + batch.latents = latents + return batch + 
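+    # Cache layout note (see the two initializers below): every transformer block
+    # owns one dict with "k"/"v" tensors of shape
+    # [batch, cache_len, num_attention_heads, attention_head_dim]. For the KV cache,
+    # cache_len is local_attn_size * frame_seq_length when a local attention window
+    # is configured, otherwise sliding_window_num_frames * frame_seq_length, and
+    # "global_end_index"/"local_end_index" track how far the cache has been filled.
+    # The cross-attention cache uses the text length instead and keeps an "is_init"
+    # flag that is reset at the start of every request.
+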
+ def _initialize_kv_cache(self, batch_size, dtype, device) -> None: + """ + Initialize a Per-GPU KV cache aligned with the Wan model assumptions. + """ + kv_cache1 = [] + num_attention_heads = self.transformer.num_attention_heads + attention_head_dim = self.transformer.attention_head_dim + if self.local_attn_size != -1: + kv_cache_size = self.local_attn_size * self.frame_seq_length + else: + kv_cache_size = self.frame_seq_length * self.sliding_window_num_frames + + for _ in range(self.num_transformer_blocks): + kv_cache1.append( + { + "k": torch.zeros( + [ + batch_size, + kv_cache_size, + num_attention_heads, + attention_head_dim, + ], + dtype=dtype, + device=device, + ), + "v": torch.zeros( + [ + batch_size, + kv_cache_size, + num_attention_heads, + attention_head_dim, + ], + dtype=dtype, + device=device, + ), + "global_end_index": torch.tensor( + [0], dtype=torch.long, device=device + ), + "local_end_index": torch.tensor( + [0], dtype=torch.long, device=device + ), + } + ) + + self.kv_cache1 = kv_cache1 + + def _initialize_crossattn_cache( + self, batch_size, max_text_len, dtype, device + ) -> None: + """ + Initialize a Per-GPU cross-attention cache aligned with the Wan model assumptions. + """ + crossattn_cache = [] + num_attention_heads = self.transformer.num_attention_heads + attention_head_dim = self.transformer.attention_head_dim + for _ in range(self.num_transformer_blocks): + crossattn_cache.append( + { + "k": torch.zeros( + [ + batch_size, + max_text_len, + num_attention_heads, + attention_head_dim, + ], + dtype=dtype, + device=device, + ), + "v": torch.zeros( + [ + batch_size, + max_text_len, + num_attention_heads, + attention_head_dim, + ], + dtype=dtype, + device=device, + ), + "is_init": False, + } + ) + self.crossattn_cache = crossattn_cache + + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify denoising stage inputs.""" + result = VerificationResult() + result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)]) + result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty) + result.add_check("image_embeds", batch.image_embeds, V.is_list) + result.add_check( + "image_latent", batch.image_latent, V.none_or_tensor_with_dims(5) + ) + result.add_check( + "num_inference_steps", batch.num_inference_steps, V.positive_int + ) + result.add_check("guidance_scale", batch.guidance_scale, V.positive_float) + result.add_check("eta", batch.eta, V.non_negative_float) + result.add_check("generator", batch.generator, V.generator_or_list_generators) + result.add_check( + "do_classifier_free_guidance", + batch.do_classifier_free_guidance, + V.bool_value, + ) + result.add_check( + "negative_prompt_embeds", + batch.negative_prompt_embeds, + lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x), + ) + return result diff --git a/python/sglang/multimodal_gen/runtime/pipelines/stages/conditioning.py b/python/sglang/multimodal_gen/runtime/pipelines/stages/conditioning.py new file mode 100644 index 000000000..fb47b2948 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/stages/conditioning.py @@ -0,0 +1,105 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Conditioning stage for diffusion pipelines. 
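+
+When classifier-free guidance is enabled, the negative and positive prompt
+embeddings (and their attention masks) are concatenated along the batch dimension
+so that a single transformer forward pass produces both the unconditional and the
+conditional prediction; the denoising stage then combines the two using the
+guidance scale (conventionally noise_uncond + guidance_scale * (noise_cond - noise_uncond)).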
+""" + +import torch + +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.pipelines.stages.base import PipelineStage +from sglang.multimodal_gen.runtime.pipelines.stages.validators import ( + StageValidators as V, +) +from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class ConditioningStage(PipelineStage): + """ + Stage for applying conditioning to the diffusion process. + + This stage handles the application of conditioning, such as classifier-free guidance, + to the diffusion process. + """ + + @torch.no_grad() + def forward( + self, + batch: Req, + server_args: ServerArgs, + ) -> Req: + """ + Apply conditioning to the diffusion process. + + Args: + batch: The current batch information. + server_args: The inference arguments. + + Returns: + The batch with applied conditioning. + """ + # TODO!! + if not batch.do_classifier_free_guidance: + return batch + else: + return batch + + logger.info("batch.negative_prompt_embeds: %s", batch.negative_prompt_embeds) + logger.info( + "do_classifier_free_guidance: %s", batch.do_classifier_free_guidance + ) + logger.info("cfg_scale: %s", batch.guidance_scale) + + # Ensure negative prompt embeddings are available + assert ( + batch.negative_prompt_embeds is not None + ), "Negative prompt embeddings are required for classifier-free guidance" + + # Concatenate primary embeddings and masks + batch.prompt_embeds = torch.cat( + [batch.negative_prompt_embeds, batch.prompt_embeds] + ) + if batch.attention_mask is not None: + batch.attention_mask = torch.cat( + [batch.negative_attention_mask, batch.attention_mask] + ) + + # Concatenate secondary embeddings and masks if present + if batch.prompt_embeds_2 is not None: + batch.prompt_embeds_2 = torch.cat( + [batch.negative_prompt_embeds_2, batch.prompt_embeds_2] + ) + if batch.attention_mask_2 is not None: + batch.attention_mask_2 = torch.cat( + [batch.negative_attention_mask_2, batch.attention_mask_2] + ) + + return batch + + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify conditioning stage inputs.""" + result = VerificationResult() + result.add_check( + "do_classifier_free_guidance", + batch.do_classifier_free_guidance, + V.bool_value, + ) + result.add_check("guidance_scale", batch.guidance_scale, V.positive_float) + result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty) + result.add_check( + "negative_prompt_embeds", + batch.negative_prompt_embeds, + lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x), + ) + return result + + def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify conditioning stage outputs.""" + result = VerificationResult() + result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty) + return result diff --git a/python/sglang/multimodal_gen/runtime/pipelines/stages/decoding.py b/python/sglang/multimodal_gen/runtime/pipelines/stages/decoding.py new file mode 100644 index 000000000..0728586f5 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/stages/decoding.py @@ -0,0 +1,232 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Decoding stage for diffusion pipelines. 
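+
+Decoding first un-scales the latents by the VAE scaling factor (and applies the
+shift factor where the VAE defines one), then runs the VAE decoder under the
+configured precision, and finally maps the result to the [0, 1] range via
+(image / 2 + 0.5).clamp(0, 1).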
+""" + +import weakref + +import torch + +from sglang.multimodal_gen.configs.models.vaes.base import VAEArchConfig +from sglang.multimodal_gen.configs.pipelines.qwen_image import ( + QwenImageEditPipelineConfig, + QwenImagePipelineConfig, +) +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.loader.component_loader import VAELoader +from sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import OutputBatch, Req +from sglang.multimodal_gen.runtime.pipelines.stages.base import ( + PipelineStage, + StageParallelismType, +) +from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult +from sglang.multimodal_gen.runtime.server_args import ServerArgs, get_global_server_args +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import PRECISION_TO_TYPE + +logger = init_logger(__name__) + + +class DecodingStage(PipelineStage): + """ + Stage for decoding latent representations into pixel space. + + This stage handles the decoding of latent representations into the final + output format (e.g., pixel values). + """ + + def __init__(self, vae, pipeline=None) -> None: + self.vae: ParallelTiledVAE = vae + self.pipeline = weakref.ref(pipeline) if pipeline else None + + @property + def parallelism_type(self) -> StageParallelismType: + if get_global_server_args().enable_cfg_parallel: + return StageParallelismType.MAIN_RANK_ONLY + return StageParallelismType.REPLICATED + + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify decoding stage inputs.""" + result = VerificationResult() + # Denoised latents for VAE decoding: [batch_size, channels, frames, height_latents, width_latents] + # result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)]) + return result + + def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify decoding stage outputs.""" + result = VerificationResult() + # Decoded video/images: [batch_size, channels, frames, height, width] + # result.add_check("output", batch.output, [V.is_tensor, V.with_dims(5)]) + return result + + def scale_and_shift( + self, vae_arch_config: VAEArchConfig, latents: torch.Tensor, server_args + ): + # 1. scale + is_qwen_image = isinstance( + server_args.pipeline_config, QwenImagePipelineConfig + ) or isinstance(server_args.pipeline_config, QwenImageEditPipelineConfig) + if is_qwen_image: + scaling_factor = 1.0 / torch.tensor( + vae_arch_config.latents_std, device=latents.device + ).view(1, vae_arch_config.z_dim, 1, 1, 1).to(latents.device, latents.dtype) + else: + scaling_factor = vae_arch_config.scaling_factor + + if isinstance(scaling_factor, torch.Tensor): + latents = latents / scaling_factor.to(latents.device, latents.dtype) + else: + latents = latents / scaling_factor + + # 2. 
shift + if is_qwen_image: + shift_factor = ( + torch.tensor(vae_arch_config.latents_mean) + .view(1, vae_arch_config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + else: + shift_factor = getattr(vae_arch_config, "shift_factor", None) + + # Apply shifting if needed + if shift_factor is not None: + if isinstance(shift_factor, torch.Tensor): + latents += shift_factor.to(latents.device, latents.dtype) + else: + latents += shift_factor + return latents + + @torch.no_grad() + def decode(self, latents: torch.Tensor, server_args: ServerArgs) -> torch.Tensor: + """ + Decode latent representations into pixel space using VAE. + + Args: + latents: Input latent tensor with shape (batch, channels, frames, height_latents, width_latents) + server_args: Configuration containing: + - disable_autocast: Whether to disable automatic mixed precision (default: False) + - pipeline_config.vae_precision: VAE computation precision ("fp32", "fp16", "bf16") + - pipeline_config.vae_tiling: Whether to enable VAE tiling for memory efficiency + + Returns: + Decoded video tensor with shape (batch, channels, frames, height, width), + normalized to [0, 1] range and moved to CPU as float32 + """ + self.vae = self.vae.to(get_local_torch_device()) + latents = latents.to(get_local_torch_device()) + # Setup VAE precision + vae_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision] + vae_autocast_enabled = ( + vae_dtype != torch.float32 + ) and not server_args.disable_autocast + vae_arch_config = server_args.pipeline_config.vae_config.arch_config + + # scale and shift + latents = self.scale_and_shift(vae_arch_config, latents, server_args) + + # Decode latents + with torch.autocast( + device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled + ): + try: + # TODO: make it more specific + if server_args.pipeline_config.vae_tiling: + self.vae.enable_tiling() + except Exception: + pass + if not vae_autocast_enabled: + latents = latents.to(vae_dtype) + image = self.vae.decode(latents) + + # De-normalize image to [0, 1] range + image = (image / 2 + 0.5).clamp(0, 1) + return image + + @torch.no_grad() + def forward( + self, + batch: Req, + server_args: ServerArgs, + ) -> OutputBatch: + """ + Decode latent representations into pixel space. + + This method processes the batch through the VAE decoder, converting latent + representations to pixel-space video/images. It also optionally decodes + trajectory latents for visualization purposes. 
+ + Args: + batch: The current batch containing: + - latents: Tensor to decode (batch, channels, frames, height_latents, width_latents) + - return_trajectory_decoded (optional): Flag to decode trajectory latents + - trajectory_latents (optional): Latents at different timesteps + - trajectory_timesteps (optional): Corresponding timesteps + server_args: Configuration containing: + - output_type: "latent" to skip decoding, otherwise decode to pixels + - vae_cpu_offload: Whether to offload VAE to CPU after decoding + - model_loaded: Track VAE loading state + - model_paths: Path to VAE model if loading needed + + Returns: + Modified batch with: + - output: Decoded frames (batch, channels, frames, height, width) as CPU float32 + - trajectory_decoded (if requested): List of decoded frames per timestep + """ + # load vae if not already loaded (used for memory constrained devices) + pipeline = self.pipeline() if self.pipeline else None + if not server_args.model_loaded["vae"]: + loader = VAELoader() + self.vae = loader.load(server_args.model_paths["vae"], server_args) + if pipeline: + pipeline.add_module("vae", self.vae) + server_args.model_loaded["vae"] = True + + if server_args.output_type == "latent": + frames = batch.latents + else: + frames = self.decode(batch.latents, server_args) + + # decode trajectory latents if needed + if batch.return_trajectory_decoded: + trajectory_decoded = [] + assert ( + batch.trajectory_latents is not None + ), "batch should have trajectory latents" + for idx in range(batch.trajectory_latents.shape[1]): + # batch.trajectory_latents is [batch_size, timesteps, channels, frames, height, width] + cur_latent = batch.trajectory_latents[:, idx, :, :, :, :] + cur_timestep = batch.trajectory_timesteps[idx] + logger.info("decoding trajectory latent for timestep: %s", cur_timestep) + decoded_frames = self.decode(cur_latent, server_args) + trajectory_decoded.append(decoded_frames.cpu().float()) + else: + trajectory_decoded = None + + # Convert to CPU float32 for compatibility + frames = frames.cpu().float() + + # Update batch with decoded image + output_batch = OutputBatch( + output=frames, + trajectory_timesteps=batch.trajectory_timesteps, + trajectory_latents=batch.trajectory_latents, + trajectory_decoded=trajectory_decoded, + ) + + # Offload models if needed + if hasattr(self, "maybe_free_model_hooks"): + self.maybe_free_model_hooks() + + if server_args.vae_cpu_offload: + self.vae.to("cpu") + + if torch.backends.mps.is_available(): + del self.vae + if pipeline is not None and "vae" in pipeline.modules: + del pipeline.modules["vae"] + server_args.model_loaded["vae"] = False + + return output_batch diff --git a/python/sglang/multimodal_gen/runtime/pipelines/stages/denoising.py b/python/sglang/multimodal_gen/runtime/pipelines/stages/denoising.py new file mode 100644 index 000000000..8ed4a0557 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/stages/denoising.py @@ -0,0 +1,1217 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Denoising stage for diffusion pipelines. 
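+
+The main loop iterates over the scheduler timesteps, optionally applies
+classifier-free guidance (distributed across ranks when CFG parallelism is
+enabled), can switch between a high-noise and a low-noise transformer at a
+boundary timestep (Wan2.2 dual experts), and can record intermediate latents
+for trajectory inspection.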
+""" + +import inspect +import math +import os +import time +import weakref +from collections.abc import Iterable +from functools import lru_cache +from typing import Any + +import torch +import torch.profiler +from einops import rearrange +from tqdm.auto import tqdm + +from sglang.multimodal_gen.configs.pipelines.base import STA_Mode +from sglang.multimodal_gen.runtime.distributed import ( + cfg_model_parallel_all_reduce, + get_local_torch_device, + get_sp_parallel_rank, + get_sp_world_size, + get_world_group, +) +from sglang.multimodal_gen.runtime.distributed.communication_op import ( + sequence_model_parallel_all_gather, +) +from sglang.multimodal_gen.runtime.distributed.parallel_state import ( + get_cfg_group, + get_classifier_free_guidance_rank, +) +from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import ( + FlashAttentionBackend, +) +from sglang.multimodal_gen.runtime.layers.attention.selector import get_attn_backend +from sglang.multimodal_gen.runtime.layers.attention.STA_configuration import ( + configure_sta, + save_mask_search_results, +) +from sglang.multimodal_gen.runtime.loader.component_loader import TransformerLoader +from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.pipelines.stages.base import ( + PipelineStage, + StageParallelismType, +) +from sglang.multimodal_gen.runtime.pipelines.stages.validators import ( + StageValidators as V, +) +from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult +from sglang.multimodal_gen.runtime.platforms.interface import AttentionBackendEnum +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import dict_to_3d_list, masks_like + +try: + from sglang.multimodal_gen.runtime.layers.attention.backends.sliding_tile_attn import ( + SlidingTileAttentionBackend, + ) + + st_attn_available = True +except ImportError: + st_attn_available = False + +try: + from sglang.multimodal_gen.runtime.layers.attention.backends.vmoba import ( + VMOBAAttentionBackend, + ) + from sglang.multimodal_gen.utils import is_vmoba_available + + vmoba_attn_available = is_vmoba_available() +except ImportError: + vmoba_attn_available = False + +try: + from sglang.multimodal_gen.runtime.layers.attention.backends.video_sparse_attn import ( + VideoSparseAttentionBackend, + ) + + vsa_available = True +except ImportError: + vsa_available = False + +logger = init_logger(__name__) + + +class DenoisingStage(PipelineStage): + """ + Stage for running the denoising loop in diffusion pipelines. + + This stage handles the iterative denoising process that transforms + the initial noise into the final output. 
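+
+    A forward() call delegates to _prepare_denoising_loop for invariant setup, then
+    steps through the timesteps (selecting the transformer via
+    _select_and_manage_model, building attention metadata, predicting noise with
+    _predict_noise_with_cfg, and advancing the latents with scheduler.step), and
+    finishes in _post_denoising_loop, which gathers sequence-parallel shards and
+    writes the results back to the batch.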
+ """ + + def __init__( + self, transformer, scheduler, pipeline=None, transformer_2=None, vae=None + ) -> None: + super().__init__() + self.transformer = transformer + self.transformer_2 = transformer_2 + + hidden_size = self.server_args.pipeline_config.dit_config.hidden_size + num_attention_heads = ( + self.server_args.pipeline_config.dit_config.num_attention_heads + ) + attn_head_size = hidden_size // num_attention_heads + + # torch compile + if self.server_args.enable_torch_compile: + full_graph = False + self.transformer = torch.compile( + self.transformer, mode="max-autotune", fullgraph=full_graph + ) + self.transformer_2 = ( + torch.compile( + self.transformer_2, mode="max-autotune", fullgraph=full_graph + ) + if transformer_2 is not None + else None + ) + + self.scheduler = scheduler + self.vae = vae + self.pipeline = weakref.ref(pipeline) if pipeline else None + + self.attn_backend = get_attn_backend( + head_size=attn_head_size, + dtype=torch.float16, # TODO(will): hack + supported_attention_backends={ + AttentionBackendEnum.SLIDING_TILE_ATTN, + AttentionBackendEnum.VIDEO_SPARSE_ATTN, + AttentionBackendEnum.VMOBA_ATTN, + AttentionBackendEnum.FA3, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.SAGE_ATTN_THREE, + }, # hack + ) + + # cfg + self.guidance = None + + # misc + self.profiler = None + + @lru_cache(maxsize=8) + def _build_guidance(self, batch_size, target_dtype, device, guidance_val): + """Builds a guidance tensor. This method is cached.""" + return ( + torch.full( + (batch_size,), + guidance_val, + dtype=torch.float32, + device=device, + ).to(target_dtype) + * 1000.0 + ) + + def get_or_build_guidance(self, bsz: int, dtype, device): + """ + Get the guidance tensor, using a cached version if available. + + This method retrieves a cached guidance tensor using `_build_guidance`. + The caching is based on batch size, dtype, device, and the guidance value, + preventing repeated tensor creation within the denoising loop. + """ + if self.server_args.pipeline_config.should_use_guidance: + # TODO: should the guidance_scale be picked-up from sampling_params? + guidance_val = self.server_args.pipeline_config.embedded_cfg_scale + return self._build_guidance(bsz, dtype, device, guidance_val) + else: + return None + + @property + def parallelism_type(self) -> StageParallelismType: + # return StageParallelismType.CFG_PARALLEL if get_global_server_args().enable_cfg_parallel else StageParallelismType.REPLICATED + return StageParallelismType.REPLICATED + + def _prepare_denoising_loop(self, batch: Req, server_args: ServerArgs): + """ + Prepare all necessary invariant variables for the denoising loop. + + Args: + batch: The current batch information. + server_args: The inference arguments. + + Returns: + A dictionary containing all the prepared variables for the denoising loop. 
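+            Keys include: extra_step_kwargs, target_dtype, autocast_enabled,
+            timesteps, num_inference_steps, num_warmup_steps, image_kwargs,
+            pos_cond_kwargs, neg_cond_kwargs, latents, prompt_embeds,
+            neg_prompt_embeds, boundary_timestep, z, mask2, seq_len, and guidance.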
+ """ + pipeline = self.pipeline() if self.pipeline else None + if not server_args.model_loaded["transformer"]: + loader = TransformerLoader() + self.transformer = loader.load( + server_args.model_paths["transformer"], server_args + ) + if self.server_args.enable_torch_compile: + self.transformer = torch.compile( + self.transformer, mode="max-autotune", fullgraph=True + ) + if pipeline: + pipeline.add_module("transformer", self.transformer) + server_args.model_loaded["transformer"] = True + + # Prepare extra step kwargs for scheduler + extra_step_kwargs = self.prepare_extra_func_kwargs( + self.scheduler.step, + {"generator": batch.generator, "eta": batch.eta}, + ) + + # Setup precision and autocast settings + target_dtype = torch.bfloat16 + autocast_enabled = ( + target_dtype != torch.float32 + ) and not server_args.disable_autocast + + # Handle sequence parallelism if enabled + self._preprocess_sp_latents(batch) + + # Get timesteps and calculate warmup steps + timesteps = batch.timesteps + if timesteps is None: + raise ValueError("Timesteps must be provided") + num_inference_steps = batch.num_inference_steps + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # Prepare image latents and embeddings for I2V generation + image_embeds = batch.image_embeds + if len(image_embeds) > 0: + image_embeds = [ + image_embed.to(target_dtype) for image_embed in image_embeds + ] + + # Prepare STA parameters + if st_attn_available and self.attn_backend == SlidingTileAttentionBackend: + self.prepare_sta_param(batch, server_args) + + # Get latents and embeddings + latents = batch.latents + prompt_embeds = batch.prompt_embeds + # Removed Tensor truthiness assert to avoid GPU sync + neg_prompt_embeds = None + if batch.do_classifier_free_guidance: + neg_prompt_embeds = batch.negative_prompt_embeds + assert neg_prompt_embeds is not None + # Removed Tensor truthiness assert to avoid GPU sync + + # (Wan2.2) Calculate timestep to switch from high noise expert to low noise expert + boundary_ratio = server_args.pipeline_config.dit_config.boundary_ratio + if batch.boundary_ratio is not None: + logger.info( + "Overriding boundary ratio from %s to %s", + boundary_ratio, + batch.boundary_ratio, + ) + boundary_ratio = batch.boundary_ratio + + if boundary_ratio is not None: + boundary_timestep = boundary_ratio * self.scheduler.num_train_timesteps + else: + boundary_timestep = None + + # TI2V specific preparations + z, mask2, seq_len = None, None, None + # FIXME: should probably move to latent preparation stage, to handle with offload + if server_args.pipeline_config.ti2v_task and batch.pil_image is not None: + # Wan2.2 TI2V directly replaces the first frame of the latent with + # the image latent instead of appending along the channel dim + assert batch.image_latent is None, "TI2V task should not have image latents" + assert self.vae is not None, "VAE is not provided for TI2V task" + self.vae = self.vae.to(batch.pil_image.device) + z = self.vae.encode(batch.pil_image).mean.float() + if self.vae.device != "cpu" and server_args.vae_cpu_offload: + self.vae = self.vae.to("cpu") + if hasattr(self.vae, "shift_factor") and self.vae.shift_factor is not None: + if isinstance(self.vae.shift_factor, torch.Tensor): + z -= self.vae.shift_factor.to(z.device, z.dtype) + else: + z -= self.vae.shift_factor + + if isinstance(self.vae.scaling_factor, torch.Tensor): + z = z * self.vae.scaling_factor.to(z.device, z.dtype) + else: + z = z * self.vae.scaling_factor + latent_model_input = 
latents.to(target_dtype).squeeze(0) + _, mask2 = masks_like([latent_model_input], zero=True) + + latents = (1.0 - mask2[0]) * z + mask2[0] * latent_model_input + latents = latents.to(get_local_torch_device()) + + F = batch.num_frames + temporal_scale = ( + server_args.pipeline_config.vae_config.arch_config.scale_factor_temporal + ) + spatial_scale = ( + server_args.pipeline_config.vae_config.arch_config.scale_factor_spatial + ) + patch_size = server_args.pipeline_config.dit_config.arch_config.patch_size + seq_len = ( + ((F - 1) // temporal_scale + 1) + * (batch.height // spatial_scale) + * (batch.width // spatial_scale) + // (patch_size[1] * patch_size[2]) + ) + seq_len = ( + int(math.ceil(seq_len / get_sp_world_size())) * get_sp_world_size() + ) + + guidance = self.get_or_build_guidance( + # TODO: replace with raw_latent_shape? + latents.shape[0], + latents.dtype, + latents.device, + ) + + image_kwargs = self.prepare_extra_func_kwargs( + self.transformer.forward, + { + # TODO: make sure on-device + "encoder_hidden_states_image": image_embeds, + "mask_strategy": dict_to_3d_list(None, t_max=50, l_max=60, h_max=24), + }, + ) + + pos_cond_kwargs = self.prepare_extra_func_kwargs( + self.transformer.forward, + { + "encoder_hidden_states_2": batch.clip_embedding_pos, + "encoder_attention_mask": batch.prompt_attention_mask, + } + | server_args.pipeline_config.prepare_pos_cond_kwargs( + batch, + self.device, + getattr(self.transformer, "rotary_emb", None), + dtype=target_dtype, + ), + ) + + if batch.do_classifier_free_guidance: + neg_cond_kwargs = self.prepare_extra_func_kwargs( + self.transformer.forward, + { + "encoder_hidden_states_2": batch.clip_embedding_neg, + "encoder_attention_mask": batch.negative_attention_mask, + } + | server_args.pipeline_config.prepare_neg_cond_kwargs( + batch, + self.device, + getattr(self.transformer, "rotary_emb", None), + dtype=target_dtype, + ), + ) + else: + neg_cond_kwargs = {} + + return { + "extra_step_kwargs": extra_step_kwargs, + "target_dtype": target_dtype, + "autocast_enabled": autocast_enabled, + "timesteps": timesteps, + "num_inference_steps": num_inference_steps, + "num_warmup_steps": num_warmup_steps, + "image_kwargs": image_kwargs, + "pos_cond_kwargs": pos_cond_kwargs, + "neg_cond_kwargs": neg_cond_kwargs, + "latents": latents, + "prompt_embeds": prompt_embeds, + "neg_prompt_embeds": neg_prompt_embeds, + "boundary_timestep": boundary_timestep, + "z": z, + "mask2": mask2, + "seq_len": seq_len, + "guidance": guidance, + } + + def _post_denoising_loop( + self, + batch: Req, + latents: torch.Tensor, + trajectory_latents: list, + trajectory_timesteps: list, + server_args: ServerArgs, + ): + # Gather results if using sequence parallelism + if trajectory_latents: + trajectory_tensor = torch.stack(trajectory_latents, dim=1) + trajectory_timesteps_tensor = torch.stack(trajectory_timesteps, dim=0) + else: + trajectory_tensor = None + trajectory_timesteps_tensor = None + + # Gather results if using sequence parallelism + latents, trajectory_tensor = self._postprocess_sp_latents( + batch, latents, trajectory_tensor + ) + + if trajectory_tensor is not None and trajectory_timesteps_tensor is not None: + batch.trajectory_timesteps = trajectory_timesteps_tensor.cpu() + batch.trajectory_latents = trajectory_tensor.cpu() + + # Update batch with final latents + batch.latents = self.server_args.pipeline_config.post_denoising_loop( + latents, batch + ) + + # Save STA mask search results if needed + if ( + st_attn_available + and self.attn_backend == 
SlidingTileAttentionBackend + and server_args.STA_mode == STA_Mode.STA_SEARCHING + ): + self.save_sta_search_results(batch) + + # deallocate transformer if on mps + pipeline = self.pipeline() if self.pipeline else None + if torch.backends.mps.is_available(): + logger.info( + "Memory before deallocating transformer: %s", + torch.mps.current_allocated_memory(), + ) + del self.transformer + if pipeline is not None and "transformer" in pipeline.modules: + del pipeline.modules["transformer"] + server_args.model_loaded["transformer"] = False + logger.info( + "Memory after deallocating transformer: %s", + torch.mps.current_allocated_memory(), + ) + + def _preprocess_sp_latents(self, batch: Req): + """Shard latents for Sequence Parallelism if applicable.""" + sp_world_size, rank_in_sp_group = get_sp_world_size(), get_sp_parallel_rank() + if get_sp_world_size() <= 1: + batch.did_sp_shard_latents = False + return + + def _shard_tensor( + tensor: torch.Tensor | None, + ) -> tuple[torch.Tensor | None, bool]: + if tensor is None: + return None, False + + if tensor.dim() == 5: + time_dim = tensor.shape[2] + if time_dim > 0 and time_dim % sp_world_size == 0: + sharded_tensor = rearrange( + tensor, "b c (n t) h w -> b c n t h w", n=sp_world_size + ).contiguous() + sharded_tensor = sharded_tensor[:, :, rank_in_sp_group, :, :, :] + return sharded_tensor, True + + # For 4D image tensors or unsharded 5D tensors, return as is. + return tensor, False + + batch.latents, did_shard = _shard_tensor(batch.latents) + batch.did_sp_shard_latents = did_shard + + # image_latent is sharded independently, but the decision to all-gather later + # is based on whether the main `latents` was sharded. + if batch.image_latent is not None: + batch.image_latent, _ = _shard_tensor(batch.image_latent) + + def _postprocess_sp_latents( + self, + batch: Req, + latents: torch.Tensor, + trajectory_tensor: torch.Tensor | None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Gather latents after Sequence Parallelism if they were sharded.""" + if get_sp_world_size() > 1 and getattr(batch, "did_sp_shard_latents", False): + latents = sequence_model_parallel_all_gather(latents, dim=2) + if trajectory_tensor is not None: + # trajectory_tensor shape: [b, num_steps, c, t_local, h, w] -> gather on dim 3 + trajectory_tensor = trajectory_tensor.to(get_local_torch_device()) + trajectory_tensor = sequence_model_parallel_all_gather( + trajectory_tensor, dim=3 + ) + return latents, trajectory_tensor + + def start_profile(self, batch: Req): + if not batch.profile: + return + + logger.info("Starting Profiler...") + # Build activities dynamically to avoid CUDA hangs when CUDA is unavailable + activities = [torch.profiler.ProfilerActivity.CPU] + if torch.cuda.is_available(): + activities.append(torch.profiler.ProfilerActivity.CUDA) + + prof = torch.profiler.profile( + activities=activities, + schedule=torch.profiler.schedule( + skip_first=0, + wait=0, + warmup=5, + active=batch.num_profiled_timesteps, + repeat=5, + ), + on_trace_ready=lambda _: torch.profiler.tensorboard_trace_handler( + f"./logs" + ), + record_shapes=True, + with_stack=True, + ) + prof.start() + self.profiler = prof + + def step_profile(self): + if self.profiler: + if torch.cuda.is_available(): + torch.cuda.synchronize() + self.profiler.step() + + def stop_profile(self, batch: Req): + try: + if self.profiler: + logger.info("Stopping Profiler...") + if torch.cuda.is_available(): + torch.cuda.synchronize() + self.profiler.stop() + request_id = batch.request_id if batch.request_id else 
"profile_trace" + log_dir = f"./logs" + os.makedirs(log_dir, exist_ok=True) + + trace_path = os.path.abspath( + os.path.join(log_dir, f"{request_id}.trace.json.gz") + ) + logger.info(f"Saving profiler traces to: {trace_path}") + self.profiler.export_chrome_trace(trace_path) + except Exception as e: + logger.error(f"{e}") + + def _manage_device_placement( + self, + model_to_use: torch.nn.Module, + model_to_offload: torch.nn.Module | None, + server_args: ServerArgs, + ): + """ + Manages the offload / load behavior of dit + """ + if not server_args.dit_cpu_offload: + return + + # Offload the unused model if it's on CUDA + if ( + model_to_offload is not None + and next(model_to_offload.parameters()).device.type == "cuda" + ): + model_to_offload.to("cpu") + + # Load the model to use if it's on CPU + if ( + model_to_use is not None + and next(model_to_use.parameters()).device.type == "cpu" + ): + model_to_use.to(get_local_torch_device()) + + def _select_and_manage_model( + self, + t_int: int, + boundary_timestep: float | None, + server_args: ServerArgs, + batch: Req, + ): + if boundary_timestep is None or t_int >= boundary_timestep: + # High-noise stage + current_model = self.transformer + model_to_offload = self.transformer_2 + current_guidance_scale = batch.guidance_scale + else: + # Low-noise stage + current_model = self.transformer_2 + model_to_offload = self.transformer + current_guidance_scale = batch.guidance_scale_2 + + self._manage_device_placement(current_model, model_to_offload, server_args) + + assert current_model is not None, "The model for the current step is not set." + return current_model, current_guidance_scale + + @torch.no_grad() + def forward( + self, + batch: Req, + server_args: ServerArgs, + ) -> Req: + """ + Run the denoising loop. + + Args: + batch: The current batch information. + server_args: The inference arguments. + + Returns: + The batch with denoised latents. 
+ """ + # Prepare variables for the denoising loop + + prepared_vars = self._prepare_denoising_loop(batch, server_args) + extra_step_kwargs = prepared_vars["extra_step_kwargs"] + target_dtype = prepared_vars["target_dtype"] + autocast_enabled = prepared_vars["autocast_enabled"] + timesteps = prepared_vars["timesteps"] + num_inference_steps = prepared_vars["num_inference_steps"] + num_warmup_steps = prepared_vars["num_warmup_steps"] + image_kwargs = prepared_vars["image_kwargs"] + pos_cond_kwargs = prepared_vars["pos_cond_kwargs"] + neg_cond_kwargs = prepared_vars["neg_cond_kwargs"] + latents = prepared_vars["latents"] + boundary_timestep = prepared_vars["boundary_timestep"] + z = prepared_vars["z"] + mask2 = prepared_vars["mask2"] + seq_len = prepared_vars["seq_len"] + guidance = prepared_vars["guidance"] + + # Initialize lists for ODE trajectory + trajectory_timesteps: list[torch.Tensor] = [] + trajectory_latents: list[torch.Tensor] = [] + + # Run denoising loop + denoising_start_time = time.time() + + self.start_profile(batch=batch) + + # to avoid device-sync caused by timestep comparison + timesteps_cpu = timesteps.cpu() + num_timesteps = timesteps_cpu.shape[0] + with torch.autocast( + device_type=("cuda" if torch.cuda.is_available() else "cpu"), + dtype=target_dtype, + enabled=autocast_enabled, + ): + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t_host in enumerate(timesteps_cpu): + if batch.perf_logger: + batch.perf_logger.record_step_start() + # Skip if interrupted + if hasattr(self, "interrupt") and self.interrupt: + continue + + t_int = int(t_host.item()) + t_device = timesteps[i] + current_model, current_guidance_scale = ( + self._select_and_manage_model( + t_int=t_int, + boundary_timestep=boundary_timestep, + server_args=server_args, + batch=batch, + ) + ) + + # Expand latents for I2V + latent_model_input = latents.to(target_dtype) + if batch.image_latent is not None: + assert ( + not server_args.pipeline_config.ti2v_task + ), "image latents should not be provided for TI2V task" + latent_model_input = torch.cat( + [latent_model_input, batch.image_latent], dim=1 + ).to(target_dtype) + + # expand timestep + if ( + server_args.pipeline_config.ti2v_task + and batch.pil_image is not None + ): + timestep = torch.stack([t_device]).to(get_local_torch_device()) + temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten() + temp_ts = torch.cat( + [ + temp_ts, + temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep, + ] + ) + timestep = temp_ts.unsqueeze(0) + t_expand = timestep.repeat(latent_model_input.shape[0], 1) + else: + t_expand = t_device.repeat(latent_model_input.shape[0]) + + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t_device + ) + + # Predict noise residual + attn_metadata = self._build_attn_metadata(i, batch, server_args) + noise_pred = self._predict_noise_with_cfg( + current_model, + latent_model_input, + t_expand, + batch, + i, + attn_metadata, + target_dtype, + current_guidance_scale, + image_kwargs, + pos_cond_kwargs, + neg_cond_kwargs, + server_args, + guidance=guidance, + latents=latents, + ) + + if batch.perf_logger: + batch.perf_logger.record_step_end("denoising_step_guided", i) + # Compute the previous noisy sample + latents = self.scheduler.step( + model_output=noise_pred, + timestep=t_device, + sample=latents, + **extra_step_kwargs, + return_dict=False, + )[0] + if ( + server_args.pipeline_config.ti2v_task + and batch.pil_image is not None + ): + latents = latents.squeeze(0) + latents = (1.0 - mask2[0]) 
* z + mask2[0] * latents + + # save trajectory latents if needed + if batch.return_trajectory_latents: + trajectory_timesteps.append(t_host) + trajectory_latents.append(latents) + + # Update progress bar + if i == num_timesteps - 1 or ( + (i + 1) > num_warmup_steps + and (i + 1) % self.scheduler.order == 0 + and progress_bar is not None + ): + progress_bar.update() + + self.step_profile() + + self.stop_profile(batch) + + denoising_end_time = time.time() + + if num_timesteps > 0: + self.log_info( + "Average time per step: %.4f seconds", + (denoising_end_time - denoising_start_time) / len(timesteps), + ) + + self._post_denoising_loop( + batch=batch, + latents=latents, + trajectory_latents=trajectory_latents, + trajectory_timesteps=trajectory_timesteps, + server_args=server_args, + ) + return batch + + # TODO: this will extends the preparation stage, should let subclass/passed-in variables decide which to prepare + def prepare_extra_func_kwargs(self, func, kwargs) -> dict[str, Any]: + """ + Prepare extra kwargs for the scheduler step / denoise step. + + Args: + func: The function to prepare kwargs for. + kwargs: The kwargs to prepare. + + Returns: + The prepared kwargs. + """ + extra_step_kwargs = {} + for k, v in kwargs.items(): + accepts = k in set(inspect.signature(func).parameters.keys()) + if accepts: + extra_step_kwargs[k] = v + return extra_step_kwargs + + def progress_bar( + self, iterable: Iterable | None = None, total: int | None = None + ) -> tqdm: + """ + Create a progress bar for the denoising process. + + Args: + iterable: The iterable to iterate over. + total: The total number of items. + + Returns: + A tqdm progress bar. + """ + local_rank = get_world_group().local_rank + if local_rank == 0: + return tqdm(iterable=iterable, total=total) + else: + return tqdm(iterable=iterable, total=total, disable=True) + + def rescale_noise_cfg( + self, noise_cfg, noise_pred_text, guidance_rescale=0.0 + ) -> torch.Tensor: + """ + Rescale noise prediction according to guidance_rescale. + + Based on findings of "Common Diffusion Noise Schedules and Sample Steps are Flawed" + (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4. + + Args: + noise_cfg: The noise prediction with guidance. + noise_pred_text: The text-conditioned noise prediction. + guidance_rescale: The guidance rescale factor. + + Returns: + The rescaled noise prediction. + """ + std_text = noise_pred_text.std( + dim=list(range(1, noise_pred_text.ndim)), keepdim=True + ) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # Rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # Mix with the original results from guidance by factor guidance_rescale + noise_cfg = ( + guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + ) + return noise_cfg + + def _build_attn_metadata( + self, i: int, batch: Req, server_args: ServerArgs + ) -> Any | None: + """ + Build attention metadata for custom attention backends. + + Args: + i: The current timestep index. + batch: The current batch information. + server_args: The inference arguments. + + Returns: + The attention metadata, or None if not applicable. 
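+
+ Note: metadata is rebuilt on every denoising step because the STA/VSA/VMoBA
+ builders take the current timestep into account, while FlashAttention only
+ needs the raw latent shape. Illustrative sketch (builder here stands for the
+ backend's metadata builder instance):
+
+ metadata = builder.build(raw_latent_shape=batch.raw_latent_shape)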
+ """ + attn_metadata = None + self.attn_metadata_builder_cls = self.attn_backend.get_builder_cls() + if self.attn_metadata_builder_cls: + self.attn_metadata_builder = self.attn_metadata_builder_cls() + if (st_attn_available and self.attn_backend == SlidingTileAttentionBackend) or ( + vsa_available and self.attn_backend == VideoSparseAttentionBackend + ): + attn_metadata = self.attn_metadata_builder.build( + current_timestep=i, + raw_latent_shape=batch.raw_latent_shape[2:5], + patch_size=server_args.pipeline_config.dit_config.patch_size, + STA_param=batch.STA_param, + VSA_sparsity=server_args.VSA_sparsity, + device=get_local_torch_device(), + ) + elif vmoba_attn_available and self.attn_backend == VMOBAAttentionBackend: + moba_params = server_args.moba_config.copy() + moba_params.update( + { + "current_timestep": i, + "raw_latent_shape": batch.raw_latent_shape[2:5], + "patch_size": server_args.pipeline_config.dit_config.patch_size, + "device": get_local_torch_device(), + } + ) + elif self.attn_backend == FlashAttentionBackend: + attn_metadata = self.attn_metadata_builder.build( + raw_latent_shape=batch.raw_latent_shape + ) + else: + return None + + assert attn_metadata is not None, "attn_metadata cannot be None" + + return attn_metadata + + def _predict_noise( + self, + current_model, + latent_model_input, + t_expand, + prompt_embeds, + target_dtype, + guidance: torch.Tensor, + **kwargs, + ): + return current_model( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=t_expand, + guidance=guidance, + **kwargs, + ) + + def _predict_noise_with_cfg( + self, + current_model: torch.nn.Module, + latent_model_input: torch.Tensor, + t_expand, + batch, + timestep_index: int, + attn_metadata, + target_dtype, + current_guidance_scale, + image_kwargs: dict[str, Any], + pos_cond_kwargs: dict[str, Any], + neg_cond_kwargs: dict[str, Any], + server_args, + guidance, + latents, + ): + """ + Predict the noise residual with classifier-free guidance. + + Args: + current_model: The transformer model to use for the current step. + latent_model_input: The input latents for the model. + t_expand: The expanded timestep tensor. + batch: The current batch information. + timestep_index: The current timestep index. + attn_metadata: Attention metadata for custom backends. + target_dtype: The target data type for autocasting. + current_guidance_scale: The guidance scale for the current step. + image_kwargs: Keyword arguments for image conditioning. + pos_cond_kwargs: Keyword arguments for positive prompt conditioning. + neg_cond_kwargs: Keyword arguments for negative prompt conditioning. + + Returns: + The predicted noise. + """ + noise_pred_cond: torch.Tensor | None = None + noise_pred_uncond: torch.Tensor | None = None + cfg_rank = get_classifier_free_guidance_rank() + # positive pass + if not (server_args.enable_cfg_parallel and cfg_rank != 0): + batch.is_cfg_negative = False + with set_forward_context( + current_timestep=timestep_index, + attn_metadata=attn_metadata, + forward_batch=batch, + ): + noise_pred_cond = self._predict_noise( + current_model=current_model, + latent_model_input=latent_model_input, + t_expand=t_expand, + prompt_embeds=server_args.pipeline_config.get_pos_prompt_embeds( + batch + ), + target_dtype=target_dtype, + guidance=guidance, + **image_kwargs, + **pos_cond_kwargs, + ) + # TODO: can it be moved to after _predict_noise_with_cfg? 
+ noise_pred_cond = server_args.pipeline_config.slice_noise_pred( + noise_pred_cond, latents + ) + if not batch.do_classifier_free_guidance: + # If CFG is disabled, we are done. Return the conditional prediction. + return noise_pred_cond + + # negative pass + if not server_args.enable_cfg_parallel or cfg_rank != 0: + batch.is_cfg_negative = True + with set_forward_context( + current_timestep=timestep_index, + attn_metadata=attn_metadata, + forward_batch=batch, + ): + noise_pred_uncond = self._predict_noise( + current_model=current_model, + latent_model_input=latent_model_input, + t_expand=t_expand, + prompt_embeds=server_args.pipeline_config.get_neg_prompt_embeds( + batch + ), + target_dtype=target_dtype, + guidance=guidance, + **image_kwargs, + **neg_cond_kwargs, + ) + noise_pred_uncond = server_args.pipeline_config.slice_noise_pred( + noise_pred_uncond, latents + ) + + # Combine predictions + if server_args.enable_cfg_parallel: + # Each rank computes its partial contribution and we sum via all-reduce: + # final = s*cond + (1-s)*uncond + if cfg_rank == 0: + assert noise_pred_cond is not None + partial = current_guidance_scale * noise_pred_cond + else: + assert noise_pred_uncond is not None + partial = (1 - current_guidance_scale) * noise_pred_uncond + + noise_pred = cfg_model_parallel_all_reduce(partial) + + # Guidance rescale: broadcast std(cond) from rank 0, compute std(cfg) locally + if batch.guidance_rescale > 0.0: + std_cfg = noise_pred.std( + dim=list(range(1, noise_pred.ndim)), keepdim=True + ) + if cfg_rank == 0: + assert noise_pred_cond is not None + std_text = noise_pred_cond.std( + dim=list(range(1, noise_pred_cond.ndim)), keepdim=True + ) + else: + std_text = torch.empty_like(std_cfg) + # Broadcast std_text from local src=0 to all ranks in CFG group + std_text = get_cfg_group().broadcast(std_text, src=0) + noise_pred_rescaled = noise_pred * (std_text / std_cfg) + noise_pred = ( + batch.guidance_rescale * noise_pred_rescaled + + (1 - batch.guidance_rescale) * noise_pred + ) + return noise_pred + else: + # Serial CFG: both cond and uncond are available locally + assert noise_pred_cond is not None and noise_pred_uncond is not None + noise_pred = noise_pred_uncond + current_guidance_scale * ( + noise_pred_cond - noise_pred_uncond + ) + + if batch.guidance_rescale > 0.0: + noise_pred = self.rescale_noise_cfg( + noise_pred, + noise_pred_cond, + guidance_rescale=batch.guidance_rescale, + ) + return noise_pred + + def prepare_sta_param(self, batch: Req, server_args: ServerArgs): + """ + Prepare Sliding Tile Attention (STA) parameters and settings. + + Args: + batch: The current batch information. + server_args: The inference arguments. 
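+
+ Note: the search/tuning modes (STA_SEARCHING, STA_TUNING, STA_TUNING_CFG)
+ are currently limited to 69x768x1280 outputs, while STA_INFERENCE loads a
+ precomputed mask strategy from the file referenced by
+ SGL_DIFFUSION_ATTENTION_CONFIG.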
+ """ + # TODO(kevin): STA mask search, currently only support Wan2.1 with 69x768x1280 + STA_mode = server_args.STA_mode + skip_time_steps = server_args.skip_time_steps + if batch.timesteps is None: + raise ValueError("Timesteps must be provided") + timesteps_num = batch.timesteps.shape[0] + + logger.info("STA_mode: %s", STA_mode) + if (batch.num_frames, batch.height, batch.width) != ( + 69, + 768, + 1280, + ) and STA_mode != "STA_inference": + raise NotImplementedError( + "STA mask search/tuning is not supported for this resolution" + ) + + if ( + STA_mode == STA_Mode.STA_SEARCHING + or STA_mode == STA_Mode.STA_TUNING + or STA_mode == STA_Mode.STA_TUNING_CFG + ): + size = (batch.width, batch.height) + if size == (1280, 768): + # TODO: make it configurable + sparse_mask_candidates_searching = [ + "3, 1, 10", + "1, 5, 7", + "3, 3, 3", + "1, 6, 5", + "1, 3, 10", + "3, 6, 1", + ] + sparse_mask_candidates_tuning = [ + "3, 1, 10", + "1, 5, 7", + "3, 3, 3", + "1, 6, 5", + "1, 3, 10", + "3, 6, 1", + ] + full_mask = ["3,6,10"] + else: + raise NotImplementedError( + "STA mask search is not supported for this resolution" + ) + layer_num = self.transformer.config.num_layers + # specific for HunyuanVideo + if hasattr(self.transformer.config, "num_single_layers"): + layer_num += self.transformer.config.num_single_layers + head_num = self.transformer.config.num_attention_heads + + if STA_mode == STA_Mode.STA_SEARCHING: + STA_param = configure_sta( + mode=STA_Mode.STA_SEARCHING, + layer_num=layer_num, + head_num=head_num, + time_step_num=timesteps_num, + mask_candidates=sparse_mask_candidates_searching + full_mask, + # last is full mask; Can add more sparse masks while keep last one as full mask + ) + elif STA_mode == STA_Mode.STA_TUNING: + STA_param = configure_sta( + mode=STA_Mode.STA_TUNING, + layer_num=layer_num, + head_num=head_num, + time_step_num=timesteps_num, + mask_search_files_path=f"output/mask_search_result_pos_{size[0]}x{size[1]}/", + mask_candidates=sparse_mask_candidates_tuning, + full_attention_mask=[int(x) for x in full_mask[0].split(",")], + skip_time_steps=skip_time_steps, # Use full attention for first 12 steps + save_dir=f"output/mask_search_strategy_{size[0]}x{size[1]}/", # Custom save directory + timesteps=timesteps_num, + ) + elif STA_mode == STA_Mode.STA_TUNING_CFG: + STA_param = configure_sta( + mode=STA_Mode.STA_TUNING_CFG, + layer_num=layer_num, + head_num=head_num, + time_step_num=timesteps_num, + mask_search_files_path_pos=f"output/mask_search_result_pos_{size[0]}x{size[1]}/", + mask_search_files_path_neg=f"output/mask_search_result_neg_{size[0]}x{size[1]}/", + mask_candidates=sparse_mask_candidates_tuning, + full_attention_mask=[int(x) for x in full_mask[0].split(",")], + skip_time_steps=skip_time_steps, + save_dir=f"output/mask_search_strategy_{size[0]}x{size[1]}/", + timesteps=timesteps_num, + ) + elif STA_mode == STA_Mode.STA_INFERENCE: + import sglang.multimodal_gen.envs as envs + + config_file = envs.SGL_DIFFUSION_ATTENTION_CONFIG + if config_file is None: + raise ValueError("SGL_DIFFUSION_ATTENTION_CONFIG is not set") + STA_param = configure_sta( + mode=STA_Mode.STA_INFERENCE, + layer_num=layer_num, + head_num=head_num, + time_step_num=timesteps_num, + load_path=config_file, + ) + + batch.STA_param = STA_param + batch.mask_search_final_result_pos = [[] for _ in range(timesteps_num)] + batch.mask_search_final_result_neg = [[] for _ in range(timesteps_num)] + + def save_sta_search_results(self, batch: Req): + """ + Save the STA mask search results. 
+ + Args: + batch: The current batch information. + """ + size = (batch.width, batch.height) + if size == (1280, 768): + # TODO: make it configurable + sparse_mask_candidates_searching = [ + "3, 1, 10", + "1, 5, 7", + "3, 3, 3", + "1, 6, 5", + "1, 3, 10", + "3, 6, 1", + ] + else: + raise NotImplementedError( + "STA mask search is not supported for this resolution" + ) + + if batch.mask_search_final_result_pos is not None and batch.prompt is not None: + save_mask_search_results( + [dict(layer_data) for layer_data in batch.mask_search_final_result_pos], + prompt=str(batch.prompt), + mask_strategies=sparse_mask_candidates_searching, + output_dir=f"output/mask_search_result_pos_{size[0]}x{size[1]}/", + ) + if batch.mask_search_final_result_neg is not None and batch.prompt is not None: + save_mask_search_results( + [dict(layer_data) for layer_data in batch.mask_search_final_result_neg], + prompt=str(batch.prompt), + mask_strategies=sparse_mask_candidates_searching, + output_dir=f"output/mask_search_result_neg_{size[0]}x{size[1]}/", + ) + + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify denoising stage inputs.""" + result = VerificationResult() + result.add_check("timesteps", batch.timesteps, [V.is_tensor, V.min_dims(1)]) + # disable temporarily for image-generation models + # result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)]) + result.add_check("prompt_embeds", batch.prompt_embeds, V.list_not_empty) + result.add_check("image_embeds", batch.image_embeds, V.is_list) + # result.add_check( + # "image_latent", batch.image_latent, V.none_or_tensor_with_dims(5) + # ) + result.add_check( + "num_inference_steps", batch.num_inference_steps, V.positive_int + ) + result.add_check("guidance_scale", batch.guidance_scale, V.positive_float) + result.add_check("eta", batch.eta, V.non_negative_float) + result.add_check("generator", batch.generator, V.generator_or_list_generators) + result.add_check( + "do_classifier_free_guidance", + batch.do_classifier_free_guidance, + V.bool_value, + ) + result.add_check( + "negative_prompt_embeds", + batch.negative_prompt_embeds, + lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x), + ) + return result + + def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify denoising stage outputs.""" + result = VerificationResult() + # result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)]) + return result diff --git a/python/sglang/multimodal_gen/runtime/pipelines/stages/denoising_dmd.py b/python/sglang/multimodal_gen/runtime/pipelines/stages/denoising_dmd.py new file mode 100644 index 000000000..1d39aaf8e --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/stages/denoising_dmd.py @@ -0,0 +1,283 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import time + +import torch +from einops import rearrange + +from sglang.multimodal_gen.runtime.distributed import ( + get_local_torch_device, + get_sp_parallel_rank, + get_sp_world_size, + logger, + sequence_model_parallel_all_gather, +) +from sglang.multimodal_gen.runtime.layers.attention.backends.sliding_tile_attn import ( + SlidingTileAttentionBackend, +) +from sglang.multimodal_gen.runtime.layers.attention.backends.video_sparse_attn import ( + VideoSparseAttentionBackend, +) +from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context +from sglang.multimodal_gen.runtime.models.schedulers.scheduling_flow_match_euler_discrete import 
( + FlowMatchEulerDiscreteScheduler, +) +from sglang.multimodal_gen.runtime.models.utils import pred_noise_to_pred_video +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.pipelines.stages import DenoisingStage +from sglang.multimodal_gen.runtime.pipelines.stages.denoising import ( + st_attn_available, + vsa_available, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.utils import dict_to_3d_list + + +# TODO: use base methods of DenoisingStage +class DmdDenoisingStage(DenoisingStage): + """ + Denoising stage for DMD. + """ + + def __init__(self, transformer, scheduler) -> None: + super().__init__(transformer, scheduler) + self.scheduler = FlowMatchEulerDiscreteScheduler(shift=8.0) + + def forward( + self, + batch: Req, + server_args: ServerArgs, + ) -> Req: + """ + Run the denoising loop. + + Args: + batch: The current batch information. + server_args: The inference arguments. + + Returns: + The batch with denoised latents. + """ + # Setup precision and autocast settings + # TODO(will): make the precision configurable for inference + # target_dtype = PRECISION_TO_TYPE[server_args.precision] + target_dtype = torch.bfloat16 + autocast_enabled = ( + target_dtype != torch.float32 + ) and not server_args.disable_autocast + + # Get timesteps and calculate warmup steps + timesteps = batch.timesteps + + # TODO(will): remove this once we add input/output validation for stages + if timesteps is None: + raise ValueError("Timesteps must be provided") + num_inference_steps = batch.num_inference_steps + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # Prepare image latents and embeddings for I2V generation + image_embeds = batch.image_embeds + if len(image_embeds) > 0: + assert torch.isnan(image_embeds[0]).sum() == 0 + image_embeds = [ + image_embed.to(target_dtype) for image_embed in image_embeds + ] + + image_kwargs = self.prepare_extra_func_kwargs( + self.transformer.forward, + { + "encoder_hidden_states_image": image_embeds, + "mask_strategy": dict_to_3d_list(None, t_max=50, l_max=60, h_max=24), + }, + ) + + pos_cond_kwargs = self.prepare_extra_func_kwargs( + self.transformer.forward, + { + "encoder_hidden_states_2": batch.clip_embedding_pos, + "encoder_attention_mask": batch.prompt_attention_mask, + }, + ) + + # Prepare STA parameters + if st_attn_available and self.attn_backend == SlidingTileAttentionBackend: + self.prepare_sta_param(batch, server_args) + + # Get latents and embeddings + assert batch.latents is not None, "latents must be provided" + latents = batch.latents + latents = latents.permute(0, 2, 1, 3, 4) + + video_raw_latent_shape = latents.shape + prompt_embeds = batch.prompt_embeds + assert not torch.isnan(prompt_embeds[0]).any(), "prompt_embeds contains nan" + timesteps = torch.tensor( + server_args.pipeline_config.dmd_denoising_steps, + dtype=torch.long, + device=get_local_torch_device(), + ) + + # Handle sequence parallelism if enabled + sp_world_size, rank_in_sp_group = ( + get_sp_world_size(), + get_sp_parallel_rank(), + ) + sp_group = sp_world_size > 1 + if sp_group: + latents = rearrange( + latents, "b (n t) c h w -> b n t c h w", n=sp_world_size + ).contiguous() + latents = latents[:, rank_in_sp_group, :, :, :, :] + if batch.image_latent is not None: + image_latent = rearrange( + batch.image_latent, + "b c (n t) h w -> b c n t h w", + n=sp_world_size, + ).contiguous() + + image_latent = image_latent[:, :, rank_in_sp_group, :, :, :] + 
batch.image_latent = image_latent + + # Run denoising loop + denoising_loop_start_time = time.time() + with self.progress_bar(total=len(timesteps)) as progress_bar: + for i, t in enumerate(timesteps): + # Skip if interrupted + if hasattr(self, "interrupt") and self.interrupt: + continue + # Expand latents for I2V + noise_latents = latents.clone() + latent_model_input = latents.to(target_dtype) + + if batch.image_latent is not None: + latent_model_input = torch.cat( + [ + latent_model_input, + batch.image_latent.permute(0, 2, 1, 3, 4), + ], + dim=2, + ).to(target_dtype) + assert not torch.isnan( + latent_model_input + ).any(), "latent_model_input contains nan" + + # Prepare inputs for transformer + t_expand = t.repeat(latent_model_input.shape[0]) + guidance_expand = ( + torch.tensor( + [server_args.pipeline_config.embedded_cfg_scale] + * latent_model_input.shape[0], + dtype=torch.float32, + device=get_local_torch_device(), + ).to(target_dtype) + * 1000.0 + if server_args.pipeline_config.embedded_cfg_scale is not None + else None + ) + + # Predict noise residual + with torch.autocast( + device_type="cuda", + dtype=target_dtype, + enabled=autocast_enabled, + ): + if ( + vsa_available + and self.attn_backend == VideoSparseAttentionBackend + ): + self.attn_metadata_builder_cls = ( + self.attn_backend.get_builder_cls() + ) + + if self.attn_metadata_builder_cls is not None: + self.attn_metadata_builder = ( + self.attn_metadata_builder_cls() + ) + # TODO(will): clean this up + attn_metadata = self.attn_metadata_builder.build( # type: ignore + current_timestep=i, # type: ignore + raw_latent_shape=batch.raw_latent_shape[2:5], # type: ignore + patch_size=server_args.pipeline_config.dit_config.patch_size, # type: ignore + STA_param=batch.STA_param, # type: ignore + VSA_sparsity=server_args.VSA_sparsity, # type: ignore + device=get_local_torch_device(), # type: ignore + ) # type: ignore + assert ( + attn_metadata is not None + ), "attn_metadata cannot be None" + else: + attn_metadata = None + else: + attn_metadata = None + + batch.is_cfg_negative = False + with set_forward_context( + current_timestep=i, + attn_metadata=attn_metadata, + forward_batch=batch, + # server_args=server_args + ): + # Run transformer + pred_noise = self.transformer( + latent_model_input.permute(0, 2, 1, 3, 4), + prompt_embeds, + t_expand, + guidance=guidance_expand, + **image_kwargs, + **pos_cond_kwargs, + ).permute(0, 2, 1, 3, 4) + + pred_video = pred_noise_to_pred_video( + pred_noise=pred_noise.flatten(0, 1), + noise_input_latent=noise_latents.flatten(0, 1), + timestep=t_expand, + scheduler=self.scheduler, + ).unflatten(0, pred_noise.shape[:2]) + + if i < len(timesteps) - 1: + next_timestep = timesteps[i + 1] * torch.ones( + [1], dtype=torch.long, device=pred_video.device + ) + noise = torch.randn( + video_raw_latent_shape, + dtype=pred_video.dtype, + generator=batch.generator[0], + ).to(self.device) + if sp_group: + noise = rearrange( + noise, + "b (n t) c h w -> b n t c h w", + n=sp_world_size, + ).contiguous() + noise = noise[:, rank_in_sp_group, :, :, :, :] + latents = self.scheduler.add_noise( + pred_video.flatten(0, 1), + noise.flatten(0, 1), + next_timestep, + ).unflatten(0, pred_video.shape[:2]) + else: + latents = pred_video + + # Update progress bar + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps + and (i + 1) % self.scheduler.order == 0 + and progress_bar is not None + ): + progress_bar.update() + + denoising_loop_end_time = time.time() + if len(timesteps) > 0: + logger.info( + "Average time per 
step: %.4f seconds", + (denoising_loop_end_time - denoising_loop_start_time) / len(timesteps), + ) + + # Gather results if using sequence parallelism + if sp_group: + latents = sequence_model_parallel_all_gather(latents, dim=1) + latents = latents.permute(0, 2, 1, 3, 4) + # Update batch with final latents + batch.latents = latents + + return batch diff --git a/python/sglang/multimodal_gen/runtime/pipelines/stages/encoding.py b/python/sglang/multimodal_gen/runtime/pipelines/stages/encoding.py new file mode 100644 index 000000000..dbea07442 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/stages/encoding.py @@ -0,0 +1,104 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Encoding stage for diffusion pipelines. +""" + +import torch + +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.pipelines.stages.base import PipelineStage +from sglang.multimodal_gen.runtime.pipelines.stages.validators import ( + V, # Import validators +) +from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import PRECISION_TO_TYPE + +logger = init_logger(__name__) + + +class EncodingStage(PipelineStage): + """ + Stage for encoding pixel space representations into latent space. + + This stage handles the encoding of pixel-space video/images into latent + representations for further processing in the diffusion pipeline. + """ + + def __init__(self, vae: ParallelTiledVAE) -> None: + self.vae: ParallelTiledVAE = vae + + @torch.no_grad() + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify encoding stage inputs.""" + result = VerificationResult() + # Input video/images for VAE encoding: [batch_size, channels, frames, height, width] + result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)]) + return result + + def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify encoding stage outputs.""" + result = VerificationResult() + # Encoded latents: [batch_size, channels, frames, height_latents, width_latents] + result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)]) + return result + + def forward( + self, + batch: Req, + server_args: ServerArgs, + ) -> Req: + """ + Encode pixel space representations into latent space. + + Args: + batch: The current batch information. + server_args: The inference arguments. + + Returns: + The batch with encoded latents. 
+ """ + assert batch.latents is not None and isinstance(batch.latents, torch.Tensor) + + self.vae = self.vae.to(get_local_torch_device()) + + # Setup VAE precision + vae_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision] + vae_autocast_enabled = ( + vae_dtype != torch.float32 + ) and not server_args.disable_autocast + + # Normalize input to [-1, 1] range (reverse of decoding normalization) + latents = (batch.latents * 2.0 - 1.0).clamp(-1, 1) + + # Move to appropriate device and dtype + latents = latents.to(get_local_torch_device()) + + # Encode image to latents + with torch.autocast( + device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled + ): + if server_args.pipeline_config.vae_tiling: + self.vae.enable_tiling() + # if server_args.vae_sp: + # self.vae.enable_parallel() + if not vae_autocast_enabled: + latents = latents.to(vae_dtype) + latents = self.vae.encode(latents).mean + + # Update batch with encoded latents + batch.latents = latents + + # Offload models if needed + if hasattr(self, "maybe_free_model_hooks"): + self.maybe_free_model_hooks() + + if server_args.vae_cpu_offload: + self.vae.to("cpu") + + return batch diff --git a/python/sglang/multimodal_gen/runtime/pipelines/stages/image_encoding.py b/python/sglang/multimodal_gen/runtime/pipelines/stages/image_encoding.py new file mode 100644 index 000000000..0f91451da --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/stages/image_encoding.py @@ -0,0 +1,447 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Image encoding stages for I2V diffusion pipelines. + +This module contains implementations of image encoding stages for diffusion pipelines. +""" + +import PIL +import torch + +from sglang.multimodal_gen.configs.pipelines.qwen_image import ( + QwenImageEditPipelineConfig, + QwenImagePipelineConfig, + _pack_latents, + qwen_image_postprocess_text, +) +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context +from sglang.multimodal_gen.runtime.models.vaes.common import ParallelTiledVAE +from sglang.multimodal_gen.runtime.models.vision_utils import ( + normalize, + numpy_to_pt, + pil_to_numpy, + resize, +) +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.pipelines.stages.base import PipelineStage +from sglang.multimodal_gen.runtime.pipelines.stages.validators import ( + StageValidators as V, +) +from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult +from sglang.multimodal_gen.runtime.server_args import ExecutionMode, ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import PRECISION_TO_TYPE + +logger = init_logger(__name__) + + +class ImageEncodingStage(PipelineStage): + """ + Stage for encoding image prompts into embeddings for diffusion models. + + This stage handles the encoding of image prompts into the embedding space + expected by the diffusion model. + """ + + def __init__( + self, + image_processor, + image_encoder=None, + text_encoder=None, + vae_image_processor=None, + ) -> None: + """ + Initialize the prompt encoding stage. 
+ + Args: + text_encoder: An encoder to encode input_ids and pixel values + """ + super().__init__() + self.image_processor = image_processor + self.vae_image_processor = vae_image_processor + self.image_encoder = image_encoder + self.text_encoder = text_encoder + + def move_to_device(self, device): + fields = [ + "image_processor", + "image_encoder", + ] + for field in fields: + processor = getattr(self, field, None) + if processor and hasattr(processor, "to"): + setattr(self, field, processor.to(device)) + + def encoding_qwen_image_edit(self, outputs, image_inputs): + # encoder hidden state + prompt_embeds = qwen_image_postprocess_text(outputs, image_inputs, 64) + return prompt_embeds + + @torch.inference_mode() + def forward( + self, + batch: Req, + server_args: ServerArgs, + ) -> Req: + """ + Encode the prompt into image encoder hidden states. + + Args: + batch: The current batch information. + server_args: The inference arguments. + + Returns: + The batch with encoded prompt embeddings. + """ + + cuda_device = get_local_torch_device() + self.move_to_device(cuda_device) + + image = batch.pil_image + + # preprocess the imag_processor + prompt_image = server_args.pipeline_config.preprocess_image( + image, self.vae_image_processor + ) + + if batch.prompt and ( + isinstance(server_args.pipeline_config, QwenImageEditPipelineConfig) + or isinstance(server_args.pipeline_config, QwenImagePipelineConfig) + ): + prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + txt = prompt_template_encode.format(batch.prompt) + image_processor_kwargs = dict(text=[txt], padding=True) + else: + image_processor_kwargs = {} + + image_inputs = self.image_processor( + images=prompt_image, return_tensors="pt", **image_processor_kwargs + ).to(cuda_device) + if self.image_encoder: + # if an image encoder is provided + with set_forward_context(current_timestep=0, attn_metadata=None): + outputs = self.image_encoder( + **image_inputs, + **server_args.pipeline_config.image_encoder_extra_args, + ) + image_embeds = server_args.pipeline_config.postprocess_image(outputs) + + batch.image_embeds.append(image_embeds) + elif self.text_encoder: + # if a text encoder is provided, e.g. Qwen-Image-Edit + # 1. neg prompt embeds + if batch.prompt: + prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + txt = prompt_template_encode.format(batch.negative_prompt) + neg_image_processor_kwargs = dict(text=[txt], padding=True) + else: + neg_image_processor_kwargs = {} + + neg_image_inputs = self.image_processor( + images=prompt_image, return_tensors="pt", **neg_image_processor_kwargs + ).to(get_local_torch_device()) + + with set_forward_context(current_timestep=0, attn_metadata=None): + outputs = self.text_encoder( + input_ids=image_inputs.input_ids, + attention_mask=image_inputs.attention_mask, + pixel_values=image_inputs.pixel_values, + image_grid_thw=image_inputs.image_grid_thw, + output_hidden_states=True, + ) + neg_outputs = self.text_encoder( + input_ids=neg_image_inputs.input_ids, + attention_mask=neg_image_inputs.attention_mask, + pixel_values=neg_image_inputs.pixel_values, + image_grid_thw=neg_image_inputs.image_grid_thw, + output_hidden_states=True, + ) + batch.prompt_embeds.append( + self.encoding_qwen_image_edit(outputs, image_inputs) + ) + + batch.negative_prompt_embeds.append( + self.encoding_qwen_image_edit(neg_outputs, neg_image_inputs) + ) + + self.move_to_device("cpu") + + return batch + + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify image encoding stage inputs.""" + result = VerificationResult() + if batch.debug: + logger.debug(f"{batch.pil_image=}") + logger.debug(f"{batch.image_embeds=}") + result.add_check("pil_image", batch.pil_image, V.not_none) + result.add_check("image_embeds", batch.image_embeds, V.is_list) + return result + + def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify image encoding stage outputs.""" + result = VerificationResult() + # result.add_check("image_embeds", batch.image_embeds, V.list_of_tensors_dims(3)) + return result + + +class ImageVAEEncodingStage(PipelineStage): + """ + Stage for encoding pixel representations into latent space. + + This stage handles the encoding of pixel representations into the final + input format (e.g., latents). + """ + + def __init__(self, vae: ParallelTiledVAE, **kwargs) -> None: + super().__init__() + self.vae: ParallelTiledVAE = vae + + def forward( + self, + batch: Req, + server_args: ServerArgs, + ) -> Req: + """ + Encode pixel representations into latent space. + + Args: + batch: The current batch information. + server_args: The inference arguments. + + Returns: + The batch with encoded outputs. 
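+
+ Note: for video pipelines the conditioning image is placed at frame 0,
+ zero-padded to num_frames, VAE-encoded, shifted/scaled, and concatenated
+ with a first-frame temporal mask before being stored in batch.image_latent;
+ Qwen-Image-Edit instead packs the image latents via _pack_latents.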
+ """ + assert batch.pil_image is not None + if server_args.mode == ExecutionMode.INFERENCE: + assert batch.pil_image is not None and isinstance( + batch.pil_image, PIL.Image.Image + ) + assert batch.height is not None and isinstance(batch.height, int) + assert batch.width is not None and isinstance(batch.width, int) + assert batch.num_frames is not None and isinstance(batch.num_frames, int) + height = batch.height + width = batch.width + num_frames = batch.num_frames + elif server_args.mode == ExecutionMode.PREPROCESS: + assert batch.pil_image is not None and isinstance( + batch.pil_image, torch.Tensor + ) + assert batch.height is not None and isinstance(batch.height, list) + assert batch.width is not None and isinstance(batch.width, list) + assert batch.num_frames is not None and isinstance(batch.num_frames, list) + num_frames = batch.num_frames[0] + height = batch.height[0] + width = batch.width[0] + + self.vae = self.vae.to(get_local_torch_device()) + + latent_height = height // self.vae.spatial_compression_ratio + latent_width = width // self.vae.spatial_compression_ratio + + image = batch.pil_image + image = self.preprocess( + image, + vae_scale_factor=self.vae.spatial_compression_ratio, + height=height, + width=width, + ).to(get_local_torch_device(), dtype=torch.float32) + + # (B, C, H, W) -> (B, C, 1, H, W) + image = image.unsqueeze(2) + + video_condition = torch.cat( + [ + image, + image.new_zeros( + image.shape[0], + image.shape[1], + num_frames - 1, + image.shape[3], + image.shape[4], + ), + ], + dim=2, + ) + video_condition = video_condition.to( + device=get_local_torch_device(), dtype=torch.float32 + ) + + # Setup VAE precision + vae_dtype = PRECISION_TO_TYPE[server_args.pipeline_config.vae_precision] + vae_autocast_enabled = ( + vae_dtype != torch.float32 + ) and not server_args.disable_autocast + + # Encode Image + with torch.autocast( + device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled + ): + if server_args.pipeline_config.vae_tiling: + self.vae.enable_tiling() + # if server_args.vae_sp: + # self.vae.enable_parallel() + if not vae_autocast_enabled: + video_condition = video_condition.to(vae_dtype) + encoder_output = self.vae.encode(video_condition) + + if server_args.mode == ExecutionMode.PREPROCESS: + latent_condition = encoder_output.mean + else: + generator = batch.generator + if generator is None: + raise ValueError("Generator must be provided") + latent_condition = self.retrieve_latents(encoder_output, generator) + + # Apply shifting if needed + if hasattr(self.vae, "shift_factor") and self.vae.shift_factor is not None: + if isinstance(self.vae.shift_factor, torch.Tensor): + latent_condition -= self.vae.shift_factor.to( + latent_condition.device, latent_condition.dtype + ) + else: + latent_condition -= self.vae.shift_factor + + if isinstance(self.vae.scaling_factor, torch.Tensor): + latent_condition = latent_condition * self.vae.scaling_factor.to( + latent_condition.device, latent_condition.dtype + ) + else: + latent_condition = latent_condition * self.vae.scaling_factor + + if server_args.mode == ExecutionMode.PREPROCESS: + batch.image_latent = latent_condition + else: + if isinstance(server_args.pipeline_config, QwenImageEditPipelineConfig): + batch_size = batch.batch_size + if ( + batch_size > latent_condition.shape[0] + and batch_size % latent_condition.shape[0] == 0 + ): + # expand init_latents for batch_size + additional_image_per_prompt = ( + batch_size // latent_condition.shape[0] + ) + image_latents = torch.cat( + [latent_condition] * 
additional_image_per_prompt, dim=0 + ) + elif ( + batch_size > latent_condition.shape[0] + and batch_size % latent_condition.shape[0] != 0 + ): + raise ValueError( + f"Cannot duplicate `image` of batch size {latent_condition.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([latent_condition], dim=0) + image_latent_height, image_latent_width = image_latents.shape[3:] + num_channels_latents = ( + self.server_args.pipeline_config.dit_config.arch_config.in_channels + // 4 + ) + image_latents = _pack_latents( + image_latents, + batch_size, + num_channels_latents, + image_latent_height, + image_latent_width, + ) + else: + mask_lat_size = torch.ones( + 1, 1, num_frames, latent_height, latent_width + ) + mask_lat_size[:, :, list(range(1, num_frames))] = 0 + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave( + first_frame_mask, + repeats=self.vae.temporal_compression_ratio, + dim=2, + ) + mask_lat_size = torch.concat( + [first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2 + ) + mask_lat_size = mask_lat_size.view( + 1, + -1, + self.vae.temporal_compression_ratio, + latent_height, + latent_width, + ) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(latent_condition.device) + image_latents = torch.concat([mask_lat_size, latent_condition], dim=1) + + batch.image_latent = image_latents + + # Offload models if needed + if hasattr(self, "maybe_free_model_hooks"): + self.maybe_free_model_hooks() + + self.vae.to("cpu") + + return batch + + def retrieve_latents( + self, + encoder_output: torch.Tensor, + generator: torch.Generator | None = None, + sample_mode: str = "sample", + ): + if sample_mode == "sample": + return encoder_output.sample(generator) + elif sample_mode == "argmax": + return encoder_output.mode() + else: + raise AttributeError("Could not access latents of provided encoder_output") + + def preprocess( + self, + image: torch.Tensor | PIL.Image.Image, + vae_scale_factor: int, + height: int | None = None, + width: int | None = None, + resize_mode: str = "default", # "default", "fill", "crop" + ) -> torch.Tensor: + + if isinstance(image, PIL.Image.Image): + width, height = ( + self.server_args.pipeline_config.vae_config.calculate_dimensions( + image, vae_scale_factor, width, height + ) + ) + image = resize(image, height, width, resize_mode=resize_mode) + image = pil_to_numpy(image) # to np + image = numpy_to_pt(image) # to pt + + do_normalize = True + if image.min() < 0: + do_normalize = False + if do_normalize: + image = normalize(image) + + return image + + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify encoding stage inputs.""" + result = VerificationResult() + result.add_check("generator", batch.generator, V.generator_or_list_generators) + if server_args.mode == ExecutionMode.PREPROCESS: + result.add_check("height", batch.height, V.list_not_empty) + result.add_check("width", batch.width, V.list_not_empty) + result.add_check("num_frames", batch.num_frames, V.list_not_empty) + else: + result.add_check("height", batch.height, V.positive_int) + result.add_check("width", batch.width, V.positive_int) + result.add_check("num_frames", batch.num_frames, V.positive_int) + return result + + def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify encoding stage outputs.""" + result = VerificationResult() + # result.add_check( + # "image_latent", batch.image_latent, [V.is_tensor, V.with_dims(5)] + # ) + return 
result diff --git a/python/sglang/multimodal_gen/runtime/pipelines/stages/input_validation.py b/python/sglang/multimodal_gen/runtime/pipelines/stages/input_validation.py new file mode 100644 index 000000000..9c3fd2fc2 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/stages/input_validation.py @@ -0,0 +1,211 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Input validation stage for diffusion pipelines. +""" +import numpy as np +import torch +import torchvision.transforms.functional as TF +from PIL import Image + +from sglang.multimodal_gen.configs.pipelines import WanI2V480PConfig +from sglang.multimodal_gen.configs.pipelines.qwen_image import ( + QwenImageEditPipelineConfig, +) +from sglang.multimodal_gen.runtime.models.vision_utils import load_image, load_video +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.pipelines.stages.base import PipelineStage +from sglang.multimodal_gen.runtime.pipelines.stages.validators import ( + StageValidators, + VerificationResult, +) +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import best_output_size + +logger = init_logger(__name__) + +# Alias for convenience +V = StageValidators + +# TODO: since this might change sampling params after logging, should be do this beforehand? + + +class InputValidationStage(PipelineStage): + """ + Stage for validating and preparing inputs for diffusion pipelines. + + This stage validates that all required inputs are present and properly formatted + before proceeding with the diffusion process. + + In this stage, input image and output image may be resized + """ + + def _generate_seeds(self, batch: Req, server_args: ServerArgs): + """Generate seeds for the inference""" + seed = batch.seed + num_videos_per_prompt = batch.num_outputs_per_prompt + + assert seed is not None + seeds = [seed + i for i in range(num_videos_per_prompt)] + batch.seeds = seeds + # Peiyuan: using GPU seed will cause A100 and H100 to generate different results... + # FIXME: the generator's in latent preparation stage seems to be different from seeds + batch.generator = [torch.Generator("cpu").manual_seed(seed) for seed in seeds] + + def forward( + self, + batch: Req, + server_args: ServerArgs, + ) -> Req: + """ + Validate and prepare inputs. + + Args: + batch: The current batch information. + server_args: The inference arguments. + + Returns: + The validated batch information. + """ + + self._generate_seeds(batch, server_args) + + # Ensure prompt is properly formatted + if batch.prompt is None and batch.prompt_embeds is None: + raise ValueError("Either `prompt` or `prompt_embeds` must be provided") + + # Ensure negative prompt is properly formatted if using classifier-free guidance + if ( + batch.do_classifier_free_guidance + and batch.negative_prompt is None + and batch.negative_prompt_embeds is None + ): + raise ValueError( + "For classifier-free guidance, either `negative_prompt` or " + "`negative_prompt_embeds` must be provided" + ) + + # Validate height and width + if batch.height is None or batch.width is None: + raise ValueError( + "Height and width must be provided. Please set `height` and `width`." + ) + if batch.height % 8 != 0 or batch.width % 8 != 0: + raise ValueError( + f"Height and width must be divisible by 8 but are {batch.height} and {batch.width}." 
+ ) + + # Validate number of inference steps + if batch.num_inference_steps <= 0: + raise ValueError( + f"Number of inference steps must be positive, but got {batch.num_inference_steps}" + ) + + # Validate guidance scale if using classifier-free guidance + if batch.do_classifier_free_guidance and batch.guidance_scale <= 0: + raise ValueError( + f"Guidance scale must be positive, but got {batch.guidance_scale}" + ) + + # for i2v, get image from image_path + # @TODO(Wei) hard-coded for wan2.2 5b ti2v for now. Should put this in image_encoding stage + if batch.image_path is not None: + if batch.image_path.endswith(".mp4"): + image = load_video(batch.image_path)[0] + else: + image = load_image(batch.image_path) + batch.pil_image = image + + # NOTE: resizing needs to be bring in advance + if isinstance(server_args.pipeline_config, QwenImageEditPipelineConfig): + height = None if batch.height_not_provided else batch.height + width = None if batch.width_not_provided else batch.width + width, height = server_args.pipeline_config.set_width_and_height( + height, width, batch.pil_image + ) + batch.width = width + batch.height = height + elif ( + server_args.pipeline_config.ti2v_task + or server_args.pipeline_config.ti2i_task + ) and batch.pil_image is not None: + # further processing for ti2v task + img = batch.pil_image + ih, iw = img.height, img.width + patch_size = server_args.pipeline_config.dit_config.arch_config.patch_size + vae_stride = ( + server_args.pipeline_config.vae_config.arch_config.scale_factor_spatial + ) + dh, dw = patch_size[1] * vae_stride, patch_size[2] * vae_stride + max_area = 704 * 1280 + ow, oh = best_output_size(iw, ih, dw, dh, max_area) + + scale = max(ow / iw, oh / ih) + img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS) + logger.info("resized img height: %s, img width: %s", img.height, img.width) + + # center-crop + x1 = (img.width - ow) // 2 + y1 = (img.height - oh) // 2 + img = img.crop((x1, y1, x1 + ow, y1 + oh)) + assert img.width == ow and img.height == oh + + # to tensor + img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device).unsqueeze(1) + img = img.unsqueeze(0) + batch.height = oh + batch.width = ow + # TODO: should we store in a new field: pixel values? + batch.pil_image = img + + if isinstance(server_args.pipeline_config, WanI2V480PConfig): + # TODO: could we merge with above? 
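+ # This branch rescales the input image to roughly 720*1280 pixels while
+ # keeping both sides divisible by mod_value. Illustrative example (assuming
+ # scale_factor_spatial=8 and patch_size[1]=2, i.e. mod_value=16): a
+ # 1920x1080 image has aspect_ratio=0.5625, giving
+ # height = round(sqrt(921600*0.5625)) // 16 * 16 = 720 and
+ # width = round(sqrt(921600/0.5625)) // 16 * 16 = 1280.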
+ # resize image only, Wan2.1 I2V + max_area = 720 * 1280 + aspect_ratio = image.height / image.width + mod_value = ( + server_args.pipeline_config.vae_config.arch_config.scale_factor_spatial + * server_args.pipeline_config.dit_config.arch_config.patch_size[1] + ) + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + + batch.pil_image = batch.pil_image.resize((width, height)) + batch.height = height + batch.width = width + + return batch + + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify input validation stage inputs.""" + result = VerificationResult() + result.add_check("seed", batch.seed, [V.not_none, V.non_negative_int]) + result.add_check( + "num_videos_per_prompt", batch.num_outputs_per_prompt, V.positive_int + ) + result.add_check( + "prompt_or_embeds", + None, + lambda _: V.string_or_list_strings(batch.prompt) + or V.list_not_empty(batch.prompt_embeds), + ) + result.add_check("height", batch.height, V.positive_int) + result.add_check("width", batch.width, V.positive_int) + result.add_check( + "num_inference_steps", batch.num_inference_steps, V.positive_int + ) + result.add_check( + "guidance_scale", + batch.guidance_scale, + lambda x: not batch.do_classifier_free_guidance or V.positive_float(x), + ) + return result + + def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify input validation stage outputs.""" + result = VerificationResult() + result.add_check("seeds", batch.seeds, V.list_not_empty) + result.add_check("generator", batch.generator, V.generator_or_list_generators) + return result diff --git a/python/sglang/multimodal_gen/runtime/pipelines/stages/latent_preparation.py b/python/sglang/multimodal_gen/runtime/pipelines/stages/latent_preparation.py new file mode 100644 index 000000000..55f4fc86e --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/stages/latent_preparation.py @@ -0,0 +1,155 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Latent preparation stage for diffusion pipelines. +""" +from diffusers.utils.torch_utils import randn_tensor + +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.pipelines.stages.base import PipelineStage +from sglang.multimodal_gen.runtime.pipelines.stages.validators import ( + StageValidators as V, +) +from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class LatentPreparationStage(PipelineStage): + """ + Stage for preparing initial latent variables for the diffusion process. + + This stage handles the preparation of the initial latent variables that will be + denoised during the diffusion process. + """ + + def __init__(self, scheduler, transformer) -> None: + super().__init__() + self.scheduler = scheduler + self.transformer = transformer + + def forward( + self, + batch: Req, + server_args: ServerArgs, + ) -> Req: + """ + Prepare initial latent variables for the diffusion process. + + Args: + batch: The current batch information. + server_args: The inference arguments. 
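+
+ Note: when batch.latents is not provided, the latent shape is obtained
+ from pipeline_config.prepare_latent_shape and the initial noise is drawn
+ with diffusers' randn_tensor using the per-request CPU generators, which
+ is intended to keep sampling reproducible across devices.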
+ + Returns: + The batch with prepared latent variables. + """ + + latent_num_frames = None + # Adjust video length based on VAE version if needed + if hasattr(self, "adjust_video_length"): + latent_num_frames = self.adjust_video_length(batch, server_args) + + batch_size = batch.batch_size + + # Get required parameters + dtype = batch.prompt_embeds[0].dtype + device = get_local_torch_device() + generator = batch.generator + latents = batch.latents + num_frames = ( + latent_num_frames if latent_num_frames is not None else batch.num_frames + ) + height = batch.height + width = batch.width + + # TODO(will): remove this once we add input/output validation for stages + if height is None or width is None: + raise ValueError("Height and width must be provided") + + # Validate generator if it's a list + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # Generate or use provided latents + if latents is None: + shape = server_args.pipeline_config.prepare_latent_shape( + batch, batch_size, num_frames + ) + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype + ) + latents = server_args.pipeline_config.pack_latents( + latents, batch_size, batch + ) + else: + latents = latents.to(device) + + # Scale the initial noise if needed + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + # Update batch with prepared latents + batch.latents = latents + batch.raw_latent_shape = latents.shape + return batch + + def adjust_video_length(self, batch: Req, server_args: ServerArgs) -> int: + """ + Adjust video length based on VAE version. + + Args: + batch: The current batch information. + server_args: The inference arguments. + + Returns: + The batch with adjusted video length. 
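+
+ (The returned value is the latent frame count, an int.) Example
+ (illustrative, assuming a temporal compression ratio of 4): 81 input
+ frames map to (81 - 1) // 4 + 1 = 21 latent frames; for stepvideo,
+ 102 frames map to 102 // 17 * 3 = 18.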
+ """ + + video_length = batch.num_frames + use_temporal_scaling_frames = ( + server_args.pipeline_config.vae_config.use_temporal_scaling_frames + ) + if use_temporal_scaling_frames: + temporal_scale_factor = ( + server_args.pipeline_config.vae_config.arch_config.temporal_compression_ratio + ) + latent_num_frames = (video_length - 1) // temporal_scale_factor + 1 + else: # stepvideo only + latent_num_frames = video_length // 17 * 3 + return int(latent_num_frames) + + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify latent preparation stage inputs.""" + result = VerificationResult() + result.add_check( + "prompt_or_embeds", + None, + lambda _: V.string_or_list_strings(batch.prompt) + or V.list_not_empty(batch.prompt_embeds), + ) + result.add_check("prompt_embeds", batch.prompt_embeds, V.list_of_tensors) + result.add_check( + "num_videos_per_prompt", batch.num_outputs_per_prompt, V.positive_int + ) + result.add_check("generator", batch.generator, V.generator_or_list_generators) + result.add_check("num_frames", batch.num_frames, V.positive_int) + result.add_check("height", batch.height, V.positive_int) + result.add_check("width", batch.width, V.positive_int) + result.add_check("latents", batch.latents, V.none_or_tensor) + return result + + def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify latent preparation stage outputs.""" + result = VerificationResult() + if batch.debug: + logger.debug(f"{batch.raw_latent_shape=}") + # disable temporarily for image-generation models + # result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)]) + result.add_check("raw_latent_shape", batch.raw_latent_shape, V.is_tuple) + return result diff --git a/python/sglang/multimodal_gen/runtime/pipelines/stages/stepvideo_encoding.py b/python/sglang/multimodal_gen/runtime/pipelines/stages/stepvideo_encoding.py new file mode 100644 index 000000000..54aa6b45c --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/stages/stepvideo_encoding.py @@ -0,0 +1,97 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.pipelines.stages.base import PipelineStage +from sglang.multimodal_gen.runtime.pipelines.stages.validators import ( + StageValidators as V, +) +from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +# The dedicated stepvideo prompt encoding stage. +class StepvideoPromptEncodingStage(PipelineStage): + """ + Stage for encoding prompts using the remote caption API. + + This stage applies the magic string transformations and calls + the remote caption service asynchronously to get: + - primary prompt embeddings, + - an attention mask, + - and a clip embedding. + """ + + def __init__(self, stepllm, clip) -> None: + super().__init__() + # self.caption_client = caption_client # This should have a call_caption(prompts: List[str]) method. 
+ self.stepllm = stepllm + self.clip = clip + + @torch.no_grad() + def forward(self, batch: Req, server_args) -> Req: + + prompts = [batch.prompt + server_args.pipeline_config.pos_magic] + bs = len(prompts) + prompts += [server_args.pipeline_config.neg_magic] * bs + with set_forward_context(current_timestep=0, attn_metadata=None): + y, y_mask = self.stepllm(prompts) + clip_emb, _ = self.clip(prompts) + len_clip = clip_emb.shape[1] + y_mask = torch.nn.functional.pad(y_mask, (len_clip, 0), value=1) + pos_clip, neg_clip = clip_emb[:bs], clip_emb[bs:] + + # split positive vs negative text + batch.prompt_embeds = y[:bs] # [bs, seq_len, dim] + batch.negative_prompt_embeds = y[bs : 2 * bs] # [bs, seq_len, dim] + batch.prompt_attention_mask = y_mask[:bs] # [bs, seq_len] + batch.negative_attention_mask = y_mask[bs : 2 * bs] # [bs, seq_len] + batch.clip_embedding_pos = pos_clip + batch.clip_embedding_neg = neg_clip + return batch + + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify stepvideo encoding stage inputs.""" + result = VerificationResult() + result.add_check("prompt", batch.prompt, V.string_not_empty) + return result + + def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify stepvideo encoding stage outputs.""" + result = VerificationResult() + result.add_check( + "prompt_embeds", batch.prompt_embeds, [V.is_tensor, V.with_dims(3)] + ) + result.add_check( + "negative_prompt_embeds", + batch.negative_prompt_embeds, + [V.is_tensor, V.with_dims(3)], + ) + result.add_check( + "prompt_attention_mask", + batch.prompt_attention_mask, + [V.is_tensor, V.with_dims(2)], + ) + result.add_check( + "negative_attention_mask", + batch.negative_attention_mask, + [V.is_tensor, V.with_dims(2)], + ) + result.add_check( + "clip_embedding_pos", + batch.clip_embedding_pos, + [V.is_tensor, V.with_dims(2)], + ) + result.add_check( + "clip_embedding_neg", + batch.clip_embedding_neg, + [V.is_tensor, V.with_dims(2)], + ) + return result diff --git a/python/sglang/multimodal_gen/runtime/pipelines/stages/text_encoding.py b/python/sglang/multimodal_gen/runtime/pipelines/stages/text_encoding.py new file mode 100644 index 000000000..eff5ee1c9 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/stages/text_encoding.py @@ -0,0 +1,326 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Prompt encoding stages for diffusion pipelines. + +This module contains implementations of prompt encoding stages for diffusion pipelines. 
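+
+Example usage (illustrative sketch; assumes two text encoders/tokenizers are
+already loaded and a populated ServerArgs is available):
+
+    stage = TextEncodingStage(text_encoders, tokenizers)
+    embeds, masks, pooled = stage.encode_text(
+        "a photo of a cat",
+        server_args,
+        encoder_index=[0, 1],
+        return_attention_mask=True,
+    )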
+""" + +import torch + +from sglang.multimodal_gen.configs.models.encoders import BaseEncoderOutput +from sglang.multimodal_gen.configs.pipelines import FluxPipelineConfig +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.managers.forward_context import set_forward_context +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.pipelines.stages.base import PipelineStage +from sglang.multimodal_gen.runtime.pipelines.stages.validators import ( + StageValidators as V, +) +from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class TextEncodingStage(PipelineStage): + """ + Stage for encoding text prompts into embeddings for diffusion models. + + This stage handles the encoding of text prompts into the embedding space + expected by the diffusion model. + """ + + def __init__(self, text_encoders, tokenizers) -> None: + """ + Initialize the prompt encoding stage. + + """ + super().__init__() + self.tokenizers = tokenizers + self.text_encoders = text_encoders + + @torch.no_grad() + def forward( + self, + batch: Req, + server_args: ServerArgs, + ) -> Req: + """ + Encode the prompt into text encoder hidden states. + + Args: + batch: The current batch information. + server_args: The inference arguments. + + Returns: + The batch with encoded prompt embeddings. + """ + assert len(self.tokenizers) == len(self.text_encoders) + assert len(self.text_encoders) == len( + server_args.pipeline_config.text_encoder_configs + ) + + # Encode positive prompt with all available encoders + assert batch.prompt is not None + prompt_text: str | list[str] = batch.prompt + + all_indices: list[int] = list(range(len(self.text_encoders))) + + prompt_embeds_list, prompt_masks_list, pooler_embeds_list = self.encode_text( + prompt_text, + server_args, + encoder_index=all_indices, + return_attention_mask=True, + ) + + for pe in prompt_embeds_list: + batch.prompt_embeds.append(pe) + + for pe in pooler_embeds_list: + batch.pooled_embeds.append(pe) + if batch.prompt_attention_mask is not None: + for am in prompt_masks_list: + batch.prompt_attention_mask.append(am) + + # Encode negative prompt if CFG is enabled + if batch.do_classifier_free_guidance: + assert isinstance(batch.negative_prompt, str) + neg_embeds_list, neg_masks_list, neg_pooler_embeds_list = self.encode_text( + batch.negative_prompt, + server_args, + encoder_index=all_indices, + return_attention_mask=True, + ) + + assert batch.negative_prompt_embeds is not None + + for ne in neg_embeds_list: + batch.negative_prompt_embeds.append(ne) + + for pe in neg_pooler_embeds_list: + batch.neg_pooled_embeds.append(pe) + if batch.negative_attention_mask is not None: + for nm in neg_masks_list: + batch.negative_attention_mask.append(nm) + + return batch + + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify text encoding stage inputs.""" + result = VerificationResult() + result.add_check("prompt", batch.prompt, V.string_or_list_strings) + result.add_check( + "negative_prompt", + batch.negative_prompt, + lambda x: not batch.do_classifier_free_guidance or V.string_not_none(x), + ) + result.add_check( + "do_classifier_free_guidance", + batch.do_classifier_free_guidance, + V.bool_value, + ) + 
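+        # prompt_embeds / negative_prompt_embeds only need to be (possibly empty)
+        # lists here; forward() is what populates them.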
result.add_check("prompt_embeds", batch.prompt_embeds, V.is_list) + result.add_check( + "negative_prompt_embeds", batch.negative_prompt_embeds, V.none_or_list + ) + return result + + def prepare_tokenizer_kwargs(self, tokenizer_kwargs, **kwargs): + tok_kwargs = tokenizer_kwargs | kwargs + + return tok_kwargs + + @torch.no_grad() + def encode_text( + self, + text: str | list[str], + server_args: ServerArgs, + encoder_index: int | list[int] | None = None, + return_attention_mask: bool = False, + return_type: str = "list", # one of: "list", "dict", "stack" + device: torch.device | str | None = None, + dtype: torch.dtype | None = None, + max_length: int | None = None, + truncation: bool | None = None, + padding: bool | str | None = None, + return_overflowing_tokens=None, + return_length=None, + ): + """ + Encode plain text using selected text encoder(s) and return embeddings. + + Args: + text: A single string or a list of strings to encode. + server_args: The inference arguments providing pipeline config, + including tokenizer and encoder settings, preprocess and postprocess + functions. + encoder_index: Encoder selector by index. Accepts an int or list of ints. + return_attention_mask: If True, also return attention masks for each + selected encoder. + return_type: "list" (default) returns a list aligned with selection; + "dict" returns a dict keyed by encoder index as a string; "stack" stacks along a + new first dimension (requires matching shapes). + device: Optional device override for inputs; defaults to local torch device. + dtype: Optional dtype to cast returned embeddings to. + max_length: Optional per-call tokenizer override. + truncation: Optional per-call tokenizer override. + padding: Optional per-call tokenizer override. + + Returns: + Depending on return_type and return_attention_mask: + - list: List[Tensor] or (List[Tensor], List[Tensor]) + - dict: Dict[str, Tensor] or (Dict[str, Tensor], Dict[str, Tensor]) + - stack: Tensor of shape [num_encoders, ...] or a tuple with stacked + attention masks + """ + + assert len(self.tokenizers) == len(self.text_encoders) + assert len(self.text_encoders) == len( + server_args.pipeline_config.text_encoder_configs + ) + + # Resolve selection into indices + encoder_cfgs = server_args.pipeline_config.text_encoder_configs + if encoder_index is None: + indices: list[int] = [0] + elif isinstance(encoder_index, int): + indices = [encoder_index] + else: + indices = list(encoder_index) + # validate range + num_encoders = len(self.text_encoders) + for idx in indices: + if idx < 0 or idx >= num_encoders: + raise IndexError( + f"encoder index {idx} out of range [0, {num_encoders - 1}]" + ) + + # Validate indices are within range + num_encoders = len(self.text_encoders) + + # Normalize input to list[str] + assert isinstance(text, str | list) + if isinstance(text, str): + texts: list[str] = [text] + else: + texts = text + + embeds_list: list[torch.Tensor] = [] + pooled_embeds_list: list[torch.Tensor] = [] + + attn_masks_list: list[torch.Tensor] = [] + + preprocess_funcs = server_args.pipeline_config.preprocess_text_funcs + postprocess_funcs = server_args.pipeline_config.postprocess_text_funcs + text_encoder_extra_args = server_args.pipeline_config.text_encoder_extra_args + encoder_cfgs = server_args.pipeline_config.text_encoder_configs + + if return_type not in ("list", "dict", "stack"): + raise ValueError( + f"Invalid return_type '{return_type}'. 
Expected one of: 'list', 'dict', 'stack'" + ) + + target_device = device if device is not None else get_local_torch_device() + + for i in indices: + tokenizer = self.tokenizers[i] + text_encoder = self.text_encoders[i] + encoder_config = encoder_cfgs[i] + preprocess_func = preprocess_funcs[i] + postprocess_func = postprocess_funcs[i] + text_encoder_extra_arg = ( + text_encoder_extra_args[i] + if i < len(text_encoder_extra_args) and text_encoder_extra_args[i] + else {} + ) + + processed_texts: list[str] = [] + for prompt_str in texts: + processed_texts.append(preprocess_func(prompt_str)) + + # Prepare tokenizer args + tok_kwargs = self.prepare_tokenizer_kwargs( + encoder_config.tokenizer_kwargs, + **text_encoder_extra_arg, + ) + + text_inputs = tokenizer(processed_texts, **tok_kwargs).to(target_device) + input_ids = text_inputs["input_ids"] + is_flux = isinstance(server_args.pipeline_config, FluxPipelineConfig) + is_flux_t5 = is_flux and i == 1 + + if is_flux_t5: + attention_mask = torch.ones(input_ids.shape[:2], device=target_device) + else: + attention_mask = text_inputs["attention_mask"] + with set_forward_context(current_timestep=0, attn_metadata=None): + outputs: BaseEncoderOutput = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + prompt_embeds = postprocess_func(outputs, text_inputs) + if dtype is not None: + prompt_embeds = prompt_embeds.to(dtype=dtype) + + embeds_list.append(prompt_embeds) + if is_flux: + pooled_embeds_list.append(outputs.pooler_output) + if return_attention_mask: + attn_masks_list.append(attention_mask) + + # Shape results according to return_type + if return_type == "list": + if return_attention_mask: + return embeds_list, attn_masks_list, pooled_embeds_list + return embeds_list, pooled_embeds_list + + if return_type == "dict": + key_strs = [str(i) for i in indices] + embeds_dict = {k: v for k, v in zip(key_strs, embeds_list, strict=False)} + if return_attention_mask: + attn_dict = { + k: v for k, v in zip(key_strs, attn_masks_list, strict=False) + } + return embeds_dict, attn_dict + return embeds_dict + + # return_type == "stack" + # Validate shapes are compatible + base_shape = list(embeds_list[0].shape) + for t in embeds_list[1:]: + if list(t.shape) != base_shape: + raise ValueError( + f"Cannot stack embeddings with differing shapes: {[list(t.shape) for t in embeds_list]}" + ) + stacked_embeds = torch.stack(embeds_list, dim=0) + if return_attention_mask: + base_mask_shape = list(attn_masks_list[0].shape) + for m in attn_masks_list[1:]: + if list(m.shape) != base_mask_shape: + raise ValueError( + f"Cannot stack attention masks with differing shapes: {[list(m.shape) for m in attn_masks_list]}" + ) + stacked_masks = torch.stack(attn_masks_list, dim=0) + return stacked_embeds, stacked_masks + return stacked_embeds + + def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify text encoding stage outputs.""" + result = VerificationResult() + result.add_check( + "prompt_embeds", batch.prompt_embeds, V.list_of_tensors_min_dims(2) + ) + result.add_check( + "negative_prompt_embeds", + batch.negative_prompt_embeds, + lambda x: not batch.do_classifier_free_guidance + or V.list_of_tensors_with_min_dims(x, 2), + ) + if batch.debug: + logger.debug(f"{batch.prompt_embeds=}") + logger.debug(f"{batch.negative_prompt_embeds=}") + return result diff --git a/python/sglang/multimodal_gen/runtime/pipelines/stages/timestep_preparation.py 
b/python/sglang/multimodal_gen/runtime/pipelines/stages/timestep_preparation.py new file mode 100644 index 000000000..09c5d22ee --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/stages/timestep_preparation.py @@ -0,0 +1,163 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Timestep preparation stages for diffusion pipelines. + +This module contains implementations of timestep preparation stages for diffusion pipelines. +""" + +import inspect +from typing import Any, Callable, Tuple + +import numpy as np + +from sglang.multimodal_gen.configs.pipelines import FluxPipelineConfig +from sglang.multimodal_gen.configs.pipelines.qwen_image import ( + QwenImageEditPipelineConfig, + QwenImagePipelineConfig, +) +from sglang.multimodal_gen.runtime.distributed import get_local_torch_device +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.pipelines.stages.base import ( + PipelineStage, + StageParallelismType, +) +from sglang.multimodal_gen.runtime.pipelines.stages.validators import ( + StageValidators as V, +) +from sglang.multimodal_gen.runtime.pipelines.stages.validators import VerificationResult +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class TimestepPreparationStage(PipelineStage): + """ + Stage for preparing timesteps for the diffusion process. + + This stage handles the preparation of the timestep sequence that will be used + during the diffusion process. + """ + + def __init__( + self, + scheduler, + prepare_extra_set_timesteps_kwargs: list[ + Callable[[Req, ServerArgs], Tuple[str, Any]] + ] = [], + ) -> None: + self.scheduler = scheduler + self.prepare_extra_set_timesteps_kwargs = prepare_extra_set_timesteps_kwargs + + @property + def parallelism_type(self) -> StageParallelismType: + return StageParallelismType.REPLICATED + + def forward( + self, + batch: Req, + server_args: ServerArgs, + ) -> Req: + """ + Prepare timesteps for the diffusion process. + + Args: + batch: The current batch information. + server_args: The inference arguments. + + Returns: + The batch with prepared timesteps. + """ + scheduler = self.scheduler + device = get_local_torch_device() + num_inference_steps = batch.num_inference_steps + timesteps = batch.timesteps + sigmas = batch.sigmas + n_tokens = batch.n_tokens + + is_flux = ( + isinstance(server_args.pipeline_config, FluxPipelineConfig) + or isinstance(server_args.pipeline_config, QwenImagePipelineConfig) + or isinstance(server_args.pipeline_config, QwenImageEditPipelineConfig) + ) + if is_flux: + sigmas = ( + np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + if sigmas is None + else sigmas + ) + + # Prepare extra kwargs for set_timesteps + extra_set_timesteps_kwargs = {} + if ( + n_tokens is not None + and "n_tokens" in inspect.signature(scheduler.set_timesteps).parameters + ): + extra_set_timesteps_kwargs["n_tokens"] = n_tokens + + for callee in self.prepare_extra_set_timesteps_kwargs: + key, value = callee(batch, server_args) + assert isinstance(key, str) + extra_set_timesteps_kwargs[key] = value + + # Handle custom timesteps or sigmas + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values" + ) + + if timesteps is not None: + accepts_timesteps = ( + "timesteps" in inspect.signature(scheduler.set_timesteps).parameters + ) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps( + timesteps=timesteps, device=device, **extra_set_timesteps_kwargs + ) + timesteps = scheduler.timesteps + elif sigmas is not None: + accept_sigmas = ( + "sigmas" in inspect.signature(scheduler.set_timesteps).parameters + ) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps( + sigmas=sigmas, device=device, **extra_set_timesteps_kwargs + ) + timesteps = scheduler.timesteps + else: + scheduler.set_timesteps( + num_inference_steps, device=device, **extra_set_timesteps_kwargs + ) + timesteps = scheduler.timesteps + + # Update batch with prepared timesteps + batch.timesteps = timesteps + self.log_debug(f"timesteps: {timesteps}") + return batch + + def verify_input(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify timestep preparation stage inputs.""" + result = VerificationResult() + result.add_check( + "num_inference_steps", batch.num_inference_steps, V.positive_int + ) + result.add_check("timesteps", batch.timesteps, V.none_or_tensor) + result.add_check("sigmas", batch.sigmas, V.none_or_list) + result.add_check("n_tokens", batch.n_tokens, V.none_or_positive_int) + return result + + def verify_output(self, batch: Req, server_args: ServerArgs) -> VerificationResult: + """Verify timestep preparation stage outputs.""" + result = VerificationResult() + result.add_check("timesteps", batch.timesteps, [V.is_tensor, V.with_dims(1)]) + return result diff --git a/python/sglang/multimodal_gen/runtime/pipelines/stages/validators.py b/python/sglang/multimodal_gen/runtime/pipelines/stages/validators.py new file mode 100644 index 000000000..1ca9e992d --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/pipelines/stages/validators.py @@ -0,0 +1,522 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +""" +Common validators for pipeline stage verification. + +This module provides reusable validation functions that can be used across +all pipeline stages for input/output verification. 
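+
+Typical usage (illustrative; `batch` stands for any Req being verified):
+
+    result = VerificationResult()
+    result.add_check("height", batch.height, [V.not_none, V.divisible(8)])
+    result.add_check("latents", batch.latents, V.none_or_tensor)
+    if not result.is_valid():
+        raise ValueError(result.get_failure_summary())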
+""" + +from collections.abc import Callable +from typing import Any + +import torch + + +class StageValidators: + """Common validators for pipeline stages.""" + + @staticmethod + def not_none(value: Any) -> bool: + """Check if value is not None.""" + return value is not None + + @staticmethod + def positive_int(value: Any) -> bool: + """Check if value is a positive integer.""" + return isinstance(value, int) and value > 0 + + @staticmethod + def non_negative_int(value: Any) -> bool: + """Check if value is a non-negative float.""" + return isinstance(value, int | float) and value >= 0 + + @staticmethod + def positive_float(value: Any) -> bool: + """Check if value is a positive float.""" + return isinstance(value, int | float) and value > 0 + + @staticmethod + def non_negative_float(value: Any) -> bool: + """Check if value is a non-negative float.""" + return isinstance(value, int | float) and value >= 0 + + @staticmethod + def divisible_by(value: Any, divisor: int) -> bool: + """Check if value is divisible by divisor.""" + return value is not None and isinstance(value, int) and value % divisor == 0 + + @staticmethod + def is_tensor(value: Any) -> bool: + """Check if value is a torch tensor and doesn't contain NaN values.""" + if not isinstance(value, torch.Tensor): + return False + return not torch.isnan(value).any().item() + + @staticmethod + def tensor_with_dims(value: Any, dims: int) -> bool: + """Check if value is a tensor with specific dimensions and no NaN values.""" + if not isinstance(value, torch.Tensor): + return False + if value.dim() != dims: + return False + return not torch.isnan(value).any().item() + + @staticmethod + def tensor_min_dims(value: Any, min_dims: int) -> bool: + """Check if value is a tensor with at least min_dims dimensions and no NaN values.""" + if not isinstance(value, torch.Tensor): + return False + if value.dim() < min_dims: + return False + return not torch.isnan(value).any().item() + + @staticmethod + def tensor_shape_matches(value: Any, expected_shape: tuple) -> bool: + """Check if tensor shape matches expected shape (None for any size) and no NaN values.""" + if not isinstance(value, torch.Tensor): + return False + if len(value.shape) != len(expected_shape): + return False + for actual, expected in zip(value.shape, expected_shape, strict=True): + if expected is not None and actual != expected: + return False + return not torch.isnan(value).any().item() + + @staticmethod + def list_not_empty(value: Any) -> bool: + """Check if value is a non-empty list.""" + return isinstance(value, list) and len(value) > 0 + + @staticmethod + def list_length(value: Any, length: int) -> bool: + """Check if list has specific length.""" + return isinstance(value, list) and len(value) == length + + @staticmethod + def list_min_length(value: Any, min_length: int) -> bool: + """Check if list has at least min_length items.""" + return isinstance(value, list) and len(value) >= min_length + + @staticmethod + def string_not_empty(value: Any) -> bool: + """Check if value is a non-empty string.""" + return isinstance(value, str) and len(value.strip()) > 0 + + @staticmethod + def string_not_none(value: Any) -> bool: + """Check if value is a non-empty string.""" + return isinstance(value, str) and len(value) > 0 + + @staticmethod + def string_or_list_strings(value: Any) -> bool: + """Check if value is a string or list of strings.""" + if isinstance(value, str): + return True + if isinstance(value, list): + return all(isinstance(item, str) for item in value) + return False + + 
@staticmethod + def bool_value(value: Any) -> bool: + """Check if value is a boolean.""" + return isinstance(value, bool) + + @staticmethod + def generator_or_list_generators(value: Any) -> bool: + """Check if value is a Generator or list of Generators.""" + if isinstance(value, torch.Generator): + return True + if isinstance(value, list): + return all(isinstance(item, torch.Generator) for item in value) + return False + + @staticmethod + def is_list(value: Any) -> bool: + """Check if value is a list (can be empty).""" + return isinstance(value, list) + + @staticmethod + def is_tuple(value: Any) -> bool: + """Check if value is a tuple.""" + return isinstance(value, tuple) + + @staticmethod + def none_or_tensor(value: Any) -> bool: + """Check if value is None or a tensor without NaN values.""" + if value is None: + return True + if not isinstance(value, torch.Tensor): + return False + return not torch.isnan(value).any().item() + + @staticmethod + def list_of_tensors_with_dims(value: Any, dims: int) -> bool: + """Check if value is a non-empty list where all items are tensors with specific dimensions and no NaN values.""" + if not isinstance(value, list) or len(value) == 0: + return False + for item in value: + if not isinstance(item, torch.Tensor): + return False + if item.dim() != dims: + return False + if torch.isnan(item).any().item(): + return False + return True + + @staticmethod + def list_of_tensors(value: Any) -> bool: + """Check if value is a non-empty list where all items are tensors without NaN values.""" + if not isinstance(value, list) or len(value) == 0: + return False + for item in value: + if not isinstance(item, torch.Tensor): + return False + if torch.isnan(item).any().item(): + return False + return True + + @staticmethod + def list_of_tensors_with_min_dims(value: Any, min_dims: int) -> bool: + """Check if value is a non-empty list where all items are tensors with at least min_dims dimensions and no NaN values.""" + if not isinstance(value, list) or len(value) == 0: + return False + for item in value: + if not isinstance(item, torch.Tensor): + return False + if item.dim() < min_dims: + return False + if torch.isnan(item).any().item(): + return False + return True + + @staticmethod + def none_or_tensor_with_dims(dims: int) -> Callable[[Any], bool]: + """Return a validator that checks if value is None or a tensor with specific dimensions and no NaN values.""" + + def validator(value: Any) -> bool: + if value is None: + return True + if not isinstance(value, torch.Tensor): + return False + if value.dim() != dims: + return False + return not torch.isnan(value).any().item() + + return validator + + @staticmethod + def none_or_list(value: Any) -> bool: + """Check if value is None or a list.""" + return value is None or isinstance(value, list) + + @staticmethod + def none_or_positive_int(value: Any) -> bool: + """Check if value is None or a positive integer.""" + return value is None or (isinstance(value, int) and value > 0) + + # Helper methods that return functions for common patterns + @staticmethod + def with_dims(dims: int) -> Callable[[Any], bool]: + """Return a validator that checks if tensor has specific dimensions and no NaN values.""" + + def validator(value: Any) -> bool: + return StageValidators.tensor_with_dims(value, dims) + + return validator + + @staticmethod + def min_dims(min_dims: int) -> Callable[[Any], bool]: + """Return a validator that checks if tensor has at least min_dims dimensions and no NaN values.""" + + def validator(value: Any) -> bool: + return 
StageValidators.tensor_min_dims(value, min_dims) + + return validator + + @staticmethod + def divisible(divisor: int) -> Callable[[Any], bool]: + """Return a validator that checks if value is divisible by divisor.""" + + def validator(value: Any) -> bool: + return StageValidators.divisible_by(value, divisor) + + return validator + + @staticmethod + def positive_int_divisible(divisor: int) -> Callable[[Any], bool]: + """Return a validator that checks if value is a positive integer divisible by divisor.""" + + def validator(value: Any) -> bool: + return ( + isinstance(value, int) + and value > 0 + and StageValidators.divisible_by(value, divisor) + ) + + return validator + + @staticmethod + def list_of_tensors_dims(dims: int) -> Callable[[Any], bool]: + """Return a validator that checks if value is a list of tensors with specific dimensions and no NaN values.""" + + def validator(value: Any) -> bool: + return StageValidators.list_of_tensors_with_dims(value, dims) + + return validator + + @staticmethod + def list_of_tensors_min_dims(min_dims: int) -> Callable[[Any], bool]: + """Return a validator that checks if value is a list of tensors with at least min_dims dimensions and no NaN values.""" + + def validator(value: Any) -> bool: + return StageValidators.list_of_tensors_with_min_dims(value, min_dims) + + return validator + + +class ValidationFailure: + """Details about a specific validation failure.""" + + def __init__( + self, + validator_name: str, + actual_value: Any, + expected: str | None = None, + error_msg: str | None = None, + ): + self.validator_name = validator_name + self.actual_value = actual_value + self.expected = expected + self.error_msg = error_msg + + def __str__(self) -> str: + parts = [f"Validator '{self.validator_name}' failed"] + + if self.error_msg: + parts.append(f"Error: {self.error_msg}") + + # Add actual value info (but limit very long representations) + actual_str = self._format_value(self.actual_value) + parts.append(f"Actual: {actual_str}") + + if self.expected: + parts.append(f"Expected: {self.expected}") + + return ". ".join(parts) + + def _format_value(self, value: Any) -> str: + """Format a value for display in error messages.""" + if value is None: + return "None" + elif isinstance(value, torch.Tensor): + return f"tensor(shape={list(value.shape)}, dtype={value.dtype})" + elif isinstance(value, list): + if len(value) == 0: + return "[]" + elif len(value) <= 3: + item_strs = [self._format_value(item) for item in value] + return f"[{', '.join(item_strs)}]" + else: + return f"list(length={len(value)}, first_item={self._format_value(value[0])})" + elif isinstance(value, str): + if len(value) > 50: + return f"'{value[:47]}...'" + else: + return f"'{value}'" + else: + return f"{type(value).__name__}({value})" + + +class VerificationResult: + """Wrapper class for stage verification results.""" + + def __init__(self) -> None: + self._checks: dict[str, bool] = {} + self._failures: dict[str, list[ValidationFailure]] = {} + + def add_check( + self, + field_name: str, + value: Any, + validators: Callable[[Any], bool] | list[Callable[[Any], bool]], + ) -> "VerificationResult": + """ + Add a validation check for a field. + + Args: + field_name: Name of the field being checked + value: The actual value to validate + validators: Single validation function or list of validation functions. + Each function will be called with the value as its first argument. 
+ + Returns: + Self for method chaining + + Examples: + # Single validator + result.add_check("tensor", my_tensor, V.is_tensor) + + # Multiple validators (all must pass) + result.add_check("latents", batch.latents, [V.is_tensor, V.with_dims(5)]) + + # Using partial functions for parameters + result.add_check("height", batch.height, [V.not_none, V.divisible(8)]) + """ + if not isinstance(validators, list): + validators = [validators] + + failures = [] + all_passed = True + + # Apply all validators and collect detailed failure info + for validator in validators: + try: + passed = validator(value) + if not passed: + all_passed = False + failure = self._create_validation_failure(validator, value) + failures.append(failure) + except Exception as e: + # If any validator raises an exception, consider the check failed + all_passed = False + validator_name = getattr(validator, "__name__", str(validator)) + failure = ValidationFailure( + validator_name=validator_name, + actual_value=value, + error_msg=f"Exception during validation: {str(e)}", + ) + failures.append(failure) + + self._checks[field_name] = all_passed + if not all_passed: + self._failures[field_name] = failures + + return self + + def _create_validation_failure( + self, validator: Callable, value: Any + ) -> ValidationFailure: + """Create a ValidationFailure with detailed information.""" + validator_name = getattr(validator, "__name__", str(validator)) + + # Try to extract meaningful expected value info based on validator type + expected = None + error_msg = None + + # Handle common validator patterns + if hasattr(validator, "__closure__") and validator.__closure__: + # This is likely a closure (like our helper functions) + if "dims" in validator_name or "with_dims" in str(validator): + if isinstance(value, torch.Tensor): + expected = f"tensor with {validator.__closure__[0].cell_contents} dimensions" + else: + expected = "tensor with specific dimensions" + elif "divisible" in str(validator): + expected = ( + f"integer divisible by {validator.__closure__[0].cell_contents}" + ) + + # Handle specific validator types and check for NaN values + if validator_name == "is_tensor": + expected = "torch.Tensor without NaN values" + if isinstance(value, torch.Tensor) and torch.isnan(value).any().item(): + error_msg = ( + f"tensor contains {torch.isnan(value).sum().item()} NaN values" + ) + elif validator_name == "positive_int": + expected = "positive integer" + elif validator_name == "not_none": + expected = "non-None value" + elif validator_name == "list_not_empty": + expected = "non-empty list" + elif validator_name == "bool_value": + expected = "boolean value" + elif ( + "tensor_with_dims" in validator_name or "tensor_min_dims" in validator_name + ): + if isinstance(value, torch.Tensor): + if torch.isnan(value).any().item(): + error_msg = f"tensor has {value.dim()} dimensions but contains {torch.isnan(value).sum().item()} NaN values" + else: + error_msg = f"tensor has {value.dim()} dimensions" + elif validator_name == "is_list": + expected = "list" + elif validator_name == "none_or_tensor": + expected = "None or tensor without NaN values" + if isinstance(value, torch.Tensor) and torch.isnan(value).any().item(): + error_msg = ( + f"tensor contains {torch.isnan(value).sum().item()} NaN values" + ) + elif validator_name == "list_of_tensors": + expected = "non-empty list of tensors without NaN values" + if isinstance(value, list) and len(value) > 0: + nan_count = 0 + for item in value: + if ( + isinstance(item, torch.Tensor) + and 
torch.isnan(item).any().item() + ): + nan_count += torch.isnan(item).sum().item() + if nan_count > 0: + error_msg = ( + f"list contains tensors with total {nan_count} NaN values" + ) + elif "list_of_tensors_with_dims" in validator_name: + expected = ( + "non-empty list of tensors with specific dimensions and no NaN values" + ) + if isinstance(value, list) and len(value) > 0: + nan_count = 0 + for item in value: + if ( + isinstance(item, torch.Tensor) + and torch.isnan(item).any().item() + ): + nan_count += torch.isnan(item).sum().item() + if nan_count > 0: + error_msg = ( + f"list contains tensors with total {nan_count} NaN values" + ) + + return ValidationFailure( + validator_name=validator_name, + actual_value=value, + expected=expected, + error_msg=error_msg, + ) + + def is_valid(self) -> bool: + """Check if all validations passed.""" + return all(self._checks.values()) + + def get_failed_fields(self) -> list[str]: + """Get list of fields that failed validation.""" + return [field for field, passed in self._checks.items() if not passed] + + def get_detailed_failures(self) -> dict[str, list[ValidationFailure]]: + """Get detailed failure information for each failed field.""" + return self._failures.copy() + + def get_failure_summary(self) -> str: + """Get a comprehensive summary of all validation failures.""" + if self.is_valid(): + return "All validations passed" + + summary_parts = [] + for field_name, failures in self._failures.items(): + field_summary = f"\n Field '{field_name}':" + for i, failure in enumerate(failures, 1): + field_summary += f"\n {i}. {failure}" + summary_parts.append(field_summary) + + return "Validation failures:" + "".join(summary_parts) + + def to_dict(self) -> dict: + """Convert to dictionary for backward compatibility.""" + return self._checks.copy() + + +# Alias for convenience +V = StageValidators diff --git a/python/sglang/multimodal_gen/runtime/platforms/__init__.py b/python/sglang/multimodal_gen/runtime/platforms/__init__.py new file mode 100644 index 000000000..c87fb3aa9 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/platforms/__init__.py @@ -0,0 +1,172 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/platforms/__init__.py + +import traceback +from typing import TYPE_CHECKING + +# imported by other files, do not remove +from sglang.multimodal_gen.runtime.platforms.interface import ( # noqa: F401 + AttentionBackendEnum, + Platform, + PlatformEnum, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import resolve_obj_by_qualname + +logger = init_logger(__name__) + + +def cuda_platform_plugin() -> str | None: + is_cuda = False + + try: + from sglang.multimodal_gen.utils import import_pynvml + + pynvml = import_pynvml() # type: ignore[no-untyped-call] + pynvml.nvmlInit() + try: + # NOTE: Edge case: sgl_diffusion cpu build on a GPU machine. + # Third-party pynvml can be imported in cpu build, + # we need to check if sgl_diffusion is built with cpu too. + # Otherwise, sgl_diffusion will always activate cuda plugin + # on a GPU machine, even if in a cpu build. + is_cuda = pynvml.nvmlDeviceGetCount() > 0 + finally: + pynvml.nvmlShutdown() + except Exception as e: + if "nvml" not in e.__class__.__name__.lower(): + # If the error is not related to NVML, re-raise it. + raise e + + # CUDA is supported on Jetson, but NVML may not be. 
+ import os + + def cuda_is_jetson() -> bool: + return os.path.isfile("/etc/nv_tegra_release") or os.path.exists( + "/sys/class/tegra-firmware" + ) + + if cuda_is_jetson(): + is_cuda = True + if is_cuda: + logger.info("CUDA is available") + + return ( + "sglang.multimodal_gen.runtime.platforms.cuda.CudaPlatform" if is_cuda else None + ) + + +def mps_platform_plugin() -> str | None: + """Detect if MPS (Metal Performance Shaders) is available on macOS.""" + is_mps = False + + try: + import torch + + if torch.backends.mps.is_available(): + is_mps = True + logger.info("MPS (Metal Performance Shaders) is available") + except Exception as e: + logger.info("MPS detection failed: %s", e) + + return "sglang.multimodal_gen.runtime.platforms.mps.MpsPlatform" if is_mps else None + + +def cpu_platform_plugin() -> str | None: + """Detect if CPU platform should be used.""" + # CPU is always available as a fallback + return "sglang.multimodal_gen.runtime.platforms.cpu.CpuPlatform" + + +def rocm_platform_plugin() -> str | None: + is_rocm = False + + try: + import amdsmi + + amdsmi.amdsmi_init() + try: + if len(amdsmi.amdsmi_get_processor_handles()) > 0: + is_rocm = True + logger.info("ROCm platform is available") + finally: + amdsmi.amdsmi_shut_down() + except Exception as e: + logger.info("ROCm platform is unavailable: %s", e) + + return ( + "sglang.multimodal_gen.runtime.platforms.rocm.RocmPlatform" if is_rocm else None + ) + + +builtin_platform_plugins = { + "cuda": cuda_platform_plugin, + "rocm": rocm_platform_plugin, + "mps": mps_platform_plugin, + "cpu": cpu_platform_plugin, +} + + +def resolve_current_platform_cls_qualname() -> str: + # TODO(will): if we need to support other platforms, we should consider if + # vLLM's plugin architecture is suitable for our needs. + + # Try MPS first on macOS + platform_cls_qualname = mps_platform_plugin() + if platform_cls_qualname is not None: + return platform_cls_qualname + + # Fall back to ROCm + platform_cls_qualname = rocm_platform_plugin() + if platform_cls_qualname is not None: + return platform_cls_qualname + + # Fall back to CUDA + platform_cls_qualname = cuda_platform_plugin() + if platform_cls_qualname is not None: + return platform_cls_qualname + + # Fall back to CPU as last resort + platform_cls_qualname = cpu_platform_plugin() + if platform_cls_qualname is not None: + return platform_cls_qualname + + raise RuntimeError("No platform plugin found. Please check your " "installation.") + + +_current_platform: Platform | None = None +_init_trace: str = "" + +if TYPE_CHECKING: + current_platform: Platform + + +def __getattr__(name: str): + if name == "current_platform": + # lazy init current_platform. + # 1. out-of-tree platform plugins need `from sglang.multimodal_gen.runtime.platforms import + # Platform` so that they can inherit `Platform` class. Therefore, + # we cannot resolve `current_platform` during the import of + # `sglang.multimodal_gen.runtime.platforms`. + # 2. when users use out-of-tree platform plugins, they might run + # `import sgl_diffusion`, some sgl_diffusion internal code might access + # `current_platform` during the import, and we need to make sure + # `current_platform` is only resolved after the plugins are loaded + # (we have tests for this, if any developer violate this, they will + # see the test failures). 
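+        # Illustrative example: `from sglang.multimodal_gen.runtime.platforms
+        # import current_platform` ends up here on first access, so the concrete
+        # Platform subclass is resolved lazily rather than at module import time.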
+ global _current_platform + if _current_platform is None: + platform_cls_qualname = resolve_current_platform_cls_qualname() + _current_platform = resolve_obj_by_qualname(platform_cls_qualname)() + global _init_trace + _init_trace = "".join(traceback.format_stack()) + return _current_platform + elif name in globals(): + return globals()[name] + else: + raise AttributeError(f"No attribute named '{name}' exists in {__name__}.") + + +__all__ = ["Platform", "PlatformEnum", "current_platform", "_init_trace"] diff --git a/python/sglang/multimodal_gen/runtime/platforms/cpu.py b/python/sglang/multimodal_gen/runtime/platforms/cpu.py new file mode 100644 index 000000000..5186d2489 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/platforms/cpu.py @@ -0,0 +1,61 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/platforms/cpu.py + +import platform + +import torch + +from sglang.multimodal_gen.runtime.platforms.interface import ( + CpuArchEnum, + Platform, + PlatformEnum, +) + + +class CpuPlatform(Platform): + _enum = PlatformEnum.CPU + device_name = "CPU" + device_type = "cpu" + dispatch_key = "CPU" + + @classmethod + def get_cpu_architecture(cls) -> CpuArchEnum: + """Get the CPU architecture.""" + machine = platform.machine().lower() + if machine in ("x86_64", "amd64", "i386", "i686"): + return CpuArchEnum.X86 + elif machine in ("arm64", "aarch64"): + return CpuArchEnum.ARM + else: + return CpuArchEnum.UNSPECIFIED + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + return platform.processor() + + @classmethod + def get_device_uuid(cls, device_id: int = 0) -> str: + return platform.machine() + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + # This is a rough estimate for CPU memory + # In practice, you might want to use psutil or similar + return 0 + + @classmethod + def is_async_output_supported(cls, enforce_eager: bool | None) -> bool: + return True + + @classmethod + def get_current_memory_usage( + cls, device: torch.types.Device | None = None + ) -> float: + # For CPU, we can't easily get memory usage without additional libraries + return 0.0 + + @classmethod + def get_device_communicator_cls(cls) -> str: + return "sglang.multimodal_gen.runtime.distributed.device_communicators.cpu_communicator.CpuCommunicator" diff --git a/python/sglang/multimodal_gen/runtime/platforms/cuda.py b/python/sglang/multimodal_gen/runtime/platforms/cuda.py new file mode 100644 index 000000000..e3324b14a --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/platforms/cuda.py @@ -0,0 +1,430 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/platforms/cuda.py +"""Code inside this file can safely assume cuda platform, e.g. importing +pynvml. However, it should not initialize cuda context. 
+""" + +import os +from collections.abc import Callable +from functools import lru_cache, wraps +from typing import TypeVar + +import torch +from typing_extensions import ParamSpec + +from sglang.multimodal_gen.runtime.platforms.interface import ( + AttentionBackendEnum, + DeviceCapability, + Platform, + PlatformEnum, +) +from sglang.multimodal_gen.runtime.utils.common import is_blackwell +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import import_pynvml + +logger = init_logger(__name__) + +_P = ParamSpec("_P") +_R = TypeVar("_R") + +pynvml = import_pynvml() # type: ignore[no-untyped-call] + +# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models +# see https://github.com/huggingface/diffusers/issues/9704 for details +torch.backends.cuda.enable_cudnn_sdp(False) + + +def device_id_to_physical_device_id(device_id: int) -> int: + if "CUDA_VISIBLE_DEVICES" in os.environ: + device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + if device_ids == [""]: + msg = ( + "CUDA_VISIBLE_DEVICES is set to empty string, which means" + " GPU support is disabled. If you are using ray, please unset" + " the environment variable `CUDA_VISIBLE_DEVICES` inside the" + " worker/actor. " + "Check https://github.com/vllm-project/vllm/issues/8402 for" + " more information." + ) + raise RuntimeError(msg) + physical_device_id = device_ids[device_id] + return int(physical_device_id) + else: + return device_id + + +def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: + @wraps(fn) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: + pynvml.nvmlInit() + try: + return fn(*args, **kwargs) + finally: + pynvml.nvmlShutdown() + + return wrapper + + +class CudaPlatformBase(Platform): + _enum = PlatformEnum.CUDA + device_name: str = "cuda" + device_type: str = "cuda" + dispatch_key: str = "CUDA" + device_control_env_var: str = "CUDA_VISIBLE_DEVICES" + + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None: + raise NotImplementedError + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + raise NotImplementedError + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + raise NotImplementedError + + @classmethod + def is_async_output_supported(cls, enforce_eager: bool | None) -> bool: + if enforce_eager: + logger.warning( + "To see benefits of async output processing, enable CUDA " + "graph. 
Since, enforce-eager is enabled, async output " + "processor cannot be used" + ) + return False + return True + + @classmethod + def is_full_nvlink(cls, device_ids: list[int]) -> bool: + raise NotImplementedError + + @classmethod + def log_warnings(cls) -> None: + pass + + @classmethod + def get_current_memory_usage( + cls, device: torch.types.Device | None = None + ) -> float: + torch.cuda.reset_peak_memory_stats(device) + return float(torch.cuda.max_memory_allocated(device)) + + @classmethod + def get_attn_backend_cls_str( + cls, + selected_backend: AttentionBackendEnum | None, + head_size: int, + dtype: torch.dtype, + ) -> str: + # TODO(will): maybe come up with a more general interface for local attention + # if distributed is False, we always try to use Flash attn + if selected_backend == AttentionBackendEnum.SLIDING_TILE_ATTN: + try: + from st_attn import sliding_tile_attention # noqa: F401 + + from sglang.multimodal_gen.runtime.layers.attention.backends.sliding_tile_attn import ( # noqa: F401 + SlidingTileAttentionBackend, + ) + + logger.info("Using Sliding Tile Attention backend.") + + return "sglang.multimodal_gen.runtime.layers.attention.backends.sliding_tile_attn.SlidingTileAttentionBackend" + except ImportError as e: + logger.error( + "Failed to import Sliding Tile Attention backend: %s", str(e) + ) + raise ImportError( + "Sliding Tile Attention backend is not installed. " + ) from e + elif selected_backend == AttentionBackendEnum.SAGE_ATTN: + try: + from sageattention import sageattn # noqa: F401 + + from sglang.multimodal_gen.runtime.layers.attention.backends.sage_attn import ( # noqa: F401 + SageAttentionBackend, + ) + + logger.info("Using Sage Attention backend.") + + return "sglang.multimodal_gen.runtime.layers.attention.backends.sage_attn.SageAttentionBackend" + except ImportError as e: + logger.info(e) + logger.info( + "Sage Attention backend is not installed. Fall back to Flash Attention." + ) + elif selected_backend == AttentionBackendEnum.SAGE_ATTN_THREE: + try: + from sglang.multimodal_gen.runtime.layers.attention.backends.sage_attn3 import ( # noqa: F401 + SageAttention3Backend, + ) + from sglang.multimodal_gen.runtime.layers.attention.backends.sageattn.api import ( # noqa: F401 + sageattn_blackwell, + ) + + logger.info("Using Sage Attention 3 backend.") + + return "sglang.multimodal_gen.runtime.layers.attention.backends.sage_attn3.SageAttention3Backend" + except ImportError as e: + logger.info(e) + logger.info( + "Sage Attention 3 backend is not installed. Fall back to Flash Attention." + ) + elif selected_backend == AttentionBackendEnum.VIDEO_SPARSE_ATTN: + try: + from vsa import block_sparse_attn # noqa: F401 + + from sglang.multimodal_gen.runtime.layers.attention.backends.video_sparse_attn import ( # noqa: F401 + VideoSparseAttentionBackend, + ) + + logger.info("Using Video Sparse Attention backend.") + + return "sglang.multimodal_gen.runtime.layers.attention.backends.video_sparse_attn.VideoSparseAttentionBackend" + except ImportError as e: + logger.error( + "Failed to import Video Sparse Attention backend: %s", str(e) + ) + raise ImportError( + "Video Sparse Attention backend is not installed. 
" + ) from e + elif selected_backend == AttentionBackendEnum.VMOBA_ATTN: + try: + from kernel.attn.vmoba_attn.vmoba import moba_attn_varlen # noqa: F401 + + from sglang.multimodal_gen.runtime.layers.attention.backends.vmoba import ( # noqa: F401 + VMOBAAttentionBackend, + ) + + logger.info("Using Video MOBA Attention backend.") + + return "sglang.multimodal_gen.runtime.layers.attention.backends.vmoba.VMOBAAttentionBackend" + except ImportError as e: + logger.error( + "Failed to import Video MoBA Attention backend: %s", str(e) + ) + raise ImportError( + "Video MoBA Attention backend is not installed. " + ) from e + elif selected_backend == AttentionBackendEnum.AITER: + logger.info("Using AITer backend.") + return "sglang.multimodal_gen.runtime.layers.attention.backends.aiter.AITerBackend" + elif selected_backend == AttentionBackendEnum.TORCH_SDPA: + logger.info("Using Torch SDPA backend.") + return "sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend" + elif selected_backend == AttentionBackendEnum.FA3: + if is_blackwell(): + raise ValueError("The 'fa3' backend is not supported on Blackwell GPUs") + elif selected_backend: + raise ValueError(f"Invalid attention backend for {cls.device_name}") + else: + if is_blackwell(): + target_backend = AttentionBackendEnum.TORCH_SDPA + logger.debug(f"Use torch_sdpa as default backend") + else: + target_backend = AttentionBackendEnum.FA3 + logger.debug(f"Use fa3 as default backend") + + if not cls.has_device_capability(80): + logger.info( + "Cannot use FlashAttention-2 backend for Volta and Turing " "GPUs." + ) + target_backend = AttentionBackendEnum.TORCH_SDPA + elif dtype not in (torch.float16, torch.bfloat16): + logger.info( + "Cannot use FlashAttention-2 backend for dtype other than " + "torch.float16 or torch.bfloat16." + ) + target_backend = AttentionBackendEnum.TORCH_SDPA + + # FlashAttn is valid for the model, checking if the package is + # installed. + if target_backend == AttentionBackendEnum.FA3: + try: + from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import ( # noqa: F401 + FlashAttentionBackend, + ) + + supported_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in supported_sizes: + logger.info( + "Cannot use FlashAttention-2 backend for head size %d.", + head_size, + ) + target_backend = AttentionBackendEnum.TORCH_SDPA + except ImportError: + logger.info( + "Cannot use FlashAttention-2 backend because the " + "flash_attn package is not found. " + "Make sure that flash_attn was built and installed " + "(on by default)." + ) + target_backend = AttentionBackendEnum.TORCH_SDPA + + if target_backend == AttentionBackendEnum.TORCH_SDPA: + logger.info("Using Torch SDPA backend.") + + return "sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend" + + logger.info("Using fa3 backend.") + + return "sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn.FlashAttentionBackend" + + @classmethod + def get_device_communicator_cls(cls) -> str: + return "sglang.multimodal_gen.runtime.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa + + +# NVML utils +# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, +# all the related functions work on real physical device ids. 
+# the major benefit of using NVML is that it will not initialize CUDA +class NvmlCudaPlatform(CudaPlatformBase): + + @classmethod + @lru_cache(maxsize=8) + @with_nvml_context + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None: + try: + physical_device_id = device_id_to_physical_device_id(device_id) + handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) + major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle) + return DeviceCapability(major=major, minor=minor) + except RuntimeError: + return None + + @classmethod + @lru_cache(maxsize=8) + @with_nvml_context + def has_device_capability( + cls, + capability: tuple[int, int] | int, + device_id: int = 0, + ) -> bool: + try: + return bool(super().has_device_capability(capability, device_id)) + except RuntimeError: + return False + + @classmethod + @lru_cache(maxsize=8) + @with_nvml_context + def get_device_name(cls, device_id: int = 0) -> str: + physical_device_id = device_id_to_physical_device_id(device_id) + return cls._get_physical_device_name(physical_device_id) + + @classmethod + @lru_cache(maxsize=8) + @with_nvml_context + def get_device_uuid(cls, device_id: int = 0) -> str: + physical_device_id = device_id_to_physical_device_id(device_id) + handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) + return str(pynvml.nvmlDeviceGetUUID(handle)) + + @classmethod + @lru_cache(maxsize=8) + @with_nvml_context + def get_device_total_memory(cls, device_id: int = 0) -> int: + physical_device_id = device_id_to_physical_device_id(device_id) + handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id) + return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total) + + @classmethod + @with_nvml_context + def is_full_nvlink(cls, physical_device_ids: list[int]) -> bool: + """ + query if the set of gpus are fully connected by nvlink (1 hop) + """ + handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i < j: + try: + p2p_status = pynvml.nvmlDeviceGetP2PStatus( + handle, + peer_handle, + pynvml.NVML_P2P_CAPS_INDEX_NVLINK, + ) + if p2p_status != pynvml.NVML_P2P_STATUS_OK: + return False + except pynvml.NVMLError: + logger.exception( + "NVLink detection failed. This is normal if" + " your machine has no NVLink equipped." + ) + return False + return True + + @classmethod + def _get_physical_device_name(cls, device_id: int = 0) -> str: + handle = pynvml.nvmlDeviceGetHandleByIndex(device_id) + return str(pynvml.nvmlDeviceGetName(handle)) + + @classmethod + @with_nvml_context + def log_warnings(cls) -> None: + device_ids: int = pynvml.nvmlDeviceGetCount() + if device_ids > 1: + device_names = [cls._get_physical_device_name(i) for i in range(device_ids)] + if ( + len(set(device_names)) > 1 + and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID" + ): + logger.warning( + "Detected different devices in the system: %s. 
Please" + " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to " + "avoid unexpected behavior.", + ", ".join(device_names), + ) + + +class NonNvmlCudaPlatform(CudaPlatformBase): + + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + return str(torch.cuda.get_device_name(device_id)) + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + device_props = torch.cuda.get_device_properties(device_id) + return int(device_props.total_memory) + + @classmethod + def is_full_nvlink(cls, physical_device_ids: list[int]) -> bool: + logger.exception( + "NVLink detection not possible, as context support was" + " not found. Assuming no NVLink available." + ) + return False + + +# Autodetect either NVML-enabled or non-NVML platform +# based on whether NVML is available. +nvml_available = False +try: + try: + pynvml.nvmlInit() + nvml_available = True + except Exception: + # On Jetson, NVML is not supported. + nvml_available = False +finally: + if nvml_available: + pynvml.nvmlShutdown() + +CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform + +try: + from sphinx.ext.autodoc.mock import _MockModule + + if not isinstance(pynvml, _MockModule): + CudaPlatform.log_warnings() +except ModuleNotFoundError: + CudaPlatform.log_warnings() diff --git a/python/sglang/multimodal_gen/runtime/platforms/interface.py b/python/sglang/multimodal_gen/runtime/platforms/interface.py new file mode 100644 index 000000000..68073b7fb --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/platforms/interface.py @@ -0,0 +1,252 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/platforms/interface.py +from __future__ import annotations + +import enum +import random +from typing import TYPE_CHECKING, NamedTuple + +import numpy as np +import torch + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.utils import resolve_obj_by_qualname + +if TYPE_CHECKING: + from sglang.multimodal_gen.runtime.layers.attention.backends.attention_backend import ( + AttentionImpl, + ) + +logger = init_logger(__name__) + + +class AttentionBackendEnum(enum.Enum): + FA3 = enum.auto() + SLIDING_TILE_ATTN = enum.auto() + TORCH_SDPA = enum.auto() + SAGE_ATTN = enum.auto() + SAGE_ATTN_THREE = enum.auto() + VIDEO_SPARSE_ATTN = enum.auto() + VMOBA_ATTN = enum.auto() + AITER = enum.auto() + NO_ATTENTION = enum.auto() + + def __str__(self): + return self.name.lower() + + +class PlatformEnum(enum.Enum): + CUDA = enum.auto() + ROCM = enum.auto() + TPU = enum.auto() + CPU = enum.auto() + MPS = enum.auto() + OOT = enum.auto() + UNSPECIFIED = enum.auto() + + +class CpuArchEnum(enum.Enum): + X86 = enum.auto() + ARM = enum.auto() + UNSPECIFIED = enum.auto() + + +class DeviceCapability(NamedTuple): + major: int + minor: int + + def as_version_str(self) -> str: + return f"{self.major}.{self.minor}" + + def to_int(self) -> int: + """ + Express device capability as an integer ````. + + It is assumed that the minor version is always a single digit. 
+ """ + assert 0 <= self.minor < 10 + return self.major * 10 + self.minor + + +class Platform: + _enum: PlatformEnum + device_name: str + device_type: str + + # available dispatch keys: + # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa + # use "CPU" as a fallback for platforms not registered in PyTorch + dispatch_key: str = "CPU" + + # The torch.compile backend for compiling simple and + # standalone functions. The default value is "inductor" to keep + # the same behavior as PyTorch. + # NOTE: for the forward part of the model, vLLM has another separate + # compilation strategy. + simple_compile_backend: str = "inductor" + + supported_quantization: list[str] = [] + + def is_cuda(self) -> bool: + return self._enum == PlatformEnum.CUDA + + def is_rocm(self) -> bool: + return self._enum == PlatformEnum.ROCM + + def is_tpu(self) -> bool: + return self._enum == PlatformEnum.TPU + + def is_cpu(self) -> bool: + return self._enum == PlatformEnum.CPU + + def is_out_of_tree(self) -> bool: + return self._enum == PlatformEnum.OOT + + def is_cuda_alike(self) -> bool: + """Stateless version of :func:`torch.cuda.is_available`.""" + return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM) + + def is_mps(self) -> bool: + return self._enum == PlatformEnum.MPS + + @classmethod + def get_attn_backend_cls_str( + cls, + selected_backend: AttentionBackendEnum | None, + head_size: int, + dtype: torch.dtype, + ) -> str: + """Get the attention backend class of a device.""" + return "" + + @classmethod + def get_device_capability( + cls, + device_id: int = 0, + ) -> DeviceCapability | None: + """Stateless version of :func:`torch.cuda.get_device_capability`.""" + return None + + @classmethod + def has_device_capability( + cls, + capability: tuple[int, int] | int, + device_id: int = 0, + ) -> bool: + """ + Test whether this platform is compatible with a device capability. + + The ``capability`` argument can either be: + + - A tuple ``(major, minor)``. + - An integer ````. (See :meth:`DeviceCapability.to_int`) + """ + current_capability = cls.get_device_capability(device_id=device_id) + if current_capability is None: + return False + + if isinstance(capability, tuple): + return current_capability >= capability + + return current_capability.to_int() >= capability + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + """Get the name of a device.""" + raise NotImplementedError + + @classmethod + def get_device_uuid(cls, device_id: int = 0) -> str: + """Get the uuid of a device, e.g. the PCI bus ID.""" + raise NotImplementedError + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + """Get the total memory of a device in bytes.""" + raise NotImplementedError + + @classmethod + def is_async_output_supported(cls, enforce_eager: bool | None) -> bool: + """ + Check if the current platform supports async output. + """ + raise NotImplementedError + + @classmethod + def inference_mode(cls): + """A device-specific wrapper of `torch.inference_mode`. + + This wrapper is recommended because some hardware backends such as TPU + do not support `torch.inference_mode`. In such a case, they will fall + back to `torch.no_grad` by overriding this method. + """ + return torch.inference_mode(mode=True) + + @classmethod + def seed_everything(cls, seed: int | None = None) -> None: + """ + Set the seed of each random module. + `torch.manual_seed` will set seed on all devices. 
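+        Passing ``seed=None`` is a no-op. Example (using the package's
+        ``current_platform`` helper)::
+
+            current_platform.seed_everything(42)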
+ + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + if seed is not None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + @classmethod + def verify_model_arch(cls, model_arch: str) -> None: + """ + Verify whether the current platform supports the specified model + architecture. + + - This will raise an Error or Warning based on the model support on + the current platform. + - By default all models are considered supported. + """ + pass + + @classmethod + def verify_quantization(cls, quant: str) -> None: + """ + Verify whether the quantization is supported by the current platform. + """ + if cls.supported_quantization and quant not in cls.supported_quantization: + raise ValueError( + f"{quant} quantization is currently not supported in " + f"{cls.device_name}." + ) + + @classmethod + def get_current_memory_usage( + cls, device: torch.types.Device | None = None + ) -> float: + """ + Return the memory usage in bytes. + """ + raise NotImplementedError + + @classmethod + def get_device_communicator_cls(cls) -> str: + """ + Get device specific communicator class for distributed communication. + """ + return "sglang.multimodal_gen.runtime.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase" # noqa + + @classmethod + def get_cpu_architecture(cls) -> CpuArchEnum: + """Get the CPU architecture of the current platform.""" + return CpuArchEnum.UNSPECIFIED + + def get_attn_backend(self, *args, **kwargs) -> AttentionImpl: + attention_cls_str = self.get_attn_backend_cls_str(*args, **kwargs) + return resolve_obj_by_qualname(attention_cls_str) + + +class UnspecifiedPlatform(Platform): + _enum = PlatformEnum.UNSPECIFIED + device_type = "" diff --git a/python/sglang/multimodal_gen/runtime/platforms/mps.py b/python/sglang/multimodal_gen/runtime/platforms/mps.py new file mode 100644 index 000000000..2312ec059 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/platforms/mps.py @@ -0,0 +1,88 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from sglang.multimodal_gen.runtime.platforms import AttentionBackendEnum +from sglang.multimodal_gen.runtime.platforms.interface import ( + DeviceCapability, + Platform, + PlatformEnum, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class MpsPlatform(Platform): + _enum = PlatformEnum.MPS + device_name: str = "mps" + device_type: str = "mps" + dispatch_key: str = "MPS" + device_control_env_var: str = "MPS_VISIBLE_DEVICES" + + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None: + raise NotImplementedError + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + raise NotImplementedError + + @classmethod + def get_device_uuid(cls, device_id: int = 0) -> str: + raise NotImplementedError + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + raise NotImplementedError + + @classmethod + def is_async_output_supported(cls, enforce_eager: bool | None) -> bool: + if enforce_eager: + logger.warning( + "To see benefits of async output processing, enable MPS " + "graph. 
Since, enforce-eager is enabled, async output " + "processor cannot be used" + ) + return False + return True + + @classmethod + def get_current_memory_usage( + cls, device: torch.types.Device | None = None + ) -> float: + return 0.0 + + @classmethod + def get_attn_backend_cls_str( + cls, + selected_backend: AttentionBackendEnum | None, + head_size: int, + dtype: torch.dtype, + ) -> str: + # MPS supports SDPA (Scaled Dot-Product Attention) which is the most compatible + logger.info("Using Torch SDPA backend for MPS.") + return ( + "sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend" + ) + + @classmethod + def get_device_communicator_cls(cls) -> str: + # Use base communicator for MPS + return "sglang.multimodal_gen.runtime.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase" + + @classmethod + def seed_everything(cls, seed: int | None = None) -> None: + """Set the seed for MPS device.""" + if seed is not None: + import random + + import numpy as np + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + # MPS doesn't have manual_seed_all like CUDA + # The manual_seed above should be sufficient diff --git a/python/sglang/multimodal_gen/runtime/platforms/rocm.py b/python/sglang/multimodal_gen/runtime/platforms/rocm.py new file mode 100644 index 000000000..9eca14ac7 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/platforms/rocm.py @@ -0,0 +1,138 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from rocm/vllm: https://github.com/ROCm/vllm/blob/v0.7.3%2Brocm/vllm/platforms/rocm.py +""" +This file is a platform abstraction for ROCm GPUs, +adjusted to match the structure and interface of `cuda.py`. +""" + +import torch + +import sglang.multimodal_gen.envs as envs +from sglang.multimodal_gen.runtime.platforms.interface import ( + AttentionBackendEnum, + DeviceCapability, + Platform, + PlatformEnum, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +# ROCm uses the same torch.cuda interface +class RocmPlatform(Platform): + _enum = PlatformEnum.ROCM + device_name: str = "rocm" + device_type: str = "cuda" # torch uses 'cuda' backend string + dispatch_key: str = "CUDA" + device_control_env_var: str = "CUDA_VISIBLE_DEVICES" + + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + return str(torch.cuda.get_device_name(device_id)) + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + return torch.cuda.get_device_properties(device_id).total_memory + + @classmethod + def is_async_output_supported(cls, enforce_eager: bool | None) -> bool: + if enforce_eager: + logger.warning( + "To see benefits of async output processing, enable CUDA graph. 
" + "Since enforce-eager is enabled, async output processor cannot be used" + ) + return False + return True + + @classmethod + def log_warnings(cls) -> None: + pass # ROCm-specific warnings can be added here + + @classmethod + def get_current_memory_usage(cls, device: torch.device | None = None) -> float: + torch.cuda.reset_peak_memory_stats(device) + return float(torch.cuda.max_memory_allocated(device)) + + @classmethod + def get_attn_backend_cls_str( + cls, + selected_backend: AttentionBackendEnum | None, + head_size: int, + dtype: torch.dtype, + ) -> str: + logger.info( + "Trying SGL_DIFFUSION_ATTENTION_BACKEND=%s", + envs.SGL_DIFFUSION_ATTENTION_BACKEND, + ) + + if selected_backend == AttentionBackendEnum.TORCH_SDPA: + logger.info("Using Torch SDPA backend.") + return "sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend" + + elif selected_backend in (AttentionBackendEnum.FA3, None): + pass + + elif selected_backend in ( + AttentionBackendEnum.SLIDING_TILE_ATTN, + AttentionBackendEnum.SAGE_ATTN, + ): + raise ValueError( + f"{selected_backend.name} is not supported on {cls.device_name}." + ) + elif selected_backend: + raise ValueError( + f"Invalid attention backend for {cls.device_name}: {selected_backend}" + ) + + target_backend = AttentionBackendEnum.FA3 + if dtype not in (torch.float16, torch.bfloat16): + logger.info( + "Cannot use FlashAttention backend for dtype other than " + "torch.float16 or torch.bfloat16." + ) + target_backend = AttentionBackendEnum.TORCH_SDPA + + if target_backend == AttentionBackendEnum.FA3: + try: + import flash_attn # noqa: F401 + + from sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn import ( # noqa: F401 + FlashAttentionBackend, + ) + + supported_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in supported_sizes: + logger.info( + "Cannot use FlashAttention-2 backend for head size %d.", + head_size, + ) + target_backend = AttentionBackendEnum.TORCH_SDPA + except ImportError: + logger.info( + "Cannot use FlashAttention backend because the " + "flash_attn package is not found. " + "Make sure that flash_attn was built and installed " + "(on by default)." 
+ ) + target_backend = AttentionBackendEnum.TORCH_SDPA + + if target_backend == AttentionBackendEnum.TORCH_SDPA: + logger.info("Using Torch SDPA backend.") + + return "sglang.multimodal_gen.runtime.layers.attention.backends.sdpa.SDPABackend" + + logger.info("Using Flash Attention backend.") + + return "sglang.multimodal_gen.runtime.layers.attention.backends.flash_attn.FlashAttentionBackend" + + @classmethod + def get_device_communicator_cls(cls) -> str: + return "sglang.multimodal_gen.runtime.distributed.device_communicators.cuda_communicator.CudaCommunicator" # works for ROCm too diff --git a/python/sglang/multimodal_gen/runtime/scheduler_client.py b/python/sglang/multimodal_gen/runtime/scheduler_client.py new file mode 100644 index 000000000..97cc1165e --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/scheduler_client.py @@ -0,0 +1,149 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import asyncio + +import zmq +import zmq.asyncio + +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +# Using a singleton pattern to hold the ZMQ context and the socket connected to the scheduler +class SchedulerClient: + """ + A gateway for Scheduler, forwarding the ForwardBatch from http endpoints (or somewhere else) to background scheduler, with TCP socket + """ + + _instance = None + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super(SchedulerClient, cls).__new__(cls) + return cls._instance + + def __init__(self, *args, **kwargs): + # Ensure the initialization runs only once for the singleton instance + if getattr(self, "_init_done", False): + return + # Queue + worker to strictly serialize ZeroMQ REQ/REP interactions + self._request_queue = asyncio.Queue() + self._worker_task = None + self._closing = False + self._init_done = True + + def initialize(self, server_args: ServerArgs): + self.server_args = server_args + self.context = zmq.asyncio.Context() + # This is the REQ socket used to connect to the backend Scheduler + self.scheduler_socket = self.context.socket(zmq.REQ) + scheduler_endpoint = server_args.scheduler_endpoint() + self.scheduler_socket.connect(scheduler_endpoint) + logger.info( + f"Scheduler client connected to backend scheduler at {scheduler_endpoint}" + ) + # Worker will be lazily started on the first forward call to ensure a running loop exists + + async def forward(self, batch: Req) -> Req: + """Enqueue a request to the backend Scheduler and await the reply.""" + if self._closing: + raise RuntimeError( + "SchedulerClient is closing; cannot forward new requests" + ) + + await self._ensure_worker_started() + + loop = asyncio.get_running_loop() + future = loop.create_future() + await self._request_queue.put((batch, future)) + return await future + + async def _ensure_worker_started(self): + # Start the worker only once and only when an event loop is running + if self._worker_task is None or self._worker_task.done(): + self._worker_task = asyncio.create_task(self._worker_loop()) + + async def _worker_loop(self): + while True: + try: + item = await self._request_queue.get() + try: + batch, future = item + except Exception: + # Malformed queue item; skip + self._request_queue.task_done() + continue + + try: + await self.scheduler_socket.send_pyobj(batch) + response = await self.scheduler_socket.recv_pyobj() + if not 
future.done(): + future.set_result(response) + except Exception as e: + if not future.done(): + future.set_exception(e) + finally: + self._request_queue.task_done() + except asyncio.CancelledError: + # Drain remaining items with cancellation error to avoid hanging waiters + while True: + try: + batch, future = self._request_queue.get_nowait() + except asyncio.QueueEmpty: + break + try: + if not future.done(): + future.set_exception(asyncio.CancelledError()) + finally: + self._request_queue.task_done() + raise + + def close(self): + self._closing = True + # Cancel worker if running + if self._worker_task is not None: + self._worker_task.cancel() + try: + self.scheduler_socket.close() + finally: + try: + self.context.term() + except Exception: + pass + + +# Singleton instance +scheduler_client = SchedulerClient() + + +async def run_zeromq_broker(server_args: ServerArgs): + """ + This function runs as a background task in the FastAPI process. + It listens for TCP requests from offline clients (e.g., DiffGenerator). + """ + ctx = zmq.asyncio.Context() + # This is the REP socket that listens for requests from DiffGenerator + socket = ctx.socket(zmq.REP) + broker_endpoint = f"tcp://*:{server_args.broker_port}" + socket.bind(broker_endpoint) + logger.info(f"ZMQ Broker is listening for offline jobs on {broker_endpoint}") + + while True: + try: + # 1. Receive a request from an offline client + request_batch = await socket.recv_pyobj() + logger.info("Broker received an offline job from a client.") + + # 2. Forward the request to the main Scheduler via the shared client + response_batch = await scheduler_client.forward(request_batch) + + # 3. Send the Scheduler's reply back to the offline client + await socket.send_pyobj(response_batch) + + except Exception as e: + logger.error(f"Error in ZMQ Broker: {e}", exc_info=True) + # A reply must be sent to prevent the client from hanging + await socket.send_pyobj({"status": "error", "message": str(e)}) diff --git a/python/sglang/multimodal_gen/runtime/server_args.py b/python/sglang/multimodal_gen/runtime/server_args.py new file mode 100644 index 000000000..88f81936c --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/server_args.py @@ -0,0 +1,1025 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Inspired by SGLang: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py +"""The arguments of sgl-diffusion Inference.""" +import argparse +import dataclasses +import inspect +import json +import random +import sys +import tempfile +from contextlib import contextmanager +from dataclasses import field +from enum import Enum +from typing import Any, Optional + +from sglang.multimodal_gen.configs.configs import PreprocessConfig +from sglang.multimodal_gen.configs.pipelines import FluxPipelineConfig +from sglang.multimodal_gen.configs.pipelines.base import PipelineConfig, STA_Mode +from sglang.multimodal_gen.configs.pipelines.qwen_image import ( + QwenImageEditPipelineConfig, + QwenImagePipelineConfig, +) +from sglang.multimodal_gen.runtime.platforms import ( + AttentionBackendEnum, + current_platform, +) +from sglang.multimodal_gen.runtime.utils.common import ( + is_port_available, + is_valid_ipv6_address, +) +from sglang.multimodal_gen.runtime.utils.logging_utils import ( + configure_logger, + init_logger, +) +from sglang.multimodal_gen.utils import FlexibleArgumentParser, StoreBoolean + +logger = init_logger(__name__) + +ZMQ_TCP_PORT_DELTA = 233 + + +def 
_is_torch_tensor(obj: Any) -> tuple[bool, Any]: + """Return (is_tensor, torch_module_or_None) without importing torch at module import time.""" + try: + import torch # type: ignore + + return isinstance(obj, torch.Tensor), torch + except Exception: + return False, None + + +def _sanitize_for_logging(obj: Any, key_hint: str | None = None) -> Any: + """Recursively convert objects to JSON-serializable forms for concise logging. + + Rules: + - Drop any field/dict key named 'param_names_mapping'. + - Render Enums using their value. + - Render torch.Tensor as a compact summary; if key name is 'scaling_factor', include stats. + - Dataclasses are expanded to dicts and sanitized recursively. + - Callables/functions are rendered as their qualified name. + - Fallback to str(...) for unknown types. + """ + # Handle simple types quickly + if obj is None or isinstance(obj, (str, int, float, bool)): + return obj + + # Enum -> value for readability + if isinstance(obj, Enum): + return obj.value + + # torch.Tensor handling (lazy import) + is_tensor, torch_mod = _is_torch_tensor(obj) + if is_tensor: + try: + ten = obj.detach().cpu() + if key_hint == "scaling_factor": + # Provide a compact, single-line summary for scaling_factor + stats = { + "shape": list(ten.shape), + "dtype": str(ten.dtype), + } + # Stats might fail for some dtypes; guard individually + try: + stats["min"] = float(ten.min().item()) + except Exception: + pass + try: + stats["max"] = float(ten.max().item()) + except Exception: + pass + try: + stats["mean"] = float(ten.float().mean().item()) + except Exception: + pass + return {"tensor": "scaling_factor", **stats} + # Generic tensor summary + return {"tensor": True, "shape": list(ten.shape), "dtype": str(ten.dtype)} + except Exception: + return "" + + # Dataclasses -> dict + if dataclasses.is_dataclass(obj): + result: dict[str, Any] = {} + for f in dataclasses.fields(obj): + if not f.repr: + continue + name = f.name + if "names_mapping" in name: # drop noisy mappings + continue + try: + value = getattr(obj, name) + except Exception: + continue + result[name] = _sanitize_for_logging(value, key_hint=name) + return result + + # Dicts -> sanitize keys/values; drop 'param_names_mapping' + if isinstance(obj, dict): + result_dict: dict[str, Any] = {} + for k, v in obj.items(): + try: + key_str = str(k) + except Exception: + key_str = "" + if key_str == "param_names_mapping": + continue + result_dict[key_str] = _sanitize_for_logging(v, key_hint=key_str) + return result_dict + + # Sequences/Sets -> list + if isinstance(obj, (list, tuple, set)): + return [_sanitize_for_logging(x) for x in obj] + + # Functions / Callables -> qualified name + try: + if inspect.isroutine(obj) or inspect.isclass(obj): + module = getattr(obj, "__module__", "") + qn = getattr(obj, "__qualname__", getattr(obj, "__name__", "")) + return f"{module}.{qn}" if module else qn + except Exception: + pass + + # Fallback: string representation + try: + return str(obj) + except Exception: + return "" + + +class ExecutionMode(str, Enum): + """ + Enumeration for different pipeline modes. + + Inherits from str to allow string comparison for backward compatibility. + """ + + INFERENCE = "inference" + PREPROCESS = "preprocess" + FINETUNING = "finetuning" + DISTILLATION = "distillation" + + @classmethod + def from_string(cls, value: str) -> "ExecutionMode": + """Convert string to ExecutionMode enum.""" + try: + return cls(value.lower()) + except ValueError: + raise ValueError( + f"Invalid mode: {value}. 
Must be one of: {', '.join([m.value for m in cls])}" + ) from None + + @classmethod + def choices(cls) -> list[str]: + """Get all available choices as strings for argparse.""" + return [mode.value for mode in cls] + + +class WorkloadType(str, Enum): + """ + Enumeration for different workload types. + + Inherits from str to allow string comparison for backward compatibility. + """ + + I2V = "i2v" # Image to Video + T2V = "t2v" # Text to Video + T2I = "t2i" # Text to Image + I2I = "i2i" # Image to Image + + @classmethod + def from_string(cls, value: str) -> "WorkloadType": + """Convert string to WorkloadType enum.""" + try: + return cls(value.lower()) + except ValueError: + raise ValueError( + f"Invalid workload type: {value}. Must be one of: {', '.join([m.value for m in cls])}" + ) from None + + @classmethod + def choices(cls) -> list[str]: + """Get all available choices as strings for argparse.""" + return [workload.value for workload in cls] + + +# args for sgl_diffusion framework +@dataclasses.dataclass +class ServerArgs: + # Model and path configuration (for convenience) + model_path: str + + # Attention + attention_backend: str = None + + # Running mode + mode: ExecutionMode = ExecutionMode.INFERENCE + + # Workload type + workload_type: WorkloadType = WorkloadType.T2V + + # Cache strategy + cache_strategy: str = "none" + + # Distributed executor backend + distributed_executor_backend: str = "mp" + nccl_port: Optional[int] = None + + # HuggingFace specific parameters + trust_remote_code: bool = False + revision: str | None = None + + # Parallelism + num_gpus: int = 1 + tp_size: int = -1 + sp_degree: int = -1 + # sequence parallelism + ulysses_degree: Optional[int] = None + ring_degree: Optional[int] = None + # data parallelism + # number of data parallelism groups + dp_size: int = 1 + # number of gpu in a dp group + dp_degree: int = 1 + # cfg parallel + enable_cfg_parallel: bool = False + + hsdp_replicate_dim: int = 1 + hsdp_shard_dim: int = -1 + dist_timeout: int | None = None # timeout for torch.distributed + + pipeline_config: PipelineConfig = field(default_factory=PipelineConfig, repr=False) + preprocess_config: PreprocessConfig | None = None + + # LoRA parameters + # (Wenxuan) prefer to keep it here instead of in pipeline config to not make it complicated. + lora_path: str | None = None + lora_nickname: str = "default" # for swapping adapters in the pipeline + # can restrict layers to adapt, e.g. ["q_proj"] + # Will adapt only q, k, v, o by default. 
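+    # (illustrative) e.g. lora_target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]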
+ lora_target_modules: list[str] | None = None + + output_type: str = "pil" + + # CPU offload parameters + dit_cpu_offload: bool = True + use_fsdp_inference: bool = False + text_encoder_cpu_offload: bool = True + image_encoder_cpu_offload: bool = True + vae_cpu_offload: bool = True + pin_cpu_memory: bool = True + + # STA (Sliding Tile Attention) parameters + mask_strategy_file_path: str | None = None + STA_mode: STA_Mode = STA_Mode.STA_INFERENCE + skip_time_steps: int = 15 + + # Compilation + enable_torch_compile: bool = False + + disable_autocast: bool = False + + # VSA parameters + VSA_sparsity: float = 0.0 # inference/validation sparsity + + # V-MoBA parameters + moba_config_path: str | None = None + moba_config: dict[str, Any] = field(default_factory=dict) + + # Master port for distributed inference + # TODO: do not hard code + master_port: int | None = None + + # http server endpoint config, would be ignored in local mode + host: str | None = None + port: int | None = None + + scheduler_port: int = 5555 + + # Stage verification + enable_stage_verification: bool = True + + # Prompt text file for batch processing + prompt_file_path: str | None = None + + # model paths for correct deallocation + model_paths: dict[str, str] = field(default_factory=dict) + model_loaded: dict[str, bool] = field( + default_factory=lambda: { + "transformer": True, + "vae": True, + } + ) + override_transformer_cls_name: str | None = None + + # # DMD parameters + # dmd_denoising_steps: List[int] | None = field(default=None) + + # MoE parameters used by Wan2.2 + boundary_ratio: float | None = None + + # Logging + log_level: str = "info" + + @property + def broker_port(self) -> int: + return self.port + 1 + + @property + def is_local_mode(self) -> bool: + """ + If no server is running when a generation task begins, 'local_mode' will be enabled: a dedicated server will be launched + """ + return self.host is None or self.port is None + + def __post_init__(self): + self.scheduler_port = self.settle_port(self.scheduler_port) + # TODO: remove hard code + self.master_port = self.settle_port(self.master_port or 30005, 37) + if self.moba_config_path: + try: + with open(self.moba_config_path) as f: + self.moba_config = json.load(f) + logger.info("Loaded V-MoBA config from %s", self.moba_config_path) + except (FileNotFoundError, json.JSONDecodeError) as e: + logger.error( + "Failed to load V-MoBA config from %s: %s", self.moba_config_path, e + ) + raise + self.check_server_args() + + configure_logger(server_args=self) + + # log clean server_args + try: + safe_args = _sanitize_for_logging(self, key_hint="server_args") + logger.info("server_args: %s", json.dumps(safe_args, ensure_ascii=False)) + except Exception: + # Fallback to default repr if sanitization fails + logger.info(f"server_args: {self}") + + @staticmethod + def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + # Model and path configuration + parser.add_argument( + "--model-path", + type=str, + help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.", + ) + parser.add_argument( + "--model-dir", + type=str, + help="Directory containing StepVideo model", + ) + + # attention + parser.add_argument( + "--attention-backend", + type=str, + default=None, + choices=[e.name.lower() for e in AttentionBackendEnum], + help="The attention backend to use. 
If not specified, the backend is automatically selected based on hardware and installed packages.", + ) + + # Running mode + parser.add_argument( + "--mode", + type=str, + choices=ExecutionMode.choices(), + default=ServerArgs.mode.value, + help="The mode to run sgl-diffusion", + ) + + # Workload type + parser.add_argument( + "--workload-type", + type=str, + choices=WorkloadType.choices(), + default=ServerArgs.workload_type.value, + help="The workload type", + ) + + # distributed_executor_backend + parser.add_argument( + "--distributed-executor-backend", + type=str, + choices=["mp"], + default=ServerArgs.distributed_executor_backend, + help="The distributed executor backend to use", + ) + + # HuggingFace specific parameters + parser.add_argument( + "--trust-remote-code", + action=StoreBoolean, + default=ServerArgs.trust_remote_code, + help="Trust remote code when loading HuggingFace models", + ) + parser.add_argument( + "--revision", + type=str, + default=ServerArgs.revision, + help="The specific model version to use (can be a branch name, tag name, or commit id)", + ) + + # Parallelism + parser.add_argument( + "--num-gpus", + type=int, + default=ServerArgs.num_gpus, + help="The number of GPUs to use.", + ) + parser.add_argument( + "--tp-size", + type=int, + default=ServerArgs.tp_size, + help="The tensor parallelism size.", + ) + parser.add_argument( + "--sp-degree", + type=int, + default=ServerArgs.sp_degree, + help="The sequence parallelism size.", + ) + parser.add_argument( + "--ulysses-degree", + type=int, + default=ServerArgs.ulysses_degree, + help="Ulysses sequence parallel degree. Used in attention layer.", + ) + parser.add_argument( + "--ring-degree", + type=int, + default=ServerArgs.ring_degree, + help="Ring sequence parallel degree. Used in attention layer.", + ) + parser.add_argument( + "--enable-cfg-parallel", + action="store_true", + default=ServerArgs.enable_cfg_parallel, + help="Enable cfg parallel.", + ) + parser.add_argument( + "--data-parallel-size", + "--dp-size", + "--dp", + type=int, + default=ServerArgs.dp_size, + help="The data parallelism size.", + ) + + parser.add_argument( + "--hsdp-replicate-dim", + type=int, + default=ServerArgs.hsdp_replicate_dim, + help="The data parallelism size.", + ) + parser.add_argument( + "--hsdp-shard-dim", + type=int, + default=ServerArgs.hsdp_shard_dim, + help="The data parallelism shards.", + ) + parser.add_argument( + "--dist-timeout", + type=int, + default=ServerArgs.dist_timeout, + help="Set timeout for torch.distributed initialization.", + ) + + # Output type + parser.add_argument( + "--output-type", + type=str, + default=ServerArgs.output_type, + choices=["pil"], + help="Output type for the generated video", + ) + + # Prompt text file for batch processing + parser.add_argument( + "--prompt-file-path", + type=str, + default=ServerArgs.prompt_file_path, + help="Path to a text file containing prompts (one per line) for batch processing", + ) + + # STA (Sliding Tile Attention) parameters + parser.add_argument( + "--STA-mode", + type=str, + default=ServerArgs.STA_mode.value, + choices=[mode.value for mode in STA_Mode], + help="STA mode contains STA_inference, STA_searching, STA_tuning, STA_tuning_cfg, None", + ) + parser.add_argument( + "--skip-time-steps", + type=int, + default=ServerArgs.skip_time_steps, + help="Number of time steps to warmup (full attention) for STA", + ) + parser.add_argument( + "--mask-strategy-file-path", + type=str, + help="Path to mask strategy JSON file for STA", + ) + parser.add_argument( + 
"--enable-torch-compile", + action=StoreBoolean, + default=ServerArgs.enable_torch_compile, + help="Use torch.compile to speed up DiT inference." + + "However, will likely cause precision drifts. See (https://github.com/pytorch/pytorch/issues/145213)", + ) + + parser.add_argument( + "--dit-cpu-offload", + action=StoreBoolean, + help="Use CPU offload for DiT inference. Enable if run out of memory with FSDP.", + ) + parser.add_argument( + "--use-fsdp-inference", + action=StoreBoolean, + help="Use FSDP for inference by sharding the model weights. Latency is very low due to prefetch--enable if run out of memory.", + ) + parser.add_argument( + "--text-encoder-cpu-offload", + action=StoreBoolean, + help="Use CPU offload for text encoder. Enable if run out of memory.", + ) + parser.add_argument( + "--image-encoder-cpu-offload", + action=StoreBoolean, + help="Use CPU offload for image encoder. Enable if run out of memory.", + ) + parser.add_argument( + "--vae-cpu-offload", + action=StoreBoolean, + help="Use CPU offload for VAE. Enable if run out of memory.", + ) + parser.add_argument( + "--pin-cpu-memory", + action=StoreBoolean, + help='Pin memory for CPU offload. Only added as a temp workaround if it throws "CUDA error: invalid argument". ' + "Should be enabled in almost all cases", + ) + parser.add_argument( + "--disable-autocast", + action=StoreBoolean, + help="Disable autocast for denoising loop and vae decoding in pipeline sampling", + ) + + # VSA parameters + parser.add_argument( + "--VSA-sparsity", + type=float, + default=ServerArgs.VSA_sparsity, + help="Validation sparsity for VSA", + ) + + # Master port for distributed inference + parser.add_argument( + "--master-port", + type=int, + default=ServerArgs.master_port, + help="Master port for distributed inference. 
If not set, a random free port will be used.", + ) + parser.add_argument( + "--scheduler-port", + type=int, + default=ServerArgs.scheduler_port, + help="Port for the scheduler server.", + ) + parser.add_argument( + "--host", + type=str, + default=ServerArgs.host, + help="Host for the HTTP API server.", + ) + parser.add_argument( + "--port", + type=int, + default=ServerArgs.port, + help="Port for the HTTP API server.", + ) + + # Stage verification + parser.add_argument( + "--enable-stage-verification", + action=StoreBoolean, + default=ServerArgs.enable_stage_verification, + help="Enable input/output verification for pipeline stages", + ) + parser.add_argument( + "--override-transformer-cls-name", + type=str, + default=ServerArgs.override_transformer_cls_name, + help="Override transformer cls name", + ) + # Add pipeline configuration arguments + PipelineConfig.add_cli_args(parser) + + # Add preprocessing configuration arguments + PreprocessConfig.add_cli_args(parser) + + # Logging + parser.add_argument( + "--log-level", + type=str, + default=ServerArgs.log_level, + help="The logging level of all loggers.", + ) + return parser + + def url(self): + if is_valid_ipv6_address(self.host): + return f"http://[{self.host}]:{self.port}" + else: + return f"http://{self.host}:{self.port}" + + def scheduler_endpoint(self): + """ + Internal endpoint for scheduler + + """ + scheduler_host = self.host or "localhost" + return f"tcp://{scheduler_host}:{self.scheduler_port}" + + def settle_port(self, port: int, port_inc: int = 42) -> int: + while True: + if is_port_available(port): + return port + if port < 60000: + port += port_inc + else: + port -= port_inc + 1 + + def post_init_serve(self): + """ + Post init when in serve mode + """ + if self.host is None: + self.host = "localhost" + if self.port is None: + self.port = 3000 + self.port = self.settle_port(self.port) + + @classmethod + def from_cli_args( + cls, args: argparse.Namespace, unknown_args: list[str] | None = None + ) -> "ServerArgs": + if unknown_args is None: + unknown_args = [] + provided_args = cls.get_provided_args(args, unknown_args) + + # Handle config file + config_file = provided_args.get("config") + if config_file: + config_args = cls.load_config_file(config_file) + # Provided args override config file args + provided_args = {**config_args, **provided_args} + + # Handle special cases + # if "tp_size" in provided_args: + # provided_args["tp"] = provided_args.pop("tp_size") + + return cls.from_dict(provided_args) + + @classmethod + def from_dict(cls, kwargs: dict[str, Any]) -> "ServerArgs": + """Create a ServerArgs object from a dictionary.""" + attrs = [attr.name for attr in dataclasses.fields(cls)] + server_args_kwargs: dict[str, Any] = {} + + for attr in attrs: + if attr == "pipeline_config": + pipeline_config = PipelineConfig.from_kwargs(kwargs) + logger.debug(f"Using PipelineConfig: {type(pipeline_config)}") + server_args_kwargs["pipeline_config"] = pipeline_config + elif attr == "preprocess_config": + preprocess_config = PreprocessConfig.from_kwargs(kwargs) + server_args_kwargs["preprocess_config"] = preprocess_config + elif attr in kwargs: + server_args_kwargs[attr] = kwargs[attr] + + return cls(**server_args_kwargs) + + @staticmethod + def load_config_file(config_file: str) -> dict[str, Any]: + """Load a config file.""" + if config_file.endswith(".json"): + with open(config_file, "r") as f: + return json.load(f) + elif config_file.endswith((".yaml", ".yml")): + try: + import yaml + except ImportError: + raise ImportError( + "Please 
install PyYAML to use YAML config files. " + "`pip install pyyaml`" + ) + with open(config_file, "r") as f: + return yaml.safe_load(f) + else: + raise ValueError(f"Unsupported config file format: {config_file}") + + @classmethod + def from_kwargs(cls, **kwargs: Any) -> "ServerArgs": + # Convert mode string to enum if necessary + if "mode" in kwargs and isinstance(kwargs["mode"], str): + kwargs["mode"] = ExecutionMode.from_string(kwargs["mode"]) + + # Convert workload_type string to enum if necessary + if "workload_type" in kwargs and isinstance(kwargs["workload_type"], str): + kwargs["workload_type"] = WorkloadType.from_string(kwargs["workload_type"]) + + kwargs["pipeline_config"] = PipelineConfig.from_kwargs(kwargs) + kwargs["preprocess_config"] = PreprocessConfig.from_kwargs(kwargs) + return cls(**kwargs) + + @staticmethod + def get_provided_args( + args: argparse.Namespace, unknown_args: list[str] + ) -> dict[str, Any]: + """Get the arguments provided by the user.""" + provided_args = {} + # We need to check against the raw command-line arguments to see what was + # explicitly provided by the user, vs. what's a default value from argparse. + raw_argv = sys.argv + unknown_args + + # Create a set of argument names that were present on the command line. + # This handles both styles: '--arg=value' and '--arg value'. + provided_arg_names = set() + for arg in raw_argv: + if arg.startswith("--"): + # For '--arg=value', this gets 'arg'; for '--arg', this also gets 'arg'. + arg_name = arg.split("=", 1)[0].replace("-", "_").lstrip("_") + provided_arg_names.add(arg_name) + + # Populate provided_args if the argument from the namespace was on the command line. + for k, v in vars(args).items(): + if k in provided_arg_names: + provided_args[k] = v + + return provided_args + + def check_server_sp_args(self): + + if self.pipeline_config.is_image_gen: + if ( + (self.sp_degree and self.sp_degree > 1) + or (self.ulysses_degree and self.ulysses_degree > 1) + or (self.ring_degree and self.ring_degree > 1) + ): + raise ValueError( + "SP is not supported for image generation models for now" + ) + self.sp_degree = self.ulysses_degree = self.ring_degree = 1 + + if self.sp_degree == -1: + # assume we leave all remaining gpus to sp + num_gpus_per_group = self.dp_size * self.tp_size + if self.enable_cfg_parallel: + num_gpus_per_group *= 2 + if self.num_gpus % num_gpus_per_group != 0: + raise ValueError(f"{self.num_gpus=} % {num_gpus_per_group} != 0") + self.sp_degree = self.num_gpus // num_gpus_per_group + + if ( + self.ulysses_degree is None + and self.ring_degree is None + and self.sp_degree != 1 + ): + self.ulysses_degree = self.sp_degree + logger.info( + f"Automatically set ulysses_degree=sp_degree={self.ulysses_degree} for best performance" + ) + + if self.ulysses_degree is None: + self.ulysses_degree = 1 + logger.info( + f"Ulysses degree not set, " f"using default value {self.ulysses_degree}" + ) + + if self.ring_degree is None: + self.ring_degree = 1 + logger.info( + f"Ring degree not set, " f"using default value {self.ring_degree}" + ) + + if self.sp_degree == -1: + self.sp_degree = self.ring_degree * self.ulysses_degree + logger.info( + f"sequence_parallel_degree is not provided, using ring_degree * ulysses_degree = {self.sp_degree}" + ) + + if self.sp_degree != self.ring_degree * self.ulysses_degree: + raise ValueError( + f"sequence_parallel_degree is not equal to ring_degree * ulysses_degree, {self.sp_degree} != {self.ring_degree} * {self.ulysses_degree}" + ) + + def check_server_dp_args(self): + assert 
self.num_gpus % self.dp_size == 0, f"{self.num_gpus=}, {self.dp_size=}" + assert self.dp_size >= 1, "--dp-size must be natural number" + self.dp_degree = self.num_gpus // self.dp_size + logger.info(f"Setting dp_degree to: {self.dp_degree}") + + def check_server_args(self) -> None: + """Validate inference arguments for consistency""" + if current_platform.is_mps(): + self.use_fsdp_inference = False + + # autocast + is_flux = ( + isinstance(self.pipeline_config, FluxPipelineConfig) + or isinstance(self.pipeline_config, QwenImagePipelineConfig) + or isinstance(self.pipeline_config, QwenImageEditPipelineConfig) + ) + if is_flux: + self.disable_autocast = True + + # Validate mode consistency + assert isinstance( + self.mode, ExecutionMode + ), f"Mode must be an ExecutionMode enum, got {type(self.mode)}" + assert ( + self.mode in ExecutionMode.choices() + ), f"Invalid execution mode: {self.mode}" + + # Validate workload type + assert isinstance( + self.workload_type, WorkloadType + ), f"Workload type must be a WorkloadType enum, got {type(self.workload_type)}" + assert ( + self.workload_type in WorkloadType.choices() + ), f"Invalid workload type: {self.workload_type}" + + if self.tp_size == -1: + self.tp_size = 1 + + if self.hsdp_shard_dim == -1: + self.hsdp_shard_dim = self.num_gpus + + assert ( + self.sp_degree <= self.num_gpus and self.num_gpus % self.sp_degree == 0 + ), "num_gpus must >= and be divisible by sp_size" + assert ( + self.hsdp_replicate_dim <= self.num_gpus + and self.num_gpus % self.hsdp_replicate_dim == 0 + ), "num_gpus must >= and be divisible by hsdp_replicate_dim" + assert ( + self.hsdp_shard_dim <= self.num_gpus + and self.num_gpus % self.hsdp_shard_dim == 0 + ), "num_gpus must >= and be divisible by hsdp_shard_dim" + + if self.num_gpus < max(self.tp_size, self.sp_degree): + self.num_gpus = max(self.tp_size, self.sp_degree) + + if self.pipeline_config is None: + raise ValueError("pipeline_config is not set in ServerArgs") + + self.pipeline_config.check_pipeline_config() + + # Add preprocessing config validation if needed + if self.mode == ExecutionMode.PREPROCESS: + if self.preprocess_config is None: + raise ValueError( + "preprocess_config is not set in ServerArgs when mode is PREPROCESS" + ) + if self.preprocess_config.model_path == "": + self.preprocess_config.model_path = self.model_path + if not self.pipeline_config.vae_config.load_encoder: + self.pipeline_config.vae_config.load_encoder = True + self.preprocess_config.check_preprocess_config() + + # parallelism + self.check_server_dp_args() + # allocate all remaining gpus for sp-size + self.check_server_sp_args() + + if self.enable_cfg_parallel: + if self.num_gpus == 1: + raise ValueError( + "CFG Parallelism is enabled via `--enable-cfg-parallel`, while -num-gpus==1" + ) + + +@dataclasses.dataclass +class PortArgs: + # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq) + scheduler_input_ipc_name: str + + # The port for nccl initialization (torch.dist) + nccl_port: int + + # The ipc filename for rpc call between Engine and Scheduler + rpc_ipc_name: str + + # The ipc filename for Scheduler to send metrics + metrics_ipc_name: str + + # Master port for distributed inference + master_port: int | None = None + + @staticmethod + def from_server_args( + server_args: ServerArgs, dp_rank: Optional[int] = None + ) -> "PortArgs": + if server_args.nccl_port is None: + nccl_port = server_args.scheduler_port + random.randint(100, 1000) + while True: + if is_port_available(nccl_port): + break + if nccl_port 
< 60000: + nccl_port += 42 + else: + nccl_port -= 43 + else: + nccl_port = server_args.nccl_port + + # Normal case, use IPC within a single node + return PortArgs( + scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", + nccl_port=nccl_port, + rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", + metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", + master_port=server_args.master_port, + ) + + +# TODO: not sure what _current_server_args is for, using a _global_server_args instead +_current_server_args = None +_global_server_args = None + + +def prepare_server_args(argv: list[str]) -> ServerArgs: + """ + Prepare the inference arguments from the command line arguments. + + Args: + argv: The command line arguments. Typically, it should be `sys.argv[1:]` + to ensure compatibility with `parse_args` when no arguments are passed. + + Returns: + The inference arguments. + """ + parser = FlexibleArgumentParser() + ServerArgs.add_cli_args(parser) + raw_args = parser.parse_args(argv) + server_args = ServerArgs.from_cli_args(raw_args) + global _current_server_args + _current_server_args = server_args + return server_args + + +@contextmanager +def set_current_server_args(server_args: ServerArgs): + """ + Temporarily set the current sgl_diffusion config. + Used during model initialization. + We save the current sgl_diffusion config in a global variable, + so that all modules can access it, e.g. custom ops + can access the sgl_diffusion config to determine how to dispatch. + """ + global _current_server_args + old_server_args = _current_server_args + try: + _current_server_args = server_args + yield + finally: + _current_server_args = old_server_args + + +def set_global_server_args(server_args: ServerArgs): + """ + Set the global sgl_diffusion config for each process + """ + global _global_server_args + _global_server_args = server_args + + +def get_current_server_args() -> ServerArgs: + if _current_server_args is None: + # in ci, usually when we test custom ops/modules directly, + # we don't set the sgl_diffusion config. In that case, we set a default + # config. + # TODO(will): may need to handle this for CI. + raise ValueError("Current sgl_diffusion args is not set.") + return _current_server_args + + +def get_global_server_args() -> ServerArgs: + if _global_server_args is None: + # in ci, usually when we test custom ops/modules directly, + # we don't set the sgl_diffusion config. In that case, we set a default + # config. + # TODO(will): may need to handle this for CI. + raise ValueError("Global sgl_diffusion args is not set.") + return _global_server_args + + +def parse_int_list(value: str) -> list[int]: + if not value: + return [] + return [int(x.strip()) for x in value.split(",")] diff --git a/python/sglang/multimodal_gen/runtime/sync_scheduler_client.py b/python/sglang/multimodal_gen/runtime/sync_scheduler_client.py new file mode 100644 index 000000000..93359f34d --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/sync_scheduler_client.py @@ -0,0 +1,92 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import zmq + +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import Req +from sglang.multimodal_gen.runtime.server_args import ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class SyncSchedulerClient: + """ + A synchronous, singleton client for communicating with the Scheduler service. 
+ Designed for use in synchronous environments like the DiffGenerator or standalone scripts. + """ + + _instance = None + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super(SyncSchedulerClient, cls).__new__(cls) + return cls._instance + + def initialize(self, server_args: ServerArgs): + if hasattr(self, "context") and not self.context.closed: + logger.warning( + "SyncSchedulerClient is already initialized. Re-initializing." + ) + self.close() + + self.server_args = server_args + self.context = zmq.Context() # Standard synchronous context + self.scheduler_socket = self.context.socket(zmq.REQ) + + # Set socket options for the main communication socket + self.scheduler_socket.setsockopt(zmq.LINGER, 0) + self.scheduler_socket.setsockopt( + zmq.RCVTIMEO, 6000000 + ) # 10 minute timeout for generation + + scheduler_endpoint = self.server_args.scheduler_endpoint() + self.scheduler_socket.connect(scheduler_endpoint) + logger.debug( + f"SyncSchedulerClient connected to backend scheduler at {scheduler_endpoint}" + ) + + def forward(self, batch: Req) -> Req: + """Sends a batch to the scheduler and waits for the response.""" + try: + self.scheduler_socket.send_pyobj(batch) + output_batch = self.scheduler_socket.recv_pyobj() + return output_batch + except zmq.error.Again: + logger.error("Timeout waiting for response from scheduler.") + raise TimeoutError("Scheduler did not respond in time.") + + def ping(self) -> bool: + """ + Checks if the scheduler server is alive using a temporary socket. + This avoids interfering with the state of the main REQ/REP socket. + """ + if not hasattr(self, "context") or self.context.closed: + logger.error("Cannot ping: client is not initialized.") + return False + + ping_socket = self.context.socket(zmq.REQ) + ping_socket.setsockopt(zmq.LINGER, 0) + ping_socket.setsockopt(zmq.RCVTIMEO, 2000) # 2-second timeout for pings + + endpoint = self.server_args.scheduler_endpoint() + + try: + ping_socket.connect(endpoint) + ping_socket.send_pyobj({"method": "ping"}) + ping_socket.recv_pyobj() + return True + except zmq.error.Again: + return False + finally: + ping_socket.close() + + def close(self): + """Closes the socket and terminates the context.""" + if hasattr(self, "scheduler_socket"): + self.scheduler_socket.close() + if hasattr(self, "context"): + self.context.term() + + +# Singleton instance for easy access +sync_scheduler_client = SyncSchedulerClient() diff --git a/python/sglang/multimodal_gen/runtime/utils/common.py b/python/sglang/multimodal_gen/runtime/utils/common.py new file mode 100644 index 000000000..c39769ae8 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/utils/common.py @@ -0,0 +1,291 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import importlib +import ipaddress +import os +import platform +import signal +import socket +import sys +import threading +from functools import lru_cache + +import psutil +import torch +import zmq + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None): + """Kill the process and all its child processes.""" + # Remove sigchld handler to avoid spammy logs. 
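+    # (signal handlers can only be installed from the main thread, hence the guard below)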
+ if threading.current_thread() is threading.main_thread(): + signal.signal(signal.SIGCHLD, signal.SIG_DFL) + + if parent_pid is None: + parent_pid = os.getpid() + include_parent = False + + try: + itself = psutil.Process(parent_pid) + except psutil.NoSuchProcess: + return + + children = itself.children(recursive=True) + for child in children: + if child.pid == skip_pid: + continue + try: + child.kill() + except psutil.NoSuchProcess: + pass + + if include_parent: + try: + if parent_pid == os.getpid(): + itself.kill() + sys.exit(0) + + itself.kill() + + # Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes), + # so we send an additional signal to kill them. + itself.send_signal(signal.SIGQUIT) + except psutil.NoSuchProcess: + pass + + +def add_prefix(name: str, prefix: str) -> str: + """Add a weight path prefix to a module name. + + Args: + name: base module name. + prefix: weight prefix str to added to the front of `name` concatenated with `.`. + + Returns: + The string `prefix.name` if prefix is non-empty, otherwise just `name`. + """ + return name if not prefix else f"{prefix}.{name}" + + +def is_valid_ipv6_address(address: str) -> bool: + try: + ipaddress.IPv6Address(address) + return True + except ValueError: + return False + + +def configure_ipv6(dist_init_addr): + addr = dist_init_addr + end = addr.find("]") + if end == -1: + raise ValueError("invalid IPv6 address format: missing ']'") + + host = addr[: end + 1] + + # this only validates the address without brackets: we still need the below checks. + # if it's invalid, immediately raise an error so we know it's not formatting issues. + if not is_valid_ipv6_address(host[1:end]): + raise ValueError(f"invalid IPv6 address: {host}") + + port_str = None + if len(addr) > end + 1: + if addr[end + 1] == ":": + port_str = addr[end + 2 :] + else: + raise ValueError("received IPv6 address format: expected ':' after ']'") + + if not port_str: + raise ValueError( + "a port must be specified in IPv6 address (format: [ipv6]:port)" + ) + + try: + port = int(port_str) + except ValueError: + raise ValueError(f"invalid port in IPv6 address: '{port_str}'") + return port, host + + +def is_port_available(port): + """Return whether a port is available.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("", port)) + s.listen(1) + return True + except socket.error: + return False + except OverflowError: + return False + + +def get_zmq_socket( + context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool +) -> zmq.Socket: + mem = psutil.virtual_memory() + total_mem = mem.total / 1024**3 + available_mem = mem.available / 1024**3 + if total_mem > 32 and available_mem > 16: + buf_size = int(0.5 * 1024**3) + else: + buf_size = -1 + + socket = context.socket(socket_type) + if endpoint.find("[") != -1: + socket.setsockopt(zmq.IPV6, 1) + + def set_send_opt(): + socket.setsockopt(zmq.SNDHWM, 0) + socket.setsockopt(zmq.SNDBUF, buf_size) + + def set_recv_opt(): + socket.setsockopt(zmq.RCVHWM, 0) + socket.setsockopt(zmq.RCVBUF, buf_size) + + if socket_type == zmq.PUSH: + set_send_opt() + elif socket_type == zmq.PULL: + set_recv_opt() + elif socket_type == zmq.DEALER: + set_send_opt() + set_recv_opt() + elif socket_type == zmq.REQ: + set_send_opt() + set_recv_opt() + elif socket_type == zmq.REP: + set_send_opt() + set_recv_opt() + else: + raise ValueError(f"Unsupported socket type: {socket_type}") + + if bind: + socket.bind(endpoint) 
+ else: + socket.connect(endpoint) + + return socket + + +# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip +@lru_cache(maxsize=1) +def is_hip() -> bool: + return torch.version.hip is not None + + +@lru_cache(maxsize=1) +def is_cuda(): + return torch.cuda.is_available() and torch.version.cuda + + +@lru_cache(maxsize=1) +def is_cuda_alike(): + return is_cuda() or is_hip() + + +@lru_cache(maxsize=1) +def is_blackwell(): + if not is_cuda(): + return False + return torch.cuda.get_device_capability()[0] == 10 + + +@lru_cache(maxsize=1) +def is_hpu() -> bool: + return hasattr(torch, "hpu") and torch.hpu.is_available() + + +@lru_cache(maxsize=1) +def is_xpu() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +@lru_cache(maxsize=1) +def is_npu() -> bool: + return hasattr(torch, "npu") and torch.npu.is_available() + + +@lru_cache(maxsize=1) +def is_host_cpu_x86() -> bool: + machine = platform.machine().lower() + return ( + machine in ("x86_64", "amd64", "i386", "i686") + and hasattr(torch, "cpu") + and torch.cpu.is_available() + ) + + +@lru_cache(maxsize=1) +def is_cpu() -> bool: + return os.getenv("SGLANG_USE_CPU_ENGINE", "0") == "1" and is_host_cpu_x86() + + +# cuda + + +def set_cuda_arch(): + capability = torch.cuda.get_device_capability() + arch = f"{capability[0]}.{capability[1]}" + os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}" + + +def get_bool_env_var(env_var_name: str, default: str | bool = "false") -> bool: + raw_value = os.getenv(env_var_name, None) + if raw_value is None: + raw_value = str(default) + + value_str = str(raw_value).strip().lower() + truthy = {"1", "true", "yes", "y", "t", "on"} + falsy = {"0", "false", "no", "n", "f", "off", ""} + + if value_str in truthy: + return True + if value_str in falsy: + return False + + default_bool = str(default).strip().lower() in truthy + logger.warning( + "Unrecognized boolean for %s=%r; falling back to default=%r", + env_var_name, + raw_value, + default_bool, + ) + return default_bool + + +def is_flashinfer_available(): + """ + Check whether flashinfer is available. + As of Oct. 6, 2024, it is only available on NVIDIA GPUs. 
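+    Returns True only when the ``flashinfer`` package is importable and CUDA is in use.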
+ """ + # if not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"): + # return False + return importlib.util.find_spec("flashinfer") is not None and is_cuda() + + +# env var managements + +_warned_bool_env_var_keys = set() + + +def get_bool_env_var(name: str, default: str = "false") -> bool: + # FIXME: move your environment variable to sglang.srt.environ + value = os.getenv(name, default) + value = value.lower() + + truthy_values = ("true", "1") + falsy_values = ("false", "0") + + if (value not in truthy_values) and (value not in falsy_values): + if value not in _warned_bool_env_var_keys: + logger.warning( + f"get_bool_env_var({name}) see non-understandable value={value} and treat as false" + ) + _warned_bool_env_var_keys.add(value) + + return value in truthy_values diff --git a/python/sglang/multimodal_gen/runtime/utils/distributed.py b/python/sglang/multimodal_gen/runtime/utils/distributed.py new file mode 100644 index 000000000..c89a31dcc --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/utils/distributed.py @@ -0,0 +1,231 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import pickle +from typing import Any, List, Optional + +import numpy as np +import torch +import torch.distributed as dist + + +def broadcast_pyobj( + data: List[Any], + rank: int, + dist_group: Optional[torch.distributed.ProcessGroup] = None, + src: int = 0, + force_cpu_device: bool = True, +): + """Broadcast inputs from src rank to all other ranks with torch.dist backend. + The `rank` here refer to the source rank on global process group (regardless + of dist_group argument). + """ + device = torch.device( + "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu" + ) + + if rank == src: + if data is None or len(data) == 0: + tensor_size = torch.tensor([0], dtype=torch.long, device=device) + dist.broadcast(tensor_size, src=src, group=dist_group) + else: + serialized_data = pickle.dumps(data) + size = len(serialized_data) + + tensor_data = torch.ByteTensor( + np.frombuffer(serialized_data, dtype=np.uint8) + ).to(device) + tensor_size = torch.tensor([size], dtype=torch.long, device=device) + + dist.broadcast(tensor_size, src=src, group=dist_group) + dist.broadcast(tensor_data, src=src, group=dist_group) + return data + else: + tensor_size = torch.tensor([0], dtype=torch.long, device=device) + dist.broadcast(tensor_size, src=src, group=dist_group) + size = tensor_size.item() + + if size == 0: + return [] + + tensor_data = torch.empty(size, dtype=torch.uint8, device=device) + dist.broadcast(tensor_data, src=src, group=dist_group) + + serialized_data = bytes(tensor_data.cpu().numpy()) + data = pickle.loads(serialized_data) + return data + + +def generate_masked_orthogonal_rank_groups( + world_size: int, parallel_size: list[int], mask: list[bool] +) -> list[list[int]]: + """Generate orthogonal parallel groups based on the parallel size and mask. + + Arguments: + world_size (int): world size + + parallel_size (List[int]): + The parallel size of each orthogonal parallel type. For example, if + tensor_parallel_size = 2, pipeline_model_parallel_group = 3, data_parallel_size = 4, + and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4]. + + mask (List[bool]): + The mask controls which parallel methods the generated groups represent. If mask[i] is + True, it means the generated group contains the i-th parallelism method. 
+def generate_masked_orthogonal_rank_groups(
+    world_size: int, parallel_size: list[int], mask: list[bool]
+) -> list[list[int]]:
+    """Generate orthogonal parallel groups based on the parallel size and mask.
+
+    Arguments:
+        world_size (int): world size
+
+        parallel_size (List[int]):
+            The parallel size of each orthogonal parallel type. For example, if
+            tensor_parallel_size = 2, pipeline_model_parallel_size = 3, data_parallel_size = 4,
+            and the parallel mapping order is tp-pp-dp, then parallel_size = [2, 3, 4].
+
+        mask (List[bool]):
+            The mask controls which parallel methods the generated groups represent. If mask[i] is
+            True, the generated group contains the i-th parallelism method. For example,
+            if parallel_size = [tp_size, pp_size, dp_size] and mask = [True, False, True], the
+            generated group is the `tp-dp` group; if mask = [False, True, False], the generated
+            group is the `pp` group.
+
+    Algorithm:
+        For orthogonal parallelism such as tp/dp/pp/cp, the global_rank and the per-method
+        ranks satisfy (for the order 'tp-pp-dp'):
+            global_rank = tp_rank + pp_rank * tp_size + dp_rank * tp_size * pp_size    (1)
+        i.e. each method's stride is the product of the sizes that precede it in the order.
+
+        If we want to get the `dp_group` (tp_size * pp_size groups of dp_size ranks each;
+        for example, if the gpu size is 8 and the order is 'tp-pp-dp' with size '2-2-2',
+        the dp_group here is [[0, 4], [1, 5], [2, 6], [3, 7]]), the tp_rank and pp_rank
+        are combined to form the `dp_group_index`:
+            dp_group_index = tp_rank + pp_rank * tp_size    (2)
+
+        So, given that tp_rank and pp_rank satisfy equation (2), and dp_rank runs over
+        range(0, dp_size), the ranks in dp_group[dp_group_index] satisfy equation (1).
+
+        This function solves this math problem.
+
+    For example, if the parallel_size = [tp_size, dp_size, pp_size] = [2, 3, 4],
+    and the mask = [False, True, False], then:
+        dp_group_index(0) = tp_rank(0) + pp_rank(0) * 2
+        dp_group_index(1) = tp_rank(1) + pp_rank(0) * 2
+        ...
+        dp_group_index(7) = tp_rank(1) + pp_rank(3) * 2
+
+        dp_group[0] = 0 + range(0, 3) * 2 + 0 = [0, 2, 4]
+        dp_group[1] = 1 + range(0, 3) * 2 + 0 = [1, 3, 5]
+        ...
+        dp_group[7] = 1 + range(0, 3) * 2 + 3 * 2 * 3 = [19, 21, 23]
+    """
+
+    def prefix_product(a: List[int], init=1) -> List[int]:
+        r = [init]
+        for v in a:
+            init = init * v
+            r.append(init)
+        return r
+
+    def inner_product(a: List[int], b: List[int]) -> int:
+        return sum([x * y for x, y in zip(a, b)])
+
+    def decompose(index, shape, stride=None):
+        """
+        This function solves the math problem below:
+            Given the equation index = sum(idx[i] * stride[i]) and the values of
+            index and stride, return idx.
+        It is used to recover the tp/dp/pp rank from a group_index and rank_in_group.
+        """
+        if stride is None:
+            stride = prefix_product(shape)
+        idx = [(index // d) % s for s, d in zip(shape, stride)]
+        # stride is a prefix_product result, and the value of stride[-1]
+        # is not used.
+        assert (
+            sum([x * y for x, y in zip(idx, stride[:-1])]) == index
+        ), "decomposed idx {} with shape {} does not recompose to index {}".format(
+            idx, shape, index
+        )
+        return idx
+
+    masked_shape = [s for s, m in zip(parallel_size, mask) if m]
+    unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m]
+
+    global_stride = prefix_product(parallel_size)
+    masked_stride = [d for d, m in zip(global_stride, mask) if m]
+    unmasked_stride = [d for d, m in zip(global_stride, mask) if not m]
+
+    group_size = prefix_product(masked_shape)[-1]
+    num_of_group = world_size // group_size
+
+    ranks = []
+    for group_index in range(num_of_group):
+        # get indices from unmasked dimensions for group_index.
+        decomposed_group_idx = decompose(group_index, unmasked_shape)
+        rank = []
+        for rank_in_group in range(group_size):
+            # get indices from masked dimensions for rank_in_group.
+            decomposed_rank_idx = decompose(rank_in_group, masked_shape)
+            rank.append(
+                inner_product(decomposed_rank_idx, masked_stride)
+                + inner_product(decomposed_group_idx, unmasked_stride)
+            )
+        ranks.append(rank)
+    return ranks
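A quick, illustrative sanity check of the grouping math above, mirroring the 'tp-pp-dp', size '2-2-2' example from the docstring (not part of the module itself):

```python
# world_size = 8, order "tp-pp-dp", sizes [2, 2, 2]; select only the dp axis.
dp_groups = generate_masked_orthogonal_rank_groups(
    world_size=8, parallel_size=[2, 2, 2], mask=[False, False, True]
)
assert dp_groups == [[0, 4], [1, 5], [2, 6], [3, 7]]

# Selecting only the tp axis instead yields the contiguous pairs.
tp_groups = generate_masked_orthogonal_rank_groups(
    world_size=8, parallel_size=[2, 2, 2], mask=[True, False, False]
)
assert tp_groups == [[0, 1], [2, 3], [4, 5], [6, 7]]
```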
+class RankGenerator(object):
+    def __init__(
+        self,
+        tp: int,
+        sp: int,
+        pp: int,
+        cfg: int,
+        dp: int,
+        order: str,
+        rank_offset: int = 0,
+    ) -> None:
+        self.tp = tp
+        self.sp = sp
+        self.pp = pp
+        self.cfg = cfg
+        self.dp = dp
+        self.rank_offset = rank_offset
+        self.world_size = tp * sp * pp * cfg * dp
+
+        self.name_to_size = {
+            "tp": self.tp,
+            "sp": self.sp,
+            "pp": self.pp,
+            "cfg": self.cfg,
+            "dp": self.dp,
+        }
+        order = order.lower()
+
+        for name in self.name_to_size.keys():
+            if name not in order and self.name_to_size[name] != 1:
+                raise RuntimeError(
+                    f"The size of ({name}) is ({self.name_to_size[name]}), but '{name}' is missing from the given order ({order})."
+                )
+            elif name not in order:
+                order = order + "-" + name
+
+        self.order = order
+        self.ordered_size = []
+
+        for token in order.split("-"):
+            self.ordered_size.append(self.name_to_size[token])
+
+    def get_mask(self, order: str, token: str):
+        ordered_token = order.split("-")
+        tokens = token.split("-")
+        mask = [False] * len(ordered_token)
+        for t in tokens:
+            mask[ordered_token.index(t)] = True
+        return mask
+
+    def get_ranks(self, token):
+        """Get rank groups by input token.
+
+        Arguments:
+            token (str):
+                Specify the type of ranks to get. To obtain multiple parallel
+                types, separate them with a hyphen '-'. For example, to obtain
+                the tp-dp group, the token should be 'tp-dp'.
+
+        """
+        mask = self.get_mask(self.order, token)
+        ranks = generate_masked_orthogonal_rank_groups(
+            self.world_size, self.ordered_size, mask
+        )
+        if self.rank_offset > 0:
+            for rank_group in ranks:
+                for i in range(len(rank_group)):
+                    rank_group[i] += self.rank_offset
+        return ranks
diff --git a/python/sglang/multimodal_gen/runtime/utils/hf_diffusers_utils.py b/python/sglang/multimodal_gen/runtime/utils/hf_diffusers_utils.py
new file mode 100644
index 000000000..bb357b91f
--- /dev/null
+++ b/python/sglang/multimodal_gen/runtime/utils/hf_diffusers_utils.py
@@ -0,0 +1,384 @@
+# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
+
+# SPDX-License-Identifier: Apache-2.0
+# Adapted from SGLang: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/hf_transformers_utils.py
+
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== +"""Utilities for Huggingface Transformers.""" + +import contextlib +import hashlib +import json +import os +import tempfile +from pathlib import Path +from typing import Any, Optional, cast + +import filelock +from diffusers.loaders.lora_base import ( + _best_guess_weight_name, # watch out for potetential removal from diffusers +) +from huggingface_hub import snapshot_download +from transformers import AutoConfig, PretrainedConfig +from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) +_CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = { + # ChatGLMConfig.model_type: ChatGLMConfig, + # DbrxConfig.model_type: DbrxConfig, + # ExaoneConfig.model_type: ExaoneConfig, + # Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig, +} + +for name, cls in _CONFIG_REGISTRY.items(): + with contextlib.suppress(ValueError): + AutoConfig.register(name, cls) + + +def download_from_hf(model_path: str): + if os.path.exists(model_path): + return model_path + + return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"]) + + +def get_hf_config( + model: str, + trust_remote_code: bool, + revision: str | None = None, + model_override_args: dict | None = None, + **kwargs, +): + is_gguf = check_gguf_file(model) + if is_gguf: + raise NotImplementedError("GGUF models are not supported.") + + config = AutoConfig.from_pretrained( + model, trust_remote_code=trust_remote_code, revision=revision, **kwargs + ) + if config.model_type in _CONFIG_REGISTRY: + config_class = _CONFIG_REGISTRY[config.model_type] + config = config_class.from_pretrained(model, revision=revision) + # NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`. + config._name_or_path = model + if model_override_args: + config.update(model_override_args) + + # Special architecture mapping check for GGUF models + if is_gguf: + if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: + raise RuntimeError(f"Can't get gguf config for {config.model_type}.") + model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type] + config.update({"architectures": [model_type]}) + + return config + + +def get_config( + model: str, + trust_remote_code: bool, + revision: Optional[str] = None, + model_override_args: Optional[dict] = None, + **kwargs, +): + try: + config = AutoConfig.from_pretrained( + model, trust_remote_code=trust_remote_code, revision=revision, **kwargs + ) + except ValueError as e: + raise e + + return config + + +def load_dict(file_path): + if not os.path.exists(file_path): + return {} + try: + # Load the config directly from the file + with open(file_path) as f: + config_dict: dict[str, Any] = json.load(f) + if "_diffusers_version" in config_dict: + config_dict.pop("_diffusers_version") + # TODO(will): apply any overrides from inference args + return config_dict + except Exception as e: + raise RuntimeError( + f"Failed to load diffusers config from {file_path}: {e}" + ) from e + + +def get_diffusers_config( + model: str, +) -> dict[str, Any]: + """Gets a configuration for the given diffusers model. + + Args: + model: The model name or path. + + Returns: + The loaded configuration. 
+ """ + + config_name = "config.json" + if "scheduler" in model: + config_name = "scheduler_config.json" + # Check if the model path exists + if os.path.exists(model): + config_file = os.path.join(model, config_name) + config_dict = load_dict(config_file) + generation_config_file = os.path.join(model, "generation_config.json") + generation_config_dict = load_dict(generation_config_file) + return config_dict | generation_config_dict + else: + raise RuntimeError(f"Diffusers config file not found at {model}") + + +# Models don't use the same configuration key for determining the maximum +# context length. Store them here so we can sanely check them. +# NOTE: The ordering here is important. Some models have two of these and we +# have a preference for which value gets used. +CONTEXT_LENGTH_KEYS = [ + "max_sequence_length", + "seq_length", + "max_seq_len", + "model_max_length", + "max_position_embeddings", +] + + +def attach_additional_stop_token_ids(tokenizer): + # Special handling for stop token <|eom_id|> generated by llama 3 tool use. + if "<|eom_id|>" in tokenizer.get_added_vocab(): + tokenizer.additional_stop_token_ids = set( + [tokenizer.get_added_vocab()["<|eom_id|>"]] + ) + else: + tokenizer.additional_stop_token_ids = None + + +def check_gguf_file(model: str | os.PathLike) -> bool: + """Check if the file is a GGUF model.""" + model = Path(model) + if not model.is_file(): + return False + elif model.suffix == ".gguf": + return True + + with open(model, "rb") as f: + header = f.read(4) + return header == b"GGUF" + + +def get_lock(model_name_or_path: str): + lock_dir = tempfile.gettempdir() + os.makedirs(os.path.dirname(lock_dir), exist_ok=True) + model_name = model_name_or_path.replace("/", "-") + hash_name = hashlib.sha256(model_name.encode()).hexdigest() + # add hash to avoid conflict with old users' lock files + lock_file_name = hash_name + model_name + ".lock" + # mode 0o666 is required for the filelock to be shared across users + lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666) + return lock + + +def maybe_download_lora( + model_name_or_path: str, local_dir: str | None = None, download: bool = True +) -> str: + """ + Check if the model path is a Hugging Face Hub model ID and download it if needed. + Args: + model_name_or_path: Local path or Hugging Face Hub model ID + local_dir: Local directory to save the model + download: Whether to download the model from Hugging Face Hub + + Returns: + Local path to the model + """ + + local_path = maybe_download_model(model_name_or_path, local_dir, download) + weight_name = _best_guess_weight_name( + model_name_or_path, file_extension=".safetensors" + ) + return os.path.join(local_path, weight_name) + + +def verify_model_config_and_directory(model_path: str) -> dict[str, Any]: + """ + Verify that the model directory contains a valid diffusers configuration. + + Args: + model_path: Path to the model directory + + Returns: + The loaded model configuration as a dictionary + """ + + # Check for model_index.json which is required for diffusers models + config_path = os.path.join(model_path, "model_index.json") + if not os.path.exists(config_path): + raise ValueError( + f"Model directory {model_path} does not contain model_index.json. " + "Only HuggingFace diffusers format is supported." 
+ ) + + # Check for transformer and vae directories + transformer_dir = os.path.join(model_path, "transformer") + vae_dir = os.path.join(model_path, "vae") + + if not os.path.exists(transformer_dir): + raise ValueError( + f"Model directory {model_path} does not contain a transformer/ directory." + ) + + if not os.path.exists(vae_dir): + raise ValueError( + f"Model directory {model_path} does not contain a vae/ directory." + ) + + # Load the config + with open(config_path) as f: + config = json.load(f) + + # Verify diffusers version exists + if "_diffusers_version" not in config: + raise ValueError("model_index.json does not contain _diffusers_version") + + logger.info("Diffusers version: %s", config["_diffusers_version"]) + return cast(dict[str, Any], config) + + +def maybe_download_model_index(model_name_or_path: str) -> dict[str, Any]: + """ + Download and extract just the model_index.json for a Hugging Face model. + + Args: + model_name_or_path: Path or HF Hub model ID + + Returns: + The parsed model_index.json as a dictionary + """ + import tempfile + + from huggingface_hub import hf_hub_download + from huggingface_hub.errors import EntryNotFoundError + + # If it's a local path, verify it directly + if os.path.exists(model_name_or_path): + try: + return verify_model_config_and_directory(model_name_or_path) + except ValueError: + # Not a pipeline, maybe a single model. + config_path = os.path.join(model_name_or_path, "config.json") + if os.path.exists(config_path): + with open(config_path) as f: + config = json.load(f) + return config + raise + + # For remote models, download just the model_index.json + try: + with tempfile.TemporaryDirectory() as tmp_dir: + # Download just the model_index.json file + model_index_path = hf_hub_download( + repo_id=model_name_or_path, + filename="model_index.json", + local_dir=tmp_dir, + ) + + # Load the model_index.json + with open(model_index_path) as f: + config: dict[str, Any] = json.load(f) + + # Verify it has the required fields + if "_class_name" not in config: + raise ValueError( + f"model_index.json for {model_name_or_path} does not contain _class_name field" + ) + + if "_diffusers_version" not in config: + raise ValueError( + f"model_index.json for {model_name_or_path} does not contain _diffusers_version field" + ) + + # Add the pipeline name for downstream use + config["pipeline_name"] = config["_class_name"] + + logger.info( + "Downloaded model_index.json for %s, pipeline: %s", + model_name_or_path, + config["_class_name"], + ) + return config + except EntryNotFoundError: + logger.warning( + "model_index.json not found for %s. Assuming it is a single model and downloading it.", + model_name_or_path, + ) + local_path = maybe_download_model(model_name_or_path) + config_path = os.path.join(local_path, "config.json") + if not os.path.exists(config_path): + raise ValueError( + f"Failed to find config.json for {model_name_or_path} after failing to find model_index.json" + f"You might be looking for models ending with '-Diffusers'" + ) + with open(config_path) as f: + config = json.load(f) + return config + except Exception as e: + raise ValueError( + f"Failed to download or parse model_index.json for {model_name_or_path}: {e}" + ) from e + + +def maybe_download_model( + model_name_or_path: str, local_dir: str | None = None, download: bool = True +) -> str: + """ + Check if the model path is a Hugging Face Hub model ID and download it if needed. 
+ + Args: + model_name_or_path: Local path or Hugging Face Hub model ID + local_dir: Local directory to save the model + download: Whether to download the model from Hugging Face Hub + + Returns: + Local path to the model + """ + + # If the path exists locally, return it + if os.path.exists(model_name_or_path): + logger.info("Model already exists locally") + return model_name_or_path + + # Otherwise, assume it's a HF Hub model ID and try to download it + try: + logger.info( + "Downloading model snapshot from HF Hub for %s...", model_name_or_path + ) + with get_lock(model_name_or_path): + local_path = snapshot_download( + repo_id=model_name_or_path, + ignore_patterns=["*.onnx", "*.msgpack"], + local_dir=local_dir, + ) + logger.info("Downloaded model to %s", local_path) + return str(local_path) + except Exception as e: + raise ValueError( + f"Could not find model at {model_name_or_path} and failed to download from HF Hub: {e}" + ) from e diff --git a/python/sglang/multimodal_gen/runtime/utils/logging_utils.py b/python/sglang/multimodal_gen/runtime/utils/logging_utils.py new file mode 100644 index 000000000..64ece7951 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/utils/logging_utils.py @@ -0,0 +1,401 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/logger.py +"""Logging configuration for sglang.multimodal_gen.""" +import argparse +import datetime +import json +import logging +import os +import sys +import warnings +from functools import lru_cache, partial +from logging import Logger +from logging.config import dictConfig +from os import path +from types import MethodType +from typing import Any, cast + +import sglang.multimodal_gen.envs as envs + +SGL_DIFFUSION_CONFIGURE_LOGGING = envs.SGL_DIFFUSION_CONFIGURE_LOGGING +SGL_DIFFUSION_LOGGING_CONFIG_PATH = envs.SGL_DIFFUSION_LOGGING_CONFIG_PATH +SGL_DIFFUSION_LOGGING_LEVEL = envs.SGL_DIFFUSION_LOGGING_LEVEL +SGL_DIFFUSION_LOGGING_PREFIX = envs.SGL_DIFFUSION_LOGGING_PREFIX + +RED = "\033[91m" +GREEN = "\033[92m" +YELLOW = "\033[93m" +RESET = "\033[0;0m" + +_warned_local_main_process = False +_warned_main_process = False + +_FORMAT = ( + f"{SGL_DIFFUSION_LOGGING_PREFIX}%(levelname)s %(asctime)s " + "[%(filename)s:%(lineno)d] %(message)s" +) + +# _FORMAT = "[%(asctime)s] %(message)s" +_DATE_FORMAT = "%m-%d %H:%M:%S" + +DEFAULT_LOGGING_CONFIG = { + "formatters": { + "sgl_diffusion": { + "class": "sglang.multimodal_gen.runtime.utils.logging_utils.ColoredFormatter", + "datefmt": _DATE_FORMAT, + "format": _FORMAT, + }, + }, + "handlers": { + "sgl_diffusion": { + "class": "logging.StreamHandler", + "formatter": "sgl_diffusion", + "level": SGL_DIFFUSION_LOGGING_LEVEL, + "stream": "ext://sys.stdout", + }, + }, + "loggers": { + "sgl_diffusion": { + "handlers": ["sgl_diffusion"], + "level": "WARNING", + "propagate": False, + }, + }, + "root": { + "handlers": ["sgl_diffusion"], + "level": "DEBUG", + }, + "version": 1, + "disable_existing_loggers": False, +} + + +class NewLineFormatter(logging.Formatter): + """Adds logging prefix to newlines to align multi-line messages.""" + + def __init__(self, fmt, datefmt=None, style="%"): + logging.Formatter.__init__(self, fmt, datefmt, style) + + def format(self, record): + msg = logging.Formatter.format(self, record) + if record.message != "": + parts = msg.split(record.message) + msg = msg.replace("\n", "\r\n" + parts[0]) + return msg + + +class 
ColoredFormatter(NewLineFormatter): + """A logging formatter that adds color to log levels.""" + + LEVEL_COLORS = { + logging.ERROR: RED, + logging.WARNING: YELLOW, + } + + def format(self, record: logging.LogRecord) -> str: + """Adds color to the log level name.""" + original_levelname = record.levelname + color = self.LEVEL_COLORS.get(record.levelno) + if color: + record.levelname = f"{color}{original_levelname}{RESET}" + + formatted_message = super().format(record) + + if color: + record.levelname = original_levelname + + return formatted_message + + +class SortedHelpFormatter(argparse.HelpFormatter): + """SortedHelpFormatter that sorts arguments by their option strings.""" + + def add_arguments(self, actions): + actions = sorted(actions, key=lambda x: x.option_strings) + super().add_arguments(actions) + + +@lru_cache +def _print_info_once(logger: Logger, msg: str) -> None: + # Set the stacklevel to 2 to print the original caller's line info + logger.info(msg, stacklevel=2) + + +@lru_cache +def _print_warning_once(logger: Logger, msg: str) -> None: + # Set the stacklevel to 2 to print the original caller's line info + logger.warning(msg, stacklevel=2) + + +# TODO(will): add env variable to control this process-aware logging behavior +def _info( + logger: Logger, + msg: object, + *args: Any, + main_process_only: bool = True, + local_main_process_only: bool = True, + **kwargs: Any, +) -> None: + """Process-aware INFO level logging function. + + This function controls logging behavior based on the process rank, allowing for + selective logging from specific processes in a distributed environment. + + Args: + logger: The logger instance to use for logging + msg: The message format string to log + *args: Format string arguments + main_process_only: If True, only log if this is the global main process (RANK=0) + local_main_process_only: If True, only log if this is the local main process (LOCAL_RANK=0) + **kwargs: Additional keyword arguments to pass to the logger.log method + - stacklevel: Defaults to 2 to show the original caller's location + + Note: + - When both main_process_only and local_main_process_only are True, + the message will be logged only if both conditions are met + - When both are False, the message will be logged from all processes + - By default, only logs from processes with LOCAL_RANK=0 + """ + try: + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + except Exception: + local_rank = 0 + rank = 0 + + is_main_process = rank == 0 + is_local_main_process = local_rank == 0 + + if (main_process_only and is_main_process) or ( + local_main_process_only and is_local_main_process + ): + logger.log(logging.INFO, msg, *args, stacklevel=2, **kwargs) + + global _warned_local_main_process, _warned_main_process + + if not _warned_local_main_process and local_main_process_only: + # logger.warning( + # "%sBy default, logger.info(..) will only log from the local main process. 
Set logger.info(..., is_local_main_process=False) to log from all processes.%s", + # GREEN, + # RESET, + # ) + _warned_local_main_process = True + if not _warned_main_process and main_process_only and is_main_process: + # logger.warning( + # "%sis_main_process_only is set to True, logging only from the main (RANK==0) process.%s", + # GREEN, + # RESET, + # ) + _warned_main_process = True + + if not main_process_only and not local_main_process_only: + logger.log(logging.INFO, msg, *args, stacklevel=2, **kwargs) + + +class _SGLDiffusionLogger(Logger): + """ + Note: + This class is just to provide type information. + We actually patch the methods directly on the :class:`logging.Logger` + instance to avoid conflicting with other libraries such as + `intel_extension_for_pytorch.utils._logger`. + """ + + def info_once(self, msg: str) -> None: + """ + As :meth:`info`, but subsequent calls with the same message + are silently dropped. + """ + _print_info_once(self, msg) + + def warning_once(self, msg: str) -> None: + """ + As :meth:`warning`, but subsequent calls with the same message + are silently dropped. + """ + _print_warning_once(self, msg) + + def info( # type: ignore[override] + self, + msg: object, + *args: Any, + main_process_only: bool = True, + local_main_process_only: bool = True, + **kwargs: Any, + ) -> None: + _info( + self, + msg, + *args, + main_process_only=main_process_only, + local_main_process_only=local_main_process_only, + **kwargs, + ) + + +def _configure_sgl_diffusion_root_logger() -> None: + logging_config = dict[str, Any]() + + if not SGL_DIFFUSION_CONFIGURE_LOGGING and SGL_DIFFUSION_LOGGING_CONFIG_PATH: + raise RuntimeError( + "SGL_DIFFUSION_CONFIGURE_LOGGING evaluated to false, but " + "SGL_DIFFUSION_LOGGING_CONFIG_PATH was given. SGL_DIFFUSION_LOGGING_CONFIG_PATH " + "implies SGL_DIFFUSION_CONFIGURE_LOGGING. Please enable " + "SGL_DIFFUSION_CONFIGURE_LOGGING or unset SGL_DIFFUSION_LOGGING_CONFIG_PATH." + ) + + if SGL_DIFFUSION_CONFIGURE_LOGGING: + logging_config = DEFAULT_LOGGING_CONFIG + + if SGL_DIFFUSION_LOGGING_CONFIG_PATH: + if not path.exists(SGL_DIFFUSION_LOGGING_CONFIG_PATH): + raise RuntimeError( + "Could not load logging config. File does not exist: %s", + SGL_DIFFUSION_LOGGING_CONFIG_PATH, + ) + with open(SGL_DIFFUSION_LOGGING_CONFIG_PATH, encoding="utf-8") as file: + custom_config = json.loads(file.read()) + + if not isinstance(custom_config, dict): + raise ValueError( + "Invalid logging config. Expected Dict, got %s.", + type(custom_config).__name__, + ) + logging_config = custom_config + + for formatter in logging_config.get("formatters", {}).values(): + # This provides backwards compatibility after #10134. + if formatter.get("class") == "sglang.multimodal_gen.logging.NewLineFormatter": + formatter["class"] = "sglang.multimodal_gen.logging_utils.NewLineFormatter" + + if logging_config: + dictConfig(logging_config) + + +def init_logger(name: str) -> _SGLDiffusionLogger: + """The main purpose of this function is to ensure that loggers are + retrieved in such a way that we can be sure the root sgl_diffusion logger has + already been configured.""" + + logger = logging.getLogger(name) + + methods_to_patch = { + "info_once": _print_info_once, + "warning_once": _print_warning_once, + "info": _info, + } + + for method_name, method in methods_to_patch.items(): + setattr( + logger, method_name, MethodType(method, logger) + ) # type: ignore[arg-type] + + return cast(_SGLDiffusionLogger, logger) + + +# The root logger is initialized when the module is imported. 
+# This is thread-safe as the module is only imported once, +# guaranteed by the Python GIL. +# _configure_sgl_diffusion_root_logger() + +logger = init_logger(__name__) + + +def _trace_calls(log_path, root_dir, frame, event, arg=None): + if event in ["call", "return"]: + # Extract the filename, line number, function name, and the code object + filename = frame.f_code.co_filename + lineno = frame.f_lineno + func_name = frame.f_code.co_name + if not filename.startswith(root_dir): + # only log the functions in the sgl_diffusion root_dir + return + # Log every function call or return + try: + last_frame = frame.f_back + if last_frame is not None: + last_filename = last_frame.f_code.co_filename + last_lineno = last_frame.f_lineno + last_func_name = last_frame.f_code.co_name + else: + # initial frame + last_filename = "" + last_lineno = 0 + last_func_name = "" + with open(log_path, "a") as f: + ts = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") + if event == "call": + f.write( + f"{ts} Call to" + f" {func_name} in {filename}:{lineno}" + f" from {last_func_name} in {last_filename}:" + f"{last_lineno}\n" + ) + else: + f.write( + f"{ts} Return from" + f" {func_name} in {filename}:{lineno}" + f" to {last_func_name} in {last_filename}:" + f"{last_lineno}\n" + ) + except NameError: + # modules are deleted during shutdown + pass + return partial(_trace_calls, log_path, root_dir) + + +def enable_trace_function_call(log_file_path: str, root_dir: str | None = None): + """ + Enable tracing of every function call in code under `root_dir`. + This is useful for debugging hangs or crashes. + `log_file_path` is the path to the log file. + `root_dir` is the root directory of the code to trace. If None, it is the + sgl_diffusion root directory. + + Note that this call is thread-level, any threads calling this function + will have the trace enabled. Other threads will not be affected. + """ + logger.warning( + "SGL_DIFFUSION_TRACE_FUNCTION is enabled. It will record every" + " function executed by Python. This will slow down the code. It " + "is suggested to be used for debugging hang or crashes only." 
+ ) + logger.info("Trace frame log is saved to %s", log_file_path) + if root_dir is None: + # by default, this is the sgl_diffusion root directory + root_dir = os.path.dirname(os.path.dirname(__file__)) + sys.settrace(partial(_trace_calls, log_file_path, root_dir)) + + +def set_uvicorn_logging_configs(): + from uvicorn.config import LOGGING_CONFIG + + LOGGING_CONFIG["formatters"]["default"][ + "fmt" + ] = "[%(asctime)s] %(levelprefix)s %(message)s" + LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S" + LOGGING_CONFIG["formatters"]["access"][ + "fmt" + ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s' + LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S" + + +def configure_logger(server_args, prefix: str = ""): + log_format = f"[%(asctime)s{prefix}] %(message)s" + datefmt = "%m-%d %H:%M:%S" + logging.basicConfig( + level=getattr(logging, server_args.log_level.upper()), + format=log_format, + datefmt=datefmt, + force=True, + ) + + set_uvicorn_logging_configs() + + +def suppress_other_loggers(): + warnings.filterwarnings( + "ignore", category=UserWarning, message="The given NumPy array is not writable" + ) diff --git a/python/sglang/multimodal_gen/runtime/utils/performance_logger.py b/python/sglang/multimodal_gen/runtime/utils/performance_logger.py new file mode 100644 index 000000000..fb4f4e399 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/utils/performance_logger.py @@ -0,0 +1,76 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import json +import logging +import os +import subprocess +import time +from datetime import datetime + +from dateutil.tz import UTC + +project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")) +LOG_DIR = os.path.join(project_root, "logs") + +# Configure a specific logger for performance metrics +perf_logger = logging.getLogger("performance") +perf_logger.setLevel(logging.INFO) +perf_logger.propagate = False # Prevent perf logs from going to the main logger + +# Ensure the logs directory exists +if not os.path.exists(LOG_DIR): + os.makedirs(LOG_DIR) + +# Set up a file handler for the performance logger +handler = logging.FileHandler(os.path.join(LOG_DIR, "performance.log")) +handler.setFormatter(logging.Formatter("%(message)s")) +perf_logger.addHandler(handler) + + +def get_git_commit_hash() -> str: + """Get the current git commit hash.""" + try: + commit_hash = ( + subprocess.check_output(["git", "rev-parse", "HEAD"]) + .strip() + .decode("utf-8") + ) + return commit_hash + except (subprocess.CalledProcessError, FileNotFoundError): + return "N/A" + + +class PerformanceLogger: + """ + A utility class for logging performance metrics. 
+ """ + + def __init__(self, request_id: str): + self.request_id = request_id + self.start_time = time.monotonic() + self.step_timings = [] + self.commit_hash = get_git_commit_hash() + + def record_step_start(self): + """Records the start time of a step.""" + self.step_start_time = time.monotonic() + + def record_step_end(self, step_name: str, step_index: int | None = None): + """Records the end time of a step and calculates the duration.""" + duration = time.monotonic() - self.step_start_time + self.step_timings.append( + {"name": step_name, "index": step_index, "duration_ms": duration * 1000} + ) + + def log_total_duration(self, tag: str): + """Logs the total duration of the operation and all recorded steps.""" + total_duration = time.monotonic() - self.start_time + log_entry = { + "timestamp": datetime.now(UTC).isoformat(), + "request_id": self.request_id, + "commit_hash": self.commit_hash, + "tag": tag, + "total_duration_ms": total_duration * 1000, + "steps": self.step_timings, + } + perf_logger.info(json.dumps(log_entry)) diff --git a/python/sglang/multimodal_gen/runtime/workflow/__init__.py b/python/sglang/multimodal_gen/runtime/workflow/__init__.py new file mode 100644 index 000000000..af2eb7d10 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/workflow/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/python/sglang/multimodal_gen/runtime/workflow/preprocess/__init__.py b/python/sglang/multimodal_gen/runtime/workflow/preprocess/__init__.py new file mode 100644 index 000000000..af2eb7d10 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/workflow/preprocess/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/python/sglang/multimodal_gen/runtime/workflow/preprocess/components.py b/python/sglang/multimodal_gen/runtime/workflow/preprocess/components.py new file mode 100644 index 000000000..1f7890aec --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/workflow/preprocess/components.py @@ -0,0 +1,341 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import dataclasses +import gc +import os +import random +from collections.abc import Callable +from typing import Any + +import numpy as np +import pyarrow as pa +import torch +from datasets import Dataset, Video, load_dataset + +from sglang.multimodal_gen.configs.configs import ( + DatasetType, + PreprocessConfig, + VideoLoaderType, +) +from sglang.multimodal_gen.configs.sample.base import DataType +from sglang.multimodal_gen.dataset.dataloader.parquet_io import ( + ParquetDatasetWriter, + records_to_table, +) +from sglang.multimodal_gen.runtime.distributed.parallel_state import ( + get_world_rank, + get_world_size, +) +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import PreprocessBatch +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class PreprocessingDataValidator: + + def __init__( + self, + max_height: int = 1024, + max_width: int = 1024, + max_h_div_w_ratio: float = 17 / 16, + min_h_div_w_ratio: float = 8 / 16, + num_frames: int = 16, + train_fps: int = 24, + speed_factor: float = 1.0, + video_length_tolerance_range: float = 5.0, + drop_short_ratio: float = 0.0, + hw_aspect_threshold: float = 1.5, + ): + self.max_height = max_height + self.max_width = max_width + self.max_h_div_w_ratio = max_h_div_w_ratio + self.min_h_div_w_ratio = min_h_div_w_ratio + self.num_frames = num_frames + self.train_fps = 
train_fps + self.speed_factor = speed_factor + self.video_length_tolerance_range = video_length_tolerance_range + self.drop_short_ratio = drop_short_ratio + self.hw_aspect_threshold = hw_aspect_threshold + self.validators: dict[str, Callable[[dict[str, Any]], bool]] = {} + self.filter_counts: dict[str, int] = {} + + self.num_items_before_filtering = 0 + self.num_items_after_filtering = 0 + + self.register_validators() + + def register_validators(self) -> None: + self.add_validator("data_type_validator", self._validate_data_type) + self.add_validator("resolution_validator", self._validate_resolution) + self.add_validator("frame_sampling_validator", self._validate_frame_sampling) + + def add_validator( + self, name: str, validator: Callable[[dict[str, Any]], bool] + ) -> None: + self.validators[name] = validator + self.filter_counts[name] = 0 + + def __call__(self, batch: dict[str, Any]) -> bool: + """ + Validate whether the preprocessing data batch is valid. + """ + self.num_items_before_filtering += 1 + + for name, validator in self.validators.items(): + if not validator(batch): + self.filter_counts[name] += 1 + return False + + self.num_items_after_filtering += 1 + return True + + def _validate_data_type(self, batch: dict[str, Any]) -> bool: + """Validate basic validity of data items""" + return not ( + batch["caption"] is None + or batch["caption"] == "" + or batch["fps"] is None + or batch["fps"] <= 0 + or batch["num_frames"] is None + or batch["num_frames"] <= 0 + ) + + def _validate_resolution(self, batch: dict[str, Any]) -> bool: + """Validate resolution constraints""" + + aspect = self.max_height / self.max_width + if batch["resolution"] is not None: + height = batch["resolution"].get("height", None) + width = batch["resolution"].get("width", None) + + if height is None or width is None: + return False + + return self._filter_resolution( + height, + width, + max_h_div_w_ratio=self.hw_aspect_threshold * aspect, + min_h_div_w_ratio=1 / self.hw_aspect_threshold * aspect, + ) + + def _filter_resolution( + self, h: int, w: int, max_h_div_w_ratio: float, min_h_div_w_ratio: float + ) -> bool: + """Filter based on aspect ratio""" + return (min_h_div_w_ratio <= h / w <= max_h_div_w_ratio) and ( + self.min_h_div_w_ratio <= h / w <= self.max_h_div_w_ratio + ) + + def _validate_frame_sampling(self, batch: dict[str, Any]) -> bool: + """Validate frame sampling constraints""" + + if batch["num_frames"] / batch["fps"] > self.video_length_tolerance_range * ( + self.num_frames / self.train_fps * self.speed_factor + ): + return False + + frame_interval = batch["fps"] / self.train_fps + start_frame_idx = 0 + frame_indices = np.arange( + start_frame_idx, batch["num_frames"], frame_interval + ).astype(int) + return not ( + len(frame_indices) < self.num_frames + and random.random() < self.drop_short_ratio + ) + + def log_validation_stats(self): + info = "" + for name, count in self.filter_counts.items(): + info += f"failed in {name}: {count}, " + info += f"number of items before filtering: {self.num_items_before_filtering}, " + info += f"number of items after filtering: {self.num_items_after_filtering}" + + logger.info(info) + + +class VideoForwardBatchBuilder: + + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, batch: list) -> PreprocessBatch: + forward_batch = PreprocessBatch( + video_loader=[item["video"] for item in batch], + video_file_name=[item["name"] for item in batch], + height=[item["resolution"]["height"] for item in batch], + width=[item["resolution"]["width"] for 
item in batch], + fps=[item["fps"] for item in batch], + num_frames=[item["num_frames"] for item in batch], + prompt=[item["caption"] for item in batch], + prompt_attention_mask=[], + data_type=DataType.VIDEO, + generator=torch.Generator("cpu").manual_seed(self.seed), + ) + return forward_batch + + +class ParquetDatasetSaver: + """Component for saving and writing Parquet datasets using shared parquet_io.""" + + def __init__( + self, + flush_frequency: int, + samples_per_file: int, + schema: pa.Schema, + record_creator: Callable[..., list[dict[str, Any]]], + ): + self.flush_frequency = flush_frequency + self.samples_per_file = samples_per_file + self.schema = schema + self.create_records_from_batch = record_creator + self.num_processed_samples: int = 0 + self._writer: ParquetDatasetWriter | None = None + + def save_and_write_parquet_batch( + self, + batch: PreprocessBatch, + output_dir: str, + extra_features: dict[str, Any] | None = None, + ) -> None: + """ + Save and write Parquet dataset batch + + Args: + batch: PreprocessBatch containing video and metadata information + output_dir: Output directory + extra_features: Extra features + + Returns: + Number of processed samples + """ + assert isinstance(batch.latents, torch.Tensor) + assert isinstance(batch.prompt_embeds, list) + assert isinstance(batch.prompt_attention_mask, list) + + # Process non-padded embeddings (if needed) + if batch.prompt_attention_mask is not None: + batch.prompt_embeds = self._process_non_padded_embeddings( + batch.prompt_embeds[0], batch.prompt_attention_mask[0] + ) + else: + raise ValueError("prompt_attention_mask is None") + + # Prepare batch data for Parquet dataset + batch_data: list[dict[str, Any]] = [] + + for key in dataclasses.fields(batch): + value = getattr(batch, key.name) + if isinstance(value, list): + for idx in range(len(value)): + if isinstance(value[idx], torch.Tensor): + value[idx] = value[idx].cpu().numpy() + elif isinstance(value, torch.Tensor): + value = value.cpu().numpy() + setattr(batch, key.name, value) + + # Create record for Parquet dataset + records = self.create_records_from_batch(batch) + batch_data.extend(records) + + if batch_data: + self.num_processed_samples += len(batch_data) + table = records_to_table(batch_data, self.schema) + if self._writer is None: + os.makedirs(output_dir, exist_ok=True) + self._writer = ParquetDatasetWriter( + out_dir=output_dir, samples_per_file=self.samples_per_file + ) + self._writer.append_table(table) + logger.debug("Collected batch with %s samples", len(table)) + + # If flush is needed + if self.num_processed_samples >= self.flush_frequency: + self.flush_tables() + + def _process_non_padded_embeddings( + self, prompt_embeds: torch.Tensor, prompt_attention_mask: torch.Tensor + ) -> list[torch.Tensor]: + """Process non-padded embeddings""" + assert isinstance(prompt_embeds, torch.Tensor) + assert isinstance(prompt_attention_mask, torch.Tensor) + assert prompt_embeds.shape[0] == prompt_attention_mask.shape[0] + + # Get sequence lengths from attention masks (number of 1s) + seq_lens = prompt_attention_mask.sum(dim=1) + + non_padded_embeds = [] + + # Process each item in the batch + for i in range(prompt_embeds.size(0)): + seq_len = seq_lens[i].item() + # Slice the embeddings and masks to keep only non-padding parts + non_padded_embeds.append(prompt_embeds[i, :seq_len]) + + return non_padded_embeds + + def flush_tables(self, write_remainder: bool = False): + """Flush buffered records to disk. 
+ + Args: + output_dir: Directory where parquet files are written. Kept for API + symmetry (writer already configured with this path). + write_remainder: If True, also write any leftover rows smaller than + ``samples_per_file`` as a final small file. Useful for the last flush. + """ + if self._writer is None: + return + _ = self._writer.flush(write_remainder=write_remainder) + # Reset processed sample count modulo samples_per_file + remainder = self.num_processed_samples % self.samples_per_file + self.num_processed_samples = 0 if write_remainder else remainder + + def clean_up(self) -> None: + """Clean up all tables""" + self.flush_tables(write_remainder=True) + self._writer = None + self.num_processed_samples = 0 + gc.collect() + + def __del__(self): + self.clean_up() + + +def build_dataset( + preprocess_config: PreprocessConfig, + split: str, + validator: Callable[[dict[str, Any]], bool], +) -> Dataset: + if preprocess_config.dataset_type == DatasetType.HF: + dataset = load_dataset(preprocess_config.dataset_path, split=split) + dataset = dataset.filter(validator) + dataset = dataset.shard(num_shards=get_world_size(), index=get_world_rank()) + elif preprocess_config.dataset_type == DatasetType.MERGED: + metadata_json_path = os.path.join( + preprocess_config.dataset_path, "videos2caption.json" + ) + video_folder = os.path.join(preprocess_config.dataset_path, "videos") + dataset = load_dataset("json", data_files=metadata_json_path, split=split) + column_names = dataset.column_names + # rename columns to match the schema + if "cap" in column_names: + dataset = dataset.rename_column("cap", "caption") + if "path" in column_names: + dataset = dataset.rename_column("path", "name") + + dataset = dataset.filter(validator) + dataset = dataset.shard(num_shards=get_world_size(), index=get_world_rank()) + + # add video column + def add_video_column(item: dict[str, Any]) -> dict[str, Any]: + item["video"] = os.path.join(video_folder, item["name"]) + return item + + dataset = dataset.map(add_video_column) + if preprocess_config.video_loader_type == VideoLoaderType.TORCHCODEC: + dataset = dataset.cast_column("video", Video()) + else: + raise ValueError(f"Invalid dataset type: {preprocess_config.dataset_type}") + + return dataset diff --git a/python/sglang/multimodal_gen/runtime/workflow/preprocess/preprocess_workflow.py b/python/sglang/multimodal_gen/runtime/workflow/preprocess/preprocess_workflow.py new file mode 100644 index 000000000..3d3a831ae --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/workflow/preprocess/preprocess_workflow.py @@ -0,0 +1,143 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import os +from typing import cast + +from torch.utils.data import DataLoader + +from sglang.multimodal_gen.configs.configs import PreprocessConfig +from sglang.multimodal_gen.dataset.dataloader.record_schema import ( + basic_t2v_record_creator, + i2v_record_creator, +) +from sglang.multimodal_gen.dataset.dataloader.schema import ( + pyarrow_schema_i2v, + pyarrow_schema_t2v, +) +from sglang.multimodal_gen.runtime.distributed.parallel_state import get_world_rank +from sglang.multimodal_gen.runtime.pipelines.pipeline_registry import PipelineType +from sglang.multimodal_gen.runtime.server_args import ServerArgs, WorkloadType +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.runtime.workflow.preprocess.components import ( + ParquetDatasetSaver, + PreprocessingDataValidator, + VideoForwardBatchBuilder, + build_dataset, +) +from 
sglang.multimodal_gen.runtime.workflow.workflow_base import WorkflowBase + +logger = init_logger(__name__) + + +class PreprocessWorkflow(WorkflowBase): + + def register_pipelines(self) -> None: + self.add_pipeline_config( + "preprocess_pipeline", (PipelineType.PREPROCESS, self.server_args) + ) + + def register_components(self) -> None: + assert self.server_args.preprocess_config is not None + preprocess_config: PreprocessConfig = self.server_args.preprocess_config + + # raw data validator + raw_data_validator = PreprocessingDataValidator( + max_height=preprocess_config.max_height, + max_width=preprocess_config.max_width, + num_frames=preprocess_config.num_frames, + train_fps=preprocess_config.train_fps, + speed_factor=preprocess_config.speed_factor, + video_length_tolerance_range=preprocess_config.video_length_tolerance_range, + drop_short_ratio=preprocess_config.drop_short_ratio, + ) + self.add_component("raw_data_validator", raw_data_validator) + + # training dataset + training_dataset = build_dataset( + preprocess_config, split="train", validator=raw_data_validator + ) + # we do not use collate_fn here because we use iterable-style Dataset + # and want to keep the original type of the dataset + training_dataloader = DataLoader( + training_dataset, + batch_size=preprocess_config.preprocess_video_batch_size, + num_workers=preprocess_config.dataloader_num_workers, + collate_fn=lambda x: x, + ) + self.add_component("training_dataloader", training_dataloader) + + # try to load validation dataset if it exists + try: + validation_dataset = build_dataset( + preprocess_config, split="validation", validator=raw_data_validator + ) + validation_dataloader = DataLoader( + validation_dataset, + batch_size=preprocess_config.preprocess_video_batch_size, + num_workers=preprocess_config.dataloader_num_workers, + collate_fn=lambda x: x, + ) + except ValueError: + logger.warning( + "Validation dataset not found, skipping validation dataset preprocessing." 
+ ) + validation_dataloader = None + + self.add_component("validation_dataloader", validation_dataloader) + + # forward batch builder + video_forward_batch_builder = VideoForwardBatchBuilder( + seed=self.server_args.preprocess_config.seed + ) + self.add_component("video_forward_batch_builder", video_forward_batch_builder) + + # record creator + if self.server_args.workload_type == WorkloadType.I2V: + record_creator = i2v_record_creator + schema = pyarrow_schema_i2v + else: + record_creator = basic_t2v_record_creator + schema = pyarrow_schema_t2v + processed_dataset_saver = ParquetDatasetSaver( + flush_frequency=self.server_args.preprocess_config.flush_frequency, + samples_per_file=self.server_args.preprocess_config.samples_per_file, + schema=schema, + record_creator=record_creator, + ) + self.add_component("processed_dataset_saver", processed_dataset_saver) + + def prepare_system_environment(self) -> None: + assert self.server_args.preprocess_config is not None + dataset_output_dir = self.server_args.preprocess_config.dataset_output_dir + os.makedirs(dataset_output_dir, exist_ok=True) + + validation_dataset_output_dir = os.path.join( + dataset_output_dir, "validation_dataset", f"worker_{get_world_rank()}" + ) + os.makedirs(validation_dataset_output_dir, exist_ok=True) + self.validation_dataset_output_dir = validation_dataset_output_dir + + training_dataset_output_dir = os.path.join( + dataset_output_dir, "training_dataset", f"worker_{get_world_rank()}" + ) + os.makedirs(training_dataset_output_dir, exist_ok=True) + self.training_dataset_output_dir = training_dataset_output_dir + + @classmethod + def get_workflow_cls(cls, server_args: ServerArgs) -> "PreprocessWorkflow": + if server_args.workload_type == WorkloadType.T2V: + from sglang.multimodal_gen.runtime.workflow.preprocess.preprocess_workflow_t2v import ( + PreprocessWorkflowT2V, + ) + + return cast(PreprocessWorkflow, PreprocessWorkflowT2V) + elif server_args.workload_type == WorkloadType.I2V: + from sglang.multimodal_gen.runtime.workflow.preprocess.preprocess_workflow_i2v import ( + PreprocessWorkflowI2V, + ) + + return cast(PreprocessWorkflow, PreprocessWorkflowI2V) + else: + raise ValueError( + f"Workload type: {server_args.workload_type} is not supported in preprocessing workflow." 
+ ) diff --git a/python/sglang/multimodal_gen/runtime/workflow/preprocess/preprocess_workflow_i2v.py b/python/sglang/multimodal_gen/runtime/workflow/preprocess/preprocess_workflow_i2v.py new file mode 100644 index 000000000..876d7bd41 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/workflow/preprocess/preprocess_workflow_i2v.py @@ -0,0 +1,70 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from typing import TYPE_CHECKING + +from tqdm import tqdm + +from sglang.multimodal_gen.dataset.preprocessing_datasets import PreprocessBatch +from sglang.multimodal_gen.runtime.workflow.preprocess.components import ( + ParquetDatasetSaver, +) +from sglang.multimodal_gen.runtime.workflow.preprocess.preprocess_workflow import ( + PreprocessWorkflow, +) + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import ( + ComposedPipelineBase, + ) + from sglang.multimodal_gen.runtime.workflow.preprocess.components import ( + VideoForwardBatchBuilder, + ) + + +class PreprocessWorkflowI2V(PreprocessWorkflow): + training_dataloader: "DataLoader" + validation_dataloader: "DataLoader" + preprocess_pipeline: "ComposedPipelineBase" + processed_dataset_saver: "ParquetDatasetSaver" + video_forward_batch_builder: "VideoForwardBatchBuilder" + + def run(self) -> None: + # Training dataset preprocessing + for batch in tqdm( + self.training_dataloader, + desc="Preprocessing training dataset", + unit="batch", + ): + forward_batch: PreprocessBatch = self.video_forward_batch_builder(batch) + + forward_batch = self.preprocess_pipeline.forward( + forward_batch, self.server_args + ) + + self.processed_dataset_saver.save_and_write_parquet_batch( + forward_batch, self.training_dataset_output_dir + ) + + self.processed_dataset_saver.flush_tables() + self.processed_dataset_saver.clean_up() + + # Validation dataset preprocessing + if self.validation_dataloader is not None: + for batch in tqdm( + self.validation_dataloader, + desc="Preprocessing validation dataset", + unit="batch", + ): + forward_batch = self.video_forward_batch_builder(batch) + + forward_batch = self.preprocess_pipeline.forward( + forward_batch, self.server_args + ) + + self.processed_dataset_saver.save_and_write_parquet_batch( + forward_batch, self.validation_dataset_output_dir + ) + self.processed_dataset_saver.flush_tables() + self.processed_dataset_saver.clean_up() diff --git a/python/sglang/multimodal_gen/runtime/workflow/preprocess/preprocess_workflow_t2v.py b/python/sglang/multimodal_gen/runtime/workflow/preprocess/preprocess_workflow_t2v.py new file mode 100644 index 000000000..b8f1df011 --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/workflow/preprocess/preprocess_workflow_t2v.py @@ -0,0 +1,70 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from typing import TYPE_CHECKING, Optional + +from tqdm import tqdm + +from sglang.multimodal_gen.runtime.pipelines.pipeline_batch_info import PreprocessBatch +from sglang.multimodal_gen.runtime.workflow.preprocess.components import ( + ParquetDatasetSaver, +) +from sglang.multimodal_gen.runtime.workflow.preprocess.preprocess_workflow import ( + PreprocessWorkflow, +) + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + from sglang.multimodal_gen.runtime.pipelines.composed_pipeline_base import ( + ComposedPipelineBase, + ) + from sglang.multimodal_gen.runtime.workflow.preprocess.components import ( + VideoForwardBatchBuilder, + ) + + +class 
PreprocessWorkflowT2V(PreprocessWorkflow): + training_dataloader: "DataLoader" + validation_dataloader: Optional["DataLoader"] + preprocess_pipeline: "ComposedPipelineBase" + processed_dataset_saver: "ParquetDatasetSaver" + video_forward_batch_builder: "VideoForwardBatchBuilder" + + def run(self) -> None: + # Training dataset preprocessing + for batch in tqdm( + self.training_dataloader, + desc="Preprocessing training dataset", + unit="batch", + ): + forward_batch: PreprocessBatch = self.video_forward_batch_builder(batch) + + forward_batch = self.preprocess_pipeline.forward( + forward_batch, self.server_args + ) + + self.processed_dataset_saver.save_and_write_parquet_batch( + forward_batch, self.training_dataset_output_dir + ) + + self.processed_dataset_saver.flush_tables() + self.processed_dataset_saver.clean_up() + + # Validation dataset preprocessing + if self.validation_dataloader is not None: + for batch in tqdm( + self.validation_dataloader, + desc="Preprocessing validation dataset", + unit="batch", + ): + forward_batch = self.video_forward_batch_builder(batch) + + forward_batch = self.preprocess_pipeline.forward( + forward_batch, self.server_args + ) + + self.processed_dataset_saver.save_and_write_parquet_batch( + forward_batch, self.validation_dataset_output_dir + ) + self.processed_dataset_saver.flush_tables() + self.processed_dataset_saver.clean_up() diff --git a/python/sglang/multimodal_gen/runtime/workflow/workflow_base.py b/python/sglang/multimodal_gen/runtime/workflow/workflow_base.py new file mode 100644 index 000000000..88522f69e --- /dev/null +++ b/python/sglang/multimodal_gen/runtime/workflow/workflow_base.py @@ -0,0 +1,188 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +from abc import ABC, abstractmethod +from typing import Any, Optional + +from sglang.multimodal_gen.runtime.pipelines import ComposedPipelineBase, build_pipeline +from sglang.multimodal_gen.runtime.pipelines.pipeline_registry import PipelineType +from sglang.multimodal_gen.runtime.server_args import ExecutionMode, ServerArgs +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class WorkflowBase(ABC): + """ + Abstract base class for defining video processing workflows. + + A workflow serves as the top-level orchestrator that coordinates multiple pipelines + and components to accomplish a specific video processing task. The workflow pattern + provides several key benefits: + + 1. **Separation of Concerns**: Workflows separate high-level orchestration logic + from low-level processing implementations in pipelines. + + 2. **Modularity**: Different workflows can be created for different execution modes + (preprocess, inference, etc.) while sharing common pipeline components. + + 3. **Configuration Management**: Workflows manage the configuration and initialization + of multiple related pipelines and components in a centralized manner. + + 4. **Environment Setup**: Workflows handle system-level setup and resource + allocation before pipeline execution begins. + + 5. **Lifecycle Management**: Workflows control the complete lifecycle from + initialization through execution to cleanup. + + The workflow acts as a factory and coordinator, creating the appropriate pipelines + based on configuration, setting up the execution environment, and orchestrating + the overall processing flow. + """ + + def __init__(self, server_args: ServerArgs): + """ + Initialize the workflow with configuration arguments. 
+ + Args: + server_args: Configuration object containing all parameters + needed for workflow and pipeline setup. + """ + self.server_args = server_args + + # TODO: pipeline_config should be: dict[str, PipelineConfig] + # pipeline_type should be included in the PipelineConfig + # pipeline_config[pipeline_name] = (pipeline_type, server_args) + self._pipeline_configs: dict[str, tuple[PipelineType, ServerArgs]] = {} + self._pipelines: dict[str, ComposedPipelineBase] = {} + self._components: dict[str, Any] = {} + self.register_pipelines() + self.register_components() + + self.prepare_system_environment() + self.load_pipelines() + + def load_pipelines(self) -> None: + """ + Create and initialize all registered pipelines. + + This method instantiates pipeline objects from their configurations + and makes them available as both dictionary entries and instance + attributes for convenient access. + """ + for pipeline_name, pipeline_config in self._pipeline_configs.items(): + pipeline_type, server_args = pipeline_config + pipeline = build_pipeline(server_args, pipeline_type) + self._pipelines[pipeline_name] = pipeline + setattr(self, pipeline_name, pipeline) + + def add_pipeline_config( + self, pipeline_name: str, pipeline_config: tuple[PipelineType, ServerArgs] + ) -> None: + """ + Register a pipeline configuration for later instantiation. + + Args: + pipeline_name: Unique identifier for the pipeline. + pipeline_config: Tuple containing the pipeline type and + configuration arguments. + """ + self._pipeline_configs[pipeline_name] = pipeline_config + + def add_component(self, component_name: str, component: Any) -> None: + """ + Register a component instance with the workflow. + + Components are auxiliary objects that may be shared across pipelines + or used for workflow-level functionality (e.g., databases, caches, + external services). + + Args: + component_name: Unique identifier for the component. + component: The component instance to register. + """ + self._components[component_name] = component + setattr(self, component_name, component) + + def get_component(self, component_name: str) -> Any: + """ + Retrieve a registered component by name. + + Args: + component_name: The name of the component to retrieve. + + Returns: + The component instance. + """ + return self._components[component_name] + + @abstractmethod + def register_components(self) -> None: + """ + Register workflow-specific components. + + Subclasses must implement this method to register any components + needed for their specific workflow (e.g., databases, external APIs, + shared resources). + """ + pass + + @abstractmethod + def register_pipelines(self) -> None: + """ + Register workflow-specific pipelines. + + Subclasses must implement this method to define which pipelines + are needed for their specific workflow and how they should be + configured. + """ + pass + + @abstractmethod + def prepare_system_environment(self) -> None: + """ + Prepare the system environment for workflow execution. + + Subclasses must implement this method to handle any system-level + setup required before pipeline execution (e.g., GPU initialization, + temporary directories, resource allocation). + """ + pass + + @abstractmethod + def run(self): + """ + Execute the main workflow logic. + + Subclasses must implement this method to define the specific + execution flow for their workflow, coordinating the registered + pipelines and components to accomplish the desired task. 
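+        For example, the preprocess workflows above iterate their training and validation dataloaders, run each batch through the preprocess pipeline, and persist the results with ParquetDatasetSaver.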
+ """ + pass + + @classmethod + def get_workflow_cls(cls, server_args: ServerArgs) -> Optional["WorkflowBase"]: + """ + Factory method to get the appropriate workflow class based on execution mode. + + This method acts as a workflow factory, returning the appropriate + workflow class implementation based on the specified execution mode + in the configuration arguments. + + Args: + server_args: Configuration object containing the execution mode + and other parameters. + + Returns: + The appropriate workflow class for the specified execution mode, + or None if no workflow is available for the given mode. + """ + if server_args.mode == ExecutionMode.PREPROCESS: + from sglang.multimodal_gen.runtime.workflow.preprocess.preprocess_workflow import ( + PreprocessWorkflow, + ) + + return PreprocessWorkflow.get_workflow_cls(server_args) + else: + raise ValueError( + f"Execution mode: {server_args.mode} is not supported in workflow." + ) diff --git a/python/sglang/multimodal_gen/test/__init__.py b/python/sglang/multimodal_gen/test/__init__.py new file mode 100644 index 000000000..af2eb7d10 --- /dev/null +++ b/python/sglang/multimodal_gen/test/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/python/sglang/multimodal_gen/test/cli/test_generate_common.py b/python/sglang/multimodal_gen/test/cli/test_generate_common.py new file mode 100644 index 000000000..aff1c98aa --- /dev/null +++ b/python/sglang/multimodal_gen/test/cli/test_generate_common.py @@ -0,0 +1,105 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +""" + Common generate cli test, one test for image and video each +""" + +import os +import unittest +from pathlib import Path + +from PIL import Image + +from sglang.multimodal_gen.test.test_utils import ( + TestCLIBase, + check_image_size, + is_mp4, + run_command, +) + + +class TestGenerate(TestCLIBase): + model_path = "black-forest-labs/FLUX.1-dev" + launch_file_name = "launch_flux.json" + output_name = "FLUX.1-dev, single gpu" + ext = "jpg" + + def test_generate_with_config(self): + test_dir = Path(__file__).parent + config_path = ( + (test_dir / ".." 
/ "test_files" / self.launch_file_name) + .resolve() + .as_posix() + ) + command = [ + "sgl_diffusion", + "generate", + f"--config={config_path}", + ] + duration = run_command(command) + + self.assertIsNotNone(duration, f"Run command failed: {command}") + + # verify + self.verify_image(self.output_name) + + def test_generate_multiple_outputs(self): + command = [ + "sglang", + "generate", + "--prompt='A curious raccoon'", + "--output-path=outputs", + f"--model-path={self.model_path}", + "--save-output", + f"--output-file-name={self.output_name}", + "--num-outputs-per-prompt=2", + "--width=720", + "--height=720", + ] + duration = run_command(command) + self.assertIsNotNone(duration, f"Run command failed: {command}") + + self.verify_image(f"{self.output_name}_0.{self.ext}") + self.verify_image(f"{self.output_name}_1.{self.ext}") + + def verify_image(self, output_name): + path = os.path.join("outputs", output_name) + with Image.open(path) as image: + check_image_size(self, image, 720, 720) + + def verify_video(self, output_name): + path = os.path.join("outputs", output_name) + with open(path, "rb") as f: + header = f.read(12) + assert is_mp4(header) + + +class TestWanGenerate(TestGenerate): + model_path = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" + launch_file_name = "launch_wan.json" + output_name = "Wan2.1-T2V-1.3B-Diffusers, single gpu" + ext = "mp4" + + def test_generate_multiple_outputs(self): + command = [ + "sglang", + "generate", + "--prompt='A curious raccoon'", + "--output-path=outputs", + f"--model-path={self.model_path}", + "--save-output", + f"--output-file-name={self.output_name}", + "--num-outputs-per-prompt=2", + "--width=720", + "--height=720", + ] + duration = run_command(command) + self.assertIsNotNone(duration, f"Run command failed: {command}") + + self.verify_video(f"{self.output_name}_0.{self.ext}") + # FIXME: second video is a meaningless output + self.verify_video(f"{self.output_name}_1.{self.ext}") + + +if __name__ == "__main__": + unittest.main() diff --git a/python/sglang/multimodal_gen/test/cli/test_generate_t2i_perf.py b/python/sglang/multimodal_gen/test/cli/test_generate_t2i_perf.py new file mode 100644 index 000000000..bbfde89a4 --- /dev/null +++ b/python/sglang/multimodal_gen/test/cli/test_generate_t2i_perf.py @@ -0,0 +1,70 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import unittest +from pathlib import Path + +from sglang.multimodal_gen.configs.sample.base import DataType +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.test.test_utils import TestGenerateBase + +logger = init_logger(__name__) + + +class TestFlux_T2V(TestGenerateBase): + model_path = "black-forest-labs/FLUX.1-dev" + extra_args = [] + data_type: DataType = DataType.IMAGE + thresholds = { + "test_single_gpu": 6.90 * 1.05, + } + + +class TestQwenImage(TestGenerateBase): + model_path = "Qwen/Qwen-Image" + extra_args = [] + data_type: DataType = DataType.IMAGE + thresholds = { + "test_single_gpu": 11.7 * 1.05, + } + + +class TestQwenImageEdit(TestGenerateBase): + model_path = "Qwen/Qwen-Image-Edit" + extra_args = [] + data_type: DataType = DataType.IMAGE + thresholds = { + "test_single_gpu": 43.5 * 1.05, + } + + prompt: str | None = ( + "Change the rabbit's color to purple, with a flash light background." + ) + + def setUp(self): + test_dir = Path(__file__).parent + img_path = (test_dir / ".." 
/ "test_files" / "rabbit.jpg").resolve().as_posix() + self.base_command = [ + "sglang", + "generate", + "--text-encoder-cpu-offload", + "--pin-cpu-memory", + f"--prompt='{self.prompt}'", + "--save-output", + "--log-level=debug", + f"--width={self.width}", + f"--height={self.height}", + f"--output-path={self.output_path}", + ] + [f"--image-path={img_path}"] + + def test_single_gpu(self): + self._run_test( + name=f"{self.model_name()}, single gpu", + args=None, + model_path=self.model_path, + test_key="test_single_gpu", + ) + + +if __name__ == "__main__": + del TestGenerateBase + unittest.main() diff --git a/python/sglang/multimodal_gen/test/cli/test_generate_t2v_perf.py b/python/sglang/multimodal_gen/test/cli/test_generate_t2v_perf.py new file mode 100644 index 000000000..ea2d35304 --- /dev/null +++ b/python/sglang/multimodal_gen/test/cli/test_generate_t2v_perf.py @@ -0,0 +1,68 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import unittest + +from sglang.multimodal_gen.configs.sample.base import DataType +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.test.test_utils import TestGenerateBase + +logger = init_logger(__name__) + + +class TestFastWan2_1_T2V(TestGenerateBase): + model_path = "FastVideo/FastWan2.1-T2V-1.3B-Diffusers" + extra_args = ["--attention-backend=video_sparse_attn"] + data_type: DataType = DataType.VIDEO + thresholds = { + "test_single_gpu": 13.0, + "test_cfg_parallel": 15.0, + "test_usp": 15.0, + "test_mixed": 15.0, + } + + +class TestFastWan2_2_T2V(TestGenerateBase): + model_path = "FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers" + extra_args = [] + data_type: DataType = DataType.VIDEO + thresholds = { + "test_single_gpu": 25.0, + "test_cfg_parallel": 30.0, + "test_usp": 30.0, + "test_mixed": 30.0, + } + + +class TestWan2_1_T2V(TestGenerateBase): + model_path = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" + extra_args = [] + data_type: DataType = DataType.VIDEO + thresholds = { + "test_single_gpu": 76.0, + "test_cfg_parallel": 46.5 * 1.05, + "test_usp": 22.5, + "test_mixed": 26.5, + } + + +class TestWan2_2_T2V(TestGenerateBase): + model_path = "Wan-AI/Wan2.2-T2V-A14B-Diffusers" + extra_args = [] + data_type: DataType = DataType.VIDEO + thresholds = { + "test_single_gpu": 865, + "test_cfg_parallel": 446, + "test_usp": 124, + "test_mixed": 159, + } + + def test_mixed(self): + pass + + def test_cfg_parallel(self): + pass + + +if __name__ == "__main__": + del TestGenerateBase + unittest.main() diff --git a/python/sglang/multimodal_gen/test/cli/test_generate_ti2v_perf.py b/python/sglang/multimodal_gen/test/cli/test_generate_ti2v_perf.py new file mode 100644 index 000000000..79d043f52 --- /dev/null +++ b/python/sglang/multimodal_gen/test/cli/test_generate_ti2v_perf.py @@ -0,0 +1,62 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import unittest + +from sglang.multimodal_gen.configs.sample.base import DataType +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger +from sglang.multimodal_gen.test.test_utils import TestGenerateBase + +logger = init_logger(__name__) + + +class TestGenerateTI2VBase(TestGenerateBase): + data_type: DataType = DataType.VIDEO + + @classmethod + def setUpClass(cls): + cls.base_command = [ + "sglang", + "generate", + f'--prompt="Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. 
Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline\'s intricate details and the refreshing atmosphere of the seaside."', + "--image-path=https://github.com/Wan-Video/Wan2.2/blob/990af50de458c19590c245151197326e208d7191/examples/i2v_input.JPG?raw=true", + "--save-output", + "--log-level=debug", + f"--output-path={cls.output_path}", + ] + cls.extra_args + + def test_single_gpu(self): + pass + + def test_cfg_parallel(self): + pass + + def test_mixed(self): + pass + + +class TestWan2_1_I2V_14B_480P(TestGenerateTI2VBase): + model_path = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers" + extra_args = ["--attention-backend=video_sparse_attn"] + thresholds = { + "test_single_gpu": 13.0, + "test_cfg_parallel": 191.7 * 1.05, + "test_usp": 15.0, + "test_mixed": 15.0, + } + + +class TestWan2_2_TI2V_5B(TestGenerateTI2VBase): + model_path = "Wan-AI/Wan2.2-TI2V-5B-Diffusers" + # FIXME: doesn't work with vsa at the moment + # extra_args = ["--attention-backend=video_sparse_attn"] + thresholds = { + "test_single_gpu": 13.0, + "test_cfg_parallel": 191.7 * 1.05, + "test_usp": 387.6 * 1.05, + "test_mixed": 15.0, + } + + +if __name__ == "__main__": + del TestGenerateTI2VBase, TestGenerateBase + unittest.main() diff --git a/python/sglang/multimodal_gen/test/cli/test_serve.py b/python/sglang/multimodal_gen/test/cli/test_serve.py new file mode 100644 index 000000000..fe39d7fff --- /dev/null +++ b/python/sglang/multimodal_gen/test/cli/test_serve.py @@ -0,0 +1,287 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +import asyncio +import base64 +import subprocess +import time +import unittest +from pathlib import Path + +from openai import OpenAI + +from sglang.multimodal_gen.runtime.utils.common import kill_process_tree +from sglang.multimodal_gen.test.test_utils import is_mp4, is_png, wait_for_port + + +def wait_for_video_completion(client, video_id, timeout=300, check_interval=3): + start = time.time() + video = client.videos.retrieve(video_id) + + while video.status not in ("completed", "failed"): + time.sleep(check_interval) + video = client.videos.retrieve(video_id) + assert time.time() - start < timeout, "video generate timeout" + + return video + + +class TestVideoHttpServer(unittest.TestCase): + model_name = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" + timeout = 120 + extra_args = [] + + def _create_wait_and_download( + self, client: OpenAI, prompt: str, size: str + ) -> bytes: + + video = client.videos.create(prompt=prompt, size=size) + video_id = video.id + self.assertEqual(video.status, "queued") + + video = wait_for_video_completion(client, video_id, timeout=self.timeout) + self.assertEqual(video.status, "completed", "video generate failed") + + response = client.videos.download_content( + video_id=video_id, + ) + content = response.read() + return content + + @classmethod + def setUpClass(cls): + cls.base_command = [ + "sglang", + "serve", + "--model-path", + f"{cls.model_name}", + "--port", + "30010", + ] + + process = subprocess.Popen( + cls.base_command + cls.extra_args, + # stdout=subprocess.PIPE, + # stderr=subprocess.PIPE, + text=True, + bufsize=1, + ) + cls.pid = process.pid + wait_for_port(host="127.0.0.1", port=30010) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.pid) + + def test_http_server_basic(self): + client = OpenAI( + 
api_key="sk-proj-1234567890", base_url="http://localhost:30010/v1" + ) + content = self._create_wait_and_download( + client, "A calico cat playing a piano on stage", "832x480" + ) + self.assertTrue(is_mp4(content)) + + def test_concurrent_requests(self): + client = OpenAI( + api_key="sk-proj-1234567890", base_url="http://localhost:30010/v1" + ) + + num_requests = 2 + + async def generate_and_check_video(prompt, size): + content = await asyncio.to_thread( + self._create_wait_and_download, client, prompt, size + ) + self.assertTrue(is_mp4(content)) + + async def send_concurrent_requests(): + tasks = [ + generate_and_check_video( + "A dog playing a piano on stage", + "832x480", + ) + for _ in range(num_requests) + ] + await asyncio.gather(*tasks) + + asyncio.run(send_concurrent_requests()) + + +class TestFastWan2_1HttpServer(TestVideoHttpServer): + model_name = "FastVideo/FastWan2.1-T2V-1.3B-Diffusers" + + +class TestFastWan2_2HttpServer(TestVideoHttpServer): + model_name = "FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers" + + +class TestImage2VideoHttpServer(unittest.TestCase): + model_name = "Wan-AI/Wan2.2-I2V-A14B-Diffusers" + timeout = 1200 + extra_args = [] + + def _create_wait_and_download( + self, client: OpenAI, prompt: str, size: str + ) -> bytes: + + image_path = "https://github.com/Wan-Video/Wan2.2/blob/990af50de458c19590c245151197326e208d7191/examples/i2v_input.JPG?raw=true" + image_path = Path(image_path) + video = client.videos.create( + prompt=prompt, + input_reference=image_path, + size=size, + seconds=10, + extra_body={"fps": 16, "num_frames": 125}, + ) + # TODO: Some combinations of num_frames and fps may cause errors and need further investigation. + video_id = video.id + self.assertEqual(video.status, "queued") + + video = wait_for_video_completion(client, video_id, timeout=self.timeout) + self.assertEqual(video.status, "completed", "video generate failed") + + response = client.videos.download_content( + video_id=video_id, + ) + content = response.read() + return content + + @classmethod + def setUpClass(cls): + cls.base_command = [ + "sgl-diffusion", + "serve", + "--model-path", + f"{cls.model_name}", + "--port", + "30010", + ] + + process = subprocess.Popen( + cls.base_command + cls.extra_args, + # stdout=subprocess.PIPE, + # stderr=subprocess.PIPE, + text=True, + bufsize=1, + ) + cls.pid = process.pid + wait_for_port(host="127.0.0.1", port=30010) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.pid) + + def test_http_server_basic(self): + client = OpenAI( + api_key="sk-proj-1234567890", base_url="http://localhost:30010/v1" + ) + content = self._create_wait_and_download( + client, "A girl is fighting a monster.", "832x480" + ) + self.assertTrue(is_mp4(content)) + + def test_concurrent_requests(self): + client = OpenAI( + api_key="sk-proj-1234567890", base_url="http://localhost:30010/v1" + ) + + num_requests = 2 + + async def generate_and_check_video(prompt, size): + content = await asyncio.to_thread( + self._create_wait_and_download, client, prompt, size + ) + self.assertTrue(is_mp4(content)) + + async def send_concurrent_requests(): + tasks = [ + generate_and_check_video( + "A dog playing a piano on stage", + "832x480", + ) + for _ in range(num_requests) + ] + await asyncio.gather(*tasks) + + asyncio.run(send_concurrent_requests()) + + +class TestImageHttpServer(unittest.TestCase): + + def _create_wait_and_download( + self, client: OpenAI, prompt: str, size: str + ) -> bytes: + img = client.images.generate( + model="gpt-image-1", + prompt=prompt, + 
n=1, + size=size, + response_format="b64_json", + output_format="png", + ) + image_bytes = base64.b64decode(img.data[0].b64_json) + return image_bytes + + @classmethod + def setUpClass(cls): + cls.base_command = [ + "sglang", + "serve", + "--model-path", + "Qwen/Qwen-Image", + "--port", + "30020", + ] + + process = subprocess.Popen( + cls.base_command, + # stdout=subprocess.PIPE, + # stderr=subprocess.PIPE, + text=True, + bufsize=1, + ) + cls.pid = process.pid + wait_for_port(host="127.0.0.1", port=30020) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.pid) + + def test_http_server_basic(self): + client = OpenAI( + api_key="sk-proj-1234567890", base_url="http://localhost:30020/v1" + ) + content = self._create_wait_and_download( + client, "A calico cat playing a piano on stage", "832x480" + ) + self.assertTrue(is_png(content)) + + def test_concurrent_requests(self): + client = OpenAI( + api_key="sk-proj-1234567890", base_url="http://localhost:30020/v1" + ) + + num_requests = 2 + + async def generate_and_check_image(prompt, size): + content = await asyncio.to_thread( + self._create_wait_and_download, client, prompt, size + ) + self.assertTrue(is_png(content)) + + async def send_concurrent_requests(): + tasks = [ + generate_and_check_image( + "A dog playing a piano on stage", + "832x480", + ) + for _ in range(num_requests) + ] + await asyncio.gather(*tasks) + + asyncio.run(send_concurrent_requests()) + + +if __name__ == "__main__": + # del TestPerform·anceBase + unittest.main() diff --git a/python/sglang/multimodal_gen/test/test_files/launch_flux.json b/python/sglang/multimodal_gen/test/test_files/launch_flux.json new file mode 100644 index 000000000..6a9d83820 --- /dev/null +++ b/python/sglang/multimodal_gen/test/test_files/launch_flux.json @@ -0,0 +1,11 @@ +{ + "model_path": "black-forest-labs/FLUX.1-dev", + "prompt": "A beautiful woman in a red dress walking down a street", + "text_encoder_cpu_offload": true, + "pin_cpu_memory": true, + "save_output": true, + "width": 720, + "height": 720, + "output_path": "outputs", + "output_file_name": "FLUX.1-dev, single gpu" +} diff --git a/python/sglang/multimodal_gen/test/test_files/launch_wan.json b/python/sglang/multimodal_gen/test/test_files/launch_wan.json new file mode 100644 index 000000000..eeb9ddf9d --- /dev/null +++ b/python/sglang/multimodal_gen/test/test_files/launch_wan.json @@ -0,0 +1,11 @@ +{ + "model_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + "prompt": "A beautiful woman in a red dress walking down a street", + "text_encoder_cpu_offload": true, + "pin_cpu_memory": true, + "save_output": true, + "width": 720, + "height": 720, + "output_path": "outputs", + "output_file_name": "Wan2.1-T2V-1.3B-Diffusers, single gpu" +} diff --git a/python/sglang/multimodal_gen/test/test_files/rabbit.jpg b/python/sglang/multimodal_gen/test/test_files/rabbit.jpg new file mode 100644 index 0000000000000000000000000000000000000000..56747c94ac7fc89f77091603488296e082fb82aa GIT binary patch literal 268656 zcmeFYcRZVK`#v69DLO>$+17|rt=NjT_Ff4=C^bWCZ%P#vtJK~rc9I&gDcYd+-jvuQ zilV5tU(fS=p3nF9=lAvd`+X*N?lafx%9Z;#uH0wt>p0K*@66v7hCeEPPWAwRjt)Qw z002+}s7Y=C$gVt+s|SFD6+rQC9{_ks!uG$u4~fA4^dY_K^8`TtpFU)7NuK`Kr%)xm z{eQ=aa{6yOAmpm=zfApVyZVOz*;Bm-0I06skw19&P&|?3KRyX586f|^?cxs}ARLS0wk?;aYZ<3JRB>DS=r0_~5p!m-b z{&&i*j+2~%l8X8o4ebp8DG3=FDLEMh1^JZ-K@$8g$$u~iAjZMv6 zpTFR{dwRe2jg3!CPEF6u&aDtu*VZ>Sx3+fZ%6COHL{qo_J|x0j@7f3L`me~av_Ho`9{x9CCIc`QzI^EFYdRoxns`nQ_oun;Qb7&T}oVfKtap+E%PMU5`yH3h6BaB=isI5+G 
zMO|%(sphm-IH;i9u{JC&rS~hUH|`DYL>0Z#)^zR;x2kZAl#~WmRK+oHrH)&70#mmt z0-)ZnZN5%;VGK7G%F%%0w}O?<9aCm<(xt_d`K{zb*5#`vZgo>R@3CofTO$Wjd*YKj*6Xhy zS5}ABquUY`6omp{DH4GC_c@E8eA9N#$fTC!(k7NC&(h*3X5}(+j}>ja+2R zrBpz3zaqb?js_LuAAoJG4Sli-{#_((?nK`RrJ^RX__bzyC|i^B!`TZXr)wSg6h){f zrkSn)>Ff#7K;smWU9&LSld>2R!^Vw5P$uBzUx4PjkHZpE?B0aJ3rMJL)g!#A2B3m$ zsYGQs05BDALRpV~X$+No7qEf;+~`|1FmyUbl-_7~-tq*1Rc!F#*tHi-r>`Sf{6a7| z%ABpMbJ(^}&GItqMvT^-8w)I&&Wzrn$hbCEre*IrRv$7!#kig%Rg^r(|45<00NL5p z(8B?0Zln%?%6T()R@_~AS;#%90tfjJ9y>SB-#!~1pkTIJWv)yew^7EM#ptY}v^XNW zx3$vgH`UXbJMaO}s1JdwzztdX7XrZDC6>1@pcK!G<}T$>?pCUPy5oDf>eyQcGrw@h z*oeI;os2)MM<5M8Oxn{eMVh(3Ufk~KMNn4qFKE`8i^eTMtY(Zsl9U17&wD2aVOuti z*=S(Xb8fu&@x>nl7ctW!;_neBRQ!N=_bk@P_8a0(;t_*>C7)I|>fiTpz^A?+Y<|S% z${V2hbNj!WFfV+asIK)t}~O{mKHo1G>>!0ehcl!ryPkz$g9O8wy;!Z^-OU?kBLG1VvAXF zPR27o{j9@*AIhInjw1&eriwt#wPGjf(eoCX1+zpi>0k#f^Hhb(jn{Yl3@0peaL1>W z{w`1uoSzV}5HyYtnDzQ58I=RmoXZn9i_)#r25!$aiWmy$lIBkC=86P9eU;t~!iy3z z5uvwH!Ei3yi0sLEVUrK-QX9<*dfGPG(+tuG-}%gaHJR5$vi%Xa&)~G)jfA>HyjO*i z>s&T+wzgqm0V-wxMRdh{F8F(zz@RPk8I^B|7k7iZ7BEc5DE38x9= zZL*S&9U2eTYC3=62x%Gz-h&O?G#4AgolFkdP}QXy)ySJ2<)%ZH!rm2B1DPCgr#T+? zKpj~Bksuv+%7w|<)<F+?U+7R#&HlA8s<7|5TC#%kW;dl9-uo377zpRcU%M_Y6 zUnn-Ll3B|rvzX=Z^3;{P8$|C)sU}ZzLBHFP#xXBx0zY?(j77 zpCQB>IcDk!$g`HZWx(R@F8EbCXGJ`#ySsU>eql0T?lT|C`5xw!88~7#TXfH(&)PG_ zkkGY1EeJZ8C6=xl;KERJpV|?rhLb|K@RC1G66{pkMxQHhVpnH|yA8Lj_YTqv2cJcG7r%)3U&34#uVW$Ma)2Rm*DM zZFob0nf&~V=4G4qptILyJpp~wWSJNDX=~n9j*wW1#VD|U>M3%uM*m^rUQieN8=E=D98532m}R5i@wl z`c7NU|sIiOXxWcyau4d8FN5}Xe%ySiv2_pqUKkU;Clc?ZiXiE^_KEozYclCD&w~^_S%bA_uulh; z-~-KqLF|)y%xvQ{x7B=mJx!#gj7gnWz~>_ELfVhYSW>|EiGLi)W(i278$@2WT&T27R3vdeM} zrh2`+fru?czWLTkByUH=gUoE7-?iRuk5r4QV;6b@nwA)3>9aPLRn1~F?Z7IH^XdD9 z!Cjq;L>x2kBExyx9_fykP&2w4d-V(Y#6pW4=Uj%7%ii=~^lVGh3kKDs($RG^`zoa= zObVckrd)xAH=sveo_Ab6hvV%nh>rDBkp;4;ebb8GZlot@gL3O_^{E9D^I*P!J-p*~ z^P7QoxGYqwp!${2e`Lq2ufS;Bkscb#Sz$ zDMC4IpDLBC>ZLOG6L}jeO!l5fB6mE}xa>1n)mDLD7W!Quwg!t^PL*0$SGo{dOO3w> z%{z_RnW!HrvcOedZY&x6fc9`oznPM``<|i2O6#ot5Oboc5D~Pw%H}T8kCC0G7cD6V zK9N#OpPQ8#a)Nj1pA%u`sJJeHvVBy#=pyd3XVs6zfO#)`(MwQQkeXdazZ`V*=SX^y zL*gU2GC}KJqv0qDq2>8#!_v(<5$<0e0Ifu%hVvtT4605zpaPDyhZUQ+z{+(LgD;n6 zQ;24U&z?+$Y9L~#E$u?3F{-2VJG(AD%=z^3vp3VNr+tgfihvj^-H-i(SoLuca?O&` zWcHzHt?DSpZJ6Qq$IZu+39eo|%s|jiWp176*MW8K;76V}DExCy(+`tW0w(c0a?8^< ztF$<~R9r0dgj!#Tt%Fn9Z@C-pjNj(Lzfj@Xlwd{2z3(t;?YhHJDXdEYE+HyYG_7@X zU_v%@ofm4Rv=bUfC}hleE2LM!b{EFvFKp^aX);tPxRnHt0R}3-bP7fTB_<*F5t9dN z?vUwyWmNaJq-Z!L*dOl{4Z~vnS8?l!0iv_rsbqksY@CBUB0;SN>E(?J@f8$IfL)EK^{h zzWobYo?7NI&Vl%p=@@G&`4FI?>75aM36FjfBfPoFoY3kQ3Z?$|WVV7oFGsbdfMiHa zU%&|EN%WEuri7MvIRx0dJtp)Kv){Mu%6y9PFM6^bffQuHDpe=mB;9CK-%$Jv`=}>? 
zVF_z1e}W0o{2F-QgdH-N`zwv4vQk#}ER46NsbE9#`D4EhOoLXfp-&=&75}V>425ud zaNCVi$a!)^*hltpY7DcsJA5*Kj#{LHGo@=gl}ZRSRaYte47Wa2T5P$J`dEfN4QmuQZMSX|(;F~K5D_(%lbfw9bosM&W zT7;K0k4i7rbJTT=oXrm_H-d_O5(_OcU3gNSd0quC#i|O-$(Ys*S`02ujK-Fu4@L)X z;yoK`o)p=G;w22~3gbhLE{<#?s&<2U&A&?NBFgKkHFQsbpD6PM-(;Z_2@1$dbTa6 zvaFkiM*qB~SxVx$70L)ztY@(C`|^h?Lm2?QQ+ej$@b(gS{OmM-%MC>OQ`vM7)^2x@I!RjrCuje1aQz5nrmE%MIWMew*F-*wskc`nI@!Fg^vvF%U|Rcx_+o zq?23c-B`a0?slHbevwqt6tigD$pzfZOpCS&+(p_H>64497GAI{i-ercwPzUTR~a;a5}P;ADB)S&LI zs(c>E^6M#)gV*oiVIb7*dZKHu{B)S3?4AcnBzQ;i&eqlgL|>(9kB(53>5ob}v_%QZ zf#b7HeX@V}dil6Ul2!VYZ?yyPcRU-?YL7b*TBK6gpbaSw;8Y{6t~3?)m%8LrZtB@q zyW}GqKgL@A>C5qhC+E$kfqRJ2YYDc0$9{+?8sF?T$<3ZC=K0BlrO>xc!yj3)lbm{9D-4lC3$BEZ*dWEa zBhV~Ny|vZ71EXh<;*Z`ovHNw`uLM`(UuH~mMCYJ;^pF}GH2HE=issPj=922ZND}Nm6l?k&V=~Z21aYK=bF8de&p1QBFVE?y_M13?2H=JD)3e5a+i2 z@rHr1xEH9gOw}ff(Y_bUI(Lf|JT9_lHCsf!4O?Ieyt@Tw%Ch4wpXb(Un|ng}^F&NC ze$mjS;vxENLZpikx|{MQl6^>O?v!6N#e=y*_Px6%)!RxpWi}c6LOyGw0s0T!jQ{}r zL)oNI(5Rk!xOP>CKp_$!v}WZC)J3cO1}TYfe>X(N$r>MyP^2f{hQ4`W;+R!XxwFX* zv3B^%Tqci2IVUC}Ba*|AfL=ko%1?g|p!?zFt5CmjQT5{8-#NcXN zo34yX#I1HHgGs)l04J5M0eW-NT)SAxl18L@XcCb)UFD%@P zufaLAk*LeZ$4K3#0?={(@&Gfh?K*P_ijR{#-nLZD{6$@s2^^iD?P3E?!OR4;2~EGn zrEJ>7bZU$^{RpiW$WAX6l+!n9X{-ird^n|!rekcCR7o73;=q~=T$?Hzom5|S4FdBG z^3_ay+%kO1{Q`C}_?Q}#*cZzVHAnfqyc1?Q)wg3y9ANw&UIBz;kh)L8rU{xliiocvETD zM3Z*yVdjaYnr5=bdtA8xbMlyIY3!4@XZD`Wv?{M2P`1Y$jC`YJ#!^>*jC%!I@GGl)j-XLa?S>X7N$HXx6wy;4DV&vuSP5730`8|E`CCE94Tc*p>KANMZS z-L!$G#rNi*Zv!bG+ci4>Pw&J;gjek$-9lT=#uay5?tAzg1IN_LmF+Wp0MkRBekZQk zJuIiC$^?f`zXseb!Gew+eFeuMc^LV9$-KTc#FKAAg7@3)Mo!eK1@ZnV z<*bGkfrGJ10Kw_#87N!vuDv!?^QZZ=w|fr78SsrwX^h?@cGjVWPi$R>P|Qi96Q`<> zu%p3*Fw<9hMBk614~P&hCZPaVt?Tw98qP_*CRt{aa`yVo^F`}PO}hNccG29 zHSkXrCW_X<5%S-EX0`2 zH)Zbv8u$EtXhJw#@)6DW$k^@rsKq(ank44Q7Uc1~Tyn=2X!ZNGGxS}Pc-2*0AU@$; z$iI|6Zo-~V=OX7f8(!}YOyt;*q>KXStJk9CsEUuI?71#(WAf8e&#EN3doI8dm+C)6 z&xeSZmi&^}FU9gPoEj9vf5$J|TP8nkqup4i-`B*ql|9|5 zMkN`LRF$OA?dlXwFw&r}m>o_>bm&E|+zmOZBL=CFaP0~?FVE)T6uK>KUVpc*m^&-y zyMA}m3GX-f7l6F%ZLJGQ>Hkd7Fjdz|NS^Zg=rO}#9IOdWaw%>i&Nb0#!9VoVGXPhh zaG4ZW@zxOUP1(}e{TDVB3sCe^vD`?oP;r9AbgOP<%Hpu*p@=8U2Cu(1%~kGKQ1et0 zLAnWfsJ7vSg{LyjvA@**0BuW*-vEDE_!cw+7P*qVmYU4ggclKLLxI#C6X%}yFX2K? zQZ!NQFCf|Z#z$@C&$L^^AQSYHqLebFeUw(;jY4q&Q%9z{>c!7Fg<^*)f9hN#nd!ki z6qTIxo#GhQ1tG5wEwD~CgvEo5L%#*qis|AK-QxJcD2k@u%90V+H4|JuyF`75r!nV) zQr;G}N+nsE&zTc%bnH92y7tT1gK#10r-#N1A;9~QUPz%vfLG;E>6yx-e(8-+!97?c zx@I*|zvd?+OQL}$6-xbvZ$2fkgK@5!sDCpYNk`3wmMnH(D#|AZ@N&RS6CPTwy5&u*3dr zn`$0rGjs8j3zsY5LN1Jvy9~GxhgT~(i;>CZgTqMy_5(Tpb-J(fbz9H}`ji6_84S@zSs7pA>f?uHii>CBzFxR>0dHtoJQHF7iy z9jc$ySJM&d#t2f&(A@hd+AzE}c;4JS7Mi`Cb>9i$Qom9;OtISfR)A3F?dtxLaf%{h z&-|l`&)68barx+-m+@OaA7FfTZ%I;k1!N!exLTODJiTY=(zjI<{DS?q1#v3ARyh>( zu`zg+6m9A>HpuA?6V?;3oDc0XPfzfb&mDmIR}k zzF$bTvOseux29mV>iG_6zo9t0j{?B$W-lr`e@x;+{L)?5{`<}Y%NitKw=HJde#wUF zJQPtA(2$2XCytJ|eu?DgJ_hz%EmW=;+v70bpJ7Z8vb1JWRCeg#=nv{8?*_U0#LpnX z#LRT9uTyQL)Q?9wHC=%z<=X@0$!ieA(_EiwLi3vz9uq>N-0}_Aw9FGCWHN7^3^+i? 
zQ~FvMgD3$eNtYyW=V|$)gc*QOR#p8iZ$MkcBG`4I?2yI-T@Ehb=*gG2iAT82@b%>E z*Fwu+E1D&kWFT~l?2Ltfxh=aVN-O&3__?&gTcAIoSdID37E_bf*?eoAKAXpNu$46D>x zbJ-?Z!LI(B(q^I3m>_Rc564cG-1X%+mbKYqZ0V&!k=hzPXyB3s#cTrw?SZym)5FZq zX;@T(AfiPr)mtO+k;al90v%RkDTEa9uvlm*Ey%*{P3=pfYoxM)g>eO&iF`dGTU}L1 z3>OGh_NU7nQWe6c*byyd**E`A{JE0JR%BkC?;#Q#=h(c6i6Jc7FfTxqo!f?sb>X z^6M(_>e15TC+wyMbyJA{$ji%y5cQ54(Rkdj!BexEV@8j#K@GB1(pvZwMEXz4G`RA1sGyOaM}VO^eaNu3 zhuRs`(y6#!90p;vec^7|7N4w3c@?+`-RawNS1JvZ8EcfKssg8YZm}IkrODah?>#{^ z_R(HGtjemm^ZIp+ID@u(Ok4V0t!?CG_jg?wyB2MzJ zZ(5kxTbT6x!jZ5r=MTJzH4MuY3bGAT0&d4k4?cb$8>{aK<%>1GYfUIE`-TIksXXz_ zHlAQfqB^yY|z&SnRThbSX%2}5FyD0le2|#La0>g zvsK0ted#Lv`{Vqk?kmN<-{Q24u9IkTSB9kwSo@Wz#7b%ZJU+$XEBK)RZ>ecG4IPaj zXMo8>0y4w8@}R@iT?v?YXAZrhP%$?nvmk6agdrsS%Z}fIQ2#FSFW?KPV|miegB7iT zN#O{9vUPF5XHNZJ4J{5_8mF@UeyheX)hhSPKPTk3ae{@=Z>QhJ3gyJi`ie1ZuV2+W zex7bbOq8kA&2T8iuNO3$q5A=Gb%`wi#8llo5y%%Q7HoEvH>DW7pcuh$>Xw0FSD|%b z?+Ai9(qB(=@3|ZEQ z^~1lfyr$phwk2(r6Cni?BnAUsMz3rUlXaK5oJ~CMxN4QE6@q5`U;-bNfFHfSG{939 zO=pXa8J%Ty^;7U3e*umhG${!J?dBe;Djs)y3vXA;+PCPy`m8-DL5W#9PJF^DzAL*6 zX@>ZlZm6H(mL_2gw^H%$?xu=!;&aWFrjp%^ti%)(TYmjLbf%2WoT^aUpS;^@@8TZP zjm#4&q42^lVoFi>5iF(y9B2ZgS(i6VwP%=Mwlx<}G z`SfVGZWZ|Tq@QCa@!)6lCFA%4m#FCDluchNy8fe)>8+e^Z-+dYnK5o>^nx4qK~P-e zpk~|%MLyUH;jZ%SZ z4#zj5ItsOxUvO%lcy0d+TK*ScouH}sQRRVrol|Q(2tKc~jrXn0_qfVH!seCP4L`CA z2Z-UERSK8gKw~zecgE@nqjxw2!!|~?!hhA(I+p&z?S9-U8jG$o*jD1?x*s>bUSa5D z`iJ%Os-*#Cz2GEGR5Cqs7^e)-&vc}c#;UjqZ=$%zqPGOdqI47h+j_rltpcDy5~&JI z344)$hyee}83z=sVO*hdH1eHB_uNA5!CwHEPc2#O__1dL%6l&4{iC0V>hWVp&*C4G zsTs#AhO2$P=_>D7PudTf?n#M>qecUG#~l1iHz;kgD&_gr0i0hI`OIerLf`L#aB+c~ z?^g76L7%}VG(@vZ&S0NaZq`1?Ch+CqM>kx2jVE$oWJFjtSTLenx?;&e`oi#)1fc3$ zmyoPL{(u0AOyVyL_y8SAjT`W3JOMBP)>T z@(d|?f~gq)oF@E`!k3jHV%f2?V=y2$ab+<7x=riFOz)Qr&7K(ZGt=fjN6jEThG%^| ze5gQTUId@V!;DK%z7=Vy0f&dYmo96Hi(h=+N}~al=C*5)$nE&_OVA8ap0gTz(Y$l# zvwIJ9FRevVk7yJh#;>h3lKb5OZraSzse`nQl!Ji^<=lySOSJG7ml8kix%N9W)Pb?Q zIBsOD zXTCRk$e_6{3X!l%ZD}*ob;j?M07F>{PnVoQ%fCMOU=>0_4MM|l2AzJ~iX zBhzz9^HxJID9n+^pBM%P2a;@>Xss9)%)2;}>WO=pDfaG0Ou!Uv?r4c1HHkd5(8=Wn zS#CP^&e(~JWJLpIo0h7WEIpzJ6KAAQ`xzR)M~T?7Vt~DPF&*t^^(4jB!2JikPVk+T zxW}*fJz3B9jFY5FPO5zfb7~1F&$-+Wk3PQ#jTekBcA#bI6s8uYBX@TKgi@mh0@^a( zu-j6WIiBturq4iY5RW;`*h@g^wWAccP7A0BYp2anO#YWZnqv_~<(Lth4=UrjkMIHa zvr1B*$3jeYW_Al>8zso~WsOjy%w zCkzr9eGhr&jbF6HE_fi2Z4*Hm2cYz}LoyXnPk2v`W_ zK)3j_v-ig(%Bv*Yzh_I0uI*<$X=n|M=*H4&WlwO;eY}{Cl|r)l8k~>6UJ-a@a6Zjn z4LWW{`9H)$h0ePL|4+O}D zjh+p^8pGUwr=|>llWD);4yDYf^8hK&{Mb6KrwW4-#?8lC?$zqcnj|+E%+*Sqs)3nu zmr}0JxerleW_h)*rm1R_N16;pncGX?((c_hOLAf%!r!=hE+spyM{x+;?@7;L%gyV{ z&Qw%TObM_$b#O77*1Wq4EQm{xwFIhfUsfpRzma_{bRzBba0cxdB(1(}Hlr4Ae&)Ne z{90(6D8X)vJdqB16hDc+6XhmrVEZ(D{FwR(RGy$5r$lJsP*nMB)9iQrDz#3uCuw2Z zz3^roH0dl$53UN@*3r|QQODjn1m{zfQzU!{Qn^j(d@y{Ir=R9R@yme#Kk&!5=o_z` zI~?M6xB?HW05j6yPi3)>6?SurvWg0;7=%goy)5G|}0wD%Xz<@x2O~!hl$yjo(rK7#kl`xd0M1;Z~ zOBZX|-wy@FEhQg6(bH^Ay)@}{l?B^V+SD!f)*h4CD%Hz6H!4?s`F`#S-UWEs5vJG(-4eawYyu0tygr|3n|TIyopcEz23I`Gmj-+S!ktG2uH&Q za$^|9TMt|C&YVfIc$k$qyy#xKPdir}b4PC&7Muku&l|I$t*DG{p|`UnAo{uQ0_p@Y}VYkNrnmc{~OR)gu+e;r#`d z;O6Fc4=I$b)X(%f1lgZxC3ed~Th5+#Fr6>hK;V-rtFR%mN87g%n!`;-LA?FZI*%-@|ouN3Rl zU>YaAM`(0aTVWsF{LOy21v=r~?3nVKW$t+i0{sD^T0A{(;w|Ng-`!iuQH5&juGr2+ z1B?{)ylY_In?JW=Ah#AnOn%DxarRBG$8)BHDAFr4kv(y)bdW!+jyeq*ua)2J`a>BH zx-o{Kl!ca>p22x-;}?An;c3N=Q>*My<;Mx(DcIaEt1|U+PS#1%a=}zr5O_F**5+*X z=8?qsSl$Y$8@!?v{?4^cTY+qRcwDf^ttW^fZ)$>uJ3avW&NPa*6>4gKhkd&Ct^7qx zj!487&VYFfF)J@yzp8PtBK1{eseI9F*Zegd7CdejJTe~0+bwXy1vuQ4* zOg_%hT$P4=fk`{oPy{eKqndgk3ASs00e9|>W+Tt%zsNSoLnm{k4^F<%dY(bMLxge7 zIPmcdJuE84P(Je!RP&Hw7BNw0#aN(kQlIUEO(aZ-G7fUKGXm6lvB3Jlf_%>s;qat+d?z 
zWlhHqJMuY97VjVde?_!Q*3S+hYT~F0z(S`|@_GbsGiu1mHe9CStR*Sd9WJ-}E!vz` z&GB-27|Qv$Elsf%d|c<|tM#<^G^xFGlr<`Wba*|QG;RhR!@&NeMLF?xN)SuM*~0JI z@0r`JBCE7nvp%WY=98VGmN(tXZI4OhReg}Ymv>Em3?@Vu|9bdcescIsBhKH-SYlK0 z@AMyE$nYy=Lv#tqW&Y+>^*88gAfJqgj=2^uLX-IsEYbWHB?ZX~TU74oSc$nEJ7nJ6 zl|!iZnHslA8xkqbNmf}#W74{NOIc1$({}8mD=o_B#piq6WwdRtxcu*t#0f$q=mW; zM>|$$J>Eft;aMa+#kbQfh(svyc#kt>O67$d%~<_-hk~GKZJedz^#>pWk;kwd`rK>1 zj2pEytw>?2bOh6DLDE*xm(xa1D^-N_8uaa1fnIvg*n2MCvyyKVALXV8L+bF!OF1BiVD;`uhsW=^C62*IHmXV_N^tcqOtUCu-ol6Q8 zHq6CYBn3hCRE zzz-~<$dYnyNoV8@cgMRZ+o*SgAmiFLYQXWi?YHzQcgRv=w7+%$%PHSeNn>y>CgOOQ zmTjytNgo@9@;J)SuO8|EZr}hh`MFPp%1keY_rS2>+prIpxQx5$%Z8-ciKI@q=G~bq zN0sblFrPFJ&eJLJ{2=;n(VXJ*nI%wt{U{+K#K7`p+zOY zqp`e@LU~tIk|)<9#@BEnN38@DXC&i5NAt7c9#a$TZ_{Ter-tVGy{q&We!gA_^R8nh zhe=BRP%!E)>-zoa_8E86=E4E}wLd9%+9ot=^Big^SN5{yvrhv#;PUnb8eu!%fX-)QA+ee%Ai8+)B*_4sR-hWWU2mO5#pg{dc#PW3~Q`yfnpUDSY1d&=gAom*dpR#5@; z;l51AKNLVoHL`-n3gcY?0TvIUdzi6>m*S;!^H1{2uz&;COOQw7K4|MZ!$Q}#6@Q%W zMV?MrX7z4>T@qeaN}Gmwubv^Y$N5L4r7>OmwpKgP|O z{)}Z&gS)>i*xIeAYmw?@qPA?WT$3jQE z@jw}?){T#*;9Av#im9G4-z9lPd%V9?bgQesunOZclyXU=cHa*xa;O$Vek6f7!s5JF z+{zVVa(t#%Qttr@n4&afmox9%qXEMMNXxd)cIMp{H+w)-uL*LRkl zWxFGPa6O{R93T7sOm7;9QU{rlKCT~sWuG~Cm)t|a`RAPd30IDPeGUO+;MiFC?I6({ z`hNSIt+*xc@5rTKQV&N(A`TzGI7N%>P5qn;dWz|hqxb}FqWg$9jqIhwzf`uD9JQ_< z#$kcFT8}IP=T3pNFo4ifSM(;-2c_Zxbx)DT2jmvGNFIQOPndmfo7Nh~i>h)HiX&=N z?v5Tg9bg1mb5!!b0vEPFt^@Gy5C9W!D-%YVPY+U(U3@$tDUBKH?sdo8ZdYF?u|&TO z20eQ*Y6H62us_*YMTK&CdV>~9cncYQ=-3zeh5ilpv-4GjdUaP3(lQOwj*NnauGE_Qn4$S-wy8W;Kt|%1s9B=V5 zF9M1OJTXST$2d(88s+uLiYzN3H+;)hEs%9Z_e1Yp72ABUlc5So=HU&0Qq8

G8lF z%IX_~HP|BiYHG;?l=OmnvJ6K}>H4B*Lg&WU3+`^SU%VMYU?%P(#9A5eItLDq6X_3fbrc+Wk`?|DdT zcc`Np03xT-6pAS_OyU{)m>AwTCsZi$Y8w`^Ob?oA6-R*t$f0xv=KhJX(B?BWO@|JQ z#j`4}E~+X&htD*L&unAaV~?>f#x8m8mR-~T6k&3(l;d1L+CT#U?w#6(ok9DjD?vHO zPDOp*ugK?!!)~C@L2AXc^2lGTN0TpH65-;jbhnn&iv;Zy-4p((*eQ4#{-q&$Slz93^W;&(#!(7>r zNy4C!HsMwv#I%Y2_}s-|&p>nkUOKVL>=@+B;=H5%HKE0rqd29$VydB9?|Qth4vhp$ ze0dM!6mhDJ>6>n*1u3p@iCMIzel##cUcB>Zg@p95^Nw(kHX}d>(tfa+`UttcFH*w@dIS^bpS?hb=q1-#Md9+JPv*c6jsaRtc6-&6Wp%6J%Y2f57ILWR47GM( z$G>g%6y&t$da}*s$-ta?ZX*|=_~(}V536gipl81f7A9wAgS1y)`eC9{>grj$rJ7Ff zygetR-2~(3p9sJX^6ABux(-|cq>3O?xuw;+WYoibSvl; ze_E&>#C~)GPTdDMdg@ygX`Ha~03m{-%0MHh5fs3OtH0zagHvf?n%dGe&zr z8Wl1Yi8rS$ov*D=U#qYve~L^pA2`slv3&S(6Xms%|16{1Qroz?#Hps_XmAc~)rInf zN;@!Zj|^yIAKs-* z$NVqgEvgOimgy)AFO~>?-glM%9=&{vcob`Lr)HnRd;d+j1o!W`IsWHE_cAytHAH?v z-YAiLEq`-;D%z3qIqvIenSK2-p+1XcBR`F~hpTm2p)spRrVFiJkVo8 z#N75gu422b7VbCnl)VJ}-AeeZ&}ZQ?7xz8)!NkwLJw4zLt7uWFN1hh7+&|{y;pI{= zoZxvy?MW2?)90C~6Ip`lK>>^(G&raAe}|B}ZX3SqB`V!e#Bn5uvJWJEok&q%@= zyi)_1Hzemwk-Wo#r)Z1q;R7o)_8Y2G2AEWO5s}x;T1vU4G2gY{?F0L^-CrP&IYdbs zjj~4co>qR03d8#{-H6Kp<<`8jhT7WXi2*;$nD!Zp#8|7ko89qQfie>@i4cEJ_~4{PwykXnm39_j1!awJw29Y@xhD`P@*&8y(yxVSnmP>|P)-?k^za z>W&Qp*NuS~ydOHJ*CJ>ZQ?Hc!C@Dk))n`+a}Iv)#}AT-R|Pr-i33bh`|k zZE)uDy%J#|{kMPcCoput#Lr;oTkvcSJNY7VDITT8-&~*Elr|#E%^qvf-UV-1mh6Q5_kkfT(gEO4p z=Qdv^Cd+-)dL%c!JMU5Df@rDdsjaf5WT6i+UqorQR|hEj{C&h)9nuCV+PG4oE@XJwVb-Zr@SE|A`B zD(m$ez+%wR>I$wK=BXV1^^?V*M`_;Ci8M8paK7Ux@NeaTrz9bq;cU+XweABaJJPu$ zduKz#^rJI*X@Q>1ciLf#^7XR%$jz+@p|*>}XGC&%JNjMj2WYn2o2-Mo=gm)zJ9mv! z2-!(jvtJ*gi0>(;|FbHDC}D*YI_ILQR>kwVIEHfu+ojuSdlEyGHoU3reRG*xuqL80 zAWg|g-%`@85w@qbfO*fwl3KwM1AN5aBKLa_OEC5;s-u%t+jcsPgHEB(hJfnR2G)(5*mUX*9CdY2(ZW(Zm}6J`RF7l?DOFo{Uqp?> zPX4Mt3zJJ}`ZgDdklF{=6o-rh*7Le{Q`6zqDG=!)rP|Y=XJa1^Ms*eu&Kv^g5b2g; zU(|D-3;8)jTAC!J7t0G#fe1`>zb<@cH-8i7b3+#&GF=lm>&}dv&kXcZQZ-^UAcgGz zqY%6YOyv0b@P)Qmhhd2iv`uY9Rp%vQ1FvznT=&oS)ksLD!}-gh1)@K+6d}Fw3@r6Y zxp=H;GJJgS2IN;7^4y(7SbLvN(wgRiV9!y5^v|V|OLC7{Gw{w^D$g{Y=bYebMjxdY z85hev2-;eSRBML3wP(3^fYX`VSdj;2MrlYJ%`cnDh}g1Fp*vOJVBl?rH)lVfg0(H8<3TPl>q40#&$YTZp({jDB0(j3ix7%DJZ~Rd-*rVwh*n( z8d!W$m^X}xXEsHxo7Br>$GqGF>-)m5ag zvpi4fqX3wKF30l`#bvb5!lG@7gcs{jp}pOGnLz1 zlb>7+vtJ#X&<96`AbuLVKyBmtkh>LGP6ju6*7UhZNb< zuQr`YGZoIQ{721>)0mNr5VWkG4Xgiv^(9?N5bZn``a$wi4;R7FbNCabOq3lYa zw)ti;#MKiwOa}bm-%Jv%FL6e*FFwYtz`G~fO_o|5#68{#KE%twzLbT(>`&m{*KiDu zdKwi|?3IKyz7v30`*zwE2yANZl%X&V09C;fZxQL(=Mkib(cJ95v>J401^4)S=Iqa3 zGmZjpHFE<5w&^Vgeu0&aS^0Y^AyXD4BOe?pgblijAm5ROn!Y;7y>2JvepauG{#fU6 z3_ZR<420-faM6%@&yEMe=DtU-YCDfqM7;8(%5T+!4Ii7*E!wZN%bha5?76AhFZ{f# zzx6lcnu^@htW4?|x=UNlpQHhf3wR$lYVhOwX{&xB5Bl)RkGdm&H0hD%8RpH;=n1T9 zHLlA0B;yTj?MV4k+f3kKiMsAHk%f!iX!*&;3kr;4@Z(<8(sxxADy}-TO3}5KKgH^u zJ14uUT@-Zn*T{j3-hUJWZ9#_K*-VQ)rB9n}*hFy^(NZ2I2D<69#AJ6L?#WKLlYs4P z^0#L<^Z~OA;>I_rstcYr-A&Kky3&a>$a_Vj)(C?azPmb|+hY<*Z8tbpGc>2u=SH>D z*N<2}KFs&znt$RFOBIhivcyQm;HNxUGi2yadJC(C%lGr_&#o9$sG{{h?878G(h9>M>2R22Z<--3b{#f>8<0(e z51cN1KWlYWhEyk>xg&#=|52n$Zm>zfbW!lB;VTt%bIn6SU6iP?RzuIqoao<(%P0Mb&dmKW?xSJY=?VJ`JOEjZ8v= z7wLn3T?Z`K+a3~v6y!guf@?mxAPX;cL58W0h}n_}WR(wNBnLZjCCR0)X2~Xh8}R;5 zX~(gxd{<<{Rv^KsBB@MQi93PJT~7a__=0_*5)Hz~F^Fgd z6&uo*IfB#hU!fj$C0t`2K@rM@bxQH_zUah)cK74SmBXp;7jDpAebGT zxRgQVHUCkLo>x{gVNzRYudbXnh+IF+{$4oxRu|v9j#cBp=}c2VxlA5EeXDhBZ%z^Z{$)N+O?^0#@2KnI=Q?SV22#NWAWU>@qid)Kx@q1 z8de~C;nyr>=g~|1X5~8@(XGN*F%KuH?+awHk@}uJD83;c0FBm?_E# z6FW$Cio@t{O4#a7fcJE6kQ%fZZeR)o#=U4pH6{Mk!JivUKix)lZ(tWw=m)o1sA^iB z`|A82=%HwhbsJ^TZj*CrdpbSN&9$?3zfpEyJ0ciF_pAbEi;~gaPl7*x@!FN}Ti5 z>SRMv`EHSSl+PyjUZZ}%S}V<#Bj<|-VirJz87YF>P61DHyGjy`S*Q130C8#02BXm` 
zo?>&wcmM6cZMy!_U(R?zO>*An?TjQK7Fwq{Me8$GE9ek}h*ve1wd1H1PhNpDeC>;+ z3H<=%HdoUyKD3qMECTG4SC(@qTEM$!xLoF}7r+TsxtIV3)q&i)VnA9aI%ec6?jC-F z|M3Q--2=LN#LQFzT=%=}C6y|m{$C~+q-96Bb?#8SSsvWe)y4VPeX<^>7pqgHT&eMW zf9z!CnL~%%PJr#xpRb@tUNeZSSr!Wq-sgojVJ0^RAbx<9k?T}VA@mGjyLiO*TaV>E zgzDB`kqtHAn;1U*XnDlRK9ReeqOk!yM}3gM--NP#!m9VJ=Rw7YFN=ed!()#LUcpDC z4qm#&s>$TOU09LVuJ@`Yu4pKN@e8->8#Xi1qs2tNKVmN-)!bn8Dl9RT6;ws^c~kWAo)=;6S|rEdB;_>J-%mt z5=GB3F%p=^IZMs)Q$4J5A|H_oej&Z!D1(Tc5`4jdLEuO!z8!1KbvFFHAY zFxWm3nt={n!7CKSTx*9G+Xoh!!=5ruxLiOI!t>WoF&p2{cANer#-=*#Xgs`?wF?U( zF^`HB6F!8(mP+@?bDM0EO!=6TU^jFgOO`0{xz5+s(*N;8dUdL*F zMB|Ps=WUf*t6{ZpT}V8niFk>evvcKJwF}jW%8}PxR8PbZCKP(>`EB zEngML)Rmt{P~q7j#Zsbec$jYM^vY9rYW&k@!kfQ=vfr*3oL5|Xea!21lxECd`dS6j zf+QNF*qrrTArr@VBPku zD4W0``V6bU#1b7w8^!pk=bjkeMYvtYqF&{JP=2GV#=n{0UyfqLLQ*aKM(XFw$Y;IL zQgRc)k?4~{2st8j&|%Ir;g+sV1i7T=e?T^9secrEIISzHb-X*#`MMU=+MV%C?fu_| z_S=dS&kc6yktd3E??wf_0p;T=B(!)+6J;0ceq4gKd=lcVHcGi^f(n+eiD~=+@oDq^ zZ77>pT%+?_OuM0!CTM2t`SH#kod3)5Tzk1AYjrUHcg=mo_V=wtcxC6%G-5sEy4V7y zf>nvz_aHX_d8ifzx%)!K{#WvLo@^*F`j~WUOKGdAkbYrZ)fnr&oYiJ~0g2ncZ29IH zG&665+ly97+DxYjuc?pzA*G{4aGn%%1+Hk77SC(QeUQC+4U{O(L2S8nI<^@@d|SKI zgS3>YSj8qyoDz&N*C1pq5EVf$q4)9oP;}XYaP;Sq%m8HO?UO2yV*QMxbKG4WTG56G zg7r;;oQo2{<67exCdF6mrD#0S;bmeOeR47&Rwc6%Fn;rbV51|{A|ubUFF@50fhobW zu`>@j7?fDKJ0C4KGS`QUiM`ekk>h^+8qM!}t1{KCz!^HWGpty)r558m(8;)r-MzJlh(U>ibC8dkr3qf!WzjsmOZus^ku|y{_3pl|*3R z9>d_IQ>(`w^mV8r`b|N*+}i+2uZHE^UF}q)vRnsMwqbcH-B#La)Qp7x$9Utx9M2l1 zeXvWH*dzGc^vC3fGmI|Q+?B8+;my;Wp>fRM9)oZ|7ClEJ6I4tZsJ4Ty zBWd?69v0(LrxkD~r`Jy_WFtemXC4euGk`yVShpa8`w?49hEiBfRD1iuz3)M$8KBo? zx8bHIe^1eEIY2GebsZr5x3s_$EXnW z?Lcg^#*SleYEHO)BNRtJp{gxOyLWQsaF=ivEf;+N_2D_-PR2_-=qe+1EFA#cNm{UE zIggEuEx>EBKlQH-1{P)Qvo`WqxswEr;?1oW1tn4;PN53-fU=r0C1NEH7CtWQMY6mC zbd?c;CA>2xAlIvx3x*iI_MB!x;a*3;*}@f*oA;K5bBBO%!(gd(z?j{hl%AI&KKUy- zn7FRQ(OKE?D5E7C$Y%pgkD9KW(Q2z91a&xfE=t!6>I$U`!e|06(-TCzA$8ECr98Af z^hhyXbXumq=|~~j_@riC-v+MR1;nLvJd`LPfRKDXZI;p}48WkmsQh>chiV zcQd~e4B8&c?bz<1^e~w8F7hhA2^tuQN-- zy9r|~WOe$y1qj*BEmDCsP;}796ZUfvq25We`Gh8KcOoX+5zp$9SX1qoKejS zs}?9?gDBB!9Gbcgk*N0)U>Ej}b8+%M8*G^F2x&#E@3K5Fc9sFC$H71I&0h;E0Ew0& zgdX3=vf9`&L0hcCw-LgABJ)HZ@lJv27yv0x-Ing6mXM)ztZv$EE>KdT9(yQoDlpfR zcV@cE^aEB((iw}zZ?u1V!ZA4fxdYikhMbvAOnyMq0sKvF$rH>8idzMo6m*%6b4IiD zK%`$j7@&pnq|4ING0jxrFy*_Pn`{R#X0ujPSQtiws*G+2*N3h-T3q-nm{G+k(-q&o zZOI}qg??!!rFr+jQFLdt*K^E@0jRdTO%*4nx>;MvHD_cL1Px%2(JJ(m|AqQ_LU4b0 z4gBW$YW*hFk>hO+cq5EP!GFlBrD<1J1bus9SC5yfe3^!_WePjU!n{?RcC$`^orD`K zo^+JOxCka7_Qd@ZzqoSc`$*wBAPc z-VsKay2!+?48bEAUUXhP70$Bf?m8qq>s!QM;oBKk0k3b-cLFMzev=yAe7L#mIRg+g z8SPU|;!6U6ZYugpp4`+_q^!?j8_vH$%Bswc5`30}_~7eknUAbZuW9AhgbwGKbpooi z0FthBz>*Tjj{vtl=M|uYc_GRr_?7NWH*BTpOQD3TX%z<)UBfC zthr>1{P|X3@ZRuxlv^*S$*j*O`cu+u+5D3Ukavd8`;hK9F}&3yH26x!ZgU7*&u6u3 z{CCFFngX6T^ls)yaHkD(Tokz9ZdlttRJoCz=z+y{2iPn?R9Fj*qJpdySm~KN#sTS< z>87Jz@}IvZ4vuQ|L8Pez(nrI%17-~`Wm~7gI^4(4V%;ty*31II{u8v$+b#6a6 zbq%>dz4{^)B0;|+`J+O~omBh$ZiGhoJTRzX@tO+TR?jcc>;~lGCN{D|Jti~CmZvrb zOW^Cj4Gvb$30Z+DDJ1u?h4D9&8R?!FS0}`JIN#;u38?LT0s?7zC!`NoBL3TjVZx{e zy&B%Ny82LoxDSgl;ws6Lv#yE?es^DWPWe1JC8J1g&GP(9l~OuPwCTKB(5|NjpH2l{ z*vbQ1?(RRohyDCru*a+J6>%sdJ z{`4=9Ft4FmRb9>+Qjmj+b$Dv-KMJTv=e%H6V`2559Kqp;7iz4K**R@Kj*F|IquxFN zcI;5p>$r`1R@88F%-)%QxcZaP!MW|uBaQUW{M16gfxq*u0c7LgC*v1ElA3GtpLqoE z*;CFc$cn9b)il5|YepiARy*Qv$1A6g)ge&g5@8 zC_htmF^qjoGJ%)P3jNvrfHiUZ|d%5ApI ztDkoIg=%rl2)hJF4zhZEg64g%>G?kOlhhTZ(njzotY75g_RPisis)A(K`oMpR#i0Y z9e`JX8yq`azG%Zix=suelalR4EOf*Dc@fWPZ|ra9&Jjxbag2KqrNGsHX`YXJV&xv+ zC&C}6DqB*6S9yTt_*V#G8%!B|TDOjzqD@AS9#P`d=Uxf*a_ppX9Qy>kb3x1^1wKTZ z2QUnf5zhC7z(WPO6JCS>6*x5*SJVSNU%q~5rMFe29(H(hL<4$+fxHOVhoMNkJ+b&r 
zWt_uM6W_Lt=@V1ECW2B}h@5gT+NOg42Ux&DTeeOvaz!_C1X*rQtL^ZbM(R#1z>_7* z`676l3fap+jtxleiBJoCiwiQiuLaGfOfNuR*y3r)Y2t7wdYV#S!qPEFlhIRu&GIo5 zQIYS5WyLy@!n%^+A#8xo6(I{BTHnGSORIK`CA%Psxj;wpD!p2B$d6o8lcQ9Q94+?( zc0rd`qZ@%HA<@ z5o0az@hj+lN-?-tfVpJI4c|yj+wQ`58?o&aEdr=yuH}Ra(Jov3#>cY923X(K7=ctq z%Q;TShJx>mdcPgPHN#FhuHfu#<{|Gn z?*jxaH2a6U5}K0;DR!%e*{jPV`h7kJddN04cf}KJj5XJ}Q3ZAhbLg7A`7az*k=*I% zubRY2ovaY?!-NBveU&vkz&LX2wHKdeP0zA@8Y#p1=UL|?I*wlTX?u{F^?cgS&$I;U zc>B~l>Z`!5y{Xq-Dz+I+qM2OC9dA`*P7!vTEXyjsvqq;O0?X0rC^J(r`{@oC&+`4r zO@jY~eVHRGa_@eZDG<39On)o%R}O#d&w1l&n83&}$MIG77y_RXfE{t-145tN{46qj zNjr4LhHbI;sXj0M3f$tBA(#plwZMpLM3~ajHr89`Zr^^b*)My#krZmgFZz3w%#O%o z4MzSB!7(~J^kY&n_!QJ>k|=!P&Lr@{D4#d7nxm|v-&W7e+S-`YCT3~JU~PKjIB zZsB84E;$Js*^W3Qrzsk6skI#->HYu*dAK@XNW`@xz!Geej%$a60rFt#&-XYT125E1 z`y8TmQ=XeUsKa583Z6ZE&aqZR$fub1JF~zqs_vRv?ia-J_nWQ~?7I$o$8R|>TZbb^ zwz;2+V!clP2wfeJ<2zUgV1fwViD81{Rz@foS8+8h$^A`RJ$WBNS$ePvtbALF&@?Q> z=)P{%9!%&;`5^V|wC=Y8i5B@o@;US<1!JVzZ3iyNk8T<%7!p&jyFSy#^B`Te&u^~l z6TQ>6+1mV&)On39e2B-%4iZv+PC9I?Fq-*9H#O0JSqjoeeZMZ$D91Sq2P@&ETIasI zjR;}KqXZV6(m-N@M?9vaw~sy6-JU}a(=5JAJl108nn#_JzA}D)6M|!hgKc{dza9l= z^WQi!1=tvS++xKtdsP4~N~p&#_}7J-)_9C>j}$i4BQR5|?B4_-_JxX%B- zG;%&A_V98(*DFCHBQi@s!rv62NZ~K0IPE8U=1sThwiBvJ!f5jWkP1vtvnL;Alf1PP z#_={FqNOGiMMO|Geh z;y}^}#sZV~NYYXCwG^5G21N8)e_(upf*YTzz zELn{ds2x%ZQ?*Gt1firkn{F2XaVc=Ac0v`OqNpsCIfxbF!I{f+^N1ZvI}5uzk#jG@ z>|Eq>lF^~&UV@GsMF)wEGIs8ZdYfJ=m`OK^ITKn*VFa*H`;z5R_f22Xv@xxj!yvLT$%Cx95p{O$+!m=eM8jSH0@osR(g z%a(G?xQE(xpBiKla*hyR+8h5vBrOV*&;+;Drv~WCzp^r4**T+Hlmq;?J6AL;H?sT9 zO`8z97{!)EAKuX^W6OP(>kD{1Z{x?Z>|Pyk>H47$`*J0h3zBH@i@jUWD*(}0=^dq8 z?PR+mi@1ro6(-G%&wg!-j0}G9uooKSrKLlOjUL>esH(rn+CGgt!M;?$e;UL;*aNP=zj>a_seC3OJgB$lqLl=qeJ9z=dVO^K@!Kam z5vPiqSNL&zWo74%6k`~~4w-ap8_WhsWA;2S2_wwyo`wE~ua9#<0|Oka;fn=5Z2}#* zJ9DiCeu57YqdE+FYUzzllU1L1Mx0@R*-_;C#XHPwkDw&>oD0HP7p`^Wa`&eiY9&ee z_NIE&H1V?q{r<#!%X#e8=0)25pT*{>-a}$J1821ThXxpJOBd=H@en5tQf0omkp1aa zJ*ps?9V0uN?`s^OOnJaCylN;~e;`MFyBHrZ0FA${3Q`IQ58L1vancQ{o$=%FeEp=f z+F|lErRa%WV)Y4=x;_>;vozWPeedY*vT%}dugv_4IMnrQCO)k+6WSv^y0|B`HMj(C z%>Z>Itvvlk{~)PJU*?J3&2Uirk%|>~Q8mb*&@LOu5_P59K~B72>77?+r}X}M(=9&Q zl+`u-#N9qwvXc@X)*A3up5BHeSCE z8}dyxAx_Fc+l#}QD*@(%+Vyxz>g@vf27;a0y5VyqC?1 zJR&k0w#=lTZuL40K-=cl=Ew5F#L+$Tu*Rv{&HO)lTeIQSQ3yBF)CcTk2*t4}Gx3A@>bvm6Xa6Xq z9JFBVq?GOPbr=l^^?7*t3cUO$b)0u@Sw0rOrEo3{i9(HxT%YSuDZ5pQ(PTe4WZAPa zi9m6_h_^w~TZ6ou__PR?+1~zQd`q2Wx6%U4WC_JN)2d&?ewEK-Zhh&J{r6g z6iCpt=ZN*|ReTo-qVW$$-H0+X9{KdlxlmhJxSBHXv()!pdxj#X&JPWdZY5I><#st3 zG7TxAH|#ZxvxeEoA@ZN|1%-SN!|L!@uDBFfVQqM{Hyble2GEpVc$6K^Xg-;@{EmQo z9>aW3J{58c>xsIhjCcGJs_r{W@*kuj_~-Afo#4FZgYy^(5v$eJ9aRm%kKA*{sXx5FoYB>E+|-kdO4w z>U68^Q+=THH`fh1>UYg`)Koy{)2e3I3|`5(I`O#xycA^!wnv(%O==~V`yPuiK(j$A z#LL0!UP_v{IAu%gHBgTWqV9m;HgV^o;|JIxvdmP~YVw8lvk4A!!;ZP$PS93kM zBN1$NOm2@&aWJXtJNO&{klyNi9topPIbkqHJ;_B2Jgjx)+P>Pkic zBt1$@fv;K3;;{;_B@WntFBPC|2~IDl@-b!DBnCng+o3xp`y2i9qID|@{B2dqO&?6% zn#Xgo*#n38X+bVYMa)h28%m$ZWeW z^rj+?iS^~Stoqy5LwJSbU5l0T?Bx$zM&~$cW1k8boY!m?D`c5Kh)AKtCcHHin32OV zEj2sKE1{paE!!(D_ya%rtnO`;1;B7sOIjTDL5pqmx?d`b7V~^gcDa>Mn6ji-IrjAV z`TP{Qt>T~hDXEpf<sJA3g>32#M?su$w&Et!>HR%zXYQKQjL=`6lFn*bv(V-h1Aul2yj-y15fJ%v zU6g{6r?oHl{-S%j zN;(@vA*}k#LvehA$l-y7tY5bs&wC%Xwhc&|*X3>DKiUzd2w4ep$3Mel3rW9BNrNi2 zIj1cfR{mFx{c7#qT|4cm=ySO}p@4;rsth2A1`F*1v@cik%(`)Ebys^82ozi~JpW@A zlkl6=?TWa@X)xVJX2ioz>dz7bOaD#&eTF^dQ$*K3m-R)*9QX&)%N*Q(sY4$@C;aqV zu>7&&g?1>1X#p)fq3yr4(}^!$_sSQ699jG#XGCiAE7fpez?Vh1qXzWQFoxW&K@)5S zzjJMjob}^*Wuca5A0tnW1vYLSQf=dQVeSga{M3L%jbwo_AfvNP_bh(xV}ji=oafaE zmgk4nJ&yXbF-1V^XGYtW*}q2aC)lI`t8Gj;zWRb1FuOZH^cG1y?-N2BD`Puh7lxB> 
z=~$q|yqpEA-^6x_gO$6}o{IQYN2tHA`wsqyv7ZT9>HeF$vDns~>y11W$qZUN_$7!u zWW2kK8V;~`Ye%t;%WWdx4j63Y*t@v{8l=Q^djoSCzRNW={C8jQAtun$=-XDDkV|_t zBtpefA=OfO!h_p$N0Y5?r|c=#U&`D{aH|N)#v-~~d<8OSOqnpFX$4;wuC_~EqpN#? z)_IwQpm3(u!Mwt( zWq51Fh=_@NYc;;BWr;X3|rK@;!Y8>ED#)$S4!Y1@?CrV_Tv3s66X>%xd%bthTv zif9Q#6ElD9K&$1knvcWC@mC=G4jgbUm}YzhhsF~n!V~}-kz>g2aX_}}CEh!FB&JGU z4}?~jX1CsiFZ-T^iicytR|X-fbUUC(fIf<;d#Y ze|ZU>0ryy9Fn_fDKEOMio80Qz4{d$a#Qa*2i!4RZB0)BQ=j%$E%D%HYItrl9vn89c zJ_exGiHryC;W9d+^g1opSQ03DEZ_*%6md4i+w0INRVDqmNz|2x#B1IXTnlb^s7M4} zmov~@$5T2xv`)uP{;xbkuJOznOZ< z-S9c!UU_YSQ#mFKZSjW%Lsi9eRr`xZp+qQY9;;lcm zm}Ymu!u0eWS=b%ljMS?gMJjpI>-D7Di&cAzH)pcD0%!&TG20?y-Q& zR1edK#VU)ZFRvhzT`|iVPC7Ui>~`S8-q!p6z=HSqIJ_SjX0fB&dgxc~zNkErIc5o+ z%ClwdJ8)3gOOwx$p@E&AKG+<+LARa&b4PpR%^&6i?SmV*j#X)|fa*!cfO)zEL;Juf-3!?2u!t!kJO5oqrUl1nSI-j;cAZu$yA2`a2K&(cD1=*R$ zr=b1*v=65~(D!-1Ja=ugU}M@SNs8cVLEE|&I#Nf%U~A(M$M_z#G#lHwkF=e=2$j3b zBlWHUNMhc;gte!?L4v{2ZOf&2(rv{E~>Owra ziM!uEvkslNcM3g3e3+tM$hS+=k+GRqzD!T*>Db$gK&SND>&i_Vu$wjRSG8umL~5I?Gt(MJq9OHG`$D_ zgQZGD^e!-L68b82REn6T8ldmucq$UPx}pofGix%4P^(G2nA^;z&}{IDM|65DA8PXC ze|=Ygd;W8RCMgEhr>ZdZKS0*~S}Ze%!18a6JGb%sk&hqpZsQ|(UMau8?<$9^8Sx%Z zB?O1EnmL6*c%#dZok3_a+7~pvn&gY;8Q#J)jqI~DjD_RmXIgi@bW-&STXv2Jb=X)ArpA}H~wy-g) zU%p-)>~jb3Hn5Vn+Q4y=up-2u3})`!wU_uw26zO?vvW&n528o=A=tTvI>(ZKJ;zjf z9dUa!JNmsyO*P7p2-ou5Rv@oXcHA^ zb@#j$m-xZ~IK6a}%Nh5?SChQdwc*Iis^BApK+ScW>;lp8KnEjt2vJLv=TF6##e&I? zm?(w}I-`BIT^Q3>c;9I9><eA} zp(W;3*(B{vbamPY^vMm`E{*_ff*u+jTJ2Uo_J1CQTz;V;6<|GK+o^<|Hfnl^7#aHD z+$t}vPw=zxXleh1*a;?Tgp3YpN50KVN-l3y2K%}-)r9x&$kbi~vy|6dcv~SZZBDF2 ze?>t6_2-VC?)T4SW*QV8f4)>jgv0QiRh~Jv(Wk5b?W7j3S6jBn4=+v7jbUrR%717* zNoGOzr|&S*l*f=ai*{A|sVAt1DE?Gx#Ie}d-*Km+c>QO*QaH}T+;fV$2C{0WmNtm`Zv8%^H3U~MC z%D0J3hw3Om+VrxH{G6O(p}J@+i@VMRW!xw>nS)x7fM#G5IEnSe3?t>;hL$BSna@f`<^PpIP$4c#RpTzo6P?x@W@Spu3sqU zSfGND5;3VuaTmn1a|sSBS}Yh)6Y~arBAwThsyXt2#cu$+h_URz1U%O}tsrcU?FTfl zH>i$W4SN!A_lK}6ylX247TW5a!+ivduIxE2%vdM|AVvzDP*!MuaA`op_W|Y!nZ7F5 zZ=R`f#mL;)1Kq~i4iYJji-KEhN)mzy)H;B$Fes7YL<^3j%vx1LP(OGk3Fgt{?mcN; zBfX^%55WiH{H@L{uo?DTjL+c+ro)A#lPeVJ(GJX45R3%V6V{z`hwDO{;Y%HW{D8U$ z^5HFRFTd9zInPDP9hXMyow5fYzv6Ed-%|rRSo91lk}>t=%`y1f%G{@8-9-U%36H=B zkY{h0%##z4V{^*<-(1E$nokJM#^tQ4j8nbR%EHNNs_ ztYFwwHC~eYSXECsR1KKuK$ltp;yW~K{jU|Q`z1>%U=iP>X5GtZ&;^FPT1+2H32foe@P2mr4 zm38I>_mS!bGL$&>G1;^%?!REL)uOmp>_L<^q00Lz05KLrK_qw3s6y9`1=8?Arqi|> zp}iK*Z`g6PyMDbDKYXB_Se{=tc)yUA+~UB0095eWE}P&0+;gc84(DH{0Lmv7uIwuK zv?}5(Ch>XC7U;w6OUn4P*sDcU-_|D)V8MP3qD$*s|vTz{0_rYJwCReQ&z6$`B zY<4j%@pishk*@%|mocMn0OKMp6(W1T6{>SY6mg&olm3U?z>=TEV#}mnoCvIC%f588_Ei1fgwis$p-zoe^rzDYpPje6Jfdf1>8yFXy^(j~8AMC~qtdKsTzI`{ILv}ioUMEJdO=U)ke<P0NuL_Q z9q7#N+9)p{)vbtNxK*qIyQ{e9M(1&)P9p{>kd7apRk#0a-4P8a3_`>`h3 zlBiGjFSH|gU&~Fh`3WOVMNgU%Wbmw=`a!9;A#z~>XUT09fhu!y^%bqrm{D;&seNrfX9tr+qR5jlVFiD+!^!cjTHp9;vre9*HI z@&Ux4?JefZeAXgP*69>(Xcg1X))AeF-ju{hnz`Z_AVZ7>HF7sa<*Pt)_>EoAXftwm zquo&nmaBCzUhPrkGv+Fe?iCvyJ4YG1nR+#f>+yzSFBuC^T(}armLRE2elyC-xc7Nd zOOxGTP9uNg*0qw^GD3#Ad4%}N`&{64DdMtL7LtIRC;!pqak&^jnVznJ_e?vD&fj<9^pjnXsYMR>}X|2Of z5&zmjgx%vb$r;+>;If%jy?B%%*-8uI&qwrfz1^Wgw|PZ@TpVAz@K%SMg_e>_k~`nK z{i9e&`B=(X!=%l+AmNak_tkcb_ZxYd^zwnQM+LdO7G`wLtNBwNg|)-HdrK`96vrIo zQ2vkNmd09v10a^Kx26(HS-pEiKV+lBdwo*+=I;Opa@QvX*+Bk}BJsZ)nDTcs2XmJf znx8uU4?;n`zTh8WUv0%5mYo2$v5jMmMgUQcmETVb%DwZ*t&O|)Ha{i46@S4h{tS3n zM|I)|XVzw1WXycH{{Zim$LdXeyQhB4e+lfYW7Ahhv5E-3>}d!JK7l|L^b0p);#J_ZZv_FQwso=vIJXbN|?+xhQEzzda^$3DL7EYl+c25z}4khHjP5UCz?;R$|ki zfYqZOnh3$HJ%OR+cM)9%Ca|Shvy4}<-Dpl*fttvF4LfiJD*@z_ETC4dp6-j*yF0%L z$QL-OvG{r7Pf8^OG;Zj(TT1g@v8P`sgjbYl`eVkzyXz}vSgV>;&WN7(;eAp_?Z(qy 
zp<}7a$(+{%@Q2}_jJ_jZCYRy3B$e^e7kBAjsDBQBU=NI%<;*r$8gQ^t7=JqBYJ;C{ zwUsOlO1}aot|}DwXT+A5@)aD_r@Maj_vecL0N|Y;8T8aIVFj(v`;r9D=U*25QTQwI zKjAPkc%MwVwv2JOp=i(0jMi1Ms_<89Ln?UsQg%4cwU7ukR-njjfI&NcXQh^y2-1O5icd=QO({2!K=iL4(wqpv742Rf zP1eE9b5%8@ciLsNnFAHFGlqK~UIi|w#4YVMa- zgxjw?fnF!3YqsXd8niX^WWt*9T_d?3*9VAYF>q^@)OJ>(dyn=|Zq<)>qYU(} zn^_d9Sn$iOF9W4vHlMWeI#;pk8h8NlTxHdu1o2oM5w3FaT4S|2t-$8B{Gci}kG3gu z8mh-RG#T`&l0nBoT_eD~G5S?$Er((HR-DF;XFiq%`TpJZnSl0(ImJ?%$dO%PfDzJ|*1(RGqRCYcKDe+s&MP^ry{pjYv=x;lum-wf^fgru z9NA->RXJ^O+Pi1B7_7Oivq>{$GmkdB^s6!1oNz0mOA^PWX~U-QIIUF86pm0>`%?Xv zYU35zJ&LHT+2K?bT6N>1r*BP#ISlG2bV7fD8Y>ZVxxiz)+RwFdU zwK?LXs~KKZLhz2&(EXbNs>v28B5i52PUg~_jU z@ZPE;QC~Lcmjyrqy%S2*+BMF4Rr^F1KCIJq^}wxpd_t?x*Uh#bB9VB_Y0cs}2;kNG zR2Dvglf%}*gU(FvVuM0@*Y zp0W;^suD3BI@Ad%&2`0^CUv%Lmg2iDEpo!RO*QgGdIpp}U}=pWfny&}rFM2lY2vvX zEMQktXxTIxqqMbBFi*8^!6v{(U|LJN1L<4#@#Ge*3pHYx;<`&4OS_22l)0!hT@~(} zJcH%*ui1~-v-Ur?@m$uvFY$G%y^hey6A$&j;00V^bDq6;(~a3b>{t5>%i^@R@u!bq z+iJLFg@+Oke{bhszdjo9-izUD2%*!WRAo>&9Whk29SctI{+}I>HLjvrqc}nQrs?z8=73lC-f@^XeHAt^K^-<(Y;pO(LY2olXSFJ9P zbIn^#9OAOD&msQ9a4HwkWzS0Y2(&32@_Z*Xb9+pV+_SH^g_eeG^2LE$28l7X!TY-X(0ph*`)jUUarazuo;MGk(;!XVWc~Kq(aDQ!1GhEvp)5t9y51oH(zu5D_ z9uv`SynpbROsg9yp2dkFbNotr{siK`6#PLX_OJWn-oK=;j{2^vs{9A>0JZI5FaCmS z`F+*53$uaRx~sySM5#Rn)1$s;mz|#Ul$FmdAG{ zt+p|W#u%>Z}$T!(TB_niv-Luje(Up(oL zwbRnQMhpJ{F4zU_SXHES(mi?&N+83U?6mz-Ae>jvuxi62HQ8y}sM%h1S{FMhpG4_? zB!R8|%;%*t^TZH3;T!JPiZ-MogY7BCY72dBQ86LwT)$K{_NTNbitsj^@l@`7P9M3i zB7J_{YXqz%#UmthPNO_jWh&DFxKq_KFs7-TCj?OIv~Vi0K(op=t+Q916+I)-pJYyK)aqS1)%9`d6pjTBBAJ+G@8o zahbm{<uLWzSEIE%p76c4poBf>C=^xpOKZH`Qn5UwO zS4B#uhnY#Iuoam#lh?1wq%JYER!rJTkKq(u9L}-hv)G>EvfjqMX75N)GhD^aowlc# z)0rFt+6Aejj{}O_n$HzoZ8+kNVci$>>3%_8iKI=nI@g)NzEB?3w{5Ake=PJ}roMBx*Ihr0 z6?6U)C;s&pF^wy9J*MYd^Eu6Pw;G5f^NRB^;wE9sRarH*CvIyORIg&r=6lesx3Q`) zN-eGuRqWjTE0A4G&A~r+iqK1}$GHCh7ax^#^D#z&r>)ufdWHbvUz}DHns3@A+y@ok zyep|`5$e6+{oPcG*%8((cB$t&V@Ag-9+1}HY!!AdGC#e=b}&YYFyl3we*|sH4h2uP zF38SHoxas%Uf^?CY=jR=sOM-t)gIZ8+Cbf(T08hX0RI3Qz2W^$t@Pbi^z>MSWBvqd zf|gJhem}_n0QFZ^>bAcQP0l&Gyz>752-p7rvK51>#Zsc^a?O9{X0z_4+@(cRjfYCJ zB4ip9CJg5#kDOLa$T8BoNmB-}Z$4btZ7H46XFG6qoK-FV05wf?-K#q+tSi){&67JV zE%_L)MADxGSDEQ#WMaGhHLyi&Y>!Q}mw{bfn`XExC+YOA-od!$gH(1_Qjnn6wRlfJ z8r7_!2h4z1m-s(TeOppXRT;?u3jF~6o4ycg9w715T2!$C-=DdUrEeJA&b0YYbp4b+ zWi3a?z9_W5*5W5hK(6nM0zbMd`ZMshfu(pmNs7Zyjbl_$zbL=}b`|Bn2EHHIcqhWv z`c1rHN+{3Py_ZnI#i;hK*g+(D6&swc#@%X-a&|3h=@x~+Vk$j5L$!BVPh@AirI)y;LqqyrQXW8K}x@eoZ^($+HMn&+jsW;+FSdL^M{ z4aEgSi(9oDH>YanW1Y5aHQ0ECgb?h-d3L3!Ev!U}M$tzh@h4Q0KM87a1OEUn@BEE_ zCHk~q>0K-OQT?L6BEN%v7i(JP1H3k^uzHspN9t?&x4j-^*EP*t^srG#r*9AdS4m|6 zWCpER>FmuKEMVkUxO_eMH{%~0-$`lVn(|ala0 zrim<|7}md4J`?`{!8klg;;lRTN5%Sr*aMBRz@kDvq>w)#YwuryU$CFT?}V3!#Jc;P zTK+SbT4npA=rAgIPI+SC6Y=-qckDg!qv9pyym9H0SV1ri=iIK;Rr+Tq@D=+10PrvN z0QhI{(%L)U8LWEk^kgy_V*)t-i_U(h@~^cm^-Ui}i^#FJNPyq~Hva%B`_~Uo;MQy&~ zQ_!E4MSRn&YyKhE<}+VRN}h^33i4eh?(*xKscWYfks|!kSD8Fi>N~UDqn6X6trT0i zg5bZ-3lFVx+Jw+sCLK$AQr+s|56ZxMX1R?^QMj3YSk-;=SX8Un>h@MVhsL^<;;>{W zfd2p&t#EoR=AgbJw);s4r1W0(*XuBSzU0_ChQRCCRgVg2F~_Pvwyum+{{R(vl(3a2 zK69$AdT^$C_r+fr*!V}_ZlkN}549$@XjCZ4XOs}253lK8&86j)%UtxY+CSTm#J8GP zgEXH4S^ogb%e~Y42(XGjQ~6irFZKp&?W1VIFzkFAN_8b9c-)#%Y>la+LLc&e5tK?XJ;T+52^TaAfd|F*tb3uSr6Q&LWQz73XTuUg3BaxnT`l50D?=*~ z+Rh`#wRBMGKmZA@VWBZ~6;RB{Kqq6+7f}R^S8b;1ijnoN2wgyeE2g#87{(MD$Jm>D zW7*qPicz$+RFd*2$KJ0cZB}@>2CB)fTs3Y`uzjm-?2Q{D&(M1#*AOF|>U2N`f zKpxfU8ZFTnuP3sNvMbSakX1!;^H}qydcCda3Mz)7cBhKQigKeBpLud?CnJWTgRDvv zn)A!4w>B%R)u-G|aJMJt&2vWQ%2i(tnoL@D(nha=DE?;2o{lYYJVd=Mk;%F z@mgqtv?@kGwjgjSw2B1c%sH!=Px?o@mcH(aa}yNWSXkhdwNww+AP8Q zD8)q35O@{4!!lA}1c(WHXoHihEMMiX?Cu-`R 
[base85-encoded binary patch data omitted]
zG<$tOKHa-W9UXEyWBa^UPj;JiDN}m5RZ)_;U&HH~)of)WvC~=@XLdN5U*<#F?PV5MM%q+RhY}$kjlmV!*g9NUD%=tZ6Bi$)ZxaSCl9JlLmc@*6 zgOGb|t$UfbTeZ3XSPhB@?b@;}=bp;rYZB0Bn8;s|vZzz52ie_Gs+Rd)XXMbuLYe(MV9jZ-HRx{C5R{z22$ zvF>lhr2EmZrmNd(A7qDOubC0=Si0o$M{aHvLnozEs7is?+G|#>aQp+_sOlD$*E1W4 zn`-*klO14re>cnMD`L;YcPSRayjD`EYM?=XZ?b){&5mk0W&1o~e*I9?j8eAcz&_PZ z3CgdQ#<5$P7uG}r%o$p}X+6Y^`?LAh8M&6!?qC^eShXmL$QvtI>|}JwJEvOKw7Z61 zk&SUfRYAj3qP&d0P~w`9n$YxyjyRWmW~^J@DfwHqa`xJWUl325f*0@HsizHV;rjZPI zaaDDKuL;^eD)H-JA;@(sPcDT!Nn>-iy&p!i;4rVB?KM)jK|$iZyGPY_+kz|RYfx)r zSn7H*ylDmx2C$|vvJ)>gkv_X5xK$Vx3u_4sS(69ao*s;m>0$7fX`J=!Yc0M&noBNVs7!}WWmgVi!ZqZ%OnI64bP*Abw7miSsuPpIBu$fLPyT7|;lLEO- zOc@s@x~S`+^SGL{)6np(Ldl}Im=nzt#*A@`x?-zbJui> z7rBf&*QMz?3|m7<&3K=Ld_+acTDMGB*PalywZ4pR0=&xW*}F5@!$C`94lOn{1Lf&i z0%+{@P2_MZ=-VF+Sojby=w{FCLq!q#dKf0%rpHc@(@xK+_{{X^S=KZc;Hd^?qcq!Z8 z(D+#4WUb7sd%gFFwM}*8)9<1XZT|pJk@2_j?_MFSXa*fiH`Bvvj^TXC+(x4(9sB!L z9}il5LGYc{pZ@>}Tu*ll(j<^T{{Zpv_H_H-G5fjgnx60WW7obf_@sE#;}3{!bo&IB zNg&bfB@Xbye&k5ORv?@ zTX+0*@4g?q{{V!0!hdA&{NHQ-{{YOg>~HS;y#96byJ^1#BGW&xynC+QL8ICgNNnJ` zSmu%#zk%QH9Q%HC!ua3f$HPAl+Dqa80D^jadVZhx_CW2kch$cAa8$rc5okp}yF;QD zjoUrfR{qu1??2Ne{pbAis$$yL{ynXHzhAvy1yAvF#By90zPyps{&|MLAC-Lt@T11o zcUMtny?h@GIXhIHmyzeU-A{~ zejCy>%h+)mZv3jhiTn)%!@g3Z-{n2d!7SILsF~SjB{sa%@iu%&?;pzKbT9@IK zw6G5{yXj@vdOyR0K(Uj1v-4U#E^7$^af<5Ty-^Sq>faFjF&3Jo&{xk-a?#10U9`-?Al~)mdhO9j&TC917XzybQU25sOaL|A=$<54 z_^%K%GtMir(&A&ydPk;vjjxHLVary%w!H&!n(+vw5s#%!2C5OUKpycG%u~48T{Vr4 zTX3%x(KR;nR~7A=B+W9dH0)+{u_Hy&wXAM}N%K`Sgy#mfbp0>Se2UdDJukz$h=w2q zdKIyr>#Av=BWk5xDb`ZyDjpLF8N;!L6HnfgT5? zd2W^Bi+hv?x>2l0#?wGlqt1I?m=^6rO!yv^ercXRlF_m6UL*ei2?E83cPoyLy<-HZe%xtAjmQY+Ia2 zGq`rIus#`$4%!H;ZEf9K8J2hGK9xMjGN)tej{w6A>E+8Dj(_E!pdQuq4v#p53qcw{ zFk5$4uO#??E}t#3>AHIv0}Q;~0{a18wwJKYFWBwXU*+4mf30%WJ81ML{@7{2TSxN1 zU93-2^sajHPbp^+klXm@k}A}5!)-K-LP;X3I9Qu3LJqk1uP&+T zQrN<|Nn~8aQ3DbOnUR6q=lKsx$)9|k_)pd)X5o#V*mmiY_zsb%NDIH0FPs@&#$txYPDeGE{n(gDMmvyuwfJQ?T z*Z?^L>Azpv9{#sZSG2G#AsQw~$^{+X7*?0L8g=85tY~>Fr(z8a2kIzwnW<`TqcW zu~?ZMoM&^tPyhqINa=xBt$y2i3f}53<=Wf+>KK3<2H*K-5b5q%H zX!738a^E$>4X)282(hpHvyb)LpSl4Rzop&2t1Z%aYDpU2)g@sxe0kGj%iIX;Gla)Q z!Q(Z|_(bZT*|+A!V@5LF@(@q3bBk*RK9-W}gbueGt+)ANGCukXW+uYzCp#D|S+Gx;dS`3!A zjAOZ(rD-Kb+9E=~za;P%lE|mCD&yNH!`CrsI&PzS%w~s6h)b~ac^`mS{{Xbxh|g}B z735c1)SfHXmcD zz#PZ`Ju9B@o}H-K>bi%CFPDD1c+WKL{{TL$Vgp_ znWrkkl8=x}D=y%P=sC;a{G*dsB;y9Vfye5$3*c`EoBc*-8iuUNGp2A_1sj)-_L<9L zwidk#Eh=C5O=o>s_OyX+E1v5$ zf8lpqTqMaPu#*VD&Ldc49CC6CI}CG;)gOi~FD(8dcq%ERF{>Te&oDk_%# zb@nfSd`{ZWgZ0lCm00eyts@RmN=EN5?qXq%jr-OgE^+B!GkE?B3(Z3NT7^QzCaHTA zIR5~iZev9sKqFNiz*nn$J=L!KJ@6k}`(KwVvg)?}U-5qF6tjKra%!nd-)o?yZbZ(j zN&d<4UEH^Ka`{loiM{^->687|bzknDHx4t@3iY4(n_YN&?RR#sX?LW$N5AF$?Yc2# z^4Y#uRNK6I5g-^laq@+y{7QVQtt0nP`3WR({M$ZnnWOo;?e*go?LHmxwaB$et_xX= z0`Ux+d7<13EU|S`2pbtVRttp~u9!&1Q9U0})jS*G8+%)gT2!}^V#6_~0q{3toNXjz z0h|NS;L|ld8r#fEeM;JLw^A6^<=H}Q7}?G?kQe10JA2oR-a~ALeGFl(W+m;s!X6?C z!WAdxT=Il=ZVBSO3&Fbe&4!sRmE6-=wV^<)rz(Jw2q1#2yslJ&2V8cp+HB>TJx9Ve za%ynP;9H%(UCOn+t2g|yp95;1aul!}4tgHd!F`rXX(oGk^EN+tgyiQS5uai^;Mb>U zQrc=7bo#})k{kOnl9oWeXbOHt+meJR&PY6raoV8q{*N8*yElk*Ab{O%xEn?TB-oH0 zxZ~3U(z~FvIHkGgmzGvZsmG{kk+Jf_M)3p2+~9HTjt`|#)!{x(oPn8y(kgEq1B1DF z{xx(KH=$+fdWud*I0Hfsi=>fr0o}J#}R1BzBRq&k6ukCk%v< zjlYH~R?Z>zW|grbXN4at0(m(frEIn`k=}VeWxmV)nL8POyIcAK#PG?r%L9hz*0|j! 
zR|-(BRaeZ(JONv_7ZSrGCCJzl<^_EYFxZBokex<1iFV0!NqG6GY(s5rd$ew-$WfE2>8o)W6vk`OS6WMw0$}OU7N>xXpdd`&;Q+mxAu4 z*v0Gd+r^$RQ>e5b^T^}!uawMV`@>kq==vK}y|uL9cdsY$CyV~Zg;R?2h4H>dJj&%Y zT}J0qecQ8NJz7s>X-6f!@A_@Ws?N;}VY5 z<;sDxmnrDht#z~KJXblXYA*!hC2hi~ORc@RG49~kJ9BWdOdT7VBSoXt!eSdeQ%;^3 zt^EAp;Ww)eD?UwmEo0d;;MY5Kb@s2`2I9E!3QJ(cRg=*f&iw9b%G9Jo0^+Lb z7i}L(=C5@x+^8nAqwP7*E~XmiA>t`~vPrKZzi%DG6+A_HW}7n%R|^iOZEtt-)^CaYVPkOE zh+@33Yw7l^=Dj*}bU3P2?0svZ-WkBh71HXLcByIUUp45yEV+gewW~i~_=L*HdzS=fD0{;M2j9s|OekQ5^0BcA6=>GsZ`QGow>-L7^)o=Js*VNXZR*}v1h#&ve z{CM#f?5Xf~Sy?Xh%R7x4>Df;0I!r_W?c__jFnx-evup6eUx!v}H-Pj{?OWhKDsyLuMSRIGBzvYu`qR`%H)23ht2U6n!CF5KPUbpd`Iy| zigk=@={>#F92n<}#LB0j+CcR5ub(_*+Wpiyd+9vB$hG@}b>Qnu$Mbb4BD`&zGj$qV z9l!R`TD;>~_%ZPR09HqfQqgqh4cBeBi2epJnEWfF14^AA`F48`h*Mg`{09EY#oyAj zNGdQpSL$DlKVnaZI=`85;mtz-07{ts$j;l#VgCRDA~^@)A+L_Sb^9^=O7KOk&a0>S zTJ?pZd6K@KOp-+q{y8CFva0_8w4y{G-mgv8r!6)+u=$-Dx?arr4ts0&k&61)r+g;4 z)b#CID&!MXu}R2pTJ)a+YT8)TVLWEK6jIpqs{a5C{ayH9t-RJKThLdrT-{`6wR}#-(-gE#(O$WvLC7p`K>)N+6o+4MYPlxKNF z!^(H=iTc$zwOgAB_oLQshkDl*SJA#6@w|7@ ziB*BmYW(xluUcjhuM8{Ld^4)vz@j+Yr3J;GYh3D)%@C2~#d-DDimmkfvL=uf9;X%J z+P94^65U~xl2*Qd@fXBRMQ)l1P6c5o+|jKM(|-tjW3q`wy_&S_KZX8Czxwt1Iq;(I z?O$ZoyhAVA?&bdgmJ$8Lujg;!r;9aXtHgY%<*(lF+0Myq^y_~Tc!k<;E+FY^=ns07 zta_#jc{c@e>lJ&L+M(2ZjOj;p}Bvv!!vG7 z{o~F96I|z6M$<-+UB-}x4JJn7Imbdk9{&J^cUsuAhTmRc!|Rt@J?6zqa*?RL{~TBc`y7+HNBOH5ZT&!kolO}j>0#8`D)FzUC-Sw+_CA-uc>&>Sfti9Sm2&`*_9GF7?Bhw?(GF|Hva&30=A;{GMl<3@mAa44BY728Z%9E zbqwz&?<`_8$5YWkY$+rT{V*#}!=4#U8s>YQKIxWQ&|Yh{Pc0(1gFHu`QAr%iKxqI9 zfL9@?YF1hvnWy-V#BJtWUN$1O2t1}LM%+h$$2bbRh6x87xUQSSl4@QslUSPmL#1h! z@*8H7Q|2?a51A`}en5WW@T<51(zmBzdmGmATI;?U8V-wcx@L;yVB zyIJX)d{C{rI98C)IU8A2;2ArYJSy~2fmU?g3fsa~x^A_rEHi4?Fh(q`m~MvL#BgSQ zpzY4oUhE0a1RbQVjG5ZlXr39<^=&Tt1-MJkBHB5mxsPr(k^%%?cexy=Bn)~Qzv9mo z*;}R0hN6&M&$jB~<&FwLAh%^mWXBHa9A_#%Ro#wpim#+x-$$p~YFb33t)}a!uN^lm z(@w=QKiB|oJAP-#QbLnoI$mm4-|&(_rD+ixcp|sElHkg7FPf0bSd;zifTWH)4h0gO zlCUGO_&ay~tKt}SUp4%#K0A$9QlI_1by)2e>$S{{VyT^Y?lA@@wjgPZipDhr)V?_D}DBu_I{G z2TzeAw2m*Acc+w+ANPiO(xW8qDwLYC=uNIJpQm_Y#&%JAXrRDkT&JtU{_9|o&jbNm z8U^azcv+x|b8w-Q{`CW8as~)IcES1zySE#ggI$q1+Byv*SJdRwt|FE= z<(4UTaG(K^wD#z6$ozW>^=&I2B(?(Cqnh#K+cakidVcN49YX*c(*XB1!+3hjPu1-1 zp|g}}n#IgN=*=cPiB578Ir&)U<$CdMJXvpgmU}SdScCM=0!b^Q#8LbY+*knL@1x4IJAJV9JyGnz^)~j}| zVAUc(Re<^2V+0ReZO^E!;IuJsq0w1O_S=9}wq09E@{XA*$AUXzy6Z9aSiH=uFPJ^{ z_5OJ2UQKDKT_&iP4iHas_fxK6tX_}C{mnrgKbgqRXCle}tKkT!y<}amN zX>TF7S^ogng|5lvwwf^viZ>qSv%bq6q!NNR1Jf0%j)pDRx1{P)PJ&B?+ReZf)EI28 z?AZ=7IAt})THG{vD4U^PGu&51MZECgn~*pnnFo_=`a|jRIEWV=RQpzTnSHhqvoyW&SF_l0PamhO0E6+E{bYP-S7anrxG% z^dBX1^XW2Lx1DX%y?PC`)1dMQ`D#h+kq5{!twoJ4k2TDL44o^g(eD|~-qoF{TD80F zfN*L0c=EnXb*7MA9ZliHSsg1L3z$jUo@-}jA;g^5B|W+%+mzz8lQT%_t?tBp9+lB) z*Acs98scogljUxe(c0TO48u+=#-cme9vqua~buBhXZ|Jvz!@2+em^TACt`abGxp!b>vpt%`U2CKpICy1bbq zkyN{H3uQj{Z{?hr8tBAVPu4J8pHU$*nsp(<8$fSErUL|{+ zZar(d(=|QuftvaEL!3EX#=WmZL}a(*sg!6+hdKEqG2lmQcG^Q=AzNYVm(x%jz>xR^Zdg8qtBBpzkv18=% zER*q1gIn_cZYu<$1HcudYp1oxL0#U-8hc}?kHy!1P(^H4H>hh@vpH?mxoxDj+yu>D z{>9z@04&8TBx!TY=DfRU3ElcsUuwJl_vv1PHirVpMg>Lvi4UNrPbxqE*8IECe`?!n zZ}}MPE!Y0J z@gIvm6L`PD)4Y%=mI3HHeJjgY3SAMyS1!ZNYX$O0+;+dTkHOtUg`Vo-#l0uDe136< zU)G|N;}5}k$zkGq2(#LEFdzJg7QZrX^>6H-`1`Lf@yCtytE)fsS!%wa8(m5M2bVmK zmqvcK$?*r_ZNbbQFww{TzQz7E#_GQpJ`w0{WS_=5I=KEQZp5> zJPI9}y)RDqr~RLW%~Do2sY|i+IPiV7tP6qASD5&V!y5BWPMgB^tPl8I2EEJSzsEld zw1ORKIEeoMYL$*ptI9f`D{8t|k4FM2fA5;&ZF~pgO-k{eNTuifn*E^g*0=D}#dfOj zEGkItUGA0OzXn==rCM7<{{Vc~L{*WC#kgt8enax_*29h8|}D_fM06; zzPPuylOjmfzMxkZt^6|Zw}|0BbOI+Ms*mYGtt`=}?;ne|cV4xtrCsgh_OG`-A^yy* zm$=gqNC!Di59MDE_{&4oyeBhS-6>tBnmuY5%5`)_9~F3+QZ0bjjU>|F>U+>rbvu{5 
zlqogoeg*KvcQ-213T{awbKO1)cpg1L#M**}jE_qDr&;(T;jM4ODz@=1{{Sp@ubA{p znEWZC#$w>$SETq?<3zTD%8`#uR|~4;Ef1Q;R-35o&yBoy@aFTu`m-`A5re@M()e+r zOEh^0TKC@@&uindCWvxA^?3#N!yget7A%p9zum#Eg*&6&z~U38x!H@Y>N<6ryT=;I zj?62~AIB*rz=UZwvMBD$S$-y;QPX8>8?-t90D8EsH&)Yr+%07fyV|rn54XM=d{a7P z3YJ=Rq-+OBV#xe!_DlA0@P~-JUTr)}<4H!Fq^u#Lw^e^OV~@S_5ECPS#0+!-zcGJh z`(FxMm@g8|5bXu#m~UWz*+XBozYcE*5zjP+IFYc3-*}(!pu(y5Cx@=1k55QpghLEV zh=;emb^2+Zd5T<(ykz}7tIo9JGwIPwa;6|T!S7up`kd1$t*O}v7%kuFT-7!(rzu{< zn$_9>+hq|%au*fnGF}^XCfQ6-#!mNdPg>L0uAL((8irr<(!9UKx`olxg2N}97Ea;s zUS&H_G21?*9Pqoc{oK*EO7a~_r%uxi#fGTKH52! z(lZ5`0g}_mah#lCan637D{|JsygLLx@Q_I)_eFOY`hM{J&`0o&`+YqtpU`|tnkJJL zt#1f|d!k!TiSUHPMm)#PHYZV^UIjy>>R<4YYG}7D=Eh}XG^d3N2>G#|nQZ6OXEmXj z=o+rM6xVjPn%gYnP5%H(xl{7Q?}4!i?2&_vV4V8WY2GWBRnz7E%y|vYtj{F3(|+=D zsCveSjJG(*9lBz^dHY_Y;tfV?*`)~#viWGIfD;+xKQDi3_c&scOtNiH!?j=QyiD}*0)eJIv0j*#munaERsrth1s2hEXcVamCsU7K67DrtH#kvC6-vcux;g+ z{bbVa+vhT#aedD;S&axlba+Pa?&-}u|aR}*Vm#PM#sSrP6w zqKM&88QpmKR{%zsv)qd5g{{Y0MFZHken`s~3+|tDRkK%Fon$-B0cmAvGaXepZ zjU%*!X2zD?A8R_n_thQp1IfXRWGT)~%TM^CZ9me#X`U(3we4s3<`D$OCK+ii zW|QPyzuk>CyPlX(z3am+w230rw6=)IV~iIFQ0Zf%}Ge-^qar3c~+Nyx`)U~ z_x}LB_kUl+SE2sQ8uS)kHowp>CHp|vCR+wp=_GTOg^2V!WG9bG@!u9&U0Q#lf1i{8 z0Ex-0Z-iRrso{@{I!>#nkkUb@Tgw`rgsh8!`s3+b&~7tK-*{O1X429-8RgMqVQ_Cl zJaVrhK!A`vj!4|Sv5#u>e-8MI`%2ZkKd2N*0;*bEH_DEVLz{WK82(-lSla}Aro8vY zzE6j~BkS5-qp$X5zN0PGhR>HUULEKc+wAU)dklJ->3l_{+5LhkEwEnge%J?Zm=rOO z_ym4%Udl-AUX2!USoH4$OE>%^X`o$OG?o*~?FEE7hxSIc#v8ppNg)y^ z%dyq4%1>-JLyjx09Ou$nb^WBT3nw}9`2(D}pML@=Gg9QOVFdx~y2a*2(DbXY+)HV38}bi& zi$}N(8pizvbhnlSMfQj=!n)%d9F9j8!N~uUpBLMt$pLDvEFgM-!H?g>nmr_6$Hy&rWo#Yd_B^r>&14mSH&qH7jyyAeMpoQm^J zZq7L_Mdi5i+;PY8S1m)Lmd8kwg+X?D*0fi9h;{E=m7DH|$(Zxnuj%uBn3(75+NLra zM{g3m$YZ+bdRA(A`gDXh`l^E8#1RqpDgN>3YSFsD`yJ5Ao|V%R7LBMbk%*KJlDCN4 z51O2zvU#l>LDL4k3P&wN z@cZ`3@lLg*X*Y8@=ReNBo-c^s7x?v5*k{?)s*IPKcL zXA>PyZV}M;;j2b^P<0h*@W{MPbDUPJ9wE3tcyV5Lu`k#o*1X>^;;b#J$<8Z>)3vb7 zH&>!**U>tIS^EMNBV$F;-6d1+UgzQXyu2x{FHepG^&py)pJN8KSAg!tVIU#^H+*hS(w)T;(`9*e7 zuAOyn<$+oI69eTfTSQpU{I%v+8gN-i3twDox;430^cCe-9vX^Cjs2ne&bA znIw#GqPaWmCO~;w`g-TWL@~gwCsy#=A?aLx%h{bi$ng6g51FOuUb~=pDddC_zol5T z(h?KXABB6rgjUpn_K{sQ>hv{(IK=QVD<8|-(z*Rh!SX{XSYrns)%C@#rJ1ziu{CQy z^h^O=5?4mo(D>d@3(LIji=U-4{t-{cQmr(xvkGVhZs^WYv`$Ad&0B%2rJV0mD?-D|v zylksuxIUHVD(9Q^`%9OU*Zrn^2c_xyaJsem9Qs$|_P*OML9g4t7I;$A;trMeOQ~z~ zm*QXSqvLN6-OHRi|Bv}Ej$oV1fZ1gdZaZQWkW5^qtuv2P~5@brIW z)MMFPl)fO1oQR~WG)WrQ{%ynVJ?jcf2W~|JSKfa0Tua!<(KZ0C z)+u3};}tvUu`u4EgF`OQ&t&F_G+FMLu%&mB-0c6zYZ(4@W;bTjYtQLc(1U&AAB~8Pqsri2ac3R7b1FWe4VCTc+bMq zhlWQ@zlD8o;Qs*HUG$q(w$)g%1HE^?An*o-b7k^G70;l>e4*p7hAOVjY`>7I_I5et zN;2quRsR4DYvK5~lvSwz0Bf%d@xZUj-w}9H?#m>0t|dQqd83re>J-`gD8IFbhou=n zDqCOLt4q&H{Ckea^b-J@@(uZbub`N!5^D1#rGAgsL^daampl@hQ;tRJeMp^Fa2#!S?>Uq+g||l5T-Td7>a|fS~lR zqBJiOXud1ZzkB`BUU8}TYSnHG)~AZ|@9fW(lKteaZulDS?6)Gm?eHpSI<}UUx3O*@ z3jFoaWovt_weQ+rkEYNx)M(oW1EqA*taHno?4LgV)!qlP({)mGqzFmvUmN&;OcGD# z?#eEs45*H>TYv3auBbYy)l7Bf`WG5TNoKHcitP26Ab zk7@5Tvi|_fdH(+Z{0Dd4`q%A$!VOcx{ui-}QoYUQqW=IlL**&-0QdbX@U!60#7!^2 z9#y`lIJ&xRp&*a$u6r8&U-)^n@-YnNPLG6e{y%X`zu{MWV|mt{P~?LnzqDJGxaYvdTCr_^Wd@wFDB3O~NaHKn zq4JTK`dl1oi5@OE*3>a8Q`!VHUK)W|SNkVQnI+c$0Ojs? zP0?GU{qHrKmn+#19f&82^ows7_$qA%>da~rUD`o3a$H>Q>6wtJP?aSM{PqW#f^4()7Byu+q8Y;Jz2K7qVKYYV`M~?5o7N2Xoi%Ii;<(qkK{{Vfy&c8cu>Qzru-12#^nlx!5(lwa;QR0ZF z)irsh^WlIt(DOGa&UT-dH~|6ZIKTqFj`%QqQQ~ii*1Ck25PgM`MlCF-1{G3@23Z2c zj&j5RKe`y!7<#gAm7SFmO6eWUo-UH^Tm4JJ-eGHbVYJg!FER+m*OkAzsuMWwNFPfU zI%G5JlWO-7d7&ec>DUpttAEQi7+h|1f$8mBt&->(mZIJiu-4$T(X@~Bjbl&@(<(*? 
zK-h>CJEc1zR%IlS-vD(w<;B;6{u}DrrJbyrb?&n)dVQU}>jQIXZGkI2!VVHg9#+`F zV!y@x+VpI3A2<9_yoPJ3rwkdHV}uUH@=WdY4S*|q;eDXI@n?tpK`_jzd1G!<`@g@8 z1Nhe$;>#3*@_Pk4+6FFL&@coK(zkpWeRrf?L3L~M6;zZTS4LJ(`{dUrFTQkt%d^z% zG~GkT8s3Xw7yc$JuxTynxm#o#zu*N_`z3k}?}_cB*KX`k=SjbK!VaZU+!$eT{@6t1 z_Y6gLx^|xihp%7!PSY_Jlor<-v?Kkg#v)Iu%OM`y39kZJG4nXsk^F-_5FUVgcCSGr znDi|o#dLoX+eJ7sTSigcmxO53Cj;281K;UgP0p)7h%K8^Sp;{n&hH8@7(i8&DgOYb zd;$K>Ju8LLHE;B2?X@x&%8bStzUqUX5Bmfkz?|1@4X4>O%WJgZBAP>zGqgBdXY%@D zyP%FnPzt}>3#nVLl?B5vkr$+G&gJ|MBmJMHdzXviw!iUA8uDT*i|cz^v~S^&$|!K& zhs;L9e}#Cqtz$K_okLAi`RqRS59P&n{v7cnY4$xEP>_@ga9i&U^065I06ihKm2n5}kCx-^iuDZ&;R-AgjHj)6R)_X`x!27&$oXr! z)2^N$^fv0J<*i^~&X(%KYY`tYw^3N4$U$^d8(4I%f*o2dh}$}xFst$nHc{7scy6+CU6%_@!NCv z(Bl=6d1{xD30H>hE4J1kjqX}_#|$|Y%ZM7>5WKHrSxp+jm{I9qDQP-(tm)zwZY8+y z(z;}4Rt)3i^r$r`t=w#ro!vdEc|zwMcjqLk@N>GgVb-FB1QLDhS2=MV_K_ithUGmC zW!+B<2ZLP&{b`u+am8+n&CRg0NR*1~?n`@QX?OHE5F_6{5T3kAe zX0_Jri0ihA$5im2jGi5|)HL~^MC!Ht`2N&?0IqJ{Yl}#E=QtJp417k_(%!}7lz$C= zKKyjmuFb^Zdz@0lTbEN@DwjI3{M>v`_%p7@btrWyf^lCj>iz)K;gH+N$BO;H_?6(B ze-s%u9g=3ga@D>RT1Pu?QC}-w+{e>Ur6#P8%N6i#)v5WIn$@@P)yB@hrG4e9d^fdI zk1T#w2ASY1NS_{9{Hu~x7S=kIHn%=T(>x<_EMVvKuTIfC8E~qvvvKK~`sYOWYo#dj znlR+{Bc*oMo(Z(Hp67C(V()6DXHUV1n+g&2g%<)||mj;^BZE;l+ zWk%;CYoj_K;*xk;4%O;tmS>2@E1A{chj+>=Uutu_A?DYbpoKA-sNO0_pz_J2)-7`` z7Pu?hcS&(i+L8U9bnUFmcHj!_^eru}S94dH=?@>87QO4kdjvxQi=wGvc?jh2@|5IP zo$G!YUHvQQJ&_c_t~*w@gp+|$R+Tn%*!ina(&V>%fk~t5GwJi~VP1=?=~A?vE6jCk zH@256Tz-VYhkR*Nb0JQ%Aggk8;SZd5IZivBG z?_BSIJ|y@D<35sfi{(a?j67|O=i0daZ}wofTjxte{lv#|5ymP~TRj=$UOe2DRiwKg zMR@1oAH#o&7HnjVBbkR1$B)*(H2y69%f2l5e{}bD(HZQ|-F4(*zOsi;_>u6AP5M-( z^4UM-oWC%~w|egUG5bVlIh#~T;NYqU9XxEI^ zhs~*_>Y=ze71Y|vc?tz|&3%`o_yMJi(_1C_%+npIJO*62>Wy4-mjzE+b@~0kEMH6--sQiqm&5NL`a~S&yiUW!7d9JN2?D*3Pw{%+SOlA%D+-kP26|tDd`EevT8M!o0~Pew zfwgb#f0E;#E8=ZF@>i1)918k};ij_!NgcT#F|KC_%X7xY(&qP(^ws|Wff01XzqJ4o zJ6D5vx9zr{=ymkeemm1Ny$O=&ukO@!uZw(X@h;BqQD~PL2C$qIb&1V_a_89buMy~X zvrc7~=N^^Gi<|2Re47RGn!W1RP|~#k8e@WL(D3eys$DE@2HIcRmeO%Pi`tlf1i!L$ zomcI5QGKrO{#nOgz8?eZEHvE|O{&NR#K1Fe`;mXM?tQEBL*e$7Y2btQttu}zU+>rG z55k`iUuyQEIh?P2SEWk4XB(ntsru8P>-t6IltR>J{krtMD;HMS>x%f(!MAqyeq7i0 z1)X|ApK*SszP|AN)u^@mPfu#jw1`T^#I{P&lu?TEPZLPC;KD1{WxEM%(|Xs3c#cy% z#m#v&PfC&Dy5vQsGC$n|Dl5u$zYxBUr^+q4yP2@Q!o7FI&$zn5)mIp=o%NX`j`T|> zCKHcW~$BX~_AHPPp{r2D7HtEi_LGT-_34 zj$5PT@$+s902u9x;5;#7b*lJ|D`1~wvW)G)>UQ-f)SM1#UlnT!s%pa0K;k<|KXnpx z48t4S-`m=QFil$S9Tj{trKkLSNE&^o{bt@kf52lMF~xTu6*YzM&Y7q96GM}2yDMK_ zk;6QUFe5&|bB>(X0i|l1MxB1T&Y-Sfw{zt}e)O_ug4q5MpU$}}OO0R0-ZZChj_xcZ-gK#0tl^^cO1FK6t7LIjm_%YfcbHX!PWHJD=l+m zcPg@^K4hqVNBO`%LDw1e^`qGghr)Kcoz|f)p`^oXB!C9H)b$nJX>1Pw9OSzn;fYkd zr+_x+J$Hz_Y2vRKi2M=o?mO9`mQ)g4N%P-oS1I>okOQ_TV>mIugd|{P*q^w_t+h>0 z%YAiYwzJdY{{Tj^i)zOMN!$VVq+$1+i!<~m90&MXe;GmJ7!G8^+j`LU6Zr1+d=|6lwwIBC;ZO7bpw2w`S!`>OT zpT!zXetM)Wr|8k`C<=gl+43;KKyN5xoR=dT0USP+X+F8*54Bvr;b{=MX&Zn~mJdz< zQP-vb00E3wdN~eD$2Jlto)F{Am&-yrvvmAMJu9NqH22axH*bCjisIr$i)iT-L=P`_ zrv0Y_uL7Oon`w1B$PK^C5=Di7?;Xc~-~w~o-m-3WdkrT^xq|Wh$*&sQ%8&ke7|8=a zn7|&CtR6?xpAWA7)!>-CDgOXHZrVTi1;kQ$QU3tBy<}tl{%eDmL6+8NtnRq`)~?5< zW5GlCp{#F&+WvKILrU||8x2cMW!m82Gn;&Y(2de4^uZN(N!70I{3mzf5B0y<{#Ez; zNcQX_bzR+8WSJwT)z50~q^}fn$=RI!xVhA=5#RT6$2ZH^HWd7a++&W|;=R+ty4AaQ zb_-kdCe5Bzc`<<&(g*31E6=sPG8HOG8j$`{k4PqabsR#AhSZrF4D~b+yp0Y)SJF z7^F}?j6URS{{Xr+A7U$#wA7B7uc{yi%y0tdqCP?VIKdvA*8P-!*msj@P@T{^g0gg0 zj$^?dn8w^!R8f?+Iw>M;J?-vPe9=1Y^kpl9_*EHgVzSg;3x@ei`Ej1uUO^wywDjj& z?I+5(JEe@{j`6VRil=dJ1`@(!_f#}^=PIKOpQT-_WS#Uqd%}Jpk!Fic)EJMH?31ek zozJo2x(Q~5F9qODvZ>%NMkjzrwmmDxbf&so-2=WvPn&P|Pfw-=citb>W4TzZ?s&mq znEn&ck^af9s#DO$a6KbTltnRcxG?k;>CxRqJ+xOSxIVS!iyPWPFQWy{aQs(GZE%YV 
zjN&&RFR`we9E|TGkIs#Zd{wJgxDqP^jl4Z!>9R>Q3lKYYimYxWgb5;ioZ+ihWY12v zF?q`qbHAsct!1}@W@~A>de@z4`hlC|&MUCDyK9vZjvac6(kR88{;g?oX*u@q*ox)t z?&EiOoOAp&>K9KP&A-euw~SYrYdTTWrxTC5{DJ9NP9~6KPn3M)JHJY&b7~+Q{K2|= zRoLXrCT_I0vPX9tCt6lPbDFvfDtRPst|wg3{GXoNsjlP6dpoAz%AS>4aG;o<@1B&b z7d%=iFD)e4-D|AUb;+JV4aHp5EJ98kt#Z1AZ)rH4Mg?WDsz~)nB@=M9KuIK*CJA^^ z(zA3q-eeMHwe?0;P0;Ydw7BenK@tIs*1T-ICMy$5yhHx9bgi#8Ib!LQ4&j9OQR{B3D3?|*y$04(bqV{Ln< zP6}MJyGO}BH}SreXLZK{yi-^BpQLxJFN)s{JVT^f%Xeun?yKSbQ&rWiF0plAHHd@e zKAweDZ4ak4{{V>*$uavkO32gxAupNx%K1I~NpLO?6+`|KY5e-<_I^grXWh2{02Y$d z<&5JM1=q%h{{Wtbz6!nZEW4{7UlU8s3MPKg`|$i^&jsoM@pA`u>{r5;{w14bHeG(@ zx^YjGu5<5!@vCUZn^4K|uO}XLes*2>mgYCc>YCfd7c8J|qm-X&eg5ATEnH_clK%h@ zET(<(wkzUstX(k3t1;`hk@q}$QOZAW~=J|rXLqPbE1Wb=eU&TCcO{h&b_T`8tkciDRte?^cCjTmdW20 zj8Ydh%?}PksA%_5YMNAr8KNI4TKy*YMe%aS#MZFs9w1>I(kRcPdyi4-J6FUyKA&y& z6(xDaYH4&s@a~D?OJB5G1@keGMmemhViLR$5{4((OOmF? z=hlsB<8KGJpGcj;uWsL!ceg(jB-Cw1kB+p!@t!=?T(|!KTD?oe8m^P1T?;#}-D0PM z!Wu7#ZG8KEq%i8A;aYO3JEPv^r8VxO=i9A!zzJ_G`fjS~%l`l@*;_x#zI@dFCf@z8 z%SS9L=+7Pe9Epp`ZyzI}TJVqfQ`c&9=~_ite6}jZ;@Pib*xl)o>5^(Pd6gyCgXhwt z-RW6Yy43a$x3H=-I?HQIjW#Tk(z)$-P?B46ty{xdTgWhLn$;B=EOwv|8q}_1S#uDs zm&5i4QvgURzxEW`h0u=Ode@-acxzM9EyS!S0J`wQTE~^HHuLQktakRUn(M?@GA7op zLi+AmGm+MSC4F4ueEhqX$mb2{&_;n#l)aTsL9aZ>nd^GvxY z56kIaV_jMcWyfmsU0%}BrX|#7l-0?NXVmi@5j?p<$-u8o8ttZ!2w+8V7q;rZG(2Lw z>t4INZ{6haT;9mj+4G*$edEdXXxbQ#LE^mkP=e%5_J7`cFrjq z%7|kXnQAabUxX#Ny|`Gb->h^Ss9^JlrFdV5gXvcuVI3>nH18HY#jlh2Rv%R7nbk{> z^iK-OcX52S5*?}2iu)9ucT|$^8^$dwOEb%Ir&g|t)XbSPD|hPGOmUNY;GQ^;(%dUG zwVavaMqG&7+@?8lky~&N6dZ_(^y7EvNf7Kl!I}?nt2^hWsrX zrb+d>Aq?0v^TELuap_Y(0*l^Q$dj8cs%H51_58t5qghV#*YAEn_Xva7LjMi)Oze`s z*=gH|785uzcXB=NS_M3;11YVaT`z-lIb1z6>gi|td8f$chwWmcoynb6 zD19?+4lR>@Zst8+W{_}L7T{!WKO}EqWRRj1DTlb8RN#xeNfV@lb#P4WVxY$|kfqe! z2d~R!XV2%qvmkUJuQiuS; zlZ$b9|8eRqHKIkV+fw@UgZgmA!sqwm|ItCpmfo+_YwdzZ7Y$GZ%t!7;0~3T}tHAI~ z^J~wgrrv@*505WPVfu8}xruqvk{g+@#6$k8i-7v9eDfJCKYSZ!Hcg%v81}h+PXZwx zX9dVIBDzuqKb*Whao8G`_;aZ&x}tHTT3z5B^}uJQ^{PU>7pgjr$k0oRSUy$Ga}~wA zA(kEXYHiZ#;zNG3EJ^fu@!kKjvFGO(y1FpTFD5vnCHw4iql&uC0@lk+R#wKX!hAd9 z)$h2+vo}-%SHwjY(|Q3zdw%!*0(rKg0uv1ayuv~q^j8m9QNnSyn9w_xFgtSC?$|cf zh9`%12rCW_0Gd#X9W;vWVE!1(7*pAZm>7ysdHuW?xV${zADz9dqQdxjN>tuUrSi3r z2==gTTw7{DIqdgKUz$s-A;xPz?`*9#<1DaJ(TOxuD*Q;yF)iD(56{ZQC7DRC;3)7} zOKH3IMtu3qDc|gkj9^D>j*3>3wXIS_*?A$_D3B%lqXv;*v=4_WC^8i!n33x*RQq-= zHp6foOx$N~izl(`b)pW(-I?&+;*zS&)m)jBSH~gvtkIGkx#>f;QFALq#;3QuGD>4S zd-WDH^uLdj{=-A;b46J2O!JRK|6x6?J=^j85B6eDEOs2W>t%Xb^dvSHT3r;vrQ8ai zVd2a|_43&2eg5~e+RM4Roto!<*;}b@4acb0w#07U&wM(+^d3N*iH^8b(CL%D= zSAldtPHk1FEk@--g7Z7IDAzj{`N}yj&jHBP`RA41)qS#2mAt!G{>yUH;=$-KzraX10 z%<|cpR0USVaryX;R4^CreAdJC3K=ts`Mk-RU_7*j?$lJ*+|l))0a%GP)JN~wokX2= zqG|G|{|1%Q0@rm?_&=yLK>;Y%4ylV9@r7N^B$v*dUO4z}#zS^>PUYW^b=iFcRqVcc zHG3>?o4;Nao$0P0T9MVqN>7+-2!fkcR1Pb>@yvyJj;ZK*X*hiAivNC*f{9W04fL>< z%{30Iqi^LVapfVhE-A|rHs)SOa9O%;CaMuF@9uzeCHHTZR+E3F&DJVHTy{FA=0>~r z;kj8SS5h9o3lYQ?!>+G<^Buy`)J}#XRsVX1( zpji%yhkGfCxJID>317Vy2fYL{jrInJ-if_MZKiRBAKj=Kf-YWvs@b zls76fd^fWZ?EA_SoDomVnJfT5i|nzun%)<1dHdFoymQ>)@}}v3O`j!zwich6eY;na zZDtH_0<40USy*L~-P&n1*W1v~yKKu=vW+@tce!V6C0F+1lppJC2Uf0ZJd*%@wUTK< zH8yALq#e?hLT(loUFnEH_t)E?OJHTGnx=LS(FW<|Wr@cm^+PrkVEYs?T6e~aIW40s zlN0NOcfQ`aF?lQ-E!~hS2F9sG>o5-ua&j~xpZOUoY2U6s=q)5vAQ0`oUQ8T2Jx#I<)&9xTwrfr9M8**w7FMzfeF*j9@$uD8r&Hj{nd!=JWB zvuMSxPH_JbFsjZgBwy0Gm~Q5dT;j6%(4rF`1f0(0BcQ@(KlIBR;XxfF^v#q1;%$(h zCn&uWzTgs}`;ijwj6w<`*+{@xDt}ER&(mP9sTNwD{WBHd=>E`vE%CPMMEU|&0}d5f z8ex6>SrVi)Pxd=JkzIznZg4l7=~rdjG!WnUj)w0cgf&to7&wbu+IVA0qU3E0{ZAo= zyp+ckURJZ0>W5rdP$5Ha-jrznEDXVxJ74LwMWeh^wh*qK^1udi?%1Oh1vSW8@sDLu 
zkZM@>naSLrH{Bk0p7GW#lC?7%na-Z4Y^DV<7L#ZawEoKBcR-ZNzcH$P8-Su*OJs>^ zM;{`OPz2`r?P*JBdRrMen7F9bOek#9L(no{w^u;77J8tSzod&_LMoh_0e9u(h5_js z#SmD}kb9haH#h#vf`tuk?xU3f&De8^CQW~^t{*2?IW;xgZT1S9NqQjhF6bd?j}cOK z^rE&I4|)NS#Hr_h-!kZMXUE;1r`@eb^CfN97;gZICiWnmk!bw#X9IdrO-mX+wOvsh5fB3Fi^G zmp4{_il72TdOV^Q7;8~yps1ovOES>OG`m4_Ad}?Ao7hf_9eLL`ErhFP>29{kf|jow zanRENb(hX!-2B*DoAjUGBG#j`5j+LsI_O%>VSyjb~t6^EZME%byzGos(rwjHi5 zi^#Imiu2e(qE7s#D0ikP5=&4M82HXri4L|?HZT#1jVv$Fa6UtRf%zK$dGKze1L}xf zt1Wa9cYA!R>J-RZHuKC=9zS4qx)Ar($PwZT9@hA5>xrz$ww@J_OH#Z11VV{g9#@Kl70V8NT z^t&-wfOn{>WqK;HY+y;WS@;Y52yFJ07F*sgHvu>{& zy6#rwX9|-y3J|)9zB?hhfi6MDAd0{sBoJAAq#q@hjZ@A|8&Ljc4WHu_R^6XK+m^+j za>*<7IXjaAJ&W1amG4i%>zeK2y(~#(0=ZC9*TrV>8aAu)E$Pj5pvcz|Dn2igW4F&A(E6z%Ne5JsSV}bceko`xU{HFZ`$fSI z0+_V8B>T}_8YTd2Ov?&dQ4hK-XiJa2bD|ep*|*qsRdSnsAX9T6QLwenFNZYC{PhI) z?t>i~Mk}NvrJQ|dFMsSykD<_`To!)uR=9pv#vlEP9UVwt&jT-*OKON)Tgf7R)v>33 zU|`!mMO9}!UZer%MsCk)H3vgS0e!)I+4`9+Mvy6j)PawH>^vSM0vXo77ec{Iv?-+n z6hL)D{qsE`!MNEUQz>hWcm3qCqAdIv=|LUEW|=WP@cO)m>4b+v^2|$by+ybixswuC z;Gh7^)~R@j0Xs2FHI*7RtoicYku-fXHdk*SK*#TLl!Vub;oJT%t?FZ%z-jEsu~sV- z+#}&!64H?G@fA}F=I|#%b9seY1|fEHG>C;+okI{Tj>}CjvPWgLEXw13Fwh*!E^s>!H+tZ1i3cDB)r*}=2z`wu<6LQM(K#KRzoC<)x`UI@0hC~`tk>qeGjMy}t1`{JJ~GVqXXLM9-=Y!k(E(0J=s= zGoO^4WEw=}BvnqHD8*)OsXI1Ml!O%DWrp8F$1YrR=jRW8m?faU0%P3q)4WmHas!K718x@{dj+Nt+;k6_P#j<5l1dyT(Z&iSYBu8PhLk_v|U#8=@SBIkRrW;mceS2R4V@Dp{nTOY`V={TUDq8(Vc2P-;o7D_T5TS)OXf zB{~e6P5zh#zuo9SC*WXe@0B<8Uu*!ARE;YfvUFlcXvl-FOLJWEZ2lXCzU!HNMBctx zimK>9lFYzmnUY`QcG?%{=R6ru?f`vnjyH3>ORo>AX^phi_s^mxzN>shyp?c0PPpN8 zWXku~b8PlWnaV8JJUPgC|K%@c)UdK0C~K>MhEEe92bSn%&b$arVURf2uqOK=eF72g zqJq&mxnAP@XI&+y+&hJPx4%`Bj76_hh~MXLZh<{ko%MRFv1w&3mcU~<$XvC*AY&6s zV>4dIG&vb4lKg9%DDhbS^9+N&O)D7EERI(nvXU-%>y*m(Wy^3})#px#944KE@ps>G z2#v)a5j0eIoIEdhhSr#nMFryX!|a$x*BNG7FZOT?niR=7QSA<;4IY$kR1MUq|BiYT zpR7ARU_RIE_?-VKsRHoQJ#j@DPJK4>%D4Dc_^);l2LXj{ehIC z#g5(o#S0o|>f-n!?~jN{Q78#1A*WZ~Q~j~^OA+LdSAxq4-DY@%sj4*gOVvxwtxF1W ze{iB_RZj~b(+{RKH6H6-_R1ATC$2sCLuYa$tZsoQHrsC-Q=T~~4A!bw_g?INn8GVL zQ1G?G*T$>v9nL&1FyPS$T*Wf>ngho}&(>|}f{a1d*WLwi?_W7G3Yvbj#_1a$&%SUg zXoXphqJ;i8W?jb;rr^8kB;$ha@m#C8@e zB}`w6f6UC5*$%~$u{sKR>_`{mdIQy!?x)~%KW;e}8+i4BF7i{pPVc+BvAviUA*poN z;o^&JVNvRl3!r-0S0-(bC#3$c{2gW7C`>H7PC02KM097osDr_SkYRQFRZksgKIyPM z{f{-VlBV>Qi3;zDL`!3ZKf9Ts6`woV^j$WdNIzJAz73L2y3_Eu5vewM6qWMS=gsKO z5rl&IWrp*!q3P1(i2>JD%7N_NZ$4gL7KdO|eI;J<-u`}FE=wl6Zh4#^4j_rAcC(-uXOzDKp)w>)8V zcrG1M&mRA`ne3D)T3eRT$F_E=Tkv=8m&X%xb?^oRmEK; zkWQ9f?w8`{XpcfJ-n7UT|L$`Bim*^5&7=CY@3)EpC_4!GQA)5ia&4T?1~aL0Jwz9F zp2e03-`C~FHba?xv>A^TuVYwHnz_f9sWFw7zp&+J*Dl3A+xGguGn;X$j^G~S>@pRk zl+9j!j_?1#IB4mz{T|$?3~OqRIw%1FTO6M1h9NpSIEzB?+O15PsNB@^R8jb12Iw1z zOHl&L%+bxhSM>30xvu;9`-OJxvs4^1{s~1K%aq_x{uE2+4c^x=CjQQqnzX|>oQ=JI zoQE|S4ZZVyes7fY%*E%++};o6R9c$Zwb{*5nMqTL#n^*9b~1V=(%en_9Hb8qlwM4z z!~=KG78s#G96!n|+c;Ud$d9y;HZcGrFtA)9$MfIL(pO_L3Y$pt*m|D!8Q zZjYMXq?EgC?cr@Xo>JLXhYL3c#;+N9TJ%#oWHO8=H$oYScQY-=)T3_YUOhxMd9+-^ zGJ9j5{d|zhAE<12>IIS~moFivVia7sofdCy49HP5JXS5=fjc$wqx#0ncU!r(q&#UlRrnW zpB`b{S2BEDW{x&)cF#^LX~07lLs;w1fT_-fwC|4m#Xp^gRc&sHGbUQJlZM-U+PIx1 zrk#niX0Pvq+_{rcrhJ=r^j##eEfu)WOx=arn_H;K>Fx(;F_kIUQF-{iqPX$l4nD=DXxx^zjt&o+&zn_BF_Sh|X3P=WHEdz;tP^l0E5%Bh4x0`YC2`qRU z1fXaUubUZXVqAiJ`HI@oh-b6i2PuGqbsS%pMCA7jt zVO}rrd#naOgb{hiyg=Rdqp9F!u);U`uTgioccmV!eAI(7dkRP_WcwnFTQe6}FI4O} z1+PZMSdWK@+g{9*Q5BF&I*W@pYUas#y!blP(l_rvH6wh`K4Jq@s=TfVld}rRshj1gJI6@v#MC zMvR-9dI~FK=JGGddOww`{2^*5mTMPNe;*`%zbQ|PHP|5K1R zYl-JSKjNn49Y=o%l|G8h2Owh*iqh^|DyTY5F@)E4OFzS8nn`FHp}jLfodn_#D_Lis zdHrnuiDll-pIv%BGO}j(uw18|{DJ}m|D$`F!)+q=kM1v2g!bI_QHBBaH~N&+Y`6HY 
zk)zJcwm?yf+wQB03%e|P)T1f6<TXfs1S}JvID+8j}+w|avM1K9sE7HeXmZq z?LTrb?Govl8Q+mGKaJSc)NM|^2-pB|g3UnDemY=h#b%RAsb(*? zM=FnX9_4&9IhAk?eh>7M9bI+WQl#f?kv*=j)$CPLgEv~hn9#rJ4oeTYLc2=Ob62MV z!*?2N|Ag$dA4!jF*aQ$-ml9>d_SBMjQ^_q~sc(5^*}(d?C4*N`?i}_588a#^@>I5o zcoy?6%sh8=T@LNZZMlnAOJW*q2rFEXu-E?7iExHAzBE8h;p6_$+K~;kj%O zJMEv21WdhF-_soTGa(pt06`>mSZC7%A0rt>W*G837w=dgG+`1WvB^Om`tC_leo2er$kGzEqx3+Q84iu{J*$|9V0Kn zV-FG~9OAn?lNxpakr-Ru8#dv`_4xh~8(5lHPB6^&Z2RqS;`4|`24ZQ9E6wKqM=>Aa z$>;}i*+!)Eq+gM!d7&b|T$eQqo@9Ns|EB9n?gyZuWBE1Fv3blFxs{N~n!jrfZ8`q6 znyzqc2HZDUTbqMF5PStJXK&!=^c0u=+)?1#ag^Ca_hhC{aP+mTQnpo0t^SWikGtJq zZ|c5|gZTXu@Na3n4Yu}9djQIbK??>+9pS~7TTdL8ou2n6FSo~2x~;c=EfZDJ9z_bI z>rDn(YVH2G?a_%gFda}IxSa4p-9h5YfR?G}mN=I@6mwPH*m{y@NM?n(AiwPRjr~dL zn#*c(nXyn?^S&5cfi*47Y*BJL@H?rJ{$1Od6nG5*8uJ4y6tNdK{8;U|E%+l~&ID*c z=WJk<#MY;HOEpNtW!D#i-@U&cZ{V3CX8Kd0`8utY{R(|4bN%gAo6p7+Pj?Hr^X+S_ zCTR!Y|K?mzpnz{;Q#Ckb1vXOOY|_={Hw`$ZiR-|FzP0msJjgO`s&816Q2_vF`EUC< z*O~i&%>^6?EcZ|r?kB8`+du3Z%W-}S)@+&_^j|EQaN5i@c zi6CU?BSLzHpTp(QK#C#ZT9QGsTLI^lYEp(w8WW#kJ~-t>dk|zoQ6kcz*bbs1TX=je zO*82H50Uv&2$3`SX|JBw$vuWQf*Xm2?8$&si$Z1Px=Fvck`a100PA@?Tc=?xkB6fB z;DN)bWeZG|V3FJWB|W>)E*uNH88nIAZ$B!I;^&*;abTA@+BoT9HNoa`z|?wgejeB4 zd+m~75_1DA7T0_2FjIBTe@-+cPHcgQ( zn<+{Wd%&YEVMun=e+qi&qH4Qw6&t(N-`Dq`3e!PMu@gW9aPuVz%o~9o-Nmd{Y$sh?Fn<-pqJ648sC{LF4^Wra$?s9OaF0LDcEA>I!F94(HqZ3( zjuAk?yuA5y*I26~&aq3^c=25TI%#FTC!n20nP1buuHnc z82@VShG8S1Ze?Wim-Q$nuUm#YfFzoQAwUyPDbMsZ{F@dR_xEIH=Jv!}Xwpnu3?KIx z?esrID|bGDwg1sQDE^xn3={Y_LDhHPrDFcki6x1j`)IF})-c%*PCg>Q?z=B{Wt<9U`WKKPs67w79h@@fdR?kyUH7g;3WE4$lH136Zq#hEk~SehN@(E z5lGJ$5q#{P?Xhf)P`WYj)w#>|x8EN9pk1D84Q=_CNHW$UZUA;>n{{+58MIH5aD5+( z|FDpTKR3kCei)td=Furo?UAp&^;X@9ozJ>StlU33|G=Udawk}Ql>ithz~r(sJ)6h? z)?*%}En!Tjg45c&v;IatNNE~U!G)Hh&)*}zNv)({3AEzDlBI)BY$97Z$4xL6`j9*M z@G?3H<&5$C3@o7S1V@9{aA{w0iovnu%01(jG565kg|sP%n66Kc{EGaI*f!ULLyILk zh1Agt1`p2=)p}F4Di%_pL>3(M+V1x^7Dv|SGruO=q?}3ert0x^%ZQN1rL1OTWa#}k zt)^I|S7)z|dB)G*!WNBYDdyx?I3z$-$C? z`5mvk+kU#{I7Hnmqvm~DNtMgrVWsLrChu<+}O0Ipz=p8le3oXofAOW6~uj>6L;hd$*<0jd$XC{bBnm?CA(poVw zh{knzDAg)N`VZgix$@Yf{M6*u3sj^1_HOa%!Ojd|*!r6YftNuKdjbs0!kM!S%azqD zLP4APBB{staH~`wl7h|?Bs{1+UipfN6DH50rqyKrgt3c`R@l+J`60HBX&>KICla%h z2Ab|9Z)aa#t0qNFn@0S0YrQN=N-mjhTQ}=XKYe#%$ENOZj*biM5EH~(PGu|)I#;2V zfK{mR6)&A)HPO24+=ukr^CCN$RsC^8$nPPSnd_5jD(vYV&pi5F%JTv;&wm&{lkOUN z)H8XT{G`4=fj9mzH2H# zZK*uN?Ix4bKlRhOZ1(;!{FqeZp~J7Oj3$}0)D_smPSI4t=14~YkVbM;YWCo0^Zg8UH;#d zHdzR(Og*0vVMjGzfb}7jp2~(*<}zniUeST}Km&Z1tqU#A6W1QDMr#}toR4HSjL=aY z5oF_)1;R3;(VrgckU&ioQj*FezH{(om%}yqHNIhAyvKB<0>b5LWSpA>=wxd- z_B@*)ljKwQw>QJ_8RJRql57_zMX9LgdNsV_>i0r}q-ae|6E(W6a*Z!k|BO*%zZsRevR`zgTJ+vD=t7SPa?#4vTc@1 zTM?j`e)*G&dpm<5J#sj|OZwIxJR}M?J91ds_Qfz}oQqEhmxWGQ=fbrRe>q2LXm1>8 z8SOILcgK!<0!jgN^-E9tm)4rLg&NaW)s-!bjo>lQPP8SgqKci(-CBmD;E%rLFq2t| zI-g!(i-N3XDmb-sFSJx=K4kl~%%-DP<%GdMyS`Q%eP?&sgz@pX$BRVifWvZG_L0wf zGd0rn&S?fL+T|BM$`l5Py9}hPX0PQvi{2J9jgl*`;vG))9^E+LiVOKIxeKYpyh;gl zD;RH&NW^CDTZ*mxR;#oBVBx4lhieVcrj5p(nXvH?)x>PS(kfo5#OjgN?vB`HD!mB@ zcg*Yk3=!tzjJw97!zK4?&K_B}?CU-6u`{=L`AxIm!J2$Q?ZY9h6JatxRASDda4+$l zzsTPtOcjm8Ig1-?@SWa8zS!@cjW^vgl>2p>8$J&ARA#9+&l7gj?P-x|NM9l)3Q@%? 
znx3X2C0^x#KtwzU4-vqbtn!`t*p(k;iRGTK~gNGRwlfKh!jyb$Ptoj}go;tEEYA!n>C}PBj^NkNK!J!j7 z@8e*6Aqp@4{-eX++7I9C@o>5=`+?$l%7&|T-tT9uCGJ`HeyF160S98QIKZ+$|L{Cg#w?5SdQlGJ5Z1GDS-u@!FTrDq=oI$7a~4_;Z1 zJ&n`MGzWhoGT!T1Z6AO|zHE_WDDvtjRSt;f4@td!#-F+M%U{O@=NRpGiP9WZeTN9; z9sVMi(4}@Kt}bbbA*r9c1LKgT+8Wi({>C99I`TjHN@AVkd9o>cKn&c%LXPK0j=wxh z^V&%EWQQE) z?zI8dKj&Uuj>NvSE!}}W>VB(%&;UCN{mkW0c|!;vYnB4|dDPo5i3oUD0fu}(HN>k6 z3&w{;ZCc(11{k2T?lCUMvuIDd`e@=29uBC^0qy&HrkT0rLggn@32g;BbwGAw+A{ZXNX%$d9-y`&^A*Ri^QXXS4 zQ6UW%TWl7IUH4W9^i`^~2ZIY-LHGC`JLRfn&@C$g2%mW^p-aUlARt8ME$~FhG-T;3lissHH{C;}b)*Ba>K2vMe5P(#x3S_-WOa%1r z)U7kE_d<}`grV8t(B28*M+NA(p>Oo3Kv2Ojfh}&VxHsrhZ6HhX?y0c2TKqB(e)y1o=X*EQUb{lI50xi zFy-O77YImkCzlu+W$2~ozrTPS@`$4fWZE20jaZmW&ztfUU6ct(c|wi?8XGlz@v%r~ zPq0OODamjP$qOXP<9NQg!9NqgNqIN{9Z&tQ;PoxF+66d(KcqMwA!hTBPAgS-vZ?JcELyBMhB;UxEL2Fx@8WA7LB;`zyG2Pl4TQoVzv0#Y=t?|X6*qob$suOX zO)Kx8Bx+VPAXNg=%Oz^$W|OX*w^;XAx-4TmKBCn%uvWz~2oJ^;?t#>9Qu%RoK4bsE zb43v^ZhSenN4SXN=^4xthPPeIKa+f9q!rh3kU2Yst%Vb9~Pra&%9dRZl5k&Nc`yc*Ig0;b`3n}_AI6rlGkPlh0SX99sU-RTR?*@`SWN#OPrjFC_g} zk;G$6%p5rv>$r~?rvYJ1i?)1D4zzXOZ#7ZBxlmnS8}}8zx?o#CJ1zVmTz^_XWi4-a zIT*(&Ji0;fqG{HyUCu+8P9GV={!LQ-ee)7e(Ud2Hx+2|t_a;&1x(U3;c|n`h&x`Z> zr1OQm1SkW1fg(o~omGi;uUxLq75Tdk?GLjE$@)|Q@~C9GH9OY+$TzNKe$ZDQ(=Z+K za?(RoWz2rR1=&uW54jxlEPi4$+Y{a&nAe-2-=}=JHR~1T$;u}gV?Yj)6xg{e%whN6 z*0blla!bJ8)5>;Ir39#wkB_L{3ZtnSOHBL9HjCH5!GAwhJw4XoQDLJy!z4o-NiJlL zGV-60SIhp5 zI(C=OLRAsc*Ajcr3fv*ca^Yqgnp>*q$A^x4+nXRGvc^x98uz?>NRxQj_0z~SqYs8LpGw@VwfK&(hNw^g`Zg#mbOVz_kl*&q4Kq!v8W0k zeP+N$NC8p2BL>^5n;!x55*gxQS+z3qD75?~zIzaR+N;FW8I)dor5tKE6Oq{{Y0{E( z(@x;K@BK*S*7B{S%OUsAbNrmg&N@%+N(JZdH|E+qQ9rb?oF9TwHV;^lI|X|JI3|`b z>iwp$nf*Wgrti4V;yU=RQ&h3~j5;_LJc|rSd>2^L{!uFU`RMe_>#BP{Aw!N^R;+Pv zQ8H$_l==YB{;<|nC!Y=5&b;osXKzWE=gVKqr+fQWkloCHZhGVNtt}3=Y3j%#YaL5e z=;MY~qFh#gQ>_!dde*c_O>(csU7ti~`2tk{))|QHPfT4-5PO8TOiB*EI?q2YccQ02 zlWYt$^@ZEDHw5Dab`#4tfr7;x4ApYgGJ zYjr|6>)W6^g>N2HE(zmAXeKTiBXrq30WjnQL*H*lRV`^>_ zYZ;k`!N>p$zAI9#`}hMz!R(b2Trj9qCRV9TLwbGlRM#a9&Zm0j`x9~C(vDTJVHX;<2 zm(7v&JXn(UFihq1E~mN+%+Q!ZiZbDNrVT?A9iHJD#ber;77ldJ^m3csUtk{v+w=ZX z#9+mW1E7AwB)lai0V#$EQ6iPC&{T8*{}B#rVNU7;>$4xt@@}Kd&7n7#Bq%1hme^d< zBkye1kjZ!YqE*CC_4foaeN&l%tQV}T1Lf$ntU3s|WxW=V*!vrRS@|!tD6PB;Q4+E|0z(mG&m(*`3*9nl3nC-CtHrXOcz}N@5<$k)-x+Sp01ROgGt~5$ zAQ^NkUa9a6jU!R$O3tuMv=oM_5}O3+sB8!G#nJAx<0K5#f;l^*k(`s7t@4+*pC59f zbeUlj8ejX)`UB5>tgd0Vb#>lO2-JVFH(4eiAH*F8CYGr%Cryn^?~&HXY5C^U5QDYE zJyv0x^m}9d-)L&TV*+kF2QaDxyZLY}M2$=A8Xj)r zLE-JS`a4J>&$4?3F+`>IR)_t^fRZ`h^r^n8Cl>~T9&()mCqcBXzGD(c=x>fw=$6YX z;2)ISTS?d|k8k23lw7!4UwXHDv8VrMkH0n1?@J&eKQ8vIG9i`pbIcFKX$E=qf`C6? zwww=Ceh=xrPk;XYh0a~Cw4XoM6>XYSSY2+`Mec7<3U+DN?MxAp@gI=`uf@AM;q~gm zn7gBiwPi`ggV)QYo4Mz@P{ZZsT8^xHXR+Y29H6wnhw8S_-vcC@3tB$^3(Nc~kTkj6 znA>Qs9t|y@r0C!fVgnOdk-gxR_M2NYj$p3aBzDty*{4z@i}D(Cz{O7CK9TK~bpvRw zBWN-*#h1sj;~yYMo{8+=Lwn|(jJ=rZb@XSz&=k@+(=y{&L04x zxqn?! zHWk7)0Pz#w#LJVV-T_R?fgiBf?L7XjPiHQ1+eaq$1AHv`$Fp z^PO8QLN1h-Lal9poFZ4xJkoDRzKE~>$d?LtuS8^H(0`^O!yRAB9C$wM+C;=&>GP^O zII8$Z$Esmn3YsM+sh#X7syf~op>}#801o=9lfwSO9y239yEtyP{yyJoH^q>;#Jz0Y9tq$?Z_F<*+I5xd0ue^AS_9d|$=ttVK`*6f(=nDqEY7y^K z0jVw`f`Jk4cqx@z&M8tKSp#MI8Ri+Ea;UsnYbt4$%niz^NXMA1MgP}kUl*gJHL30I zsocVS0B&4rLpR!Far^RMcR0NomB_<@QQzR0I&oH2FL%{gXh&Mxa__k+Up6%?Cm6f! 
zfq#zpZkgO>#u^+xesnjwP%#y-_N ztW4LrOcNc_r~rOycO~s)zkLjt&4vQH6|H@1EkzMHrRL|jHy99dt5*ojU`y{GI*DC) zFjZNYxUN6xa8N++k=Y)2Z8OisZJIIjF*$4tH=|hX;c&6he==P&f_5{=z@l%sBuv#8 zsnocL2*Ox%0jUjb>GM6^0=~B7kCn3j;R_SML><+1EV{@Mfc3S4vSoMg}b$s;L>S z83uePI96f$36g!rr)#XyjSqFt>G05ZE8BQ7E(WXBXJ8~n(#YFMflh6jk?PBBCF0L# zp4d%4IrxfXS1U7HZQ8nEK{Ug#FUJ5xx@~lH)EbiDMv*=e>~-1dK|u53L6Yf|eWm?} zmMZg4{7(@H(k+9QhzUbY7R&ku9XJQ0w;K4Nu9}QaFJR=*+0LXobEl3Z$y^d-ym{9z z8m{$ZxL^Mw&HYF-bxUn~95PLN)i)vUy0AK%z01h2dVy=_S5^qon@Vgs*!;cqZ3k?$ z`{SLy?0NC-OkcNt$>v-op7zXzdXO==cN*SvH<+P0I5$U{cs*4kC}JSLiw}YvrWFG> zLaf1v-$Y>>e<2N2;Ma(WBhMqz{DYBrl#@_PppDnBAy3%MuDo;*J1sh}JAV3`#Je43 zE|+H^6=`exd{Olp;>?qiowR~Pzk?&7z6GOI6moBZI#5|^6;T)>L}dAL{i@32!Q{ic z4vE`(7&A?}j8#%)Zs)#2+E3~Aw-iNU%aaFwKPyMH2r|vZ-F!=vXBW?_b*v;$UJC`f zBVd7y_#QuH5scmV`Yl<#v|22T`5#@UnQGsi*685ZkaZGb53Uwi;WFNRUoWh!LZy~k zNrV<%KGF3#&ZGB(^j$r{aSISSov2i8H_*ITD6XzKN0O6pQ#Q}SQ7Pmy013Ye6R#HK2ogZ*Bz)Zbe-r0Fr#rM&^f(wzjr~M;VOoUXC z6icIBI8>1M|zrsde5)$Arlp7!KNf0pP8MEtI(twI`3qAs1OV69aByb+#WR-x>I zPHyIl`pk{=V-3ycdgR}p0MPJfF$QhcITb<1!yU;~@ZhmRm27=Q$_r`+muZH=aED3~hh?3UM^D>K83a80?gle@#7YeTyJqfi z51m4j#cs^n!;WhBkD-7yBR+3d$rjb+tA`~Lx7{0<_sefwQRyuv1jMXk{G1SqmVQ3x ziCMWm*UpzNL237CFdTwt%|%pz`ffiD{5G9OFZh9`0z8k*e1f13@r}W2UTineLR`lX zoQ5Bpj#+4PHph957h`_&+Pzy#W@^C>qKHvn=T$#?uS2@}(6mX4FJW~pp-j%?QZxUt zw)|Zvk(M=^TNkX7aSvE+=7KhZSI$$w{{7%qO)qy~+1ZpMeY!UBvSgHLt(BL@#0oxr^^vE5EW#~`$ zO2qZg&r*cLT}#U3y}wb7e$6?L;KZ4xO*QUKqiq^H6Xk8;erSG{HX(?Lkq)ifEr=ck z+%&$&)eXXHj3#xKT0l~@Fh67%%L_rJUC;R})Y|6_(xf|L`i4VSaIL@--gikkEnK)E z9V|zTw5g($H3tnp_ZCZiUy$f2^K#%n1rA@8ONst83yEs`SOm}9DL-tiEGW(ngW z*LDWLW@eCkHW^g&z5`lA>YH?w6%wL*$bAr zL|t){QqNa!Nk%k8uIWqC#lL3|TwZ9<*^8T5OOP~ia%!_{Ir+MvarUDTYGF1{@0Xai znm2RYWQB4oGrK8N`%NUtvwxoFibc|!k3F^z8xSmyX_>hi1&}Bsa)xtXwAkK{1K4^? zQEf$~yVM4CTzXbs{P(mCc>B5A6RHe3gj}riqR7}o{yC?6qVZ}JZ{yN2+X2miaytO- z>|@3l4XD+212{p7kmZ}Sq2Q@;+F1qgsiLZno&PB?fPj2mS4T{3t98$i88sigmktosGQmUuQg$fg2Vb|lOb#K25 z?vEOLNYC~c9JEkeIbaodV|=C_q09`~Q?E){5PhtZP~fl;_G^T4DRUn%OWQ`A%nJRl zEU0~{AHSP?m#T64ZZa`Lo=sE70V++%G8$8SR$~RmyqRs=N*!xGwyd7?JHIFz)TQ5` z;~rAF@GQCg4PJz9c@`ds3_olY`B~3Cvlv}>W=hcz!jwENqr3_o27X)YGNF;U6qvG(OF;#MaY1gXY zDkn8vm6OgATLVj9P*^j)$>5oDOzh%O&&|SQMUjDpGL11uxvf1nFWLY?O5-V!o6RqOFyq~rldwQ%+ zlB6wv}(#`H{G?=(0Z{*R(7k7xS-<2ps8 zk|WndrOZ`G&ZR;~i!k@Tj$Cu!$5@V#Gj~yzBDcBbX1PVk&77P2-kf8zexKi;`)7~G z9-sH#pV#yGdLEN=q|zSaTzehS42*E0K1V6i8AwpT#Rh}@;*~}Fc%I$5`IrEJL<;;6 z2AJ)GM;k#4=Bf{zgVTK~+HUq#HA6oOpw3vdPz8ef&{s=oYHT0G8u`WJ(;w)$c;8Q( zPTi@N(eqqoGdWc_tyJPd5a!;MZFCNTsG=wLujY5k-AXvV3ZCf5g)zyS+W6~z-G^l2 z;tzdlV6XrfjU2E*gUpLV+f2;kwu_v^el$fvFJ&$z8GT|c<;s6Ek%fB+t3oPQku;*| zkJh?~ED*$?TyFtf^&bO7Ko2beIx;`J1>C~2V>A~9%m=t0z-KGPWY&lk3|`}Uuf~5i zf1ZJ%18Kn%OjPU76M8}JlRI2L&2Hu~kqm3(C_&k+Rp%VLwYhTZX60?~LOo$#bCi3M zTkJ&+6@I3jzK1Heh9%!c6NmN6<6qeD*PMyh-_7|iY?y=G72S1WsOC|f^|Q$c;?rJW zM%{o?$mGkvEFLyAmepN6oPB`ci;eb~67UhYIN!#X6|^yCeOeSOyk3~>7|O)J&nboL zY33%URn@3SnRY+c{**03r*Mq=; z2a7l3_8yO3=Wy1CPe2jb&SBHZqC4DZa+TA^u#Ued>aWZcl*Ra_5^h%VyjA{2mX!Su z>XJ0RUj9`cI72{W+;9d7yet*SaoQc@*W@qJf)FC+WWE}f38(Rxb-JLupy!*-Z3p&I z$woVww)E1cd+i+&mtmPPM~UwGwpq`#CvW)5nM@>ImxJ9`9;a1tgdbh`;p$ky+$*45 z&dzRvB5lWA zS9;`gcl*e%H8W=1`CHJ%xgi{j(wP#+>+3pJCA+;#1zn&{@JSj;E~L1b$4}36QbM@q z18bGuD~+dv-=+>!!xDCQ@>Q$V9E`8|J)9e$VDf%{sY?32gw~G(q1yrD%5j?ZP49xf z+|yRm80v4E`cLTDhcB7G7alm(l!3I4T+dUcsHPphzCP#;OJMJs+?Ogibzl}69MLG; zISZ;ELc)LoYUpt!eH(J9lxE;MGHzw4f&Syau6m`?8RdMPqiV>Dk|g80eR??1yxqLQ zu?fH8BEha&T^)tK_mcBDBcsa;x*te2I@_AyO$Vyahf{TBWz9bNO#tn<(~r)S?zAez@n}(cHCRg1#KZR+jfaWt=@NeW1h2fx~;h zL7S_+iMv?uv`5V^!QLQs<@@b-tCVXrEtRW0!}dZ<#w)s`Zq|?UXE)|ww>9#a!2YOH>IToM?Dab}tOk}dd=~M~9i1&O 
z-?1q<0-TlntWAGE8XjiH#5}ki2)D;n)? z!L)$Eo2;}q^WWzseX)t|_pax8^K){2z?xtxAA3)jH^>bP(fdndE9_Y6XIGiP^tkdO z-N*u|-n)P1O58Z-s=7W1X9{=*W7|aYkZzl&+>|iF&Uq6zrDkwPqnsrH9QRU{E4?j) z(g3GSVK&4(!Ja#2P3i4AF0*jOJ|*fnZ$@)xPqI05to$cMeRhJ&zDr`D#lrfgAQ^1T$g_v`%VLuXK32EAndj6GW>*BV?1`GkTY zXhUR}l5k3klwP8a`k5OZbZpx?SAAMpr}l4cJFI?A0~456K$={?ZZh1*ZexcH3Zlv( zpAVleH=Zrs1``xrna$0=n^^w|W&br`VxcdtIno)yQ?6|-axwVl26s6@u`c`Y{umfUV27hGFz}Q6N&2LagsT=Ma zW>=a>c~9%;7a*xfR2~gZ?i;OwTHeREm6sOLxu7x5f*_YW{h~=Pmv1;Bn?tLYCnnT z&HI&4y||l4h5r~xjY@!KUKwH{f58X!Z=C+RQQm%RmOJ5f^aU%fQgFA5h~H-=JEFgB zWtyf_Nw`}};{o`$ppS*$Vd-AG9%YY!O!D)E4~y(Bv9t?#g!O*bNWUBCtTexn z>sOjgkMX9bvv(|M(on(xCh<>6bzs>mNb)gd7#R@H*04f+5;c5ZHC7!So$7TRlG@R_N%&3-^WQ$G4#pb}Jf*{HZM zXb@eYnD7C`Q`N1cq;!k0_A;q4c3ke|!-6kt;!6}A6e&N5 zRP^WR5ZMl8iT*X#D3}OuEArDrmCcwwqAaT43Jcg97mf{ZA8NK8c0D*XkxDlIO>wtC zXLn$X+jno`U!MMjYW2I<-VRzJK;nGT40VkuhtJ~%pYr|<;g}j7Rp+9geXwN_S!$of zOumRdE%ClnvUpc$@Nuec-4S&dunqJ6t^t-z|)U$3r4f|?(j&r+; zOK3gtRSEgx@KL?aOdd@OwUSBu{g2`7KZa`-KZp=Xk>?(O`l{n0peFFRIlXYvzD^urLyTFap zks3FD*CsA+F$2?K2k}?Fn$m-)w~D=N=VuYp>(snZz^xT%!z8dH@wz4HwY+?S%j3S!u5~R5$CqAGIGQaF#I$@vP%fp2jm6Y49+5_53U79 zPwJ|^mif{No|s%4dIbKW;!nGdJnRYxUOD zY)ZyL>sZX-YQJNZnTd+_qnfN6dsSayM{m?b`q3R0VL8C@gDtpHQjL2>k|j@S*%#GI zFdwo6t0R#uZuo_KtssUqdB)~yQ&80{iAbLNRr_k@;C<6v9t~8oY z^akz$$F2J~Z_u}3c8XoDpGX@SaQ{=U>}t}W^ghpE(V>!=-?d?XK7rDn9eqC$7+sv@ zKCSVZD#BDy;^m#3Mw~ROsQ03rP(hb3f{q3uT8HLrM06{h6N9%V8>C1I_H_AOD2|+R z%9Y<|NJ0!JC5V=<^t{(bgDFLA135S&?AGBKA_^*JvQkQ2aCairzyZQylW0?^>uadJ)+ed-*P>KQLziq3!DQQZ1K8U z9Be?H&camR#mVbK1AQ8NHM#zyxCC!!EBM_`Eg37IJGA|vX*kK@ z=>0W?4Jeg9gim?^(i4Lm8608t3Tv5#q3^PqAvuDjXLCNSwzST*U^ssUJVTm zx6AqZ8YGZ09db2z1&1Ap5#j%6V2}auH5H`*I!ePX2|OO%!rW>8dr%<2RsD~_kPLe8 z@{5}WpZCfa>%6?>5x41Kjcmi%T=16zbXy#(J;S^|*KwU}{Qe>G*ME?6$FQW?p!65@ zk81dfz;sSS3}H&U1gw=V;5&!8HR62=(@p8IkysgP>Un&BPIt!Sx4^J-VJ|BGUO4I7 z843(Xdz|n(W%^moksJR~Q=w!ygwa*JY6{IoshIA#aqZiSxARGe%m#AY&(HkJiP`+S z(+`Yo_=R7sD{m<2h`8u`6!z$J-aCR)jL|lu-+>#$6i5=9MIGQ`d}*E*asL=DKO7p? 
zq{1>10ai|VCKbF_2`8J+kBpY^CnS9h0w-o2QgvM)5q=lH4b$pO%E_%ex#dsJntKUC zWa>vFfnDBA9y)%s^;z*STBo%6~_o~VbGWKUt zE~-_T$b_80ErsFcVorxo**M0--k!fH6lXs}`sZ?0SMvEv)pILR)1x2E$b0977$ zqvPJ^ei9@u00#2N+Y+#n>rouXzS7coYFqoD=e*m!+rg=s-eS7iWnU5}^vidiOUH`_ znoCV**MGF0Ez}g&|6`lb&&OQ4IYxt-rCsLYh|D*MP|}ZGSyFKms=K_5aJYgfLd|GK zFE*+UXpP=X&2BudKP=Fh6oW7?zV9C%EAH@4yD3}$pl*5!y1Rg{wyfi2=$FEFsCqr< zaA@{?W_S+--=@z5o;{e1+eemhnCS>x^YS_vR@;D`l!N4#pEss z*M5w%7F?+Bt?^e5O@F2D-1jya`I3Jf&BL~9w$72Su(41)TjrcTiRt#c!1-B{n!Dc} zz6O@88p=iFqQjfyiAAj|)HR4T$C;6zXXEZgR{f!gMLME7v zpZ0oaU->1Ic%z8qM9H=SljJo6ZshdUPby7dtpGl(!9Xs`PLgQ`$UGXFp+wSEmJ*eY z{xMkl`?v(}I4;=!b=^LQv)j^M)(n&#N>gr0zfBKqjsXOQ*UdB-!D7<&ThytgSaqfY zJlEp`p{hnz!VY$Q2yX&pmi%K#D8zd4<&&Ghtws4p_9q5)y-lrm|1oUuW$Y(>u()n- z8G)CQO1IajI7ZTqOWRsL*n4N1g74A7IFCKYd9@^khckvLF##mMXkf^xdH!)~np~5J z{h`k1y=t$fDWI8c@2*Dlcg3iWrm>+(Le0@6+pG^_X-5m5Q`(D~`w z%K;+(vFBt>H4O8-$i^dz!0@UJ`bQ#;el*Vsv9oweNXzzkWc8 zo)OL}=y2qDEP_DyslHr* z!uuTRU_)=?afiRRO+X64+OzG-u$J`+?8mfGm{nVozec31A?q`{mwd1ad_-f^)}oAG*;>vS{00R))T<`5n?uEC`J@>VrO_qF1n&b2U zrDFp)WA0}{lX@&>z#+pwc}k*FspPSyh+>lTeeNl9{lpF_Dr8a?bSL1=0eC9YvdpWb zp(skkpi%c2rigj)`JCiS!|)pmjiynbo4{|hMt?-B!DN4#PII^Adhi@JNu>2`JQ3(3 z7WGw6rAJdXh@iNaw%Y&sKZ~9m4VGT()RdtfyAAB-uO;*}KcplQTRih$bWbjq=WimI z?bylbZPzJTgV;b9Wdz}Jo3EXxEG84K?U%V=o|eh$P@|E@L5_^3Y1rb!(|syg+hX0fe{?gn^8{ocV>xrh{n$z>5yP1qJ_(sB04KUiBfQhl<*4#zf1{*b(*> zg6dvhbhqF|&F>zGogHsjpS_#3{4|_wMyQM8-vo=k)i+efwC#krZGF+L3fe_Ki^wxY>VymR`W5B*{!STqrL z82o~jl?{Y=i$L1rfbTCyW zv0LyhZ_vi#sw}^w;(tOSx$@r5ZeOJN3=1Z@Qd~cjE^FKYzsop!ZJlUNOB-H4&zUds zS46Gb zyUR`5T4fQ69xn}g9Lepr!o9z4Tu6+$u(i=%@+PQ|Iu>4sRl}>_2ZC&?4s$p)fCi81@(iJ&`2Ki5oq6P$uzWy{TKBTCRy97feW|GBgHpAWp@^A`q0b~#Q6 zRLHs^52ytY`>2$m@eX&hJ<^i4C&k=L$yWf9pn z%W{ei7M|Tm0wVhYt72l_+9Cpf>v6u0WV{vc@gVk|2)+(LBFfX5jgqdHn#12Q;;rI7 zKN$FZ&l~_LXM}LxO1H=e6LMr+Oou1Cju`KM(S&P#fBszBCFyTVQ{e#}1iC z2SD2(k$)-Sex-p>c1JdMdYHrY)1><@{K%rWF+1Piy|l(WF21+CKed!!XF4&?q~~ol zooy3rE7umJXn@rpEWKTxp>a*4;ix(p;;s3_ykslPOfE|0s^7+QKcNi!#@@JzIFP41 zeE?2ft~d9Y>VUI){dO(c&NxwgM;hQGUo_&CXAMi7j8JC{gtzQOsocV1tLUR24%iMC zfcul{$iVbC{$R5(O(uoyOKCH&UwUgz`Q4%*JF#-mSMl7vC-b1ji^JpP3we&4`{)~O z*Q*oUoyJ)w?gYWemkBid_XG| zX*!rS1fL1-R`_(O5F#|#YVsK8hHGR|OzqP0C`~c7P5C~Z0{x3&_7iEvRqkvq034SC zoGe*SW<2pe>K>Gx8Ft#!F8VL&J#n{X?b5IL@nwyza>A;M%s+;U;+wgU${$S=%>PA6 z@Cer>l=HzhYPOheE2JmvxGpzdW4mIC$9G-qvOYg7lauD;92Xm27^=P$Wt|&*EBLQM zS~(FU&{tq4a>7yN;M9;~{`uwNf=V9fy#i0yr=c7tB;uRw*3ZU4{F9$ETszlVN{=i? 
zzYWcQ_`v*vG5U!0GL-Tui!TV$=9|X#C@Gj?eA+-!xew&J&k9k%3*siDOQQn) zecyR)v>k*KpP_soRC*e&-^msg-}L)9x+Tk7mQZfq#WJe-z5~7EQ~XZgdnFcr$!wKN zhoxOrOvx#|L_?sQKM9*crcnm4Plgn<5zvc%@jXUZ)1OsM!XKwE3Vgyq#+|C1A9v-g z>4C%7H~uo?^6axjN4yvARHqzW`+X$!P-Cb8b(JPYz7YTUG|$x~v%cGUF?-(~k4Xsw zy(EJUotLUk0{vg+ytQ6M6pkTbd~UL*p=c?VZt30eRQnT^va*?&eA^lBjT&CWxJ(I5 z72P^+;@wfv5y&TRh^imWjDF%4%tP+P{_y)Ci@K8NARHHu^S}L6${_M4BMCG7k0El7 zyAg(2yZKjQu_Gr?V7SkD72S65YMMMTzH$^nlO|loUGKJIyO8$vqD{JA`%r+mZWURe z>D&6)$x&oxPZpC@L-QIV1zMOzY%KOU%l+P_crfqISl-5abYwFZQlF&#|BY~DiL+$M z!xHgEZ5Cf`=}IoEOrw`y&&7Y+tKp-C0Y`q)uFX&?xBs0BG7k6nYum@R(&3QARfSW; zOz;)I(t&nPaPeAKcLXX4h+oLMex!c1f55T7T;%KEozvmqy>#3#^caA2p8j@ACtDUx z8!J+E?_VmrT8a-q*www}us`?LgCPl%)b_;R%YODT9g*pa_jT9IxqxNI#L8hfC20P* z(VO`7mT-ABe*8Ym)VBB!PpnOvc%i|~s524jh=a&^yVizNZQ}b5QjKcf0PzL&*woAFOmI(WJLh1u2s^s{~* zD4;opYN0YFGyZ1eM|hl;z}O0*CTxFz(?NB+Dsv;#FTrTRqc9k|lXYf9MH#6^K5uc2 zG*;qQdXWne)}c#55rTZ$o>rHQ%!kWlhCh`jey%Q<+7o?nzk9)lY4ZU*7CTB*RJN`Sty}3s4{l?5oL|KfMgwVKKP2vJar&#ZmJkS z#&A$^C0B3CCThNfQ#$NuwuxDG ztx3?$piKuMyp4_mBtj)J1DDoS1U*0PQ~h1OYMd{AVH84PrZtm}w&27ab}hR2nCq)B z#z{Tb%^fk#J%L9%!anZtFU~eq5;5&KCI`?3C38ZyM>oJ%j;Xv@fN8|ZqD}nxvqdvv9zxyuf@WaeOq^;gC(lRkjDQQE{(12?I-0OZzut`C{+Du_0{vot zY@W)65#bz^t?$oVo^`lfA)0ypUhB$5_=O#YHejp%F+7hvg`LzN?no}-wJ;_+xR;xC+?#XR34VllVl1?SrV^>Tqnmhzk~G;QcePG8p_LEC3q`Xz7-}- zdB(r!?hLeQtbGOTarg-$I_49;*T_X1V&hvk+J`Su)b5rzG=sD3X?0d+u+QDbEA6k7 zF1Sn>IUXe};Z4y!1q6rGp7?F5vAD>IO2=WbJWUnJ`*!LUR24v2q)>zo8@}{zoRLq_ zqhO}iXs()qlGQkZl@7cG@GfM*4B?ze&$_@HXv=K1h|iEy*;OFU2J}HQ|!Ofmfd%>!po;RE+fUP z!$XyESIi?mqw}7f<`6IaeUGsmmXWCpv7+<&mOdjKeDSTZ)y_T&WIhbb>vb3TgcFr? zyu<1EdC?^POQ%n28AIu-?H)~GagglB?4yd%VPK>BaHi-ZDfbJP$#;icbfc(l`BGeN zlE-Kp<>$0Ig!u&3PQTtfc-wwk2+!sYL?5Zv?cUz~xWLiK@yAh7z~FWcdvw~(D3#F& zyWf3A(J#lkm+MBl{Nq`jOnOWOQrmCmNbQ%lz0~V@@TAnW_i-%4sEkQ!=d|FjqUw`d zlPlBv95>PpD?mbX=Wd)i388X1@Q5!r5J2Y8+!8S4CU-75VE{P7TN>MAlx!rB97{N;5&c%j>LH)g$Hd$HHljGNS zEn&O!Fx@-R)2r=MteM>raC35!wJR+k<80si97#p7IS;nf$fGo~ZAc`?utS=V^25no z`HbHr)T9}?t8C4zWOK8~vc!1zw(D$stJK*n<$A*}z8>}h!U@kCw2*{HWV)UR|Jb5sw@F_wwHqrPCZ2otnQ8h#^HL&Gk@^yqHy zG(o78acAML$!37L={xKfu8J$l8_ouN0M;J8dF1hllfztB2O$lxGqL&HY4$-!<(`D_ z=~CkYyQ;`zxItlwgwX|#ogceY;@i1HXfz-;L(iy5()KX^B$qX9} zH8o5D{=EeNK>(sA&}4T6&XV{$u!S<#=lGb9>99Y6JRX_w$;6yyn>NO?O^PiN#Xo22 z14m~dtyWH#%lqE0b9zub_1cklygDDrZ~N3&Cs$vzctmWBgjm-iB_3GhUaYN;Ju9#= zR@Y?Jf1huFoBLeGq_CXuH8!_|}8*zF8Hr)uewYTQjS3kDVv}?MQaC`f<_n zKZfY6$5f66e+`z1Mrk^a+1j(sj9ldD!a(#@zr+mtPifXhJvjk9!(l;JU|n=@gB%)h z{U1X~sWJ`b8<4$Zg?QgMT3UGzi1s1tn*GJ9(^WVL#RD|)v`UzypfIOskJPOu4ha(%0%%kjjw zMM1S{6h}lHkM~+!J^j`wVkzd6_$Pn;L0^#dH^02L`ldWQeUv(hcr3J7>TRXsDfUdUaKsfDRxs<)|f|_ z#5}jnb7ua>Fl9LY5O7%sNkVHzl6RQ*CX^nwijy-Y2gh|QgQEO6?(oa5guBGQ6qdyh zwzyuuKLipg+R~F{%N;sjaD4SAzOSx(@ce<95zgUKV5F>xMiKt`04P!MT!%j>!EUse zw|96!e7`hZy{Qg;dLQ4Q9av{gGA?DANzpC01gZo9XwXuumf zv>O1!pX9aaC>oS;T~$9PhefUtF&8Y>lB!J zkTGWdiP*&0a&o@YjUQYIWO?$4e69Re8s_;+GUYh1&+~!f)fUNt4zDcH+DFe@e{f@j z>x@@|dOH{bm2pv)QqE=uQ}WGYlYRaI8Oi>r2c^9)ae)?WLkae_0nHzO__$ll5KxTT z1(E~YvU_=v2fByFJSXbIM2?>!IksaZiXrRGJfw?ZHz>`b?WijX1_K-vXxq^RFwoY3)$uwUF8MtgB^Jk(UsEywR8Z6SqTw^RMZ>0evVuRIB5& zzVY&+uyF=R!I4MjZ=IllgRb z8`42usKhp{M1^Z_?zCEM>Ug}>*00Mj_pxaKhxL7nOhc8&lsBZb4Rv$5!%vh*rGU5b z$@1x{{QONeU1|ae_$Iqnfrq4(EnX7fucS1^Z$oR;Zo3zW{lq(~*wn{%&;m&@f?acG z(i(kU1zH`*um~iqVrawsFD~#cv$g_eO3+{`?0N$-#N+X~==I!iPnY7X zBAWn;26teIeL!RfIzr`gJ$=G9^qyUW{}!{5i-UxCj$d^_Ka*{O`|1_pi z`{ke*XBoQjYi_Z>9StGjUm!l1ByY(`2&ukVKodpUF%GpC--??#rw)}A z^m0L*JU>loWO66%5oAl77QlWmea^%g^ET%5?`3%xp%3pJi_}aWLuR@-Xp{7AIGanb zAbm(dE=+PB)9{W$o$aF_nvgHU)2k3$!Xng@lSrkr1j0CGvIk$WEhbK35Hvq#k?3>dxD? 
z_!E;~e=pwZg@cJ|#cjn)tbouB)64b_SbZ(!b@n#7cFt{heO2*2rJgR=PBr~7q~h+~ zGeVnPL2Th-EJT|rfQSK20EIcV&k8zg={(b&i_)PfIH}nDykj|VVz6*`6eLw`#zX5S zX><;70wJ|;N$Sb$YrE~c1BW`9?0u<6$iV~iqA6wf#CY&ecC%`A(SY@!w}hSdCWe(n zat;D--iUWjVg9zLp)~C_Cn?5g_VMA9vkhN)54|YPIgW8fjv4XUVw;n?H)1Xs(6D8|!2l zJ&lWhjoF=^#UC7=u55vRn%Wn5J?$PlwYBz)0??-p|8}DuAnHS2BE>LE`z}`3?#i*iGm)XCi@iW?5@0y zV;PZAO4ss!FQNEIcb9eMW~Yuk%`MUiTUK50y0Wp7ZNu5o@buLU;T);^42oSmk<1Mt z`+O{W7E?t?{?QU83T0my zG|&cR;upAJCS`7-O)vEm$IsCxD=NEDTi=}Q1mCGcyX(_KhF?V-cD2<62d521R;P4* zEf^6|y8iCni+OsAYJvKysK#UiyI)u*zO-Z>D1sV7Hb5(!LLV=@*Tpr|W!5PV1e{!+ zAN`3XCA>_wXhz&_ZcMZZ*FsLFvbnJ_2++vz8MFdAI*PtT)8o#?75J#o$&QXCq1Ab~LDJk&KDjz7wOF>}7cPSJ9)KRy#BJ`daHr z2K6%0HL?-nK=S$P0NFhPC7usdDoHdgEFCp3s&Op+0~=^&r<4=tBI&Hv6y!;Ol8@eW z#qHhWz{cf0CMD#Q5jpFEpEMxC=WQ)>apmEu!KWkwLFT0A~ZBKzUh;wAqL+Ypzy1#(R`BOTr`bA zqytHI8>u$mbxmf9W2+=$l=nU5pV$3_gd&Pm;|Yc?kCk7 zDpeBZ+P6GgfEHM6c+yU?@23x;rCr?9#g~MIe}cDraaU(flO(TFUJdt+$ey1Ix{%wj z`QeZ@WlUXx2)zwb2Colbjk#YPDH^a~xlS{78J887ZLkdJ=(aauz8?-`*}<`G{3@!*}Y zTz?4ODbfjcjmSG7Efv<4>s#_-VxKAa*3|k=C_|q??bwa0){bN*+ucP04GtOKnm?3* z8^Dvs3$)j_o_mg6rnu0E>kPB}peYhGDQ%g_@9OY4MeOj%cEqD;Y`S|H&%yG9jcU^6|EgtC;!R z9U=Z_16wqXdB{c6OQkW6zC|Ntlh;AZh{SOv)9*$F%o>@rdtc62=wkeiOyCdk6Z3YM zfW6$hwORx9nAw+N0pbJEpf`%#969mx8hiIkZ{B|WFf|x&r}UsSyx2bBzd_llDE{1M zz31RLxU#mjV7@U4C2MbKwA5jCBo7E|I2s;#wMv*+q-68lRqi-5nD54Kt zav&ENyjl+Z=d&a;)KBLsnBDX3?tSG;zrJoVI$Id3)P8-`2<0X3x?*0Mo!fUet9)|h zXi!mtP2(lXBDIzBMqT&J@CsXmDKhz#*dB~s^YJ~baC9@*2(h0EvpF4N+b>eL`$|(9 zHQ7+FlLO-%5Iz<(S}EiS!p;DDb#krG!?({G+){-aV7!S}Zu2_KM#_RBk#LlyU9CFT z)(b{j^_#Y9`%2@wf!UXv9?w1T6%=}5gz{(n11?PKd^nvq-}Its`b2y$>hAFFAdb}i zqmVAU{&eWll2I;PLNSq=2y~R&fHo}N+;GdqP4PVJp}BU=2>Q|uHZxd$7*tRdY2E9u zpa$NZV;nQjiR~O%R#DT;lsfKdg+Quoq!Cg5GBB*#dKjV)2^`h5C=Xr9iN8tnYPW|83gI1S-S>us`R?$ z+-FoeH^pCv2}dfU@*R(^x8$=7Tqg9Kj-BX_Z@Q}S^~Vqp{2IN{#D|eTVezTk^y?&Y z4b~blBprfw?5wV!K5{8Hp9^&vnU$^j3Q9EV$cV3krx@HvODdpm{-k^zT zPYYqf4ll|P_1zqn<%jOX%CDI9dAWpSdEVRq`=gYxDRG(Kp@+f6&qR3Gd#4T=FLK?7XU%ywO>N zv@UNR!m4KtS#|ei9H0DSc>Q&tKLD0$n46%LD!tD7hUhvah8y%t8B=<%df?@{HTke`9*@(<$0 zGTV=G7i;ZCJHPQf9qsXUicfB>_6gnDQg8Zj;Qaj9&4FD(3e*M%4`8Mzo^CL$ZJC22 z8WC=!^fW!@v4V-7S1TL1UvoDod(VftUX8ukdbEfthJPAQ3?7Fwipzd|?6=VtQ$3Fy z7GfJG7h@);h}iXK0$))g+r9t6<{*<6AaYaQq|(HOpGu_Cx8h)CimxV_b&@KX3Y4R3 z6BIu!G&aACR0i)`0AL3mioE4zCnfS(M6u{MO|x`$tKc-PeT%~^rz6YNoAgKH+wZ`T z)HS#i9w`kWeMIH@UfxlYVK{;|sfkL}Nx@DekelRcvkbwmhF)lK$!Z&J@LWS44Dv^-7ujOwGmxC|uG zcrkw@*>%2p#C0jh1(Y_CwtL?UcMP|yxKG1=ID(kGN`wv-W88Asv^K?qd%fu&e2oeB zs5ch=?ycEAFh%pDTJ!6%%R{yBYYuWgfHhb42q+wu*NZc(rBIX5Dfa=uPdekAG@S zp~JCAw+V4&(A)Tmut}MuPJwfRV>RMYr*^XZTRzycT#5pD!oRg4n?OZesxDhRdCiC* z?R9qcnam8Btb2b>e)9%m2DGgzDBc-NQenjU_AZxXn=lMlA^iShO65K*QVGHp(hXMH z9Vh4YtYQ1b#f(5Dq-P_L8BW;2%e6u3=)aE%4GmDDy>X4Sw4?`ie{PG59N9(}CAWj! 
zhE3NV&TKibuj~>Cz8njR8ZOe&6j&>Yl;9WEe$5Zlb`o}6j?Cm&fl@95nz&tPf{J2R z>%A>&-#L(YA2LYndV%4?&EkMhxqSu><=g)ls<;kyMn_ZxC_%nIdk-7dP(xenN5xTv z4d!*nW~GQuB>Mu2gC~6uG$O1P{Y2#nBP8SHV+V_}Np4^sQ|2Th&`40ud=G{tg1 zuZAE9Ri60%DCfK6M}_jv7YWP1I}vy!7gfz3%|-($tV=tpYc8fZlq=^;pDma8yZyV} zal2&sT9TjHeyJlL*<{nzM9OtWW#u%6ejQ30nBR_bAR2^rwKt2_qT5fBdx}1@uAwTr zmdcFvpCOp#BSYBBR}qWAsAo-VE&{~qP??1P?+hIA4o4{R-G;bEOojPri52SpUK4FG zswO?T&4s(oZgnnpPT(LeaNg#`rHR>mqXLhgnC#E}ZkXlIwIwxEdTz(SK}uJ%m*oq~ z_^$vuI*`;hAGFh%uvs;_5<$q2Y3NF}WBO$J*`|dxEvWWd$ZotzL%yDY#G?DykyDaV zU>gLR+E-95AYr$c-5u}n<8)C^xl7ptPawpIuqD#fF`NL zGP2@YIox6V_acx=!TZ(&g=LJQv--_;_EWXX;t|fO8LWR;JiFp)0Cuu%Jl6ub>g@`g zP+(pP58al3yF1s7rhFzG`v!yiTvvg5;(+@xoo8mj|LMB22s?M5)o^!_PZ%LU!QT#B`>v%+P_OuMkkn{vJ&LLV}f00mb1>`Kpw zbx0BKk^!z{DsCEyK4A?GdDSeG3swXR$m9=Y9D2(_p247`{hBl``n`HTJd>wVGis7X zp>%(Y(1v4=F*3Bv9KN2Nu5h)ih`p4a6OJ?o6_a%-REwDSj@_f8vXtN#AAw zY`_b&hV_b#c0W zQ?HBvA$U1qV(0L2^`KV)5=Ay_h1N_5Z{vHm#eYsq>it?sUJHo_kF+c^XR%Ki8g3@Y zaX~X+5E)@AEZXVqn0JtZw*A?iJt1zE*$Fz&5anhzq6T-(Yet3X2Lh3qqsgH@4{JiH z$yqQB>7XaoN^r@i1NRYoF~B|0Dz12FPbvD2d`+2`8Kkk0-GBvW| zOs{!Lu+I9Mrrm4!Jnp-7U6@@2pGR%s_x_RmNs2+HvM6C{aS;|`cx}o31wyvHXE3Eo z5str17;kZnRS5T00h5-$`YSD*^-bvg8l0>V!7&=#hhwU8jSRjtswAHSRqlwj%+aqq zlO>%r!5Xh`7dt=3nwDB*7MpQ|2|4jTtXzOn)bl{j4ib73bd9hfKu#`oW`-q^Iw+!lT2mo{FEBcnX4gg(BC+^A;1Oa|ycA$8RY;i@B4*vkJXx)Oh;|39v;N)(ba zLLsV!k|Woa3L%!n9I+&cIdWgKm2zJp6h-b3%Y9q!BE}qJ&dq&q&atE4=l2hI%s!vJ z->=v6gmigK@OMmKC6Ix*xcAZlZcUvQ(p!SJ+0JU~vvD0ZlAYqJgW+d;NgEbo^`-JM z`f+=oHvmcHZ||;nbwb6$NXr~C;Vd*N58@dAv{-d+px>& z$1Ahb5&qmVdCBYI35=ZZCSl9#n+FFQv~4efq;3P6++wko@6<58c|MNGnC!@%hm!mguN*uVx&d zH2wLn1dS`8SSCVatDWGV`f1y)7->dRo5Q{#4tviu*b=x3y1S9!v|sLDg1jVB-T-Hj|WUtQi>*;c87a4^iuFDu8Lel+YB zng&Af_n7N@@!0^1?RbiLMAwEa9)#s}4ob;`uUq$&0#-RfdyS+ZSk8J{o|RJOa&uB{ zDhU+Do{-I<;)UyipSNR)Teq8yoLiiccrrNIlO(I*_Uu<|FiadNQ<(eHtn@_ru3ja@ zYsUJy%6UoS`hI^XA3S3q{wDlZK<}?K)|sZ8Q@5a;^d#St%|q?j>ZL9mcZW|_q*874 zE*XWeMW})Pv7vdlbPmp0Fw-I#l2nkP@m0l3wM`Abje;36o{ir0dNl`^gsqf(T|(Z5 zndPCd#*@6}Zp@9+ZF1? zCrUXzwj`x9VICi^+7BGM&Pl(E-b06@kImD9lBN(;oTu}5;^*nr{+n8+ok9jneqvlo zcQaG=IiggsCupG7hT2G!@Jf;-_dNkGn<+a+;sbj5!jDS_9|Bw9CVPWsXZPQjw&`}y zK(4SZvi_p^8GL!-ev8ACME$=>ca``M-3;{$JGODE?9I(<`>TXbG$n;5WrN|J(orb! 
zuj~yC_zXatqRgd3)ou22w#-xFAhL%j3PcWk%m6OIQM+fS;_K@o8V4g|GOUnE_>@pC%4P*-WYDK0i<)m2!zB} zNk@8J&MTg_l0H`S#BaD=5O)Q&K8w@=_WfgxUBDa17PQ_366{F6Cvd^FwX6p}y44{J zA6GMl0TSvqxx9dj*621J`9fk@)`4g1jw=+E@<(H^CS3Vrsb5GlU5kcznD)!xA1qsM ztI;h)^iyuam_L%~0pO$DyU2EmjS0w&pTWCU~nvBRC8=Q`$W|3tLSJL}{?$l41 z1T6zP6lqXD;g-GXZB;fG8ga5|QkAFbNOU_(Co2Li@FICb=+e0MEt?~gWODOGBVY0< zJ1cwOx(+TO^}=J}{QqyDxzw0IUu_hF5+P)vPsrM)bD1oc@V+3B;*T2ss4V*ZmGWMU zPVvMYH{HL`srKyOmI@y4a6@M_-Xjc(qBkZgKw+Z9$3Zcw~gU%3511|Qb zb$?3|l$4(?*)Mbc7UxDn7Pl|}Y~Z+cc8un49^Kz2&D~KJ_kCt-!8MKJk8XEQe^70{ zy@iJ50IL?vV{2m>@aSH47yf3vic(gFAL|nYIwZJvGDTwLM~5`?g(GgXMS=Nw4}le= z$w5cZb9g*_?={++xylrz)Q119c|G|b&j{P|`e)MCz>V!2S(YtX7v}H*Q=`X0va&7< zCxXATo-wE9Zml15dD=zDe^4@v`h7EJX}YFKH!r57L~P>ifnXG1AYx-H>nx8^ul#Na z*U3jfJ<{}E@j%{Ml2V!*EUF74j|9?0?}hHy^1Daw$9?t=`^WZT7h;$>)3#$EQJAO0 z=a_T}JUAlwQ#A{jfN$&wA2P4c@?4PNJ*mNVXFm80%vL`3a{c1#(ooy{>^FyMPt#a@ zvNx-nZ!<~?hLj;cCqQGM&4?VQ%_1#DY&G0X)>zR>!)@38Qi{L9GCcv1mTH5FD)2yz}dFY-ul?>wD(y;SO(Td||Quoho?)^T5rQ?`cvDv3AdeM%?9+366|c%M09Q}4qO z1!WI`&Xen1dTbbsvOM>m=7UXZt<4n!LGx`O)9O@sN5_;1OyVf!l}17u<2qu{7o>{6 zBs3R#B2c1W?fXx;;`BBBO9^wQX|lCSYE;;gTsKabTP!3h#Sg+pn9LGzpLj^wnbdE+ z--kMm8bpheN2=A8^)wwL_XNYGa-ubshrEJB{Lv!%5qnnp8Fr`VF!9@~W=ba^W^h$n zWrr|9>1U1GI-=M2X6fV=2Rp5DalX~Np;MWy-x-K6w#pJtQ>Cf2@)EnT@3kM7#7h=} z$6gvhKaVyhaB@679|$*WHFs=4WDlFd?atcSLUV%iM-gG0 zn+A?Qu2sG&ktm{=;Mc8<{1JB5WcENGG(Km${HcE1jL0jy^FPl{UQ}6bxd^1iO_}@q zl$n?npBtbyd-Yr8U2gEfC=J`$$X(XpT1T3ae}Cpp?5+-!XEhAz87vmJtG*m@NR?0G zOSX#A=-^EqD^53g=Zi5Z@9aX149(V5jh0AdR>k&!`DvenNx}Kox?KpPGaM+e3y!QA zsG|8T(J3~ahJ(@Uf2@EOFN}6aG^qsUe{8I)yVCqdrRolIe~@(xsY*>L*73gKb%%2c zF3W&Z+(sjOAyGb`dx=i3rN3eNcINu-;#k^_=zbvF)QNhl0~m`;vPnkR>0UwW#H6@* zg@ZZnu20)C_^c1)I(d&d?J4}?Rqnv*(GbaYiQZkJo{}rBp3K!h$)xJ$=DKvlg=!4fO3e*-u3M(su`YwLBY7c!Nu5~vs4+}0 z_mV^O#t)Xdv9}`exi(@OwxQ;^=Imy#e%{p$GXWo)*mp`J_f?&oKgpBR^w8Jp(bK+6 zTcr4vl02ml|4z7o@~7PxG5s#K=@T#DNIq&vYARk&9KMV=5~I1%Zv=3VM#I=`ZF2t> zYr=XZt0iN!_(*gsH|0+bm;|JiqRUEWbZkAC|L`jSVHl=tgtg`Gs)|BF%cIEvPKTzb|CbCSiYg>+(zgp;A=DFj$h-p0@&)pkEH?BP1xQ zY*>w0ZuB_=^lkHIQI5fP?Cq7O@O|Bp{LuxBfjFPPHleMi0h+wDs;P4{+J!}vdJ(Tj zGmK#UNxEf;jflE-r6Xi)eRgE7vH^w19xdk>P(Twqadn~tbs>ZwHk2P5vkR2lf+jYU^l0>oln{Ez z2_y$~&SZE?%ika2Ga=WYDoO?}^ z>$VfCsL49DU#qy#h4&2A7aR%z&vnl5R+S>!HI7dkPuDyK#Y8SS#%;V=;8IXScAPbs z!#6kp0Epn!IL-V1aN*90Ud3_n=!EJ0_hE<4+PMFk#oEzlHme}Q0jT!wW8luBr*)wZ zX8b6H&XR9c#x*UMVqG^aX=)VB4%Vf*YAQIz^HgI+gUHTlfm#2fUCfXTbbN<(uFnQa z<%lEi_@2TNekE1Pw7`y<3yX6+wmN74x*sVFM00ZRk2tAJooURTE^w{suouO2^IUt=FLn$ z(kH(xAxlOdcqMQG0)U>Wh5o(8BPihc179L+C~CpJH&5O~W_3gs0rR#mnrnA1T}{j& zFYq>OLn>dzX^bM70GDvp0wJfxX$Yl1Axeji^O8!uI(8Yrk#%~%zlM_XW%>lABPd7E zROZ*x)hwME!3DS=)(HgVr4bdYI~S~tc5wa-cs+e&K>|n@56=an!3&T*FWLoK+iDz^QnXvLnV+;A zzcXAZ)h(64n|jW9#aE!>Jh*MmQTxj?@wnHgTBi=e|1z2+5TbPtoho3arrRn-CLc*X zwGe|+g)-J3e|lyjp%+~xIC&0vd1SZLdE@Va=SBtnVW_fABCE7j#;Y(txU(x4$R-7P zEA(S!N2RgVdgIsgYdQ4~T2j2|bNat*ofn|z-k#san_ce9{qRTbZahCXF!pP~B!q=e z$g2U($q!Fn9*zz4;NN*i+ANyGiBDPu@VB=%X-mWkjT)5sX(exowT4M`+*SPiv-4}z z?>WIeKnuN*O<|SVlb8~|N>GAtFr-~F0&)to<<2lACGg3U&USmeOS;s5{%d~N*3_nJ9BECpHt4VjxrE? 
zaFM(5Cd=DS?VI)KUg$dZ5Md`z#kI9uq0B^dM|b2LW9NH2=?K^}%oohF4U6|Zik26u zKVEYMmY@4bxGf65)H#~QX7#iE?2LajqX9b{(Z$QHTIw8-d_OHQ&_KCo?hg{vD{?%v zTZCO?{&^8DgbM0NzqTyFS0ns*Ds~ff=Lh6JRx3ncx6cnQupE^1O8C|#Cm{!yznET- zVy!#5S?k0F1)hXTKZ~(l&uTeMD~{+U4ymWk)~Vl|D)^0DUp80J?+;mZ3S^@OMbB(= zMO+i$YklgA`PiZzaO*K8Z*Z0Pbu+uD+y8pR2Ep-RoV0ky?W)|ixk0!l2D*HBIlM_D6zVL=+?|Phz zhxP4wb4FWMwMScwvFuhSN`R8vOIoqBl%t;5*6|G9kIQ9g%jYMu`h-m^Z+XA{Nmy0{ zpl2uGjF0I^H!n@CI>8WnA@)i(k$C02j#%!!ZP2s1OJ2IjEUNInHcgJalG9ZApee?J zn0SRxLVKlIxX z(Jf+8%I~;3_KUI9^o=qwkI^hb-obpp%%rk2Z9Mbu`Q&PY)H8FN+^@+D{VMhF%3~_GIBTwtB^rPC=nX@Zkfz-7wqH$iu>RB1|ix%C~ z8E>=%tp{_u-^6<|zWOkM`feXWg|%x&E{1&@oeoq$jqqfg$-mxlS#`FyPwIxl&hcY@ z7m{v_FuPa-|JZmSFX;)#?~xXhPBiN7@C^N#8`KvEQL#W}?96ds%u0jdpC@gK6i&1B zam^VtD|BORsc*ZWMP|v9*Vla3^rr_{a0Va6o9rK}ay5fA{nqn)o<6sdy*$ju-vgorqXRlq8^~MS)fzFKEHe`wwav z)Yg`%gIeY-O>{(dvE6m%XVyg;#;_6Zrm)01#R{33KCCrw>hY#usdN00K!lv!(*Z)x zao=g~6Q=R3OCk#&!(n#0;s}m=l%>$U9adr_Ku<+}KwC38X+ZwP@iAH=vjiBVP6CM# z8HzJ>Ya`)4;I=;5@|b$Co=hETJyL#%yyXeH`ygqx`=BR;w|9rP1Lmf&TTqqwM@!`a z_WfSgJG6ib^klnSJC3j2^GCpA+TqLNeriJNGk*7bPtlf@>crxK6oWorx;hQiMKsBp zgCDnTU0ivDVgV-8V0mnhn73qF&#Ayo?X`EzOzg`oi8X5e<0=_hYl5`3!ij$%rhpq4dTED;YTcp{{T=$v&Q#tv2^wm4_^cR8DG$PRvfwp)(YH&*0CCD?U*xmtGzPO2Q}$mtDFn}(qA_= zRX91VBL7~U`YlLF=)$~{kflgQJ>J$nu^R`T6@C&uJ5AIJIFzoFmz_O!esX-V;cBO$ zR^{NA+mYkbViG03R4!cm(T_U1Ob$Hx!0jJfC5#u)mu5NFM{gO%tS9|r!#1=8;;XV7 zH~b&7cIYV%Pev}sE49-C1brwU3N}3rML+tE=eBVB`(HU11j`(E?kjAc2aJLSwq!r9Q3*!!-3o**E(c^7+R#6?;5QaisrrZrz-;0IH=c#w`(k z>AD{@bL@;AThhw(9RwLaEAIjJ!u`(S93)eQ*YSId3gk7_)((UC?(DastOPipRzZ29 zZsDarWGWR?XLpP3ZjB4%;8*goG&q5R#SWDV+$gp&e1dts77^l>YX9y zfoTTE&OBET5ZSF~1f)AwP*M&)_k|c!WJf_B!Zxocw^AA$;~$=>mz*-q>S@r#MO1V> zNG|oys_g%)7J<5Kt$YkcxE^}tsd%2y?8dGp+62T&cs%Ag%*sPMP04cex24OyPw|vK z)=i1sYu$=wDp5+`ds?NjPS~X!Hq%vGvWgMueldmRFNWO69RAZ&7yrl&-R_@6etV;h z*j+NJrpseLHg|e=c-;M=x@ceLSu1(%D1BF5o+p>I)O{{gUppFhq)Pj3dpOU$xng_fNU%Atq-ZQVvSFiD zQc*in9BVj5=}$QhtFZY!?fN*SU++A0h!H&!toc=yksSK=f#Zb*?dE@MI<~_#T1UyZ zHPMx0f0&8iyCjs|#ILGCT2L>lK-wrtX!-9G|%;g(4S&%*9^$ zEWIt27x{SpQo@;U`{10$D_=$~XX9+gBH{GA3THz|*3Sth!{7L=?8Q;r`#}S~;Sa!n zB##E$j(?i7YLxvL=H8}P$5o2z;2PNMFe2}Bqy7~tqi>su{>aS*M5KF8oA zm*9kKZIu`ZBNE~^B%D$dvk;F_ArB;fr^#9&{?ZX_}R^1#}R=c;6Q{^+m z7l2xY%D%>aD2tg>3|}l)()i+r0S1d|4|(?;JFAuZbaz6V-FX6vWETR;X`pE$UO@&O zSKxeOv;&LEkAbc*!Rci{PUc(dA>FTK_D~dDh-V4rVJZwjlVvRev|twB)}a$zW?53e z*C|Edw~4N=bwr1oSucGtKPNS!A6U?R7AT83l$EWl{J@aT#TqG|DNq^p>)P~m^WSd2 ztFFFo8mp5wJ@4=_=QB+`LW0*~M_V#0>lCG*9L0T=G342PSkrI3O(UV4q zdrFh+O63>Xba>8D9da96-TAdiYIJi9*#vXSWeoswl*genkHhqW=DG6uvxw00ZK1T; zy_=sEhZYY6_Yc;qDckrCm-L^KYto#DT!--eP-<;Nsk{`xUP4*cgk{XSgY!!iI}^jQ z%lFF$iQh;(L{r%xg2k!aBi|rhBkMDjK%Rqp6TC#s_U*9eQ%Khgm&*(#1rWSr;r6VK z^cRqcFa0*)+&(sOjHP)}JO7VUS83CFYL*2g=C5?4qeTTJKWS*v6S6$AweZr$%GEdvO?HQeW z(wm%;TSUbGE@??y=VwG+h8B!8vjL%Uts*|BzvjGO)CQyT=M;r z<6$P6u+?bH&(V_sWsU-K*+T)T__0@Uxy$cewP0_JXAx^QZnz!3KL75=sVDhLB!dPM ztR@S%!)njK^dxV8&v$n9AyFUGoj!FP88>J4zELZ^TpPV!Dt!Y8{7;B#QbReZ-H7Mx z2OVQTu6thJDFSk!@~u?_2iNYtPVKfWkg?7OCNxc0A<#vvih5#O(y831<0dmM(&;w} zJGm&;XDudP8eXpQOJ2@!L1p4nF^>K&=~`+&raf=`)5m87TS|gaaz|RYbC5;26FY$E zY(>JtR`eOt)6$p$UYf;zBBKi3)>!q;5AtuA_8L$MuQ3&ddVw@dp%xojH(e#*mSxO_ z=-acZYDxL}y(^jTFe`Q?BuHc_*~m&}X!Bm6bG`Ni9)oD`kF5U6rHMV;=ONzQe%$N( z-(_B%PGnordVjH1h%;GWc6S|RMRbO6Ahns) zUYIvyz`)sL%$r)U0={aD;k_MV zUYSb*Bx_>jWYN7GpwFOhrV1J@an#-_+N)Z0eYX4MZ~D^pcb><6f0=y35OAky_;yQt zwqx&>afgH8VFU%ce}}Re17u!zrpGPdd-IGIp=XrXCX%n*eb8k8VV*F&|FiV#M>vnx zdCuz0*``FO{2yXinif``wcPo7?$zVs!=g>krkY#O484#MXQ!9JPlPkYjRxO}*zaOe zTSRHyohUIpaj47gQu&DW(6Mi|kHsxx%YrgJW690XcC^Ivk^h$OPD5)n?N#G*w=w_u 
ztQI0J;D4{`7U@oIXT?eWmk_O7^ekMfBvb=74=$%)XP8jZqNeBCw>908ZysXhdg1k% zS7;+>`)nh@#k2*tMB%;<$A-yG8FRc0Q;q6EV%6auZ#86-^=GGJpkm3_-`-QOsKjz3 z`)!dvdDS0Xsn($w#Y++k=H*|JhQ=N{6+G92}pbB?p4+H7bO|WX4wIGF(90O&g z;*C3HnwZj{y(?{8c(BuHi>jroZP%EvL&>X1R|>#{_8Pdk0|at4RV}U*%!@qNIJD<+ zb`>CLoeR7A_cdHX5*4k@i7dCjPm}2Sn+sEX*RXOYTT*EG)S!twBN8~~GqdPJTqk-Q za@Or~<499ui2-`l20MwdVaMs~m#>j-QZwd*E}+{nr%*UBXh?w(ILB1+C|%0vvFVHr ziPQRMT5~(~(gjQUUxCEb^_zAue>w#9`v@9`COr!X8tMrjV%f?$(!pDYtAn)%!4{Y{ zY&7a3C9I`>7R_PxC*N6uUv|xX@%O-$f;mQFdW=+R+dUz5$p$?g!m=^)*qq|n5V^KJ!=cpq6V}RXDXi6i zV-_Jy;@wc0g`cbng2y3mx9v~-jR*CD!rxW|In^<)U3J_oTGLq>8+`7z+0(s>X$6EE z0VZ3IC|NPebaC*eCEr{qe!~_#Fg_*xar}xP)MR8QtsC>MVs%=ODumI_{P_W1+>&xK zAt+QwyCPBZ@3$>?p5@#wu`fQ~8TZK82!chT%e3l_az^8KwSp3fi3HU-k7;kwGIPSb z_Ng)sTPkm?ma32BIcGbmNM|$nWoqW1C?%4uDWA9@qp>9-d%xJZb4RT1DoBE#a*K4A z#tUttLObXsA&!gjG8KJ)e@;&H@^3z?lX17N0+ZgK7o1vw0|Jnl(DV{39`~`=r-fY720~s@^s{UT+Qwpu>lu@t&nW{>NLyDn>pR#uO=-0*TCNKc??kHYy$w( z<>@ylU%!|6*eT`|6B9lM1Px!QoN7YF(DAHCS3pT#EWo zQ72(}`uy)}D5~eMAGt-lVgUXm^ZWwXc;t^?CWqv8nF|V|EcqUG31JVDAi@BIt3@s# zm&8*d4`}yD!WVDcsLGZgk9VD1o~mabvk3apz?h+j*QTVWoE6U{=#Mm{jn!z~x{_M$NC+r)`S$RC%LX2X_~H zl`R%Vmg~Ts4VdUqO^3#n?KNUIINiX1AfMtC`xb@GzrwCjisYn^i;qdW>!R<&e{zIt z-Ob{b_fXl8F&wb5*|>4{XV>5EYcZnNj>|6#IN+Q7--VqB3cfXXG+CVAk8`B$&^4E> zJ>KUHZzxV-EVPjLGYTxgi2sC{*re@CWp6*KjbQz`(DG#akRLlW-ncNahw)+NXc=>K z1)oP=OlnDxejxlpCYQ<@m#qWdaTbV@^()v2%Q@APG}S3N6IY9jV%Wg0LFW-erqwbD z8cNA2pC?@;Pr*sY-HHSJ@TWF&&21}Uqs+ZC#{>_v=Kit8H zeZ3INYujA13+sSTDz)55%zK^wK@*3N=55A79RDI~f8;!-10%RPLnOfbG7(d~%`NCz zx36odDVpM|{`!EW`9BnhhTV@!y|No4e|b_hC$9MbkZd6z(tsbcfV*D3Lc>h6 z@d`hf%VS+_|M&SH#RIVx;UCuwu+mjve#*vgx_RDH3!zfd6v(_W(w3;xeS}SF@q$OE ze%p~-KT~g%yODzelP71rgP%frQ6`Of$Cuuph_^J1cy#LGfBNjP4h5hMF=(hA zKeaZ)Id9FgeZHIF)s>}o5R?KqwE8wtrvjGRk2=z?+c;UW<}U0!SsAwF@wd(-fRB5O z%h^(JMBJO|IHKb;ci+d=(lcRRPu{lP&7+t%dz~UTZ4d2RQJvmCHc)Q7sVsamP?#^O zWq3>Ec@I|oCEJTw5T>{LKkTEdyHtrcIYU;J0SEr@V#c^!}7qRm4kt!LjmMz5^aTv7u6nQ3fy1=L&6oLp4=@znP}; z#zl3blfV5QZd_NtbemoLHfD@HFM-4|?_C2Q8&Uzmf)o2(sYw>)r#ZwdS+5x|`V}C& zClE`CHlk1Hdw>DtF(|4&A+`MG6EAapoS2aIz?Wx?%Vg*oT!D+yD}vY6^y?lIU$R+i zSGmk1d^g~udW9rgSpYJvtmc+gx7WU)|L%|w8-Sg~w(&+9Gr!J3xC4J^O48 z;PxsKN|id2N0T2W%5Ao>-Nsa7Ns@F}#BN5ZtR41y_OAZhj zD|DBoI3&BHVB_N+)y!La0Y(5zP!1wh;9|}i#(7X@`8?0S*RE5QPJN9YL8Hx%&SK%Mq%ocs3WC?l!sT|W`ojNgg8TNS}Bz579#<>t~^9xWfkiB4YM*welRS0kU z=%c;`?Ir-ZEB&|&2@9I)j`XorGq|x_JMo2yUn<1cZ;1t%b&85l`ioX$h}ZLG4YN9F zFb=0#n-DaNu0`N<;9E9W*(-IgRsY9U-xGKc2e1ktL`_Y?qubtNgdA31t^dKyxr@V_ z3ST;1?#I^8@Bcc`ZU%?ed+;L_2D5J3*s)(^Nk2k@RWmtv+)E+B{s+>)`%=|pJ+PP} z%<&Ta97}Qa$`<;cu-~$|?k=Cgq5i zp9fw#l>7}_&ZV^*4drCn5VCs8hsdd`v+;w!3ho_ob&3C-)R36TA)YDz#}>e~S$eUu zn&99oR+`VIiS{APtz*wNIgx+%7OkeO_!r%mn0v9|{&RD#H%5d@x{diuWR8#ygbJQS z4Qt@q#&i$jmxBAWMQk3pm$54&V=F58CvJ1R8v@OA{i525Ry!Jw+TsxB+f`PEmW*%h znVi1jh7o>THP>^Ehofg@Ex~2;(%D4E$h1&Wqg^qw~uu>%-S!deRU!ki&5qx!hOCC&c#u4WXv}#%0Yw0tb(Q#HLd~^+7;kPLwcfIKPS&TdM%L^Yp5O9q7d&!wK#ksF@c~dlFQAUNp7W2b8*y07ENRklCi9*} zyBwBJb+RzHgsn|INbt-IX83t^x5=ZzfK4y^%Yu*qAejFFQrj7Lpa|#3tW`MHZY2*1 z!@HNF2{I~&>Uonc-o@ybwn{?c7pSm3Xe40l)Kl>F{?v?Ea{2_zFHOj9y{_%1fXZjw zC9GEiGowL4r$*W@$ki^y65ZH_E{mgI8_F}^MkP>xy3B{9+ku+vsP%M-xzS42aM7(2 zx(1&*G+NC%u5-SjYTM0Rgi*3HLxyz9JG##g?T)IOMIYh1Am`jR)|YHn9LzbIk&+l% zgz@N}MSYaoajlJlyoC~}N3lN9TC6ogs#9FRb`x=D>n!pN6;MBaukfI8hrob&Gb+eD z?MN=!)+vgmuzUm?pR0*j@;zDmAl_kT+aq*ZYKXtDp~jJ;T(CJGd4}LFE&QnMbqMxo zWY40U8s2W5DDN2attgrNx}t{Dt^AJo@=pUvfJbdBqso4QjNItXw@7X!ZlaM3sOnwr z?gX<(hcnK&d@q>)E!&zsWlQ~DmK*Jp(9w2ThbN>2F{eZ%dk zho7W&o{Kvl(}x*7i8di8;uRM2Vd+?n7k9j z+jVP+Ijp_K*Dw2uy>B=val(&iyDQ8e)_lI?=Y{);hX-PouyIMQmJ>k9{q?6Sk$>sO 
z$SrnM;fTe!mXQ}IvlSxtJ`kIaj!V}M9TnH?0HvZoa&7O3Ry#jl%Cl9lL)=<`M4{Q5 z912{L#A0Bt+Kyf}WB2uGeV&vUKu+@BtOL8v+w`p7W7Zu|Y56FI0aXqtwSecE8PeyB zuEpQJKvAk_KzLGVj5&k99`U?+e=^tBMR^?JOL)%wz6ET8z1aUvdnX4xIoB@SgoA%L^3LAIbTz|B2Mlws|6_;^U{Q}Mnj0C@ zy+!02vcMy)Sotg_5*=Z=)|Iagxzo1Bzc0t)JfM)=);G%LANpv6_HNDbXGn5?{o*Eg z2;8g3hlVjK7+#ytpAvtfuhHlBBeB75=N_aLfEI4I3#vcQy7?!z(GMbse%0sf>biO1 zg0S2+rrjgj?sD|bldtqMgcayU$XsV_?dOZ=#_fbsl$0&4V9R>MQ9Lf>fZfskwgu#A0&Q~g-ikklOg&49fD0-(}L)l>+ZZ@>bH0-wg)OpQ3 zV-0a?_J+aPZkKIK6?vnFU!F?&l+qXo89Zjf?7y0v;!b*__>+|hhm27zt^KY2?|K3R zHD-~&^R^ErTRt0)y!T3~(*BEAwI6*}uNs_L%X)Au;x{Zg`sks`%h7`PhVUe?C*G%( zEDu|+^`am)C66Pu8z~21Ulek+4?R005O{9k#l7yre{99ryI`I*jqa!>joVYwJlsD0 z`i_c$w&w!Ux*JAPi$@cUJDw>nSV~gL_W;qH=^C69>HXa$*|r00;b4WP6levnnEv(I zK9u}7!a}ilt&uQF))72;;U0D=6q><0UJCc1A(1+bz+zk~70@*Hk8L<)W9Eo=8`Fd) zNz#zYS-YiK2VLF3;DYFGa-7BTxHJf7k@}tx3+^)B^*T-svF$NHb8BkykHL*xDcZa# z&nVTnn5sp*FL26KOtq$2o_Xi54rHX9-IGs z&8w{qT8E!SUcn>X{Dn&W=_S~$u62M`-*RKrm%$8Z!LwTt4I5#Lr3KC3-m0d`sl(07 z>KqDta=flzYjj;u0CUk!w-Nk5tOSj0CL_Fs##3jUVMHa~Pl3lJ^|ox~TC@0`BGra+ zC|&u5Q1X>68CfE=vc)`kbD{i=HU$AUXar7(hz{W(^Zdg7@Z^QsG>wqeSx9x+1`;zB z*kF>mu^qCSrlxHvlUok_zN>^vsqhUiK=#3VBvX@I5y;2vzRwa`>)8eG9obNLwsfhG zb`XTeCAYU}sMPc$1KSkW(vm{5;~zl@pa!fljLRt1(WRkUGKn-L&*#6g)OZ$)ZZkq83`N_v^j_Fu@&U+myH&ROa7vK{N|05;2x2o#gH~^3@YnvH3Gkdu8_>6z)!t8&N$S7HYFn= z6pS0{A^?t`eIAJVbHQ{4J829Iv5JPw;s}yrD?(dbdBgsBDw*%?fLkXrg&F~+-6Crn z(XVm3a(5Vh<0PtZJnKGxoN8U%LP+rgH~#naO0ecr?7k6HkS446f75aPp#iJ;AKPdJ zx?&133)DBGdQqILk%kdpTUp37Fs&w{Y3Xz!!W>UZzIEP^bK+^sAfuOUFFZO3B`aW6 zS_oCqAkz6!w?GBpbeebB`KmtPjvv@Ao&^QNE|(uRJ5!8Ed!>c@9Q$5@kV!y7Ldj)` z9lo9jdB>Cxn*JG>)iA}{GCaV=N(f>pNX62Xy4OKmjG<9r7?L-{bq>%h#8mZ_J_5lw zLfW9>c%iaOc7ohPUo^6^h0t7DR{^F6K0eqf%-CUZZEG6=p>RTkq1B+&xHC!hGIN6K zk~}Rkh@hqSsrw5KmThmA_4^X67kqJNCq7%gxq-+@wX&cm!j%gR)iN|B-_4w9;&&?1 zbB?DqSlWW^ z_1>(h`Wormzwlwxg+0I%L%&2v5NjFjb5&`*hZ4$|4A{QSz;(=5clR@?^FU^o!GFhS z$Kb{R8Dwx}i0^g&-<0Qc3n*`c`83zHPqpzS2uK)vH`&mfL>!(qibR zb%n$Ft7n5h=s)7z;orr)W0=l*2-AeEj5W;~YU-}v&AOb>Gs(!iVV0?RCYZ2Z7z)0; zVv(BlqA~5Dno8JE1L6+h9SKnuavj!l%a8h8@kLd&Qy`-+WrTI_r&e}|$nZ-8oh4-n z1N+X8;xAn_1shN@tBYNLQiefy2dh-jr;Z-%&F%Ix^E~eanck@1GyDn+9ZHTQO&8!Z z-2?uy?TTP$RT$kg2YpIfXw|rn2M=w+059UkeS1i2;?L!NzX$DKoJ8ab@z@~9><9Qc zBtI!3f?%x}S$*us!%Fz|v?RIYrNhVPkH7b^*zO5uU%joet5jB)AOB|opUdO$2kW}l zEXO+63|_^ooJv=gzQ)t z(R5WH2ufU+1`eFTf~CQAG7!e7}M0+R(Tt9>KQ2lbO)hEewpl> zw*o;lrUvA1^9?P%JmhDtveJfrNpiPH26&BYnY-;rEtouzPU}*WGzo?+@osQXf&qv5 zCU7$&A&Ohi-}S!h^_ZhiSkzOMn@bjLd72ciK7LxEzcWb6B6fGA`4PY;sIp%OFfYw8 zx3z_xp#}Q*KBfZV0FAYVg!;yeG0@)fKW>%#75zQa!aq_~&=;^`^aC4C*_?o(YQ3rw*gQ5 zzg{}je~R&t^3d)Q<%=B@01F+q6sO>io8$BgUI9HRfO~wy;tIyGE>u%ptHwL)x@%X! zUOt!OlkZ;>mWA7hU6s`TY>HIY1Q8EL@4Wwti&uA_4FN4cVnMTRCz+FzrOjPd<{Y;j z#gDB-Q_g6O z#f~~pYk!JYkvB0dFe&c}g?T5cyT%sJCe~tInBr89^@S277%)IHWWW-`a>MeFg)cla z9eoG2Xztg9_B`fLS?0AKJosT9RXo45HQfECC6DEGjOv=QQ~&_xpd=BWcHgh9%5I5plcA2!|MYPjJ6#_G zeBYy~!p223$GGN+5R&AiM`<8{6m%#G+BoydPvuTNjOO8wUTzYj3U_Gir`8Ht;IDSy zMabzg^PD&MQERHw__B?Hn*j+%;q>*%C0a?R3Vi&a73kN&lAh5}s{r<_*t^YoISw}@ zy38O%M$n@pAt1}7-bd+IONYBo>U@7!)9$inwhfz0TgOr7=qAP=Q`7|Pqf$%3mWNcRy zi-Aq%;gGN2P?rJ7>Ow`n!(kKkDR9TJ{_K}vc~>C-ualYJI>RsyEGS`p&TIE&6|(9X zaquC%3@Vh=9VI7m?tM#^{I7p(W-_+C1tPro7Gq`^)}b*$JmBJ=S&@s?Ewzb`&}&MW z4t%u;HxBk@f79*)d}f58t3Z-h6k~oBe4ZMYnSFVo3jh)B=$UhqYdH-QKMz z8{M&23k>0}or^~d7g=a!uIZ(SR%hM{lw^jb2&tMx&n8$O3GOzl!GuJ?#7U!9hw%oAJU+>>2c5}%%0yT{h2Up;2? 
zUFloO)T!x05i3*9Lw{JLvt*TUm?OnFti!8HZC$Fb{$lc|O62-nVyct)sYR^D-wHx_ zhxgDJ{YUM2Aloi^FyCPdJD#=wme3?vM042ppsL^|$fC((HmetJNDbe+Dzy1Sprl(Z z_1i4T>&;ogl3hFj)`w!Jz3T)3T_gn9=h@pQ7@xj^2`!}LKE1Sq;JKcxo4hA=c6_={ zf>yJ}*YnsXH3_u%J(PjU!$y#m7L#BPq zkt}sKSY$v%TTpHECbqg7f!(b~1?_roO6RBGW0D$EfA4sGiMJfrY?MLP)9qJbEKvP$ zH62k9&i*<=Zjs$7A5~1d*A?WBoz<4J9B`N4yM5w^=i!IjUFLc^T|uxV{YZHrt|b6% zPIL4ZKs~`WqPJiJ=yIp<79hNv)Pdz{Qgm%_FLJ(~WNT9umagypH`o9B+<)}}M(X=} z#dmWJnoY+$MHC$)3P@Exnz$%Md$so-v%_@>>q~!+VB4JZM^pNgEwak)qo+;j=3Pz^ z9)E?!bw^b@e$bA?qx#S*_B9;)DS?>2O_aeC)bKT}PjLl`nF5lnh_m^{= z<2d_#K6`&&m*?~GwESc+v6?c?n~*n|^@9j<61%fl4rsAAu~!qjco?KOnf#*t0pMRD;DGjLrrQpeW~^%1KMm3$ z?}4)+3EYQQ!Ix_Je`Ht{3hyNFse^n`M3~v&;icM;yeo%=yW;0Q+qhvYp-Cy9!DrDj z9y7Ej^{B(n&f#h6I(zCA9*B$>%vzvGJ*lacIdDs=T(PLH&tcb!BH0$$kJK=*C5!v4 z?hc21x~>H&qB`DkIAbP|G4ArHW+?vqcZYFUv}4?_GRc3^7B=|iFUJJcepC(RBXT`I zG2@3jM|CKG!V>0q!q<(<6sceq%K-r9Z9_#z+p z)ue{QzvMHvdU-J!N`%ZSujHF1?xb7v8(6chKQOQu8+{c*aWc@RD@p-O{j7L%{|c9p zZC=spwi{I)YzQ?KJdyAj+KZClp9Yfc3K9YaD&9<2ox8IYe#KeNtJ^tuOQPg%G<2H{YZD7q+?egEz&^|t7^f7G&o&WiVe4|WGh zloCzrQ}2Fk5$8J7rZ0@(M$xL@y?I^xWpHbr|5l%3tL^Dr!;1K%=SQ?GgV&1aqw?A?6NcW z?#MH|XMN#Wy%$8v$I|ICSj-3cs0pjQF2~D>87U8eVS(S^Q2Z{A3R?x*LIg%S%C~I8 z|1-3P=5j68`jwhr>!%P(@mHPMmwl(7{P>7N83`hCVMc?f$5w|mqi(@dPbgeGC{72} zNc^~(a>D9mtjq&`WS!5enUgO9dOp(UU--Hl{%_x5L3X{H=j;l)@{WlR3!9Sm4XxRQ-rcGd6S3m1D}kGWDKTV6K?yfoH6nZ<7@MJ=j7NDNNuD>krZg#O)>%(DCg&GJAC ze;ayvN{6e~uOsrdA1=|>?KhnD#^9D9qLvtqrKc;r zK7O?;>JirnC`~j0qQQUrwu@O+giq380jU2*xoQkex_>xD=1R@dy4Es~e#>X_^W%2_ z_YF{Kag`?(;9etF$XIInJoL=Qhc?TJPQY<4gpsfG-y|h3bQ>*3Omih4Dlg}UjQft*} zn$^g#HdM5YOOJt_E+76YP~ebTjapB&DflecHd*~rIwFVVVA%qeI zcSWq&178rMyIa3~rX7nei7U-|%awHIFpN=WmNl=}LWOukPBtwXp1o4tr>eheySBVT z!g*htpn<++9=Ia7ez1CV<8DZwUnsh@S;mwt0$d+p7g7dTRsgaqd&e7?2UUgvsEa*G zvPRrDuKtiv$-gI~k9YXxUas|23=1z@_8Z%s0aG`hHSa6BR)swjDuozV%e{|5pa4a> z8vR}ua4+ssIT?RFh<2e?=_PLwnLHS0WNG>|`xCrmm%zyAmCu_G?***GbNs-~G9cdI z)-cz?ceeKVDT|HPgrwQTt$W4-;ytogi}FV7tjp;H#*gUHbk_a(-7Ea{UqL?EylD$0 zfkJpg5NYzENL72*HGK#6lo%8+?hC}c8vtKi8kJsUfi>$-7JlFO@-D!DDljQ*5w@K) z@&V_bxqoCwujeeVPDQVVZiREztx~DbB)xUh7opBYkzNMMWNSCzBN(hKcs;6^+@!&t zE#XCg>7mFIO|Bq5go$gLvg=$IR9n7m>-Hnjy zSB!lUH)Hp(!oM~jzn=w$&6rF;D!$cY%)Kvpk9AAAW$pOHXGTa?G{lGdrd|BqLctsH zvha#y2*>t!?pd@6**EbVBS7_V;3_&US9fTw9sj}_+D&5NIA6JNZK`T86gwK;a^nIl~G zDfBO2jg2F(pG4r=NtPMKV019XAw=?}F?|QUV7F_=8ydP!jT8#+0}7GkD#kZ#p?Pot zO`En8lv-FgES9fJ^P*1FJFdMy8*^ryO(K^(x0tm{S_km>=p+0=2^S@%^JnvCFG^pZbut#KJ(&KX6R;jp2z2RO zc?N)=4~V#H^vR?08S<5P}W~03HxaS&FEUZ?S1g{^^-$_)fPjIo7&}3y} zlQ4Z=Ti=0Onoq^#xdO=%LKkEG}ui`B|jO&po|Di*QeoE3x!t~YQZl#$@*=Cgb?pggry#hcV!|DbqlIy{Qcg$UQJ%q67M~LxUY!+eJ zF>4BaFO=tv*KZVT?w^R14Hf(8!wQpEW6iXK4&ddx$k~qKv8%Ol^2dMWpA$PTo48MV z{XRg3fJJe|;Oa$hk#$WofC38$@$H@8Y-Fx!+c()1tlnz>ev9b{N=b9iy;PfK=rVDS z{}*i{cK76u+IXp(hqR_f#>+Eg0}c?3_y_q!@L6VU+lqxU$B?p-KeldAGJwp*jaYs2 z{0Y|OK4I$vBfF*=^io#s-q`t#sq7u~%`0o`eFo4@0L%tBA~3PUp!yFj=M-0Z9}j1Zkyh8Wg9LIhQo!*Pz(5sn^};2PIy^rAHVb5-geM1ga;E zp37KpQW@I_QG9alkzmH66m<*Nmz^TT9;852%{j01>ETg4vv*chr^4c<$`MqxmjZ8{ z)9YZ1bD!3_245WW`kYc2ah8pu65F09*fFYf(c?tFY@4(;)|Z>k6I6k3`#aY@;riEK zflo3}?YR9D2p^L=82vib)uhVgM{4bff7WMybVS?(^ct<~H=I0H{eSxw(R)M=dTIss zfD35uqCb&2b>4j8cBb>L7tkor>em3j41qX)F6y~9EA$1CQZ)MPv}T&gr9#fYPKUmx-#~e8HLUq=|?E%a48{#>nx*t@*OlLBTt7Mc{zR;$y?JErksT2 zAJCQp2HR~Dvx?Wp*ciU#Q3fMs$POU!6Y5>czb*46n2FYoPTiv$7k|;vKcSzwChW^I z;>v?nefu>(7B96>pO0PpeY|R$@*_}WKbzCq&`pz&ep>qk(tQ3!9nZ5HBWyHKWV5sy zagM9bGP%HrCQ0(;zfbO> z6jKKAdh8C)8QRR5KsV<3S$I}j()VjpNxc11U5l__7o_VTp2)hdnGVz@%dJS)eod3H zXR$>xC6r~&SDC+dz}{^5tcVczOurXUT8ercRZe_dlW0ezTQk9&n+%JXRB~96q$JtC 
zO2D%uT0UbZI2I@aHAT2ygN1s2U6L6fZtd->7e(tePN@gxs@Sd(4b%!`{@(j23dMA(Z3%` zoDm50a#(2_&5StLh???_M9i>j!$&=KS0Bi#DNBBZShYO?qAga#V;LAqO%XG{Gs`5p zENTdzE1OfPt=id)cf!s5)#_6q+%BVE;uE)Uy$}(A2tIRyYY2yvX|C}X2_wW2X@T~c z>W3Y{o_gGH6PQ{q72~Sf8K`UaYs-(TqfVzvPIqZC&QfO^ueA;&33scL}Z&ym%nsNd|k`nRPU?rd<)*W?)A^MrcT)JfF zh54y@z2J$x=Rc4TN{=IK*OP#qb31(yN0qa~A3py{iZ zdbRuf?S!>gMelr_S)UB2I=$`R+q-5E4+>MIp4^<>?`T|Wq$elx2P8c{#lGLJXO~sb z3I1H(FdHu|ntnVYWG_WqQ0n+N4cAx9Ao)Fk2cF4=8 zEwo8zbkwDZ+|q}B&HrjZRDThc8Kcy%v-FfdcCQ?FQOMJ?0-lM#MYF@!XdD+-dtb@# ztxLOJO1|v)i*R9?$QYT<3@IsAwL)OqpQ3CS?uDB|}=LA3&x|O4%@WwVeDE9z&n*B9H z!?UpdO3)U>a4Y4=a(VX;=O0HCobdJJDf2rGk{|q56&g~+mq|0xO%{;Y?b9j>(!Z$A zjdyhPaVq=1xIO-6F20sjv1~OHq%%*HFbBmcrUfe}Z?8!Mb&HJiu0v6@LRfJjPJg6y zjjKi>R3IV;xJ$){h(?nEDe-u92JXF=53ijOAF9BO++N|CLJEm2nQqy$M50Cz)f$1*C=a4`v z@h#;9uG(YDdW`l&gZI-EpzwW`3AKOm?*-P4i}wRc{F+ndQef31FSg-UW?R5%#tvP9 z4Jz?RKh&V$hG5ho_J?InMR#G}E?3xj5g9htiW^R#!3M^IE{HD~za(|J8X7%h`}-G7 zR=xY3wK;KJM_tTu)8X!ZhrQr-LmOD;y&l1>Ue&fX5*5I`K*yY|0XupK6{Qi_RM`dx! zvudn^-!uRD+*i;dM6h{f)xj+G=jXw3CA)=+e{ye>=4L6DWw+%@5bYA4!aQsVgX1Hjvm24sD4j6_-IyYn6L$WT~?TNKXCGaY_rif#*Z z2=9=9njY-SztxN-z#&H+(aNYPN;+TKtp)H2lvO8mO#}t`KY;v4+`Y9vyO#EX$!XY3 zawE?dMb9UPUCX@G9E}V|aUMHamGwMv1GKZ!%d@<1iuE-|j5me#vv%Xn8eR6ldAAb5 z5-|36YPgZ{;|nRFO!52m#`?g&e@!@9%`>-28{F88dvI+E7=yx8*}y0O#>jMT+Fbm+ zb)+67f_yU4q)e&v?%Ts{qZrcbZS(z6Iu5xJnrDX4MW4CdU>n6&^L)@pZn4}=C_yS6X3y~tUH)RLSlphp?YW`e&ntA z+DV#Z(1x*ys%IZJoct^Bs@7%M_#Od=9Mg@QdR?eP_Cj?^uhl8UX2)Zpishi8tLT8+ zZG77W2ni(Cve!yb*#L&>C9Xzo>Dk+&)A^^v z4ctvV75{*xk$0H_+)xB*5TeJ))`eEIcnZW-Ln`yLwfGF&kx{%TqMK?TLm9LwDGEb1 z!psm>Enw9_FFycSdS@x|x>ec#2|8|30gZ^Y0=r+WZ~8}4GICl}7Kh1qn-)L{FZ#$Waw|>4Rr(1C%Z|kfy3pLhd zH-hwnYbG^$745=EAOi`szZ>`X_fqDmfSf^x^q7Pa?DZEBo`~ZtId&1n%N3~a3u-Ff z&&NH7WJ=+Nts+5{&bYRZ`4#5hbuQ-1`dp{MVC;#G+@p-#bU;b~Oj*{TWFLTIKh(7c zX*U_V3Y#2r(7L5~?XOx=n8QYep@!Lm{O_E2Bs7pbVPBOEo7&G6x$o>T4PqYUX#NH+ zU>{@__Xn^p&d8w1Y7DO46Brd(;J?xd&mcqCSu~hu0|q%vUN=lplZ+{{o>CO(#2=~T zXjd1w7g{lhu{~dY!(UqWUt!^jAk|>v0_{fe=93vS1C?hOfA~B_bGCF;KSM27LrrC2 z2^k(z00`Emiaa>Vvnd@|MgO+kq2kZ=7U;d;vhXloy ztk@h@xs-impP=Uy((uc;#~~MwJ5QD69>!V~zO-dkWvm50-u%^2Ta1_o5;hv09eEP; z`3=G%`d3CXGrI$jx3CR#U#o~wzt3m)ADqbA;Q%_Z@N2kFVBK)s|BWEq5RdG|ai_pa z)ryL65x?Zev0kcSmAg|7>~U(iGUZ4us~$C2h*DqK`v}NeG?7RK0P&Z*km^+*=xLGa zazYMx%rjLVrrav~{=Vz3D2brJm-a-T;F{iFpjzJ4A-Bom*`t0<*4(7kD$toPi>#}7i zLJNhnaRq3TobhLI5eq}V=OZ3PWB^JprpQ~ca1X}8971X$us#t20U;L;fwl0b*0y1= zqFc?D$|KLUGB>kyY;X1K0Km{N;vW+VAje)mW;)voevmwET~+Ud?^|1_S9|VTu$rs^ z|C^Rz2Y$xP?7)e!-xxPA7I3Ucf$rnI%|}`zmJiEV$rp_YCd|%sT;oFE!9iEy%Fn!( zJMUc{n@Bb}P=-i&P&ru23^pZ`3tTLG3iFH%O5Yp`C4PV1X?bsFBx>oq{@4~z&2*69 zKIKRctwGP5pN4}t*=lR3H@Sk0nbsr~FQw|3<;9+EuX6&$E1jM+%~;rR10}L%Ouq-~ z@~sM!lln1}8vpc!5nm%sa=UxXvh4YM61n|}VTPsWXRA}+KD!PXjUYu&8AyiWlN260 zb$1jz)C=(-OrOvP1fgd~s+}D>2Z8wplrxTO2F@PM-(uk-w%%eyWppTe0;taA@ha>H z01Fm+zW812wVbbZPx}N8@ut;U3<9Yl0F91TFsyO@ozuQ6hKG+|Z@a0dtD2!Cg#ms4 z0IDLk5)Q$|9N?LE(vJ=rGH<-u2p3p^U2@=$x*SO8+(Rd6YATz=obB$u^3=OMe9%ft zih?&D0yGAsmgMP+os=WyexAg9-Y5L8gUdsxpC;l9M2f~UkK29^(#Bv?d?H?lX^h>+ zxcfp~(WQ3x)1M;tFW0}2sFO$3de;HNQLggGK8mvg$DX%nesh~}F z)z=!hxTu1Qt+!!@D|^!^1(T+7KlR^HgAf$_GB52?ggl~VR>*EdW8M0Si1HbIz|$(Y ztt6wEzyDo#ZAac>4K!>(G?}{+mm#esdau=1QA%hW zvyK*t=)!|Wx5khUOMaD)uWd?P3F@4lz}z3Fs~C3ZY*%AjeLIo?ga3t6e4DLVzyPhwY7(?;8-`GE9ZauyceE?F0Ex!45j~ud1;%y3Hw?9^tOP<9zfEjz~~?f;V9z6Ml11- zV_lBUNx8?Tf&?^Bu(RA0&7J2vtO8=O_c=o?|d(I^XwV<-)4?n%Ses8;5mR zuXy=N=r-k^TeW`wCbsIJ-lNF}ZX1I#Lbi+L6FuEStmvX-pGdSY^7dn@$Z3Tx!vtU> ztbpr=j9bUEAP2ZEY=4dsTmKguCUe}L*zI}>eaTBEG;vP!^V}1>GrL(h1Hi92JV%#y zDXo}mqID{@z)o)8SZ^gD464f(4UE6&icf4g8f%{zTGj3uEm{rx)k%qwZWq>-f 
zqlc=bufvRyVWVKRWta0QW`lWAaK1}0#!K0A{|DENxpivp$UHOxP<#Uz4`2)~H#LoJ z5J(B-4i@#ycQ*e;T*En@Am1+PemR_oG0D)E#m)d;dv@sH+R*O}Z*FZ8m45`+gFbWH z*NuYON@yKisnl-!u_?h$9@e2$k8th(fxKKoHv1m(gVI`HW3?s zlj?H_SitMgCK&--^$*trh0t{~To9160nxjG68teOGWMX_B&f!u6Lk{l11R=yug9Tj zsP+H$G1XkO$TPA_p*c-i0&UGVv4zU+4}%pU>Pdh8+*DjkX`B$>aP%NX+#oIzHEi28 z+Ru4R4pt?&uXXdXOu>?}5PMV5W-^;ow*vn$GhsSaf=hIFb6C_!Df!^i)vMyGEqtYM}|ho{)bR1XY43B;Ek6`s|CNdfZ*%XFq@Yi6j$q>E3s z40PxjqRZ0MF!@E$De=rJkkbZo49INo>-AMvKu?Cx2t0R!SX=0rJnc|xRVEZJBG7>w zZ~M6yz?Wg9ymcMAU>cRFciaoVMt!vVB<8$H)-CR%cX7QD#z+JMl@cMrUZ%E17>1My z&SRg`WZH5Clc%G?peYCx} z?^|WwYcL@UuW9?7pAeocf6eLWalI6zCwo{f!i;UjPoNwp&{31TGQ(}Ey!;4(tO^)o zK+PWkVM7)k!vMA7k|J(2)S}z6E_L5ZKxus?y#|6&Gx~I5+iZ}V)-KQbL~o?qgo_k# zk9*T`<~{+p9I*stcOUig#$D(~#Ee}-BU5NP9WI%{GXM8Ax9Y*Jn5VYxRMY3V9mADX z%_ICid0EEoJ&y=|=6PZRX7#@3;|-PYI{rpv2o>w$K@bEVZ8tKrH-VGObGIhK zWS960RVF1EcqxY%<{oUDcR&4KWIhJf3SnxqkO{4sFS~@Ma}1{n(8}D#$}<$7j^E2iU(Qt&#>)w zshOWU8;OkLy$F2DH&ZfBEJwe~PJ3c;{noP0O}6y3RI%mYx1lYi4($09JpVev%7~p) z;6AO(2rsT1E~^WGQ954o5At$nj7Cw_5$FCHXqYXPiyf^7a$7)&y66*IIYzcjp_w%- z23OrT59)FQD*!GH)H~|MTaT6ZFgE^Y^LIRpmRNs5QE}z@GBj;{;9y zuwk;2MwoZxUTu)Bu20s!FqS>qyz&!Wl^N5dEt~$Zi`Db(pMCXrRnd>*ch@OCZ;?h! zJy^DF{l>XX)MD0)(%rkX`|g`OqgGv#_gCnh_vZU|uy1C=VGpV|y&fxtV}yq+M-9#< zYGMtxt%*8I$C;Y<-*B|ef5drpbYY(atpa2XwhC_ALzT5s-3-KvFoL5E6pMM;9Z<(V z@s=&5=Am}91XAlOT-eEE;pR@V+Pu2M+EC|?4Lfo3e&3qB?be{JwN>418=3v1YLzY{ByzeTjGs%qT8 z|JKZ&M-_+ijCnZ*4^xDVL};)+lNQalH?T8rp9}_i5_C0ZaYS^ai&{snfEjn9UC?0o z!PgM+?|zD_p!lr$_Oc7F>&nDh_y*4&AL^PoQZNFjTB@<0-H%?E>fF+LB~6eIlvF;q zFqEBGmhzNo`wK#gSDOS&uKZY$R;1|(oFf9=F~p2va|BZ`%fiU-8}`2H^}unBq}}na z$MxH{x4@YKN)NaxK!X#>mV%Rb?XR_d^VFD=KP07ecxNM~_@XIxLulK;-$?XdLN0); zvCXu757<6t5W1M$Bg(c82Bm$22904YO~h{#h-6ax3nyq~uWPY=PjjDHeg0JS{9=ZB zc-q(^cp3=AjWe(+?V2wOuBE|#fcFi|GIFy|-_w{X5fEo+uA*kct!$cWQ%4Dps!-9H zJEclCPUWWlq%^5@qGSRH$F)vb z)QU54?INTIl73PAnd>ZE-%P@c?6zmm%mpR=XX`M>jF&kMqd8upG@GA#&-;eZr}?9! z%zUNnHLm9zL^5JHv-ZXtosQ1AzNtPl?jTb+Y1?%8)92SqYQNPDpqTJ7{EkrDH8X1$ zdf$_zqb7NKKM9gtpYI~oG^&=~)ItuR9yg&V@1PuZ&HrV3&E0rAYBFZX^Yimx9McNO<+-#6N9r`CBl;>_#+z=9 zL1}jBB_wAOs*tEj0ac)SaArh~v4`C?dFhMn9Tsr-M!f;hybc!A|LvRPuMzGe#XVV= z@HMqX&Dd$Nsu84-_*6K?jXYG~fUG7j+P`fMGL)55WczSzVRBL2BnMjHDykTmCC8}#+1mSyzw#K1#s{;UxX6##`h9uW*b$ zryDclh`6Ho;hc&3j^N`ML_!r-XFK7!4u$pt%bhFavsfCD)tPaHxKPGvb? 
zfK^4|Ze_s#{#E`9?fWsjxVj8Vl0VBnM@FAox%m1@?7E7>o(AAU1muro$kvTr4%Mdv zV4P;*NdN5{f%Zyg(z8Y07KAId62E$x>5#=+UvHN!S2^Veg^qd9Z~+Du45D*0J1IOY z@ajfe(4WF!xp=F8@@vI6P@SN08PgFj@JK-`{{_(Ifu_SpShz@a11}7SeJ9Wpr}_}J zm~)@$lOS_@m|9NsPAo4cTUGmZbZbr-N-Z~U#__TUPjjp+8BGq!KL*9)FL&EB(Q7=< zbu6gGZW=UXcMTjui1q|77&E=?B-K-SCD&sYHAQ9=%BzkBPH9ZO2Wf?FIUiW)b4cjT3|gLUDj&d- z#~0|KjB?u?i2uixl6sX~Xo-~}rNJVRc!1>k=^w@Tp>mssMJKjOl3~ zD4}w&eLm;sbI;w*ee(r*slW821+e+{UNP^#x-cQw0jiTF{>_SQA5*LhDbuihx7C3d z&1dfYRg0;rxz+BQ{u8Gf^DsG$w^{`;$+mc?xZE7pPCAlhdFQ!YFOtj@VI?p|qmG*f zo-hVq$K?VcF$t99vb(uLkt$yxh(&J_H;?zF=`i!u;4(~%9cf+z^{+c+hU3&x#ka0| zJDT==$k)nIF_s;LSa|bq++=Dxnv00sUHpB(^p3%!K-?GpslGA<`hJXBmBM-7m_6v6 z)oSIp6=kFP(e^*CduGpZ z79AI%I|8gyfc?Z;s+;7%45gy$v0*x7E9DccL5^mDh%v2S%dW{C{6R9RY&t*5va}$*V1$+13;q z?~vf-B4oeg6dNROSaHp&;y2gD*>3GSX0<}vit-SCjDlQXTa&C72RfS56aj>{(sB&XZK?gC6&WwA`#+#V)W0 zTdb|&dE+-ksJn5F9RJB?YQ(1IzWq*CT|%D!urfhEC^sKUjxs~=rh|2v*-(=4XpH2WZ`i6>nTd@>uW5i3J zM<$D%Fy&0VUCcck{y`$j#zIup`deDOfg3{uhsK`Dg_MTo{6`(%q9WY&C46|Id80BP%jo#KgN;;U0hm*a=}cc$D}f0!@!bkm=a^Dbn= z&-W~pb_X9a{^dMT!x0wtoC&m=5St~uqF+7hv)a2&I)CH=Z|-BNAK5qZa_EOys54W* zmX`_8=dR(;ur2M+eQ+P@+0FD*kekT%HcK5MrowXhYP`=#F-Bi6(-~kil-}d7Bi-p- z2pt;|F#r71l2zLD2KXaLYp(H&wO)7-puyJgO4(BojH1);6qgz>iKuD=Y@l2 zJ1^T`l-Utm0=!L%RFM??F?Q_T#wOzDo4)?iMjnGF6NlX9HC8dnR%a-U5zXTHm-q5R z^H5K4u`cJC(UbTa%+TLOlg^TV(!ACh@95X{1yak@R{F%?>1+XL*mx6W!lLB3lS#~c z%w@$Fmw#^kACJ4j%4h$mE$6C91A`xUSob&Vb0giUx))yOOhLhjB@;8wW`}fFd=(^W zFC}3|11;WB_x#;&?Opc>oBZRiEVVx~t;PSYIp3B#|7q?61bBVD_5rJUNW0`;AAD3J z;9BLOpd56~cRp9ZHGaJb!i2S=#0+M{f$lD(91iyEYNn3Q>udH@6Mcrx6T1nb{9_2s zcGSU?xKVd(Akptf58=d)utB@Z?+d!JdS`IC5wbuRd5yT$gFCL%1KvryikM=F9X(;G zU3pMZz@27UAPv_VnwxJhsXzwunz@RS37si!0XtU?oh)mS&Ig;ER}(+p<}7>>SNGJ zhzMR3a2jjGx#Gqqu+t;w6;CyZti@0ZgZ_%H7ExbzeYDO0_4q?ARB#wNOg{% zM}q~fUB7yWeO82f0(REvg(2eRUf;Gi$O8fbB$FUE)_4P z|6KX&*^ktkT~CI{A^qVLoP1XxwteX?@C8#))wB_hPYh$NL-_}}f5qt)u$HLT(&jbZ{hjShYuMyM- zjEu}SjG<~D$9RosUq_VhdcLmPeNkd$1Yv-_N?o$@nEg5Xa>jur7OU6yYl{Z8&$^7W znzCX*)O6O_7fm`XY?y^Z*eZ@aDn_>!ev})mxgkzR9frWsaiCL33or8eDYnyO={Fro z-O+U^TgWu0=M7hR6raz2GdmLnXz6rt$IKs-rzHBU((H4?8WE2p+1fFe<4&qxPKtkb zQ236rz}b}k<5uG)9LLrVxqu&IX2&uYmI*Vl$%U=f+n)Y}X&rt;FdTR@%M)fTwpr-C zcMUPyGC--)FB#~GXJ5VFUh1EAI~YTzo44)-G(;3j{Hxk%P}LC<)QmIr?iKo|6@A7P zb~y1b9>@6zy3WUjg>HN<%v*IiQ`@FTJ4K38Ajxj|@7}g!edS{AosdR6dQ02n) zCsHrjbl&X{VS)O6hcBxfLg<^ZFSM=b=9eGRgEHz93ke`dq{ZqgJ!@Q z{RkbY2O7(jDvu0Hh&V|b?FB@q__q_JrlBqr!`4v5xYo2uYqqg|r0&Qs`wz4-3o zLZlKqi7F!#ZqiD6ka=%X(tbuYFA*`7n7ZSqAB;V21#k;RDI)A!NN54C#2($lF`#Yu zU{a3Er)8|vT9*OSAjSBH8J=p!J^j4Uet-`V9!9B2o;jmq8~3cLn`Y{bUh26#k*Z?YJBx9&R$SNy1@mWcSM-5ky<6 zl{5_jNns*JU%RxsD{Dy_kE`x2i@r+M3A$fb{^Fw1>;8Dcs_fxixrR?_u1z0OdXX3NN99H)qZX~_Yu(KW!y2ooWx(`p-1deZLvNszLS ztfBUtn3Zexg(-!XteZ7{SXtz)qFOegl}Jg(Vp~ewUT!l*^X;^9?472;s_bkmSmy1I z+J?7Iw=fzn>kKj`>SGV);cRkEGjBUTrZkH#KEUYG{g+^G2n-t-+a7$5qZ|)x{fbB- zVO;i_^aJ3=1m)qXfC<)3+6~-WK#NS;GDXc7fcyxbQ{u9)Q&U7z)yIh6#qjmbK3UBC zRY?1cCbaLqvq=kpk99fMkja7a!IUv(6H56v*ZiUFFZWv;-%1-LV>>jX(L+iz?SyDO zDD3HApby8*C73BNi~Bi~>|pMQsI~er3|y%BCWL7$rEUQ)$nXLF1aw-1)aIgD^hd!* zHQ;5+=~vVYf#^M=4AXHoGr7w7xeQi~5=AR~Gz=CcsDO~6bw9vh8db2u3_g?Hc1Png=e~gOX zWC2lh7mtajG^qe~5MDaR;V+GX%Tr=|GGy8{7aWk5RhuciLy8EK1%Y%H{ z>dZN<0U(5RptMN_Yg*D(=pHK9F*Sz>$H-jG7TJmYM=hD1@YtU$B2kSmu3`QC8B+73 zq2>=uKY~{9@V5o?hLEcxI1f!b_!JlOBt9iXmtTXCPF=O?b++%e7%{fyM zsrK**NM$0jwMqX0zE(0tx_dP(eAYxf?-T3)W)cRc1q_N`9$h;Yc+7RCvKQbjK4fZk zl4{qzG_a302=lppxlQs`2iPU7HV(|Z&edlnt9-WCM*sTrA52M~rlOKj_~F%L0=aO8 zWii95CX!)!gkyZw@EA46YlV}WLF)1BFplalAg6T`R#*5A2!whtkdwTUzfFll8p*0c zrn1|o(#m-U5st7h`b9Hd$4!h=Suk`_@)iEdif+f*+?6MygZQXwHqJBHGm%LBgg!c4 
zsHamt3(dROGi5T$Kfye2rBa!l>he0I=xB})9p41)!pk#-y}veAeZaRe)*8JqO9K;$ zKKq{I6}5^~uUH;oOK(J_NVBm(%!y07dd`fu8gr*pUSRMXJ1XmZ>QQpWTa8fYvR=EK zpQm~KS%G+VepfJ#jCDtRT^O6Z{dtzSf;^V^UWKDqzO_A1hr1}hL*C_m#u`8h!!c&O zRoji=NW7%;a;~bU)Dsz_R2VSSX`CVx1VjSKA4Z;_jbyHz-~(QXa$`0w_0vdH%@NGA~VcerH7 z|JN-vV@daqBY8rXB2-}fZq(_8kicUKlZk$+rJzca`0Zz#I`-FUsT;VN#$ZPC7mh|> z;{%PM)(YLAlx{3@^+WET5!Wv$eXuSss~T%FUek&I(5r&GjdpDP?+CM7idBhg~y*U0cwx;cwJJiJq@$om*I=Z#Hvq5KeklURn?^ct#40qh)3h5Kn>{o}a{K$@?8fQ6}gP7rZ)fK5&}e@LM! zG~J~Xt_BPg_M8%V>mr?E5JDsdg1i%#08dw9U1hxdw zB&qU_4F)>iEDZmFD%n^e=(Lp#l;=npYcCk=wC~MhqZ3S&z17}vn; zWFAZZIW(d5o%V31aS*iZfGIea8OBF5;b3CUqGNTbm+732p8A9j}FGfu>> z2M9$B-Dw!}c!P)$%R4r1%mBzhaj28G(OK&MBkA1Zng0LxuY(RGNg}6_qRgq}kkj5O z=d>i|ypWv9d6?~uAY6kLTmM zuKRTzCv&m*k)?2Z$$j&EgEk$VD%mfX$b>o)2FQ!)em3|yxP~}Oe}p_=$b$Sh$0{X)Jz^oqvl(qWrkob z{2jI`s~ti&XoiJ^0A+tmge)0G21A>ww%J(JoL!*Ncjp7yXLL93J#j-`z3-*e)^Vp0LL9Mf~(ybXIA zK+L-fjE-k9u=_!W$was)UFKd+M6Jl*nmzqg#$z|(>)D-`y5!F1LW)A)B|ThW4L9a; z3YGDe?R$bDT(iC4L@MV*Q`Jx0xkh>=Tl>)atoPhL+2LGeHmUn(BWr^(o*UQdfYnMJ zNQ_zgi+sNrZqj7i;Q5w9D~qlh!bj07^4q;!QaO>`M6$wGgncvqJr0X|G}x%6%EVN$ zMUZ+~CX=2r5_{74szK)Cx>~(y#>eq49bQTVebS$j`>EyYq|e^SfNVX01|V*l8a6$h zIJ7%BL0-xbT8Gi|Nk0Sk{3O?Y{LgV0&(^4~AF*%N4zA1yXgmpg$aSJx>K0cqK4pE0 z`utVrWP$QrKwZxv&<(GTBF7W(1-TA1ZP6Tgnv%ny2Zm7AZ1D-8w(ync8)a%a#!AKp zeBh!d@PgJ+7pNZ&%6!Tz_NOb0|VWhl6hNbNDV<%JhM6kS??ItogF0g<+qI^Ump3$Z=kP9XG|) z9&!RS;>1;B7K?sspsCtOqr7_;Bv6B5{BdGE0wT$23_3b#Dn&pY^qe&o0ev??bk)Bi7 z?-#B*n8K}CxtgPGzOdR}Dx#M7bmaZ5tI^d7N9k~euzon~4r z^DvBUY&@r*Gg4oI`#YW$!xsLAt9vLV;hZ-Vd69Pp)Pig~*_WcmhOE4vop1GXuqZ^0 z=79*{qJN0bJ=V1hs8Z0{Vff>wdJ(Z>8xE&O)@9rau*&9*frE#g)H>s83TuAShlF!w zeLzXlxs52@N8znv{S7mhmyV?O+WSSQlzjBpYBF81eA~1kT7jwN3h#E%o#>qRS*n`>DU2mGreD#-Y z-LD(edQLT?8JC!%Lz3Phk4G-Ig}eRK|NSRelnC1fJ<3>hI28?}Bp2v$F1Gu32XJ&> zJEB%UL7GvwJ0Ta{SEf6&U)kin_rgCrb!S!_5|^>GII@C|mj6{?+4wniZFGeOHV^J5 zjz@i@#EQaB%?>^Q?+K$;*_Fm&zW!E|TdL*M9l; z6$7Z_nnO(ig3u<2Jv-ikx?+*Nev)A_)yXGyR^s0CfbK_8_-EC??-m_G1-j2UKbJ8| zi&;#x-yEFi;!6ZiwS0XmRycTP-NEq4nKlU=GUhCPOS&s#UKXN^bZeSDWZ4 zMW`&6Nzsz&#Z>>*wjFzW(-7_McegiFZA&zD4ZALJ;G=7T<}_Fd&4R@FQ}`CD6fDjD zh4-)iAgL>A@vmI`#zUlDXu?Dn-9M+SpErvs+t6f#ORFX==t*ftw~$luhp;HWA2@kd2Eln-kRBE|;+X0b zWhAbw@Mq!Bbp@(O0Cj-(oEKMHy-#0%ZTVU6d3s6y}CMt zmWAys%(sYL^=<5{(}vnDrtH&JTNVF3g2rfR3r^*ySfiJ*Jyk5Y{!`u56(cC=`@EPZ_8StBaG@0TS?Rn{GDlt$ixw7eN&T zf4Do=1Wj@fucn@k{Y2$`#q4oR-0l|fnI&wM&*_jO-cwJv&wCExdjRj;2@IR_0Pi&} z60(q31~XzPq_gX`*;Ksg{@Sw)lh1?8q_5`2f7u>vrP+1-rO~3%8tdn=mYMqLtGM5o z_<13DEi+QRoJy#uy?e(FlqZ|z$7(^|lzFuuJjx957JWT_@cB9Xtqj%WEE(QEp1^c> zE`Ox;ptdmHXBA!NefN#n*RQ;@Qwaqpd_p&pj%_;SBPMq9#hs5%vM|lIzN7&Ts8|gd zAJ{bPxnzXV%HNp|yubbG_oc<8ecVul^!m0@V?SCOQnB@(YVv^6JTC%}a7+!_9Fs!E z#!MS`ozjF2hea3i*9OGE`@?OJ>iIS-pZn#ol-!a#W;3`Rfh~#mPzV?t=yg*wYyr01 zP=2hd>Sc&pzkhFr+s#3^)Rv~b%`ot6$JD1M93FJdpY~RPUI1FTG-`ulABNdPL0uv1f^?Fgno3w-9h2*Lh1rdh0KRT&ARM(8V*d?s>5-8 zmYTr3=UHf-cQNL=wCud3D9rB5#Dl4)(t}86(zb5E2T~r)?wIuH_p{ezs49EK?Eqe1 z3j*vqLm3mE<9nqhns{)LDcA63+B?A&-Kpl1)u4*ov^yH&EUh+cm(y|qjexrylhte; zuFV2kl0GHUYxnkWU@E+QzH)!tp>C0EJl`$|ni4Enmd3xNw^NgV5}3Ec{OF#cs*dLq z3Bw8Dd1KsH4EPMDhbMEp5(|b4*Agx}e0_03t7`81}1FeQ3T4(dH=CG8Cq5T{~T}D2Do?p-O)^27VRw$i| zD;5K9-u@s`ur(AevM6QRZI8*qgOTFEIjTnj=7>PsP~o2P!iM@>~|CIjmx7q>ju}5HVZp#b`p+OXKGN#A<9aq{uU@E{A4Vv{PK2auVH~yhE^K?KTZUueLbT;-gzp z>f?6?&bkX9#NH`={o$(3jGdfdTE&7lkNEMUY>MY$wKu43o{~VvSXY|Dj5A*hmnARh zn>3J|g&(LjKg)NFnvP0t7T!)e{I=;It|?s}N{Hhe9Y2yEfgH@U`|0C8Cnc@r<=`Kdak9^*HCUF3^8 zn?rbI8}G|xt;Fp7!8@lcUt93$A>kva+pn1RG&3TIgNYhM`UQo%o>Zmz&lo&wyLLB0 z$R$9fQFSMPH}|Yu2*lYwwIt%IzlL374@!TV%m*lo4uYdZx@qyOPDYs!em>Om=1DF2 
z9iC(??J0n?5A>J=>bM#K>+DhNW-Lu8hounKr~c!yHyRnqXaw()#wKU|Fgw04fk=ztM7qIRnP(?Il99S-KJawpYI9S zy_+C2u^xJuvlZEiO5mN;-!`H!8*58L+!g9V+}ZhjMG4o~*XL&0(E>3HS$n24Re(gl zIVDhAQJKDZkS6uk^4b~7qGZuba2GGA1K2&xMKWaQiQU3p4(Dq}H#fJo1P(N}&^M(- z|7H9;1D)e4WF+ibUiJUFkTzT~vmBFel3ftC2@oKGbjqvdfP8#bKw(e8`paIn_e^tF zFXYwp$n209!gEFb4dGW)=vR5e4)rdlG;Lu9F1jD8_+#`$&!yd!2l0*PJ`?BfGJ>_pm9m^YD%YV7*B!R=Iz@YI6b9c_*e@+@H{ zZ`bv>LKl6B{Qx9L2#}V-4Tc;k^U>Mh*oKH5!)}i+;E8iQ1Us(XT~P}|g3H6b??M_# z{8;f!KqXRzg{?U(8YTy7>Z&w5`+vKB9xdBOHQ01y=y*`?6g*PhbUNn}`93&(6H)FA zoEBnC89GYzAhkZQ6G;TM5>C_8C(NhCAnhN&s z?*TOaCg_xAJ1q+q+*msxYsDleJeYM5DP*b-1+kupDHD-VTZeUbXjHMGl}GPDHV0q& zpoqhx5ieHz67I_XnnGUoe6IQZ*SdDuvzK{`w)X z;a#WpkKYu}duBVt{Lf*(baCc#Pe%8Ngz-nQ0vLaQbQVq^TU7%S!%h+I3-FVN#@m3@ z^b=g>+~|_SX=>dorkSs6F+HK}h7}U;Pce04`mupw@YQY5=6mEc^cEx->tB(_^?`lO zMW}Q=JO_mTo3b8F5PiVPel+*0wq02sxASj+}C%PP8iPHNaU zbB84JTsN{74+FIcJ-i=mgnadVZ@YqVVzDPEH&|$2(^aDPSosE&ZY>A^R{KJ}=+%Ga zso61TC8Xd5fKbG@pnrDlmAIr%7ok#^Qe)ywlcFC{n$~>5rGdE2UrNJ$AM?*lb^lrT za?iNW`H)Rd{KU&a-y%tJdxdCT*z1g#qUHrgJ=&>tR3FZ80Pa^mPqtLEp_stvZ?~$|Zj>ad1 zA?*xaUMmA%*`$wa2}dG^bdMvSx2e;gl{8^S?iAvd`Ry{4EvEnL*w7J3KC$ek4si2q zDEGd$v-#=B&)Q$QA>d~UzB8S;5?PGe*$@53azTQNZDJn!IoN#l14&LHLOKMM_rUhI zC>_)XIurCVp#4ej&tCj5Gv^ai&HUXU5rM@W`HD!gW&X{zOUg~7zSXnpLaEULIN4Dx z_e-vPz#k7e{##s{9v$G{0bYQ;pQry0@_CMf@@AP9`P=9e<#cVVi1i<4C$jkwEQ3|J zSl?px{J}be{%Da+FW>(15J)}_?Z5>*EHfJ&=cRv1;9_-d%vdG$_$t=`w(UKE?KwFr z50+;6aCFyo5<--0PK2!2jW4}-;A;p_dmfRy`En{GuGz#Dloj~ekSExW9*?df{7LD; zLC2cqD%_i-{`snhq<&db3k&}ESTrA}U#G?UDE8W5Rv~@bvW9A>zw!Bqb?Q&EQBaB7 z;B)G70&`o#NUZ{NTl)9K<ik~(p?HUg{pH>)agRG8--3S8a!dE1yWO^6d$t={!wj_-J` zVaX;BQ39Rx;(Nmia?Cy5`mvo(yE%2|m7FHqG6blGz1zQG4{FD8&m+RkWlT*YTfovw zIbv2eisG)uj3#37CX2+KEb0-^(UEGbURwS%s_4A!U11G=2cwTdiq2QGd5}R=#rA`& zQ?>+SMwTUoH=e0{{-t#|7u6maQF?qucqKHTMEhY;oY^4gL-G9)_~Tn0`u_Sl8T0N( zIyJ@aT`vm($saeQCFp!^Lvb=Bv@DyZPe&KdOtRi}+Hdf3x!)*IH=0{r7K~%MFUo5; z3pC41h;eK@-P`7)7g#>4bIJd?1}_xEp7z>J8voz{IVBbQ*rGQ6Wkjh>V7WIrGzj3=~+~tTm~rs0qYowznTE+^@D$c^EFF_5X~$62 z691I%gR^u?YzN`=Uaqq4t;73$VmfAayn@_Q-C_@sSN9NL*skR&M=^QvJVs=(s88 z4L-vsw+>;c$lH{^RZR?wKR6r^5C<__@=X0R=Psmxqb_iOmzC5LEToBMf9<{-%aX-F zvHQL}ZnR$qiz(pW%z(vYWwWYm@UdmtJx9CM4xXzk3lV9*?VOs0H`o_%6nQI}ol^gY zII4l^#W%^n7QTHDto8Lx;-I(_@_LcS)HgzNzY79G%;?OBzhE7j!`|RJat0_WVG%h8 z`c2UbWzVf@n$0QT`4R`k2X;3!y-+;Nf2eiI03Sbg^Tpr>v>hVIbp8gulbPl%c1GZU z(4iOw_ZVl7lXw*x;js`dt-CL9pVzqs(|T|1m2^1WH2;iMD3((iy*rC}%LW0hDAe;n z2>GmiAyg6gOXv(7I`6{BK$>uXG$>hR! 
z$d64mdxy{rkza}crs6K_P7g%b7rWLRlvo-ri zx3Kc+!bjpE3}fcQ`r*CFK$c3chO0WRU-2T-q~ZOeRJm_yrXwp&zKK9Z4enpZp~4)C zk>mCzK+js#mNPt07z$wA<>=80B`fUp77rmEo8UXN-pe|ZbjJlKX3EOmye;hC5&3I} z%#y>!Bb30wb{mL)Bk%6WRmyg<{L*mE$z#ZpfLdLkvJ?jBC{@@5)2?Km?7AD%eq^|t zO+@(yfLPh_0qlaWd})M9l_lrzFH4HP{T_9f50+)Z>8=XN*Q>MTzeZ_;m*xeSPpC4H z<(b446GMskI2a>WBb(uDL3L$b1qv5{g}NB>VA1zA3Iw1XX*fC~0ZtI}8*J6I$zK{b z&Tx2+l4Sz1$EbAX<6)+=Y7X&1Wn)#H6Ss-Tg^sWbdM_WNTTR(YeTd<8nsDIY+Zz>L zPW;-&nb)G>dhAg`#pS3#Ms5hPzdmY<7w|Sok>LqbztcmC=8RlJuB2E!8h>&t3cN+R zq*QYANldoZuLtXpMCbkDA3%)csdj+zpev;uS46IAI0x$S;>~XIJ%^eFonNGihTN)t zBjDol!g(is57(J!s<6J13OBSu^+JRcnU}NzKD$>!U+C+aO|=+Fae3JDu@r1jL%=s} z*~^t48{Wh|&lnB|kPRRnc3(gZ>@P%^g{50NQh?SD)l4nT0G_gyzj^|z>Eq4VP867;UkaEw47I%R+l+gG{||jGYa6j%T9c_aT?bLxzc#*j(5gtkiSByQ zgg%dpl5VN3wZ|j*^?NVv8$^!W5HBN01PZL7DLKdx)@Lo4{CILTcGIv9fSy?CvWqS@ zHx3%^uMMT!ePgGg03<7$k0m{{KKvM6X#(0}zrej`>QSCBnVvb;_2V*i)#vhDw-X{I zS8o)@*!e{}a1xi>GC6_n%Akck94C)xgH8w|2@amQ!X|AlvPdp}8f}+9Oi@!WV z08r$|ci~N=3xrqhHzky0Z`NgAclftbTzBtk^GW+D7KR*LLz?H;QK6&TYW-VPewrrk zewW~MZsYwO*WZ&JVWX51JFNjF_bu82$?*h1vP|&6Y{T2y`R$9xw1TJMFy;l;r(!Ct zcI+rk)IK5nIyrbgc=rfZTa5d$RF&!~C3hI*@_2H%-`bn&O5K9_DMxoUgv z?JnZF39-Ki(QRT^BPtCq3h~JZdBu2D$G4w5@n;EE_WBjOPblgC$TF=pqT z04K)&qjC^Xf!+a{O4(l6{{N?~{~=+we+-VLs;uCFB}vp5rgTd~`oponLvNS3MV{lt ziod$$TKeECS8OLW&&LQ`(g#;Slt-APLFT1VX}StE^Cpn&Y*T#^d$eUjVOlsY0U`g8eGN`TNARGzt|Kv z6ps6DswOcY?mWx>BXJJa@|PeEBzGv3HQ>zOFTHnFJ^ot!`;1$Ic*hpqwI33&ey55L zHKAY4{TeojIZI286;#(atq=h10=iUIR&A(K%h&6ruvco886FAfXPwa7#P9-%#q)}YZm+sC-S_Myci8H!$Y92r$Qq6y!U=iZMjjteL#*PW-P?@ zj0^dpl-n6M?<$xBJl21K^wCRdc?PXT8cyFkytb>OKkraNJ@G$y&?_k5=NZTm^T!?o zIC5gL<7`=O!-7Pzzi*!U#yDG^$@{0KU&l})$;YA-dHtb+aLUa-UJ-5qVU-&@PEXu` z?94b>26%jq6)&$5M3onOM9nD~Zzw~%JYW>{y zwpDH0PHc;T39hY zs&kgUykL(uRUIETNOiCjE-bRB4SSbhd?ziQS&ABQ09u6G=P#SMQ*RHFc0;s_X3icA z5`ryqWVUEOaQ*qfzch4epSRK^B*dM&Fc|Bp=NVkbz8-XFAG~K8ee%QyKlu0nS&uT% zKBqbq#0{0!1WKCJcKNrF$vbtk3}6)jkAYmqpKU+g5pkLbeXS)YzehI0qRKcw_=lSP z$objRdhwEleOYAI&udgT40<`u%#CtAI-u#{;VvzL-sw&$==k4kEWI@}1gbPyZPjiz z1yCr%RX#Gm47PZ69S*2!i@cjf)TS=j6Q&I%5Bse>f~#aC@NY`8*7-q$x<0o^nI7-I zPaZr+`gAIHK`KF;RUhdo@=$v!U^{)ir@}Y8E*xI%|DfTi^o_0OT)ji9_~4|4K%CsT zjxf8Mtq`>iKi-Z*SWcUV{*Dd}W0d(^oD@qD-YhjNtN5;KzQQZVGh_J%_o!gS=)J6f zq`3azBz7`gzn%hnzpyPuIvp!_DbZ9Ap2L7xR6|=5?e(Ymh_Pm}vK$3P_H<8b><@TV zP)FWndMahRBm4SXzJ{^yzCdlmgM`PWaZ6V@5n0LkBJzJejCZh4(p#wL){4os+CTl* zk4rPA}ge#}Y<1?_j3DM1y$8*V}D^!AD{<=2}kn_eQG^6%+>ER~tyM zMq*BLJXpf@eE>k4i8a_bo9*al6#LpQM|OUv-V}X(r7V5Q(tz;;-PrfnU%lLwgaMW) zQFGEQ|04b9p3;$@fBFTKOy4nbanovm0u<}7&*Bp~zAX8Ftn05WpqAGx*c-HmS(Ywf z;vM*ghiFqyjXDjNVxOk_ZVif?XvPbjH9x2I^L})D#&5wqOx4`EsWR!MmO{%Mkfs5p zpc8xQ+dgkRn|D;u`g>kKv9b1h%*Tv&?_L!3Y|#DJe*C8;{OmKZ@%gMzY`(v{_8zxw z5w7=Y>W7-GN!&lb5BLT657|ykaQ}NQJ2LKT3{T~a?)R!3#|XM?PWG>{#;xej>pIsN zVhcasPOkZzczdN8-ZzrcoPo+E|M+3Xb>lH7a*<)l?sB+F&!V1arv-=jDx^C$p1G5i z)mu@$>%yR=8tjDbxUM>>8Hc5t?N6C!fnSU&1tf)j!W#@6DRu&~T8#Le3JL&z&L?@E zG*1UyPP?>H*Bp(k=tGwaF&3T0*93G33M7Oo3xi^E9-@n4X@Gg+W4Z9wMIHzeI zm)(vBdS#)C(SncYwg{=*#iqy#e@h?z*>a()inHn*u5$ps)m3p4yXe}6;nE{AN)4}% zXeFl(h*mT^c$GEz*ZT7c%^}RYrLu{*1dM#iR*(zj?Ya&00n(mfi|Qj>M0!oBti+Vq z7=5+X$_q;HF^JQdpHp7+?ZY&U18gQLq;()6OmQlPNtnXhp@T2?RUSpQqcRx41A>b| zB1{zlnk@Q)O{jZ_R@iNo_D+CgI_--(T-9kHHQNHz0}4BXdg6bq9)V8h!gb*o40gq% zWZ5{LpL1s%HG_Y2{%(1ac32ult_IBay11Wwl^j6bHKyV-TVuS(tF_*`yf=pzc)1?v+GiX z0Rv3uB?3BU%mwBHO3<*##$@>?PsYTUNM&l1OqS3&WKmo4#Bw~Xyj$ZX=C%P=L%>DH z$ol#U3O3h!g!-5{CsCU@rU&tNxnZ#2IiV06Tb*NA-h5LrL$ODP&SujNQTq`_47>GO ztcVtH*$x}hY^pEVU1*B-tlcKf^bJOoj^nSef3XDzbp&(^=^c>)0q>t5`fpwGd;LQq zN%UEWz`XFF5xk^xGk0F$Cxtgzhm*2ZQ>^#ZeZ$?|muBmldBd4Y>qm}_Ubofmu2?|= 
zb8K8k%%d^5EGWQtSC@B&`wu&VVF2vb7k_A=Uomm#H?kzJ#qX|`q+S*4(y=)L_)b3t zV%WEabiY@`PyX{DwN=4*pZ3kqYU0_KU6uR`Z5q3q0F3B2_T=8?Ox)VY; zH#{qcK$NNoKZ6rKGy3}?5aZ2GNLY?{cNCjgTzbbq?y{bieG7B*b#Pq ztG};9*ED%Y%xa8}=4m7494S>8x>+gW$*@J@fdxa8V#4Ml+x9jA#syn=-uwvwbo_xJ zR6UVTg>-0&wB+*SHj=d)0juX6=oC;SH;>T~pkB+9;Bj{*Zrr|7*x10u2m7)%)T|)Y zDyM&C+HZj>-erS*2yDc!hXC*jFhLerpJuTyx&sV|qi?nA_s#|4g6I#k#hZviMfkI7 zBNhY!`gA+{Om)^Kwh`AxOwNgJr)i&?{L=Nai6qhjHe4NQKHZ5G3~W(Q7Qm~8!#W`R z5HHiyWdHv-3SmD{GOUtn@ML2AMge!4lPB|*&egIjEnX3&)YFU2O*40zo`#S2A9Z)4 zFM3{2YggYV!y!yP)|Wd2=iUCm>EZvdNA%%yeN93tJJ4jn4i8{O_{n9})3!j@m8S^0 zyZJOHVcz42Nr>`j(ulc#?5@q~AVCPP1gC4YWtbYrwP~OBO9$n;{Ld% zd2alDxx)1-v@kQa1j>1)?9o%E3B&BiT>%hKQ#)|s3To7WwFsI3lVGn2p)zrS?i}ke zMKOdE%g(!W;FFttm41qdsYSp=EDx2*UF?~7u!mW$ob;a{*=-mx90RW$gl5}HB2YM#{NaH1n zf%mwHZ-Wmk9iE%uo8^9Ie0=95o~3!WY&qR{1`2jCDe=7E?qprwaJ&U;U6XmkPiRWJ zbP=0;l+pPbqc(Yd3{dh!vML{` zGp6NEvQ`LEmW=z51ULx*yRNcPk?9?YV#v*Hhb6b^`?IvvB#TdB^2OvAjm1f1ve(kD zi*=cP+wf}aF*e8^F>aKwZ4t}zM&Bh=p*O1*al$~fvKoKQ+SbkhoIBAAra*+*{g4;F zu~K@2a-_z1)9}YX%%ilvSvqc!ox73RjGKQUo`8dScoE1XT*l*psg+@(Ewecbo--|u zy(0qIYosdjAx#1M)sEaT`;0K4i177gv8#xnTNxf;RfJlXh^DEO(OcLuitt6dt3-?I zV~i7L#7%|~^C{K)KaT7v(SHi(!F)^=>cH#kBSGXB>OP-#5-rygOq@N;8#WvSPAxr4 z$yn23VYtVdBDjgB%wXFE;~!LbtP0`$)Zq%`3r_VL@4e&)p_17<-H+3tHw+_6%m3hr zW7n-dCp{oF*qXc*ydAIZ(Vw!TuBdiAB>ZXd&!C->f<{2{yH$$Q$uJ=Tri@wM2cxM| zl4;tq;pn)69yc7@#o^rs6m;Qto@LO(FiR`U+_0aG=v-$RZ%rAbvh2=RDz5@AWgWU+zbzYcNHho;t zb)N0H1U-wq*I3^eZe@j$TYydAMX36@A)ocfO?P8(yh@#;NnRL(W6()ACqzcU;2u-g zfzxmCP|oGXVoTU|qDov4=guOjc~_s3{JJVk@5g}3U^^WMlHJWkWH*W@674k&shlMQ zrI^eTdO>S2^g_MceIgjPSyJvQbXw-6-GQ?E-clATZk&9k2%C*H}Uf~7i+`5 zR8(gU^rFDMCT%YiNbOL+_s;%>N3d=Jzm6~>ik$nZaA~~YzFND9 zfs~$kQ4)k7!Dye_Pg0(p!@b|@7MMWDFr+M%Seaz-5*bPyBM^UE1_P-@;VvM++cm&T zf3hww)4^xtFsde8dJvWnw_2xTWCITtQR&ycbuSBR&c`>nKC*)tz{Qx8 zph{PI?L%+bOn>5Uhn4-?=09SdcHV8rS`?j(Kew(?{%Yuk?K_N*% z8z6GJvyNy$Y$suog4(8j@cpG{@WuXt8yT z(z-9005pQo4FBT@Io-dkf?7oV0vzZ%5_D;BnXM5+W(&C2iZFj+;JIutyQjLknqddo z86w;#yinw!LtqqZzzekU*gVwIzIX@6&3|I7s0Ncj@MeWT35JDrBTJRIr^hN!Ig=*A!}KIXTs z022uCP~E}IjIUrZctNHUWwwYl5vA?8#T^sfC(EBVA;G_2ZbVwQ2&=+pUVf`m5zDxV zd^iLkdu$&MdEePE%ui5vI30OK4gJFgzd)a6=+7bn?pOK(hIcA-JZzeo1Bv9{zKrX~ zOo|Y*lJ~GcXL<&`T(Qza9Nq((c>vuSABG6wC7J4!?M}t>^N7lWpV^nggP_-^k!j>9 zZfB448hPNR?76kywGBs(#4@f3bUS6kro!rddMLceUtE%xaDPeDmHeyhoqUSy78o<$^QM+d$C*a znEyCxWx68ZmeedJxaDqw?V9;n1&~&s5~E?Q?c6iiW^vNi5@jG}AH2=I&exEjJ^1t6 z8U4MhgdBKW$P)*|#*@4k3fs3{IFA4^ORvhbx3fp$FLVlC31I$UP9Q_*?u$_F>A&cU z-4o>*#}jjN^1aSE=`AfCNH~jU&;%nPuJf@aMC)dQg~WH;YZ8K*BLU*%C8P*M|}W8d!_x;@%FaDyU$~d zs|&yd2U0)gYgXI64Z6N325DMf)VqDt$T9Ub*9X69VV1-YO3qDOBlEWj_$^xgpPL_6 zDQoa3K+17#$M7%>vbSvaI>@qZ-$Z;>Xx*D{Gd1cuW?#&cxBWPi|HKUj&{Bg>7QRTb{%FBYE!zTZha#FT1FYpWYkQJHjK>fqwv8RnuEK4lv#nB%YU*naF{1cSF~yq9lry6=|K;Ke8l-jm z)%+!$Ywi_O>y6Q$rFa<8b%(HFHa{ze8A`l~+|ZnUH_dLmY^p(|j%@tF4>_`>RCdESjA~Sh#AK1FeTsMvu-G3@`y|B*jEklw<>DQ5^e8NS+jWdC$FSo$6tpc{tsd_Q1_SrTeFa9&};; z^hOx0tA(3{j#oh5#kqyu$>P_`x4G@EE~q&vW%b$mQ-}YBIdi_B8yhrumzeC<&;|p} z@?{8Y=x_mNYn1#mD4%s`*TnT1nQf1GSu|C$i0M{*OrzmXZCTwT6}2mc=V-QCVzNI*@sEN zmrY)WRX^=6So5IL5KvkaH41;LtvB=C2FgeoWEigN4~ z9G2h~A=NI^cAz7fW0Gm|^JMy+O#eH#jRl0fkuTfxZlBkLMWcS9)Tyvm;CWnyyiD~T zcrVVXP?C}o`mvh70Ec9u#4JnsSfKMsb^(`(E=yY*#QNb&SXz$v`6=rM?!Ar2*q z^kDxUyECoM9`yZ;czx`8k{@em>_CAm2feVmf-*4e7TO;WrSmT^=hViSW{U*gwH`-e zpY@iFNDg9d(mDCLkZ%3iG)O1 zm}_53$!o($WnCj3O1%i@&BMwSfoE95IW#0aU@`Y1NCYWDwW7O!H91*J(M7DROj#{t zq)UYz#@X|!rXZGRxD2K@gC8}B2hjmLoq1c@1U+gP3PgE=c@GY5z0mq_l>Uat4|KlS zmx}4cL;j$S(FJXGeI&ANsU_Q{d5<;U=&N!jI_aPK(;e!;N~W7p1}ulp`wj|@DG0Sb8Ju(INI_q}JVu$7265fyrYpD6j%fXfJ3hVHa2oQOK**BMbtYie>(u||@> 
z!rR2$;^cB5KVU@|8^^r<{lP;cCyj6? zM28-C{=Z5S>iS(tRSK8qW)HjsR#9>k(mJhFyX{2vxUQ3!k68Okqe(r1Wb=0|54?X2Q7NAzV{<{ARFMin;}V7&#9i?PTAa9I|W5p9=VC0s38uItRp|(y0dzFdg$k(uu}ZbC%;~9MgC(z z1yDzq=-Z^GFU{Tn(XqEWK@D+d&MZ}AB_ssd)Et}k+aGENLpd|1P(7OPk`&4qReMl`Fs)T z2|cz|H46o99P(kdNGobLH_(vS&F&or6<)RAtotW1$&5&H(+F60JHQ=d8pl`PC$iAz z>ZF#;pQ3A;A?~Fn08IehCkTAt(EgRx15u6RugiT5Qua#?$$|82+rZkZ>(CmVr|{^rTNyy6Zt^ql+cq>_qP!R{+KE9a5#(R;Fn^ zB<$&G&fR5|<~RKPD{vdXvZeuyJkUe311}}Sf<8?F^2)Me+QbCeU33!s0un|NDei|^))vNz)V2`gHLM1%QLm)8>}8Ca5M>acm}w9_uMHjvRJ`<8jnBKW;lgSO zqZb8`QK>S_YuO6)-T={b?5n@i z8mt_}^+w4N_noZM0?=CL<-vHvz3oOA#i)CwK`J2^#gWAk`u0z{jf z=uXrLrljYDmao8vw6nIqw!jD5MEMd=_2TIygQmUm^O6USmr@LiuZ4{#)fkeiw{IbE z{pjR5xz|vgiJyLOU#59u_VA+#9+W22ovNeqAhBOPt@%SjVy;PIhAt=s$=CSDC(zKP zAv$nc|JGc~E{U^vS_?wcg=b+;=v-S@J^JcyeW)CxlW}DBQ3Q8kKPQ1!q4Y9$;8Cts zJ;)5zJ9lo7Twq<8#9w0oHD)DKazr$khV<}qdxXxeOQ~4G1ji9<5JDM5Ne@30cEP=m z-Z=0?io0z4&rtl-72P7TfV~`YSc|{t%(HrPrG+mWa1g?PgvzIi6{Z$m$Jgv&QmW_0 zk%nVaJLDx7=MA_n@+uu&g7CZSnG_xAK-TgW7$Iezat9eW4>OJhUI)(m6{2-uk`WuTJ+jgJCorAwDl`CzmVg)KJ4Lv zdfWEKgvhQIc>^XH>33rZP9OSG=Nbs2^a6%^(p_4)k=34iDTsSRT&%7xBL(>7IMD^| z!dnjDzhrcKFs<YXp2&g|c)Y7fHHO#Wn?^8C;^rr_dKNfW5e zyIxC3_l$gslXx8(@&K$HF7#(s;HM8bo_ZcH|7Z}Xq@JIs7lrj>FK9(L7yQ_H>nEi@>{1wF9c4r=#A%PnOTG-EL|HlF4 zpZR*m?6evDE=(pCvdweT_^6<_`93N6$@~*R{)fGop90(eaquAjDLWaWtbf*HS20eU ze@JI+NK!=diqMuPxXnTS^>3maNr)p5r=KZ&DZkIuYPNiJVsJww{<%=OvpMzQMZ1X1 zM`v)Fcxu?xJk2N4no6p01SJIQx0Rv%1ZKYiBqFX&VWe}2=mYea{gC5;RgPtuT@ zbpZ9o6AO|Vysp&Ed%TgC7irQ+wf&YI(Jnf!Bl3#sm-aRnBU>w}7;_1KWAf&EVr`bE z);+=v9ToS3M&I~PwfA>lUnRhsO+R{M70&!tQ@sV}Mw-yIOYPU^r|Kd~_fVjAmPN5D zVwjc=eJ=_E#_UK>cB94tH$pyb^`Kszgtr>f`Uk1`wa+zx@eTHAlBxe)m$7;ZkKv0A z${k$2JB2KsVT(}jt!a39drKOes0lsYg-g1q`w5aEWukO|lNmdfhXNoyp8q4+Zy%UN3^Y{$UF`ME+=~Pex**gOefCfZsuN!M4{^x_tgLOV zyI)L7r~4T-@N7;Iv<=PoB(M{b>91zZ=Lk|XH_L6(HS;e5&|N{qBFP|LvWhyNE?-vr zXrh2`xgrU$jW{7>0Nr&~rk!ff-ohJ?)l4)Sm9%>1Piu|)5UpgiJ-6VmK~>!;T6%@| zgTDU!^Upt;=H=i-K4Wu}w;RRk?baYQB;a+6)4*A6e zH^vQ_$R65^z*#2tmveU$Dl+<}Bd~gk9?_p$VX8!DIl}Y6>&TIxy>G9VA~S zK$vG4noK(iAxStqM0y(1WS)0=8E(w{2S}#$K~CEpJ~@?3;wl~RZEN-|y~7#_+vymt zU0ueKMHYQ|ncDRC#dK&Z2$0#Zge^-MOf~it(`k@6+L-m(_FcljB=VuSW#Bb9b z7kmDXqB9R<`v2p2rBVs0+{Y@SEOa=Ft&)T!#4>D2a+7ilvqibDFNGqOTbSh-ximXG5fPn7NqeTcRAW_Yyw+DA&jgeM7F9=VSP5__bS*8@)zMjtucj#fX<4C zuURB6nVgjO=n9BrJ!YQ;62a_yN@2=**mnSYcPF5O*)nTw@r^S4!dJpH0%9uJjk}f> zdK634gy#E;ZQ)Ka2`jm3wp~T9Tklme1gDc3qz9{r-09d5KWKQ-q*3wsP8cQlNOfwA zjLWDagpYwi4?g>kBiPRTpbz{t=F6-1-M|!2_mS4t^8JX;9BN1Bs8hrukmCwMo^d1C z3@8@90ytz}b!;>ZudH?dvdIS+*d0u6c)$C6uoqXks&TboWfuEvKl&a6X+#79VIi=+ z<<{d%yNc;rp2a;*K3s{$}t zK}R6x%CKs$NQf;~0weaLrZ0nQDmE&HIQ=rc{m7$v!#;iC@6sPy{KZo4GV5td#QBYR zJ?wI(8-6)fcoY?4E0ygi@qM@1vawkfpDAm$ko|Q>DGWN5&H)qqL)m(u10;KU54@Q| zOKJf=7p{4hbela}H^lDi;#SVn zMWwdWcdY$1C&nSRC}Uq4&r8*~KoGH_K}pn-%6v$jmacZtT|4!8;g2d8MsTO)(?J4_ zN6!;C-pT7L^0m_FYg(6>X55_aCceNCj-N=_fu@nP8RbCfsaL8zY}0v1Zl{P2G{!_^ zRBn*p`BahkhJebjg4PQ9BnrB_w7QwMx@!IF^aESZZY#T!o;Ut)mm!XicRJI4p~oRe zGO1RYW*-a%q`adfwd&1nUg&S5OQ^}+@->{7APve#4F)GpgsT;t}aB;2vYPq%mbY&c* z0zg9vr{qtlcTIgO&U&{FgIdK8Ci_X)TlIrRRF(8C4ev`S5jknOYp>vCPwyAt63k27 zx0&|Qw4RPmgmH#*yi+fDVa|)7*UMH=U^=$%J%fe#?LhRIWIA>7c||-hEf|fr@4g3l z0#7-#eJ09#Mvnq&+^;Hke&y9xJqs7%-$4Wy|N)`TUB;rUod=o6X!Y$$i&q zCPO+OusKB*(w-uO{>WOQ8y-7|JOTWk-;6kn+@0cUO;=xRg>R6l&NTao8lLI4RgS36 zEOO~rlyYl#;!u$5O{n{~VxB)f+=8Wa@=hD`Rzatu>QC9;=K90aRU_AIbIn`07mzTt z{xEP9L19%o^IttKJ=l*fZ|SakGFa!!Gmr!zg_98MAbBgJdl?A)HY8x#cSR@Ho0SzVnqq(oJEQ2 z|LS9N)7#T{UYA(lWh;gh?L%b0%!00FJ6xY=JF)r!m=SE~e*X^X(JMmS%D%h|+Bi{^ z@hd_7W}xEVU8y;<@`c`0je?O@z+`me%cH>vbLdH^oBI)FwJwu@UxtAG%PvnR+6z?P 
zo%TIkH~n7vg4C@oLNHou(N{PH`3oV>7=-XJwg-XzS)03ylI%$2uf%89zNk0y9p`sq zvS>Xk;i}bm3ufewZRsat%Z*hJAeHo4H^$(>t5@Rwz804~>a|EI+>uuKXLVCyC$D68 z{s0D^qxh*K0`yB7b!PPEN=2y*}^Q^6(JhO9A}G*kT8$_NBtP zCpO*cDihi-m!%FBg+A$XrA;Y* zmZj_EEDUjx8WZdOCR_3zckv=pfo((Nw3Rhyk&JfL*zMLMKBY}6`W_PbMV6T}&`1&& z1){*#E$LIS9G4%^O&niHoL}%5Et@)0-D-7?3G0TQWnX288}`%at&N7}>kn`L$I|Sv zJO8G;c-HNCj==F)RNGkA_15F0kh1>5e~0>u=k}_150;u4ANLd11J|~fLEFPuEU;55 z`>)T+M_)d%^@nbMXK=nebY{KDD3Bz{)MMyV7xaG)s$97ysHj)@>!$DSfa3JkHq*~@ zU*cLtHDQZN^-AH%6-_-Wsd7NIRC)p2rOdzj5xLh`z&i5b_$zs`Z|sDNE!UA>polaP z`a+Y@VP~lDq#;8!Jb&zdM#l-4=-&krBCK{1D#!8lO8T>m&0$i}K^#d41)dgeti97P zAq+61N=k_1O)^HC@#`|3sl^H6X0Y!|e#Vham^FELY3VV2#kZdY;|L5rPirN^lty^u zsCCQF%=|}WrR3Z^t2YIdrK{x?Qs)^l{|$2_2J!CQDH0x53N-~+g5&PWdmoHRjTb~S z3w|`?mv0YO9HTm&6INZkV)UuqoYnV&ub}(U(I{Aq!~SH0`*BN$ePm<6!&gH}A1PwQ znhx7PJr%ac?e@rFFV5BVbWS$4P9Ky+r?tMy_lJb`s$pN1(YHuaRiV2#*JBR|U%e@F z++mtp(+4mbUpPHu-3WWmA^EW@7iST8_LQ!ziB#*>mhYt)^S%1}h|mv;5L?)B|2Q** zAJ%hA(Ee%4R$0K7tcR?yZ`u1Yq~<_R8QhURH;<$T6U*M7kGaz5k<|KC+V&UXLkt@y zs#MT!&?n_BAOFJ~Ivm;rjEm4$t^ic_gJV9`i9eSL*IgK|rMf+;8K(BN{yU__H?h4V zIPnR1<$6uwPTdDK%{PIMK^|ou(5_9jwV4H;x)#4(l^wYVpuf;nlW+!A^qifCo%{Ij z&{CXgqWkU+ zXBIG(Ec@BwklPlN^2aW`SByxA#u7&+f#{0k_Ep zAt^CBTfs47bS&!Mp?o#Bn=)xn`LJHS-Wno>FDmsvy7g!cK%|%kCJfEwr9UZ^uM%?7 z{G0i9eH7nr+?bU@)lGRGdRf^cf>vjW&lpL_1#j10I{J?POG#6P<~LWDnUZ+$IpE(U z47rHHHpUZz*;=ZZgz@gPBd+g5)IXvAyt-Us)YasqHF~6!Es)M!Frv5M3fPFB-b46w z2fK|0b7WlS(ck_RuH5q)JOc!`=BpcsofTWlk&5J!y#-tqRV88O1-+<@J2}C$ zq(u{3>I886lr>5NxH+VS4?6edzwC#?!uovu;%O_rwH9fA<$Cuf*gIX zF0(dYJ~NNNhy!6Jcq-iUgScJ?Ur=S~*fWjFtCA?Eo(Gsn=ACH_t_aS~sdYADIy*aq zMjqTNA69EzCYQVckH|7)gl@mHx;_i*dOGjM(BjdM{_vrGXaCj2i@=8g`B5vC8G!(eyPjJEw;V%T7h(Q&BolKm@TFzzl?fVj#laI?v_Hs6&QL|Np zRkc$BdcRaK_RgTP;mkby&gd9y&Cs$G7qi(n&!SazT`6$j-ywqa&%)uBo8kMpQ%Z^= zLUJoNtbsaM3y}xSKb>dEA;`1vRROo_G90fhuN?dV-Z}Yi^us=Z@Lp?_dB;vqnlmu8 z0Vy5!ZEP^ebfAt;3%n+I+tR;Fx+Ci4y%y)r8s}-2<@lwg{x#y*-|V$&Ttt=2=ID~8 ztcFPwY)I;S|4LC!itR+1%4G1}wtmOHkojLMEi0z}`fZ$)*)yAWUI;CI3r0l{8oFb& zeSk(=*2V-WjHB|jWOVCQ zkp96pdQL(%s#Grw-iZ)WL1gzhpE_tl0XsdS)n3Nv?of2PR#A915J44hBNUoW=OpAV zCJ9&Eh|(5yohLn}XbRAvqPiCc$X@Q^31fE#8&do(Em$aYgRK}%`-LstvWKM9^GzMQ zguc#q_{^s#O+!kjaa)U)UH6_(TxXbtBz=FaN}VBk_cc79_ln!hZ4Q91>p|@{uwIB% zy0bydM%U$%fEu8uk-G&HC1*@hEmrEX#?~pt@aIfk4&M5BMe>&Q(9OkstA$HV_j(dQ zl_Uu|XeTvsF_RbJ%MhNv52vb5OuD@qv&2|0OpEo4q^R)`QNz>h(96})#aSY}Wy z-sq7VAHBGU^kLrw>PRdcrQl-rKy+hdLq9{*=9S_p}y9ye=_pB!12Q!u?ReL6=FQ~WB z2YWWY=#>pcT0|v`bCg8c`r~5!$V>A@8aWjm;M9?-{UHAwbk%qo$$vO|iEf1)w_Py1 z-;Tqhb7xlhXM9%qX<^$So&mU2BXFnjb ztJE5;`5V2tR_J-H-<$oPz0RtP8xU2FjV0Soet;!nt{AvJafjHP5eQj~ciz`vO46xZ zM|5u=i?|WUge|Ps$A*-9S0Jgy2YVd%Zn6zW5al*MU}Qbs?EXa99-^2%6jP;i4CM_P z&q8Lyp}d+}YU^Sb_fD;#f30$SRslrD(8W)?&WKdOC#^9kP&kK=`F3W0BXrJ6rM0o= zv`90)Od2}|4Gu8#5p&||jLq9BdOOfyUS2bkt(`4+X{@YxPE5ESuMubTIg_4;cBLyB z40_frjpfn_f?Ia^=i!ysWz#%PF=9W-rhBLZ(!WD5wYX?Lp=)KM%#XZMk4;*xcMNU9 z>SF_MD+OuF567Opl)o}h&UB}~o9=hMtx{L=Se1Vc+2hj$3QxPpyv*R{$oBbGp$T*A zoEcbBOx-)(WJ&(?j^Z6ecTw(pws8$THc;`;(*9jb4-vDkx}D9a-B64y>Pi3lsyRtF zMQNWLN1DS#qc)#0Z5Ze3Z#p@ivb=NgELYG++W|J1u>hpZV&T@~a4MkryaJw-LwndJ z+&Dq(`d@#|-9P*PVH*S)N0x!i0J0Nc(0pa_qO_Zh+p*Kno`-vJN%|6g12$aP5O9zU zyqvEAnABaYR=};P*wJ|?zhkE@?T$tEXzUlrZ5aCMP;mC6%{lqEyH1iJ>8wLbs(mg& z*7(&xYgi-hkG*l9N_zgmkB{$dp5(<;ggiV&8v9T}N`Ol&xGcbuxWFjbHRTHW}2`$szDA-7u5pEeJgp)2NJsYywDA>(x z;I2!8o+cL$Pju6E)G<#XoMB+C)?sb5qcfDvjQ3ie<~k*OG%*0s6UH$x^XI+0P@F zytD*EZIPM|WFFFu#YOkr9;e@*a7uH>s{COU60REHwDUiQ&f{%6mH%ckbfb<3=l?e0 ze)(&MfcFCZL7EJ*HI5Ue+QmF>hP>9Uge`#=$s?0H=6z*``jW7w#1PBxYLIm}t6vh? 
zm?jkNvQntim?m>W^26L>3$HIKs-tJu{zpjX)-HhXIK}D43ZW%w>Rwl-CwD9oIQPv+ zrpVr&I%oumi%tnX&rEN+CBQ}xPc|TV9i=XBrkGV&cgdP-zAOKUk!7Ew?^1!}t*hN8 z#z+yWpbAS3Yw@>~>)rZcBPb2bTvW9gcc4`e;vQOto9LV?bLTYvJLKiO8CnlLtD+QA z)nEUjIfGliHi9w;AC3jgVjo3P$$dh%`h zGyTU0LBP%GzyLwFf+wcM#&eIU6V5Zr3mB12s<^+U0r5&Z%YPt8mc@8O0~N9=PZW0v z$cy;F0Mi8s>ohYx>&$b7=4l!@IfvjE9Dm>3ZKh#!?%ES6iHZg2b;!EBA)^W(#Og?0p-xf>d@lfA|#<*XZ(%AeXvetgi7GY=N}tF;tE$<{g@oQ ze8FN`z|Bl+JUd-Qa}2q>35~mNtE$g({f%LBiB%^U#gt{pC86LV0zl(5o}I zDKv>>5>aJzf@vq#&8_j@nddLC@+|BWvdMRKCdSaZ+_lbT3A{O`qngmS^LdtFyVdfd z;6195R?@)Ykw4eR6JKkOTYaKv3v-^`{+r(qN253c8({&W`_HS@=Uc)j!xuY*r@ zWPP?>UPXqHZZeEwYlIo69N9V<oXG(j#p7$aMxaUAec*Uz4fnNyx5yeJ0BH73$Zw)47*Bu*w`TK)prZQyRI1%1Bz+ zTQ+z+JX7kB$hq}1<^*A>vl&~G8(8C-?YQd$SB7DMwK)PH6eaKVv>;1b>&M^er|1WW zUF1x0vB4#xMdZWnmI`yQdFVMnzYC{&lcy`uQW9k!dDmALtt+ByU<{-_JqK5Sk~1T1 zxT@aeXjdGRb}B(yNMD$#fQ(YjP>6D+xqqxEcpExz`C3L5I?@eI^-o&qp?j}4Qq3qq zn)e)8qxiTU@qdQ`z?lMHj1yE)X(|>qcWTl?`*0<@VTENq{Eqt*UnS>ug1t^@Iu~3s zz#a|M{VM*2WRgx=Y+%HIw^Ga*!l7*KX>~U|bU141 z2o={A+$(ySZF;7#M?b_?EDoIllEw5+o~kjV_3)2l8$|7S7Z1;-jd#m@X@QtUO|cig zoqn;^22ehavdxD8vSA#UZnJna9VlpI=42z3IG7s8d^{xk|2Hf?l=mocw@n2P9M#vJ_j00(;Ex%F38j zf0guoSMcCIr=w~ON3&E=v;OHiy!oI~vg#E5R}x?5iNP?&`BVqT-rZH<24S9?se$E` z-YVJmao?vr_RMQ>p)1XvYVe#-S@4(SJI&8Xp%}5xjxBw@m5Rg zua_@rV*+Pg4(5CX-?u>4*Os?jBISaWFr?Hdt%OIN+RKKfX25g9LuQCQ^#-$~%8M3c zaPRamo)5X<(dQf`vu^%7B!N?7wYrbblkbu)IJ~xg;(BxKx=z1nuabyDFKj9)0z=7i z7Dh-_0f9fqJGw|+!xWZstd9|OW(Md_I7p>ge2-$xsx>bJ+j(bXDhhC!bTJC3WPz ze+bUaqM_5!Ff{WnWG<|K#I1ePC&{>3TU;8Vg-` zU^waTjP=~KI)`by$W zB~ocZYiuFciYp!*#6U(|b1fU$8a2o(a?ARUDth+fsRMj>zqZCJaF$JnvFtRD`N8sFo%W-XtnPa=o|Exf8hb^VXk(T;R^o_`SpH22J7-Mu@&h& z3cA<|rxI$4D(?Jf9jVre4JAP^9j49sO!sDv7NaK3t=eJL-e#fnQ1LNow5`b4E0en} z-@F%*fI~&M9KbqfGch7inzi%!pH2CVzeilooQ35tynU~+aKYY@pv(+*wx2psY7!Qr zzC}2NOpF-L4T`INd39Vl^c0UFSpDsj!-xH@5C|L)92g>*cL0Pc<^ky|IYGpip3peulAuhxX+LnT_@nohunWarAu&d0@H zlo-8>>V1Wd4E(TeWzSZGQk6LGlaYYHXT*^w)zfh+J$X)kg0<^WURQ@1eaEGv@2{90 z_V5=|w%OE47ZDOex{=q)kFC=-)8iSxam~X}kBsldAGW=`e<5 zzyC&JLg>UTfoYezj#}B#Em?c5Kx(qR!Me1_mlldo{S_w6;hNE zGPX8}l7IMn!$WxmJz2nv*lY$8<=K7pmeCUl z+E0N_;Q#DAN;@Se?JWI-rRQsz!8@3};YFlC`p?$?4&*_CXx${!vW8+p0`UND2+@7XSkT`6GNuWO@%*_89%nDSm zC*+}Z3gFD22ClA1X=ZM{PyBMUR&K}F#pn6OCUajCe6_B5-({xPb<$`zG>f$_>)e`eunBh!;& zz@%3Qdx-!q^>Cp*i4_kvFVFgkF$qnelWf=NRq#qK(VXrb>~-w7vD2iw)6R9uXZM#W zAAO+DJ!2!hg|8oE8SWNj_7VZ6bWvr}b=(x;{T6{&{pr5HhE`Cq2%(@xFqauL{!-}s z`#PVZR&L3@ZwvurZ>>>yA6m=x1%F~SCt+`zb4aMEY0<42!R-4RtkO-}&x(Uf$xb18L zog)yvq%PJ}(Qzrez-VrXJ|bp+m|l5Zgw4stck05gYz z^Y8!n6YdljEK!hEN$Zq>U*j&Q^JMSJCosG>L2tsX1TP6>BFodbor=HH#@PHSvKY#8L+ z5T-r1YBfD;eN;SNxeLv)g;-?y{nGFS@lB!yVul-rcfi)^ zya5SxqbXuyrK42lm5#fIUT76W(E%pexRG&u_lVvH0F#k>?Djz{ulECN0P5Ao)9|4x|}`RxB6Ozp;Pjz!meB znM>0N4bMuPM)D>zN?mSy|KL4s+XD*w3Anw20-c4|D=1X*$r`Sr)X8y%^G@8P? 
zgO&hVfOE9o!D4siF;+1|7KM(SHcxPNh*Ne#NfzXS-QJY`wKXSP?FL#Glq9n}5EIQ& zwoWwCVE@*0Pyfr09*&)=<|QVEI{X~HAR@A8rnC*cJtm_mwtpKg-T7@SEk`yR^-NG3 z;Vhf&NqH?Nqu_U}tG3NZD{LPIF35SgnOO(bAM+dzO-DHiM;%icG<18quC=_-2?{eG zW~^`enGfYQ6<9Ww{>CfW-u@hGBY!D2xrj2HmjWO}*%KP?NsV%Xp=wbpr~lkEG21JX z{___kRG3aF?5doV8e&SajDgTE%pBlvM|Op>wpood!`H*hMi!SeoY=&J425ggUi$wT zG5FYhf}qXMtsu|?)c@d)JV#$Ch6sIn=sbA?6cqyjxRkjZQCx%HBf=%-B$p)OA#9NgrV4AVLzjYtrP?D8Q=Kv zj&N2jE1pzLACwZ0=ynAg#*C>W(9LtsDWm3Dl@8xRdb(0Uh)S_`ePzxSl_lm%bL|~PzE`s7+=Ba8S?AC){cP#i$<6)KZim#H z{ZtAXl+BaBDgSulRama|$%<*{kp*~;MsLrxh`(O_Aq8>Q-;h*mmEsv5xinB1XsK$J z`sq``GSZvv!>9lpKcL74Tup?9?cOeX6imbMIY9V+4%ks#-7{FCikr5d;#Lt}x8pv$ z#Q>!iC!f~>i7C~zi<@XUNB&Qad_9V4hS30Hf7dy)AARLmVFkvWj(uAL?9MRRIm^_t zNbJjgZPAISC2c{#3oUhx(@|E&8R~bNjZoT~2P{+1+s7v2HAl@D#|x&xMkVk)?27yA zhj`$9#}90Q6>#fN5^mjmJ;;SE;Nx{>C_M(LwocmL?gCBc#CrLjA3k-YQO>F1MtSao zRR#X{kV_NH*!ufgfHYwE8g%Q$bH=A?@-)&000JhWDjs#l6+El-sb&k4us(B;y-WFkglD5#aYS84j-SRarE3U@=Gq#318~X-X=2ie$ zoAf@TLXhLwp4%KRvcqH$mLGFr+6erv>2Mpsale<8B8p`>mab_3H{LPbK_~$`4 zjHx4(n{K>4?`~b_LBG@tZ~)hZ$kPEnbtQ1a zLqQ9rfgcdEcCpVbuBg2I1B};PcEp{P?n!`+#V~HL^HRO1?z`0oPm>Sy7FJYb3(SYHU5$^F;_38E43lEL^XA z?0iYL#)kFHvl&yvEYL5l-oy+6W*l)QdCDl*tDxNt->;$BQxl`X{fCl$s`E$dZF&9Q zG|~xl09`HSn&LHO7QI7o_U=Lz|Jn}0xpc_Y%k{Za#Z%5MYkRMrp|tZ#Ol#Q*bH-PB zG|H&AMv&w5EZoIP>oJu3$^UVI8`%S>gt)BRWaSV7XzTxNGzred78jJmizoK z|JC?I=YK%Q>j>9}Ywr1b@OF01t^AVLAmrDan{kzEq#EtZ?Nj$Om#-%{3b+l@WS#OhVkw{8=&v#`z^eY>p)1*?J)n}cjKQmgYs?_V^OWr}`0SR?vF z+keJT-`Uys)zxy@RzyNG(;wKK(+#sK{O~#W^9aPjdT#YiY+3caNhX)I(rW8e3Q~}1 zvJ715T}*%?9m+yd{V19F;NKw{jzPlt6gd7nwDEUG*d^%n_@CMXHKc)<>jfy&RU|*r)w};W4t>%)ZcL|?Rfv!0?~T}>oKA|V z37A^v9cj#4DpACXbH8=Xa6s4B!!AqwXnkxKpyU7P2Y-nl%Cj3h;~r5|bBoYp<1XO1 zcCkFl?(JJ$5ke^E7N$e2Yb_lbv$k&2)K;_{{iSs=4VY7R3M`Fc_6ZSeZJ+t(4Ba`- zI&-pt{;@m8UQ>8z*_TMkVhbB@THF87UrhUFiI8Tl1S|p&at_-yONx?1g%3 z+xqe<_Xrj7uvY^Fs~p=mR_|he`BBEPrLpW=u^%o5n9Pd0r#_0IIi#tW=xU4=NF83U zDq`N*c{S)|6uRsWgamRYT6ghnH=&ExNugv%fKerN9#FbW^?;W_q!E_1gOfYhjxXaB zKcU{b@}i&vuE3wk5&gd-1F$(d;;9C+EgQaTk!R&t)m9WPM$%4EXgvuVTqAyYsF^kz zRz$SSzQ|6v^T2zWlH+Gj7u=8oye#g&R4hY28PM%MjY|=P_iu3=GWRIQ(lONJm|9ve zJv4$1I*kX*$m#+tJx^_NY%KxtgT`erTica3QOEL0t z5i$#(ohW*?b-8=ImJv8idO{&7FrAxl0&bcLzTacTI+Y>THBT^|Ie#X{L9*N!ZQp#* zmA@V}*QK34SsOJdAZS1kcZFYesNL!mAVG$axj2`cFhWCK>U1tW;S<^}pY(vH;lOX}mQ9Yf6=2@OFmBX$7StL9!P4hQolbgzW4pCVZ@mvL@eC&?VXSL>*u_=vmib=Y7boHK9$i`IAH7#u6| zgk;LgO9AtH)w6M3p46R6`WDmTB_k#pe5`Hr1;Eh%JJbL$&qlMB$fp$9SLxl)x4zVn zx7RsDAVdPCBKqPMUb_|M>8F7nZjtR~Jtofq(1Udzlu?VYl;qFoHItrQof&Y&Te|Kv zA^?aSku$J8-zL8huF-qy3hu0)qOLuEJ${L+Uw<6X=MLurRjQ)ot6-xSvAEb%4?7nA zQI--lPm5I^eVQf$yY40a8G4;Ype_sK9jpuPJ;=i3o=8#39_g5(eXs26z4GU(^zcf> z8yyF%vt<=K%lPEcj28xWjw+=OVe{}kaoX@Fci|LuN!|ckj~|T!cH4v$j%@p}u9~`- z?UDQNZ|ccD;d@#PNQZ&7otFr=Xpi9fc|T)Bg75rsck`fgXHBpMm%jl-zK-0Ujdjz^EfgiSStmlO#X6_(TC=tWs*0=ngPv@fN1W#GRp zz}5$@qsmbK3=tIi#@_3>>mM}HjXnZPOZ-KlpA=eXy+2;J-qaxP5xLTY4?02^1imyR z(fy;srSjsl>aRiLg#6B*?4LtYs+TmE+W)LU6wx9v>xM7-Vn_*l>(0|!vIEWkEvjLc z0JRT>%;j8Qeg(*J#N2%uawXySB=uI~)lpD2u@_HRxp{i3EMMW&%8k+CH4b{&pqT#!-d3Mf4-UI?er7v zGJW%dnd;SJ2WF*1c*Vl^ge%No*YpX;4P77EF-nU8>xF~B=;6$8oL;b%OA{q1Pk&>on|pWl_(6P{;LSj>Ad9w=8WFRp`3)%VzX7{Eqgm2$uVdZZzKO z%3Nh?MZ7X6II zKI}I~oROHB;}|K~RDGT=VF7doqe^wWWbDN$Cmk^!ZXLaEABX9~2`HowxVe`7JCrcE zLjV%jXU3TytvkJtDAYnL@OmQRF5^7o)h-O<8rK84nh_US<^}v~HiW#6yoaga(x%@9 z8eLa#sz7LbHw*K5EZfWEo{l#hAh@Gx>) zsiZ2J3qFL$-DKBMN>7`zbkEa^u_w1Gg0{esDCrk0+g5onbt_r{Gt7%#{?6>5@3OUi z*zG`*UeBYE!uyG?@=M*|0=RV~{MDNkp0kezmc5>#o7vAr^g2_20^;@0cwPC)4$87Yb4yEmHC00b`4Q+W875WxM z$68ejp_)VL^-hHE3@GyL-}Z|)6R6I!IAOMP#qgT=ljz7h{=;6S>ZUPa zPqRi#K(wAPp~U1|YAZ(r&N;t(U%lSwfd-?V11HrNu{jOpTdILuVwP&`zdS1d 
zEAagth%mSJ&114Hl>QW!%+-Q#cKFoF1{ll&G9t2Z^o~l1ci0^eoXX%1!Mkp?D=Ma+ z$jkwcb^f^ZLD*P&4p_lHPI^!-v~#7aB1WrVI#z6XU>U|od;B}Vd65tUf2SjQBU6{N zE=YeRdr))VHnjx2^0REux<@FBP&|ti!iX9KWTcGpc$Db8`~zx6e*OQ_*uK-yvV90m zm}q8I9W)Q&-I{YpT(esNbUfOGo{%ZCE1YL0y#Sv0RT}Z_b5Y#*XOo4r!ae`#)zpd5 z>1u!VugxpiD+~RfjF@;l`qD7^;2wk`*?FX`kjHB)A4yLbW1^a4w@nT@dML{F!z#+Z z9z`wv6=A{J$pKEWKi=-Re`a0kq&?u=EPhs5Y-#EFxMfDxsFln0c?CoEQ!Z_d5p;CM zLqug3bX;j=OF69pbT!#%+$F)R71;F$rc4!}iVpkFgO79OZOwM}oPQ^PJ6h9EQYBg* zOrBNN@}{tEc6!x6JgTZ=@J=ezQ?ptm+0Iiu=|c>dry0qMoj&or2%xZbESws5uB9tQ zE1$UW?!~-j)T4*u+15s5>HU~Ld*UsVM~TQ)2ck7>m%Q0?5V7Agt;^Q;#4+=r!$7k) zv|5ig8J?J!$+bl|Qo|gbUN+?j3}gqGo+Z36$#gAC29g{auqIDhdlruz5K?ohSP7G76}5{HUipHC+-*;Py^Ew?rL^@hcCS&c9Ea1Gjx_oYnAU+v(YB905Y zOfb3}!G{S`DV44&&(VJ}0=9D|iO>j5jqfEbEjP)n698)g+KLoqcC2pQIsT4Is-);( z6;UJtBgZrO!hanf4&4A_q;1i(zjl*J@ri~Pg3WK~;geP;jt$@DI(A&FVHTRr< zI2R|L@ZVyee*@mWHN9SM>0j-B%Q(e%7ydUP1Ai|6)PARJ7Sn+4)6lmlM5z#1Nqy#e zwQdp7L#vhVr|8kDv8PZMJ$W=m--25)s^iwcq=U6{1TL(20*_>f7rgpr64*O@Q;4u* z2z&Nr(OB2oC)fqJgB0NOnag@ah8@srxmGixi3idKGpvN2;Mgcn0PORIIGV3l<|R0J zBHQqu*g$k_`W8sQTIWdAOhj*0wOVf4_J$uD(f|5a2zBE5#*^EyDwILfnWd)+gO5R6 zanv988$TEM3Ff?)>b~zw;ATt)x=V)`MB0hSlD_xzExYUwcv*v2^_ta8leE>F*!IjM zJTtQ1$JNW*c(r#+1rQ=nW3*IrbF*`B8)v!o_hB^m`cXqzmf2@Txyf^>}Tb1#b?_DE@F%Ojm@hpJKT;skLWVoUrizCAe41D1Cm< zuY3R<(?~?2u2p4uC{}1U^*B!Qp?wRmt&=L-xH@;qNy`USUduMuRg?F?Cc3|BFY@|P z8lp$B0R28};abS%*`9|F) zpBX9jJPWpDzmmy#(qlz8OREr{6H~F_%*5d6=bG&~y#W3KAbO>0>s14~eW%0nrq9lG zV2})+HmD|t4#)46Cw?mPptw3TkOUtF@ar^=3=i1uiqjye^mSTAh+Ju!OpKoN#hoZ2 zuPF&{LMLF%ga8*B=>sLb2 z&!{XSk3!Ba!u+gmqwgyIJCsUF*!E$AM~(QIAp=dI5vX&-`HgYOD-H7}d*Z&n*fcXK z$Qo{4fW5~^I^L#46`=$2PRf2&%Klt`Vcgv#aSz@7Y`DCF&KyA%JY~i*jtpx^%|03) zUdb>}zoeyl9O`Q%vWiTfq(`DP=n9{l;V}sT63}ItBT6@Xx44KHSF|s4G!Ip^zy7GE zk1$OqmbwNFYbw;crF<=J)tvb*xV$m$G;sZqKtepV5QFP5!8b3e#9^jWP-9BkM zX`{6Pb}`z24W#fn9~sGeK0BEJa~s>Nb?|9!sR|RG(iDR7=|L)d8Di}|h z{V)Vl;d7vx!a*gt68FYpy3({+Y@_?xLRHSi1UPL)V+?_))&}YC@<(!blg#Kpe#J>! 
zCaqkDAmPjk&dF8QJ@mtabT`6fMp24ZS87y>(BAogMD-qev1Y-)L*C?j?3W~6 zSG)(K06W?_v7YLq6|R#j zuzILC*4Zc=ajMsCn}^J6;0VdfY)d6$U4=>87x?~2`osA6V&p`X6hq$O_0tyc03b$f zrH}QLnV}R4JG`Qayk=x4lplZaMaQ#sb~9xB)1_}m2tm5zc7$t8VY-0@lY1uj|IL@_ z2xPltD^T5gE`Q{3RM@O`_@xB8{=$iehvY6$TK^q7yvBCTXZk+{;vRPKNWlZbuKO`E zbxrazYpL(8wUI}}!IuLdID{IVA%}Sy@uMFPtZIHSuBrZ|meW)bfafaQ^PGJriY9!Q zfc4SHBYEB8%^0ar-pWa+>7?d6r+rvsubZn(={Qfj?3^kqh4!Uk>sm>Tj>}Vy@ zWvT^VyH|}!l&o^kN>EaQbHx~DLQR$lOR@b0N=HydJZ1x$0`36?i{%zU;6Q&diL2`O zZ%L^h5nkKk`cZ(S2ca$g9r`Yt4n+Tf1U#dIy+U_y?1`@83lxgHCY1E>PYE|mXVx`p z7JS0)lzxz~)+32cYh9pGMhv>pKmqhw6%gl+wUB{9TK=c<)_=XJbGD7A8U{gJhiIkT z5vwArDzBYKE!#GTd#`@ra_^_(PT#uYW}cn4u`&2UjOZ>FTjh{Arth0^{pS<=bN+az zXj$sCrnNKH%i+nb^ zh+>)lluVI;UqS#aco`rfW*#Ts;%Kpz<1KZQ$5lJD?of(9^aJ@;X@?(QT>5z9ShM$C04_mHq*%yo0ieRCaV{XV~c z{Ndpb=JPr4_c`bFdcLL{8OjFU5%0LtL=sy#(h+FtD``uFrj-bw2WPD#oPaRoTr>>b z75Y)s=4(Y|>9X}4q8coK^2wRA#{MN{K&=aP$zu8bQ?NM95ijrZnAg51POSmMb1614 zm-dEtcJBiPv0;j!YPODI9`u#ZbrrrGsL#MgLB)6joNV8+-3_wCNZOL4Z`9q(J06VL z7Rayizo!(hua44fZrpiVtnweF7Ul~D@dmW{Bh~WXbxecO_pc5XnMOjP zW$c5t1n1-HT7=nGG(!%a5pL#6O`mWY4uVj44S+#mzfSn=jj6Y-x3`mf@=v=ehz)^n zxmO0;RK_T`Ch-c1_vvFvN8HO;*^CTa&fIxw*WSw3_zJwhc*5j?#q^&8jx{BmRL`k` z*QvT^(LG6=P=6=?ekA@`y}P}}qpq(>MShz1X@Ba$x`Mh6XO(gwXJl^9*}wo9q^oQQ zFPSOsF&O{kh~Ag4rYim^N(n^)#{pU6X8w}ZyaR}3V{^P$h}g%~3zjyPN$Sx%(o3;9?bGm0+Yc zGS-&C5r$LNA1eA}AK1;O=W|}OuhQ-TIuJwD_kvX5KmaCN+)ji%_crSj&QH~?#johj za?q~0ty33I0p&t1|C`sRE}?ox50BEy;ZbE|%7Hk+z#Mvrs?YZ5!odGWzQCie1Z|&& zH}JGL7R(!mgZz))m$}|HZU4wwl%5zDa((zC?+fn`x<{ki8AM#HrTQ3QH`DgAP#0C+ z(F|F;Ia`V^o|xWUzJc2w;@Glrh@M#!(d7`w*=GvAKMAEuFrxAjwBJMJD}<GmcVha8Xal&&6Wy?H@x_LMp($qD3v}r`QCGsmmNpITCbBelXytJ>DS*>(x zv&kGVzx?4lTsLmUikogJqH2z|ftaEW2WsPY3{No1T9DPIt)Z5G|AzaxF&1EppFNi| z*9^F2p=${zxr(Kv@&Hvf)@TF!E;u-709d8sM!mT=az>8{td|{wE82ORHyMP1Pa=hX z2A}NMxzpt7y1WP3IIwBJ2@m{c*9j0?ZXm)@RE_lfzYh&JebS>A<`7Q2Zn6di-Oh_r zLRN>nJ|+~BPO>Nz>fU92hZh-=##+I5 zYV|c4$wM^Y2}m!Wdvf!Wj&2Ju z$(Mk{3B4a5wue2GL*am`pjkK?zLS|s(RwXhEL?%9GK!$}R|Tn*%FwyGObAL0I+H@EqbPaVdB9d{w>3!j`Yz zCemO(sgN@se!KO7$|Cjv%7E%)`!L4jmARn6YAtpsYpuD)u9Dr+TYF}iki~r>{hLMz z3(zRcIPOp3EZ7n&<6AZMdc;AaM(vlNHVOXBZUOYK(7O7VOKD48op5>$|4&?!obl=a{sG#op;38pB-4 zg{jQy_hbk6Y$4*%nn;p{Xyc5bYl_YwxaJT_3l@v$z$oii*LL&m$7rwDbF6`n_;pS= zuIv92q>R8f*n{c_A$RK@eC0Sy{o-k}2Bb%WIg4~4z0Z|KBdiwRJwT~VE|_$Xa|#89bQfMw9M zvc3L~&+2%%F#~qI>xs8Gi`Yv?S^!C)a(kkZnc!$w>e85T%U~kg;!sa&zvSS&H?Kcu z?$oFxYff#8tz3Ai4{9l0NQ;5W+Kk*(R+fu9d0m1aYZ{}ixf$)YDbfb}6f7ovI1H|T z|KLZhZtVyjd-}IGl0C;dUY#qq2KGGv~V+UrlF+S!mm(prQ8w1G^SH~XShK$ z->{qs9wpwDpE#%{hus!TQ!q^_RZqZH@en5W>nR z|5e`{B)kRa1+Ny}-|yja!)Xk>%h603CaEy7bo7Cn&Nb+6<8wO}Uorn9mZZu#{jjW& zyWrGPUoRiOFD<8E;cLKvRR#(&*vr}Zm`VD&E)7|A+Kl2nDVm2WiIHWxlO>vzFb2?T ztO9a!;#s)FGGwiO65g1b$~19og`JHboO5{N?9|n2W_7Nk`=@PbuxsF^ZRRc=Sjk2N zNycNnE4`~@*9*a7y3z%D%!30m=RYTTM(V(aU_vDM%=yp@qHG0kS<0tNxmiW1cOuHf zklXgcf@v+MhTYBP4EL$X`64Vi=V{H_`jtDDr2~8{=to^EVR{y!geyPcq2slG_F>|H z?|E+2%fubaHF|pLjmO%3?w5O^T5Ljd_X%VW!2P;B9l_zbHv!PcT4U1_qw_WQ$?!CL z)8=mA017rAwE?IQDZIzD_L_VB0VaV{4B}r)LnQ7&>==6%EVfvJ@tHYM~*2R!{(yD*mGX* z!W*%3R#TLQ?%^>xtgV=Z@MFuDtF-2E&gIG_5Q}DqZ48#YKyccB6{p9CdB48DDuLs3 zkEeh zuBm;+!#s&iDzh#}BgGe%24qgpCQK#6CT#uk$JN^{M#5xbt(X{v|Gyvd>@q;}oYFKV z^J(W%F)xH0*Dg8>@aCoo!4wQ58-3dC3J1fif9F%Yo|dx)L&qszZQaWde}DNgr&K8-y{49kSj5^mf0d``QFP}t-K?sh7G z&csafBsUmoRgc?cELpnkIP4{^0L!^Pc5$K)eQ3UtCz?KG5uDc=L?=0|9I}#tgO{F< zUMJA7Km-`?G_aN9q}kVvKoRV`segR+UaPb}h=X`27%AHl?4SY2nLYq`!wf2B#Ah>Q zptd^_Q4aJ`(ncsH>H4p1RIM)n14?sp{_JRV`R=d6(-&`RgH>cvLIA)Rl5h&?cRH;( zh+R@aA5{zB0S>*h(fUE~ zYmY12@R0L0CI*`wtl+n5f(>XZ@AL4J&rSdEOGD&M22x{;mrH2^3SSk&ZJFSMNDT83 
zVMuvNme8p}cWMAj>E4(Gm;|>`dU<~n@33z!Z$_t_$Xof6<@~UgdLeJE{pTMKGRNQ_ z-y!)Pb%JDNgnCQ7?1fsT>D6CgO>sFGrQQ~L6a{6OE*%OC4rWU8eDT;ZB_~!?q8k+(`&?T9%whws(6u=+ z-r;s0Xvb;S7?6Hyb0NSbyHDoN9}@$B9x#{i+!(b2(Fj!`BEj4%4_^P{yJ`K85165F zQa5B3MGUKCoHFEfC)YgTS7IWCUs&Wtb`L0QDopciW|amQc)Y55r7oF zprFu>-u`tmo=KKFnere-Z;4g)*Rp3%ua_7(WtO^llJVXTGL1cg?2wA9ClSJ^>NxmK zV_LF#;=HtOukwhTx^h~6YuA@U&Xtkn992?uJVxx%-TGSf)JO>#+Tud?fZed*+h0hGz{DXE8R2sPyVjMkmBR)70JuRq63o!G<8qLmn7l-?%JXlIyV zx^zlueX62rcLM$aE zvBo|vk0$9gV)=_ndN{qnKnK|@&WzPLaUN(hp-K33tiGaG17CD&Kz4+~#P>f|xhq&3 z06z-o8I-OUGZz|NC8{J^%z?5XyuI@l$h2XCW&Wz#6rFth>#0og)PO7AD7{-T=AM39 zr=L7L|Dno-^yuoa)P}!}6^O6zuf2>u`}kAy^JBE<09mljy(3V-a`#L(7qVa~ z^z+gQ92)e?osQ4vhAP79K-J9pNHIr>8A9_fPdjy`IX=Ze;rPkK2k%!nR@SNYk?5gd zP`DD1m%E5%&mDK?Tvxs$=q#5C!k~!Ec?9i zOKQ~}kQX-fmL!o}LEGzOz?7g!7J6(qwD)1r7?CG)ge@+5L>cNSu+23w4bz)xx|<@#}2#FD@i z?T^fLPM~wKTZ;Uqv)ilETM0(X+#75sOJdos`T%x@V#A?Rn>8^`#PU^-im%D>HXjo@ z$0fNZ$bB)=&~YK{6RMBL#8_J`*rW7xo!=?Hc$2GDgrP2eMqmc6Igytf$UEs)v32Me zCRg6kEAz@Xv;n4(2=w|e>hYj4!5?h?w?0nsGb7~+!?$GmcVL=L>6beEexAKG3^HH) zsjb~Md$pK4*K)gg?OB)e?O(a9&K6P0y3}Di<<>T%jTDk`7$x_G9^y)So;?y{E@H!Y z7z+0)K)%hRG&ht?z#~$r&pAP^$?&5RyL$-^{VR&orK^cMG+6GlR1CnF;dF&my!OSp zn_YJ93MD~BXev(-7kyhow$Uj!jEs7sLKm#cC`21-)y?*r16%-ailXU zChEb?Ht_h`^^IHjTn!q|NG`rO620G={}zkzkKV)w7^k zvCxgtd_J_*!V^eE1(f;jRH`Akh;08_TXg~P(uw&A zm-WX+J_xQ4;784AACD?SdgGx*LE2+*BM-~)7NEg*CPdyDNQJMU=-Ym42Bd14WN-xP zwzmpJA}uhquB+=qU*B&zbKSh2&bfwX39-7K{1FK!Xp!B_u+GjTLA})=LN)qq9Apve zzWRZDwiiP|J#bV_P(hDae}p_1K%(3HzaiBY5{ufMZ zrq3ycCui#F&HL7?w{D&O&-%f~eV<#-I98D>y}I_2ukzkL$*140uRZ{3<&Z7of`)HH z$Rw&dGNDuSkZ;Xbu|icfE3j=ufXe9CPjO!8p$;aIpe8?4?$}t^USL1i+I;DXHhdqB zYd|V{gY&g47$2X$Mug2=2kuwHHg0HdOkh0ZCooQl$^bt;UYv3ET2sRf4xSLH8Jt@8 zwer*O0z=xCacGnfqpU8tTcBG_oht6|-Ma65!!x0Rbb|>7abELumnWkY7CZq^pi5Xn zCGwRpaU|md!F_f7ojE{53w3haSz7eH_U-bcH3t$}7PteCXe4~Q8|4EYNvgH>>!Bo1 zE%wLrd<2mg1FAI>eh&^-D3vN*2Tkv^#Y~r3R`!&SUpd_Nd2*Ut3ok*2HDc+NX7-$S zDjMe;`hm(X#DN&rCVNjCI&&Wz)5r)1HexwLOccpjh%N2iFvi7oR-cbTZsBZ1GB&i+1T z0Gi~jy2d4Rxkg~C3)W*(6ZU;%L2@YSk(Z>dIT(H$4=?V#8U(7(oVP9=_4Mxh+dn3q z5UfI2Vn}S=FaiKu3BTtMTBrwQ#D1lV#9F39H$uWIlmY9TpNl)qv>FdtMcbx=jitvJflE~t<@7-*03TlkSWFg|NyDacdtM>vMAppe zn1%G4{{&|6U!PE?z1un+8Rj6ZtoYp~@~H%eK1Z^_pWa`D@8!~30i7p3mHOK#r0k!e z(A=!XgMdR&=8~Y-$;n(ljAz^8k9o@v37YcLGG|#Pd3a#=Cle2u!p6XkE*80ZH0!_J zHyiJ3$2@{33YUz46AvE>Z^21H@N!Fc6}J8sm|VWco|!XCPUyl2 za4ylBDv`h(1#uhc-{vUzSWy|P`nm-PZwt5Gy+rl^W=G9MBNM*%YR{EbtdWoMB`is% zr9^}<$Fy||j-l5tU|uCnEJC3iH>SiaYmaK`EWf?v>R%SW&`#$khG5>>%rEZ!n6Hz) zGq9}abF?#SQu->SrK^x1Kw*P{NrEPu=WKll3b zj8bOpri@%=7%q0&Qs}E=jgJ8G)&I8v#=GbB73s#G4+~f)=&|3v z1;`uQB^7e(#oX2_uSxe;mF42Eez1~V$q<`d(X{V#O)#t+q@8lc#ZY+1=1T6V{~iBT z7r8;WMJ0++FNGU%L-L7!4zld&>8YD%MN5+gNC_V8_-p<6=&!s3!{vmi;GlPAxR5{q z31^R1Iqg#(cC#VvT%oke#8$J;dTY~DX@xM^!Y$u?$JHUhBevb zz&EKznNrL03kZF^0ypiYS94=>J;8P_i8)Wrg~?0Rpy7z3x$ca6i`lR67dy4U9m0{y z#aww@crit6kIDnOl@WK|cL1TBj+Kdj>&mE?MqK&O{ZRkiI+)=+xoz!|&8lS-XFKhs z1!vxx4)sFJt2zt=#5fSkEBDQqG)0XK`f;Y1PYV4R)E@^V)LCVF2<4r!OhUCSxDY49 zA9WjweJ7xLwahP$uu>fBWA*d}D#31%=tW=FPJR zBHSnaGks%$BZa0;(UJdMJZB0P$9_+ZXsLsq`@MH#p_635_8Z&XGBe(x&EBYwQ3(v$ z>gevSCOETFWK%aTPqW_59`bC%-pgu>1EMX7QV}25s%FRl^7>a<%}>yleLt4Iq>9i* z-C~`xTl3^opQENmmri%X{by{uJZ66q9<%scW4Z-862}jy9jMGj{vk5zY#OJtqcl%G zYBK4sx%aNnNW}Vmru?ljxh1*Adb--YhTEYG@sf2yHA-X?ae}^AyjF)kfc|r7#fV*g zsoqytgqgW{9XG=+@7G#w-Z^tI*gSpcugrNC;2!=HD=-nLbhp#^je&nD-%9l6VNyAX z@}r+E*DM8Oy_0s&0kmd2BQaM4qM=N-Og?CjZjuXKd~xN52i1%AIKENc)?fA>Qhfp@ ztcLwDxk%hJedQ%IzW$vw22tqzIG?M7Ptd_KmTQ?1L-No*%V8^Bc~+$0PaqX=SJPPZ z9nFgd`fu1xqb<-^YjEC73rL_-D{1s>N<5zou0> 
zF|Z(iCEDZtb~p1d0ibIA<4be%2OhM$?5_0YQczS@Tzw>*iUF--Pc8s1kiNTmwVeM{(?)O8r<-1lzy46)AqLru{~Xm?>L`~*ZWbrUc<3d}FJwVIIU&K{vxfjh1i+(N#6u9(ypCb_GpOWg zo+#_;9NgA}wvN}b*b3)2T6-GmJGh#Xnb{gKS_H>(3YnI(hg-+V@Uu)Z zzGu_xnvrH|X}ZUJ=N!D9)YUs;lMfVgOeQ?!%`PN|27uP_`LFOwu*m_ED`O_pE6=8| zHUP9Z&PsDLyLm9y<4_p3+!&H?GsYu(LB7S|Wl#^u*BFStH=IrnI|Y6GJkV_<=NqP)nF zqQGVR*bPcj`DEImEwEXCm=oc+e}?P>HhQ>~IMbI?3(m6PkN&Jb)80dS;#L?#-&QeX zo2Gc3nmzlSpx1*ka;i4fFh5Nbm9+R+^qV1LXt$%p)$s}UsYW_E1Qv6ImItmqpRU*y z*dQDqZ}EF6I}yFE1*0gn)JL!bY}6cGOf*;tq&`l^j(xzL0dn z*C8U8>q;sAvN@mRI!I%RMWUfdYGX;uDjb>DPh8KA#SkdI<3LJ}@!DXi`}9med^HAYnVui_=;tSMn;+nhr>W*4Z*LrD+f znM#W{^fBsVV2&wU(C&@HV8p)qSLEM&U+Z1FrECE6EuHw>v7|gxYHiWG^g~4n>={>X z(t`jR4KJqdS%p8Iu-7ndnlTh7q@YhC6}C*@dI=ochscyU$>JE>cebEpykb9WIr^LfON&YA7{)j#AqDpjxIBlUVR$G|HXcjV2*PGAZ4 z?lmijX82z1#a_c|C;-rUyEEPTL$abP7armQ^m!Y?I(X1DxCFqb=P&)n6ZbLrX5(PxW9OXn< zmPLs92m@Xn#f6Wod?39-$#yXnpPzPur)IQw5cowQO*MGDZ>59qtJd2}=~ntaCJ+gu z+4nU{h4N*5=Qr`vGN#0`n4c$)U#fRI^f~1885MU{_bDWu*;8cpDr(8zlJ#}UZs*#P z|3s;WZuU}af~P*AJ@-l5m(=GSqi}ZHV`5x} zH^9?k&qhYPLy3588V5SQvls`>Y8lpE2+)?tmPl^30zX0A_EisOK<2z;8?0p{N*>{}&0fatvwxUoZ0$dg z_n}t?RHV<8NOiN;44Kw4Z+mICln{ZkOUZhiB)@`qI83MS_qP?~W$(uhn2#{!kvb2F z6(x@@Zl4kIhy})<0xNeS8~A_QW)`oqp6qM$hBT89Iri>3um3;CpKi4&Dk zeEJ}`R$b5h+@VM_Il zCVFGjrx7!>DJb~N)l%FU|7`N2rc2=T_3H|iiN*~Gft8C?3gaCp!YIa`RWleRV)`rj z0n|Hr3mpM0NaB>d^7!~v9bB2CB1Xw+j>iU8;i1@I@!mJWNxGVQ_4y8>0V*=cmN2WN za6F8H34F9e#^rE=hGmi`B!Si|96iIxq=>nP4yfiH=}b*6T^o9Hx$kYEl5jM0}pSI~w!qR`DbBilOZM^AY@|ghA{pYXBW|A2Fo9p;iw} z)fcT-ALPJi4HmPDyCrXGeqZeq^4O?t8iR;gW1LUW?#g`h2?|?nu2r^}M2MnbH56se zoP#OO#wosvP(?mF$F>{3v0KqaawH#T7k{(r9*T!1o9A~b9S328E!t7HaP-TQ(K~{h z&w*3snaAA($A5gNbwUDH&o#{7>t`WY7-dm_)O;A2`uGfN)Q$zo+*NU1DQjPMO#)y{ zc-qq3)^)3de^j0wfB+HsQ4F_a#l=rL;O#2%x4J?izM=LC-OPqMtYxlvVD7L_e{w%J zD0Rll9{ZDw&!XWGQK^ z=#H|#kX!&vff?4@gM7H#26CM+xwPr#U^Z5h$%?s!QHk{ih27m<-SeP2%?VTT7Sm)* za=EIM_4LOoOJI*?KUIV`N^@_gMQRcnCIr8i|BBdrte4XGNk~8nSC2s4xjg+mPsB`4Vwx1A@nCn)e|NKHs?0`}m7acvnzCtGU#A{Zs z$ei)ghCE`V9^WqV5k-c!&PZISkg>AwMNe5}7h~G8X;*{35bLN@UT|Y=d#al0R|_tr zpt$tfWBW_KE$&Dh4+@?Gd@n280__{cpIpK|=$%$5%&ODWqdcjiw$87){3R~R059l! 
zKOX2;m#|{BAwJ$Rnln5}8F4cY|3NYT5gNT~V7`Ko1WcXt7|FRz_M_GxTMuj}>f<(4 zPFy(Yfgn8_20&`#si2yWJD4)SdB%qsAnb1J1Ogu*4z?WKQO_11XgkO3QdRc4SbhC( z|C)YtPiLXNhp+f~auHjM`^tGo9hC&*Mr7vlJ7*T&+Z;c}vAzEZrhWRTZN@pulF}lu z_!l5DggNU4Tr=RZ<~rvPbS~{O4t$k^YZkMy&s^OJgU5|>Z?cb+{Yr^mH;n|dXFPtc zz$Sm~LEPmp_%e?z9Wbb18$^84*fat-L9DbJgJG%&3%XaJ8a6?P1x2q*@uU764Ie(-J%Nb@-TSx1>T^{k-o*W zU34R-bz&08HD-2i_#hI-D>S(;Z1>i<3v;V-E9I66orX%v8z*_`nqTa#Q&`c%wv`fVEcDJEfEb!Z+$w#H}`r(?=_`2NaZ3VITeg>I; zw-i(Bu%^NdQR#sF3~n#(wp_4t9=BPYl0PFB`Gr)oozz;n^YMPG z#FyugIqXRcj*vp~>gOfTJ7Xn^&I6VNIn-m?y2s;S^eP!v#tzD#42SbBy{jWk`fx*y z5o-T;-n_@7iSG=1W3D`)z>J{6qZ%I^t11={esj>Pr0SXj;)K`#=UKFo)HhZlFq4b$ zq(~q;%z$xYH3Wa*X;Q~-V^ue(TPQ7G;ql$mrNv+B`j&G#=H_Soc>!3fPUaxRCr~56;O(upiVx=w*qBgD5jjveKLQPrX;3I%?^{ z>AXuo(UC=@V#P3%^O&u!aEFu1PMx??W-w2Orv8ubRF2<0v+zUV^%Dr8&I>0@kB@w} zxpHr`&4{}|#Z0-H=AzU8{|6>W`Rp+3JF2e*5Wov1tlukvga|4pUmvj~NEr2IsM=#* za>X(SlPgy(e*jU3kZ-`(#BeAvxDl9KaiGLCYVWg9wMv6NviOhE3d+E!Y*D>PNYJ!n zT>s&v3G(LGc)nJFcPVfSTz9XQi`wL+*_1SZxuX;{sTXm+XwH|s*fldC&sh%@C zaQeCAkMGiuCKCL68d}w18n>>#MvFslgo<0O>|4b)gC%aH8vZ!<-ZzzF(TtGxRw_PT z1Bk2?b~>3phG^n;GP8dB2(Z&1ynMG)?++g2G`MamZpgeV+qoJ+40TC+;!`MG_6w%H zy~`LMBV?mgdE`1x7FNlE6T69dJ9dJa-A8#2ods=itwnb(CGD1`aw-`qTr_-`-pzfb z|3Jb$^wA@IzrU!=NkVjJauLs}a0xWs<1n8(yVG9!kb4vOXP0i|vbs$hYw|+ZLZz;9 zlpzbg7yG`b{~AH+8B&kB06+3JOc%kPfhIS2j0X&`jg!G*<~g}${?~qTEttTBACc_A z07rJ?IA4w34z&$!mClgc*jwCC+Am20H@vcM!0T^lmHU;kwp6YT=ZRt}S?-H4K@}v!O5`fk<8P%Qeabbx6Hx#KYtC*%!(7J z*kcaq8e=+FGoucqa=;UET#RsTcz-^oIO4Cu<&2_kSxzZ)VK+vih;iwP?XwnCL@*y# z3io#%%@4~+ i@u-g7_f5ImN2gnFf>Z`=cN@#UXQzMG{;-VsH}pSz^>Bv( literal 0 HcmV?d00001 diff --git a/python/sglang/multimodal_gen/test/test_offline_api.py b/python/sglang/multimodal_gen/test/test_offline_api.py new file mode 100644 index 000000000..2e9ea67a3 --- /dev/null +++ b/python/sglang/multimodal_gen/test/test_offline_api.py @@ -0,0 +1,75 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +""" + Testing the performance of generate command of sgl_diffusion' CLI +""" + +import unittest + +import torch + +from sglang.multimodal_gen.runtime.entrypoints.diffusion_generator import DiffGenerator +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +class TestGeneratorAPIBase(unittest.TestCase): + # server args + server_kwargs = {} + + # sampling + output_path: str = "outputs" + + results = [] + + @classmethod + def setUpClass(cls): + cls.results = [] + + def verify_single_generation_result(self, result): + self.assertIsNotNone(result, "Generation failed") + self.assertTrue( + "samples" in result and isinstance(result["samples"], torch.Tensor), + f"Incorrect Generation result", + ) + + def _run_test(self, name, server_kwargs, test_key: str): + generator = DiffGenerator.from_pretrained(**server_kwargs) + result = generator.generate(prompt="A curious raccoon") + self.verify_single_generation_result(result) + + def test_single_gpu(self): + self._run_test( + name=self.server_kwargs["model_path"], + server_kwargs=self.server_kwargs | dict(num_gpus=1), + test_key="test_single_gpu", + ) + + def test_cfg_parallel(self): + self._run_test( + name=self.server_kwargs["model_path"], + server_kwargs=self.server_kwargs + | dict(num_gpus=2, enable_cfg_parallel=True), + test_key="test_cfg_parallel", + ) + + def test_multiple_prompts(self): + generator = DiffGenerator.from_pretrained( + **self.server_kwargs | dict(num_gpus=2, enable_cfg_parallel=True) + ) + prompts = ["A curious raccoon", "A curious cat"] + results = generator.generate(prompt=prompts) + + 
+        self.assertEqual(len(results), len(prompts), "Some generation tasks failed")
+        for result in results:
+            self.verify_single_generation_result(result)
+
+
+class TestWan2_1_T2V(TestGeneratorAPIBase):
+    server_kwargs = {"model_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"}
+
+
+if __name__ == "__main__":
+    # Remove the base class from the module namespace so unittest.main() does
+    # not collect and run its tests without a concrete model configured.
+    del TestGeneratorAPIBase
+    unittest.main()
diff --git a/python/sglang/multimodal_gen/test/test_utils.py b/python/sglang/multimodal_gen/test/test_utils.py
new file mode 100644
index 000000000..37f7418d8
--- /dev/null
+++ b/python/sglang/multimodal_gen/test/test_utils.py
@@ -0,0 +1,260 @@
+# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo
+import os
+import shlex
+import socket
+import subprocess
+import sys
+import time
+import unittest
+
+from PIL import Image
+
+from sglang.multimodal_gen.configs.sample.base import DataType
+from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+def run_command(command):
+    """Runs a command and returns the execution time and status."""
+    print(f"Running command: {' '.join(command)}")
+
+    duration = None
+    with subprocess.Popen(
+        command,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        text=True,
+        encoding="utf-8",
+    ) as process:
+        for line in process.stdout:
+            sys.stdout.write(line)
+            # The generation log reports timing in a "Pixel data generated ..."
+            # line; the second-to-last token is the duration in seconds.
+            if "Pixel data generated" in line:
+                words = line.split(" ")
+                duration = float(words[-2])
+
+    if process.returncode == 0:
+        return duration
+    else:
+        print(f"Command failed with exit code {process.returncode}")
+        return None
+
+
+def probe_port(host="127.0.0.1", port=30010, timeout=2.0) -> bool:
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.settimeout(timeout)
+        try:
+            s.connect((host, port))
+            return True
+        except OSError:
+            return False
+
+
+def is_mp4(data):
+    idx = data.find(b"ftyp")
+    return 0 <= idx <= 32
+
+
+def is_png(data):
+    # PNG files start with: 89 50 4E 47 0D 0A 1A 0A
+    return data.startswith(b"\x89PNG\r\n\x1a\n")
+
+
+def wait_for_port(host="127.0.0.1", port=30010, deadline=300.0, interval=0.5):
+    end = time.time() + deadline
+    while time.time() < end:
+        if probe_port(host, port, timeout=interval):
+            return True
+        time.sleep(interval)
+    raise TimeoutError(f"Port {host}:{port} not ready after {deadline}s")
+
+
+def check_image_size(ut, image, width, height):
+    # check image size
+    ut.assertEqual(image.size, (width, height))
+
+
+class TestCLIBase(unittest.TestCase):
+    model_path: str = None
+    extra_args = []
+    data_type: DataType = None
+    # tested on h100
+    thresholds = {}
+
+    width: int = 720
+    height: int = 720
+    output_path: str = "outputs"
+
+    base_command = [
+        "sglang",
+        "generate",
+        "--text-encoder-cpu-offload",
+        "--pin-cpu-memory",
+        "--prompt='A curious raccoon'",
+        "--save-output",
+        "--log-level=debug",
+        f"--width={width}",
+        f"--height={height}",
+        f"--output-path={output_path}",
+    ]
+
+    results = []
+
+    @classmethod
+    def setUpClass(cls):
+        cls.results = []
+
+    def _run_command(self, name, model_path: str, test_key: str = "", args=None):
+        command = (
+            self.base_command
+            + [f"--model-path={model_path}"]
+            + shlex.split(args or "")
+            + [f"--output-file-name={name}"]
+            + self.extra_args
+        )
+        duration = run_command(command)
+        status = "Success" if duration else "Failed"
+
+        duration_str = f"{duration:.4f}s" if duration else "NA"
+        self.__class__.results.append(
+            {"name": name, "key": test_key, "duration": duration_str, "status": status}
+        )
+
+        return name, duration, status
+
+
+class TestGenerateBase(TestCLIBase):
+    model_path: str = None
+    extra_args = []
+    data_type: DataType = None
+    # tested on h100
+    thresholds = {}
+
+    width: int = 720
+    height: int = 720
+    output_path: str = "outputs"
+    image_path: str | None = None
+    prompt: str | None = "A curious raccoon"
+
+    base_command = [
+        "sglang",
+        "generate",
+        # "--text-encoder-cpu-offload",
+        # "--pin-cpu-memory",
+        f"--prompt='{prompt}'",
+        "--save-output",
+        "--log-level=debug",
+        f"--width={width}",
+        f"--height={height}",
+        f"--output-path={output_path}",
+    ]
+
+    results = []
+
+    @classmethod
+    def setUpClass(cls):
+        cls.results = []
+
+    @classmethod
+    def tearDownClass(cls):
+        # Print markdown table
+        print("\n## Test Results\n")
+        print("| Test Case | Duration | Status |")
+        print("|--------------------------------|----------|---------|")
+        test_keys = ["test_single_gpu", "test_cfg_parallel", "test_usp", "test_mixed"]
+        test_key_to_order = {
+            test_key: order for order, test_key in enumerate(test_keys)
+        }
+
+        ordered_results: list[dict] = [{}] * len(test_keys)
+
+        for result in cls.results:
+            order = test_key_to_order[result["key"]]
+            ordered_results[order] = result
+
+        for result in ordered_results:
+            if not result:
+                continue
+            # Durations are stored as strings such as "1.2345s" (or "NA" when
+            # the run failed), so parse the number before comparing it to the
+            # per-test threshold.
+            duration_ok = (
+                result["duration"] != "NA"
+                and float(result["duration"].rstrip("s"))
+                <= cls.thresholds[result["key"]]
+            )
+            status = "PASS" if result["status"] == "Success" and duration_ok else "FAIL"
+            print(f"| {result['name']:<30} | {result['duration']:<8} | {status:<7} |")
+        print()
+        durations = [result["duration"] for result in cls.results]
+        print(" | ".join([""] + durations + [""]))
+
+    def _run_test(self, name, args, model_path: str, test_key: str):
+        time_threshold = self.thresholds[test_key]
+        name, duration, status = self._run_command(
+            name, args=args, model_path=model_path, test_key=test_key
+        )
+        self.verify(status, name, duration, time_threshold)
+
+    def verify(self, status, name, duration, time_threshold):
+        print("-" * 80)
+        print("\n" * 3)
+
+        # test task status
+        self.assertEqual(status, "Success", f"{name} command failed")
+        self.assertIsNotNone(duration, f"Could not parse duration for {name}")
+        self.assertLessEqual(
+            duration,
+            time_threshold,
+            f"{name} failed with {duration:.4f}s > {time_threshold}s",
+        )
+
+        # test output file: the CLI should have written
+        # <output-path>/<name>.<default extension for the data type>
+        path = os.path.join(
+            self.output_path, f"{name}.{self.data_type.get_default_extension()}"
+        )
self.assertTrue(os.path.exists(path), f"Output file not exist for {path}") + if self.data_type == DataType.IMAGE: + with Image.open(path) as image: + check_image_size(self, image, self.width, self.height) + logger.info(f"{name} passed in {duration:.4f}s (threshold: {time_threshold}s)") + + def model_name(self): + return self.model_path.split("/")[-1] + + def test_single_gpu(self): + """single gpu""" + self._run_test( + name=f"{self.model_name()}, single gpu", + args=None, + model_path=self.model_path, + test_key="test_single_gpu", + ) + + def test_cfg_parallel(self): + """cfg parallel""" + if self.data_type == DataType.IMAGE: + return + self._run_test( + name=f"{self.model_name()}, cfg parallel", + args="--num-gpus 2 --enable-cfg-parallel", + model_path=self.model_path, + test_key="test_cfg_parallel", + ) + + def test_usp(self): + """usp""" + if self.data_type == DataType.IMAGE: + return + self._run_test( + name=f"{self.model_name()}, usp", + args="--num-gpus 4 --ulysses-degree=2 --ring-degree=2", + model_path=self.model_path, + test_key="test_usp", + ) + + def test_mixed(self): + """mixed""" + if self.data_type == DataType.IMAGE: + return + self._run_test( + name=f"{self.model_name()}, mixed", + args="--num-gpus 4 --ulysses-degree=2 --ring-degree=1 --enable-cfg-parallel", + model_path=self.model_path, + test_key="test_mixed", + ) diff --git a/python/sglang/multimodal_gen/test/utils.py b/python/sglang/multimodal_gen/test/utils.py new file mode 100644 index 000000000..b1d7620c9 --- /dev/null +++ b/python/sglang/multimodal_gen/test/utils.py @@ -0,0 +1,162 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +import json +import os + +import numpy as np +import torch +from pytorch_msssim import ms_ssim, ssim +from torchvision.io import read_video + +from sglang.multimodal_gen.runtime.utils.logging_utils import init_logger + +logger = init_logger(__name__) + + +def compute_video_ssim_torchvision(video1_path, video2_path, use_ms_ssim=True): + """ + Compute SSIM between two videos. + + Args: + video1_path: Path to the first video. + video2_path: Path to the second video. + use_ms_ssim: Whether to use Multi-Scale Structural Similarity(MS-SSIM) instead of SSIM. 
+ """ + print(f"Computing SSIM between {video1_path} and {video2_path}...") + if not os.path.exists(video1_path): + raise FileNotFoundError(f"Video1 not found: {video1_path}") + if not os.path.exists(video2_path): + raise FileNotFoundError(f"Video2 not found: {video2_path}") + + frames1, _, _ = read_video(video1_path, pts_unit="sec", output_format="TCHW") + frames2, _, _ = read_video(video2_path, pts_unit="sec", output_format="TCHW") + + # Ensure same number of frames + min_frames = min(frames1.shape[0], frames2.shape[0]) + frames1 = frames1[:min_frames] + frames2 = frames2[:min_frames] + + frames1 = frames1.float() / 255.0 + frames2 = frames2.float() / 255.0 + + if torch.cuda.is_available(): + frames1 = frames1.cuda() + frames2 = frames2.cuda() + + ssim_values = [] + + # Process each frame individually + for i in range(min_frames): + img1 = frames1[i : i + 1] + img2 = frames2[i : i + 1] + + with torch.no_grad(): + if use_ms_ssim: + value = ms_ssim(img1, img2, data_range=1.0) + else: + value = ssim(img1, img2, data_range=1.0) + + ssim_values.append(value.item()) + + if ssim_values: + mean_ssim = np.mean(ssim_values) + min_ssim = np.min(ssim_values) + max_ssim = np.max(ssim_values) + min_frame_idx = np.argmin(ssim_values) + max_frame_idx = np.argmax(ssim_values) + + print(f"Mean SSIM: {mean_ssim:.4f}") + print(f"Min SSIM: {min_ssim:.4f} (at frame {min_frame_idx})") + print(f"Max SSIM: {max_ssim:.4f} (at frame {max_frame_idx})") + + return mean_ssim, min_ssim, max_ssim + else: + print("No SSIM values calculated") + return 0, 0, 0 + + +def compare_folders(reference_folder, generated_folder, use_ms_ssim=True): + """ + Compare videos with the same filename between reference_folder and generated_folder + + Example usage: + results = compare_folders(reference_folder, generated_folder, + args.use_ms_ssim) + for video_name, ssim_value in results.items(): + if ssim_value is not None: + print( + f"{video_name}: {ssim_value[0]:.4f}, Min SSIM: {ssim_value[1]:.4f}, Max SSIM: {ssim_value[2]:.4f}" + ) + else: + print(f"{video_name}: Error during comparison") + + valid_ssims = [v for v in results.values() if v is not None] + if valid_ssims: + avg_ssim = np.mean([v[0] for v in valid_ssims]) + print(f"\nAverage SSIM across all videos: {avg_ssim:.4f}") + else: + print("\nNo valid SSIM values to average") + """ + + reference_videos = [f for f in os.listdir(reference_folder) if f.endswith(".mp4")] + + results = {} + + for video_name in reference_videos: + ref_path = os.path.join(reference_folder, video_name) + gen_path = os.path.join(generated_folder, video_name) + + if os.path.exists(gen_path): + print(f"\nComparing {video_name}...") + try: + ssim_value = compute_video_ssim_torchvision( + ref_path, gen_path, use_ms_ssim + ) + results[video_name] = ssim_value + except Exception as e: + print(f"Error comparing {video_name}: {e}") + results[video_name] = None + else: + print(f"\nSkipping {video_name} - no matching file in generated folder") + + return results + + +def write_ssim_results( + output_dir, ssim_values, reference_path, generated_path, num_inference_steps, prompt +): + """ + Write SSIM results to a JSON file in the same directory as the generated videos. 
+ """ + try: + logger.info(f"Attempting to write SSIM results to directory: {output_dir}") + + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + + mean_ssim, min_ssim, max_ssim = ssim_values + + result = { + "mean_ssim": mean_ssim, + "min_ssim": min_ssim, + "max_ssim": max_ssim, + "reference_video": reference_path, + "generated_video": generated_path, + "parameters": { + "num_inference_steps": num_inference_steps, + "prompt": prompt, + }, + } + + test_name = f"steps{num_inference_steps}_{prompt[:100]}" + result_file = os.path.join(output_dir, f"{test_name}_ssim.json") + logger.info(f"Writing JSON results to: {result_file}") + with open(result_file, "w") as f: + json.dump(result, f, indent=2) + + logger.info(f"SSIM results written to {result_file}") + return True + except Exception as e: + logger.error(f"ERROR writing SSIM results: {str(e)}") + return False diff --git a/python/sglang/multimodal_gen/third_party/__init__.py b/python/sglang/multimodal_gen/third_party/__init__.py new file mode 100644 index 000000000..af2eb7d10 --- /dev/null +++ b/python/sglang/multimodal_gen/third_party/__init__.py @@ -0,0 +1 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo diff --git a/python/sglang/multimodal_gen/third_party/pynvml.py b/python/sglang/multimodal_gen/third_party/pynvml.py new file mode 100644 index 000000000..546dc8b8b --- /dev/null +++ b/python/sglang/multimodal_gen/third_party/pynvml.py @@ -0,0 +1,7227 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# copied from https://pypi.org/project/nvidia-ml-py +# version 12.570.86 + +##### +# Copyright (c) 2011-2023, NVIDIA Corporation. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of the NVIDIA Corporation nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +# THE POSSIBILITY OF SUCH DAMAGE. 
+##### + +import os +import string +import sys +import threading + +## +# Python bindings for the NVML library +## +from ctypes import * +from functools import wraps + +## C Type mappings ## +## Enums +_nvmlEnableState_t = c_uint +NVML_FEATURE_DISABLED = 0 +NVML_FEATURE_ENABLED = 1 + +_nvmlBrandType_t = c_uint +NVML_BRAND_UNKNOWN = 0 +NVML_BRAND_QUADRO = 1 +NVML_BRAND_TESLA = 2 +NVML_BRAND_NVS = 3 +NVML_BRAND_GRID = ( + 4 # Deprecated from API reporting. Keeping definition for backward compatibility. +) +NVML_BRAND_GEFORCE = 5 +NVML_BRAND_TITAN = 6 +NVML_BRAND_NVIDIA_VAPPS = 7 # NVIDIA Virtual Applications +NVML_BRAND_NVIDIA_VPC = 8 # NVIDIA Virtual PC +NVML_BRAND_NVIDIA_VCS = 9 # NVIDIA Virtual Compute Server +NVML_BRAND_NVIDIA_VWS = 10 # NVIDIA RTX Virtual Workstation +NVML_BRAND_NVIDIA_CLOUD_GAMING = 11 # NVIDIA Cloud Gaming +NVML_BRAND_NVIDIA_VGAMING = NVML_BRAND_NVIDIA_CLOUD_GAMING # Deprecated from API reporting. Keeping definition for backward compatibility. +NVML_BRAND_QUADRO_RTX = 12 +NVML_BRAND_NVIDIA_RTX = 13 +NVML_BRAND_NVIDIA = 14 +NVML_BRAND_GEFORCE_RTX = 15 # Unused +NVML_BRAND_TITAN_RTX = 16 # Unused +NVML_BRAND_COUNT = 17 + +_nvmlTemperatureThresholds_t = c_uint +NVML_TEMPERATURE_THRESHOLD_SHUTDOWN = 0 +NVML_TEMPERATURE_THRESHOLD_SLOWDOWN = 1 +NVML_TEMPERATURE_THRESHOLD_MEM_MAX = 2 +NVML_TEMPERATURE_THRESHOLD_GPU_MAX = 3 +NVML_TEMPERATURE_THRESHOLD_ACOUSTIC_MIN = 4 +NVML_TEMPERATURE_THRESHOLD_ACOUSTIC_CURR = 5 +NVML_TEMPERATURE_THRESHOLD_ACOUSTIC_MAX = 6 +NVML_TEMPERATURE_THRESHOLD_GPS_CURR = 7 +NVML_TEMPERATURE_THRESHOLD_COUNT = 8 + +_nvmlTemperatureSensors_t = c_uint +NVML_TEMPERATURE_GPU = 0 +NVML_TEMPERATURE_COUNT = 1 + + +_nvmlComputeMode_t = c_uint +NVML_COMPUTEMODE_DEFAULT = 0 +NVML_COMPUTEMODE_EXCLUSIVE_THREAD = 1 ## Support Removed +NVML_COMPUTEMODE_PROHIBITED = 2 +NVML_COMPUTEMODE_EXCLUSIVE_PROCESS = 3 +NVML_COMPUTEMODE_COUNT = 4 + +_nvmlMemoryLocation_t = c_uint +NVML_MEMORY_LOCATION_L1_CACHE = 0 +NVML_MEMORY_LOCATION_L2_CACHE = 1 +NVML_MEMORY_LOCATION_DEVICE_MEMORY = 2 +NVML_MEMORY_LOCATION_DRAM = 2 +NVML_MEMORY_LOCATION_REGISTER_FILE = 3 +NVML_MEMORY_LOCATION_TEXTURE_MEMORY = 4 +NVML_MEMORY_LOCATION_TEXTURE_SHM = 5 +NVML_MEMORY_LOCATION_CBU = 6 +NVML_MEMORY_LOCATION_SRAM = 7 +NVML_MEMORY_LOCATION_COUNT = 8 + +NVML_NVLINK_MAX_LINKS = 18 + +# For backwards compatibility, maintain the incorrectly-named "LANES" define +NVML_NVLINK_MAX_LANES = NVML_NVLINK_MAX_LINKS + +_nvmlNvLinkErrorCounter_t = c_uint +NVML_NVLINK_ERROR_DL_REPLAY = 0 +NVML_NVLINK_ERROR_DL_RECOVERY = 1 +NVML_NVLINK_ERROR_DL_CRC_FLIT = 2 +NVML_NVLINK_ERROR_DL_CRC_DATA = 3 +NVML_NVLINK_ERROR_DL_ECC_DATA = 4 +NVML_NVLINK_ERROR_COUNT = 5 + +_nvmlNvLinkEccLaneErrorCounter_t = c_uint +NVML_NVLINK_ERROR_DL_ECC_LANE0 = 0 +NVML_NVLINK_ERROR_DL_ECC_LANE1 = 1 +NVML_NVLINK_ERROR_DL_ECC_LANE2 = 2 +NVML_NVLINK_ERROR_DL_ECC_LANE3 = 3 +NVML_NVLINK_ERROR_DL_ECC_COUNT = 5 + +_nvmlNvLinkCapability_t = c_uint +NVML_NVLINK_CAP_P2P_SUPPORTED = 0 +NVML_NVLINK_CAP_SYSMEM_ACCESS = 1 +NVML_NVLINK_CAP_P2P_ATOMICS = 2 +NVML_NVLINK_CAP_SYSMEM_ATOMICS = 3 +NVML_NVLINK_CAP_SLI_BRIDGE = 4 +NVML_NVLINK_CAP_VALID = 5 +NVML_NVLINK_CAP_COUNT = 6 + +_nvmlNvLinkUtilizationCountPktTypes_t = c_uint +NVML_NVLINK_COUNTER_PKTFILTER_NOP = 0x1 +NVML_NVLINK_COUNTER_PKTFILTER_READ = 0x2 +NVML_NVLINK_COUNTER_PKTFILTER_WRITE = 0x4 +NVML_NVLINK_COUNTER_PKTFILTER_RATOM = 0x8 +NVML_NVLINK_COUNTER_PKTFILTER_NRATOM = 0x10 +NVML_NVLINK_COUNTER_PKTFILTER_FLUSH = 0x20 +NVML_NVLINK_COUNTER_PKTFILTER_RESPDATA = 0x40 
+NVML_NVLINK_COUNTER_PKTFILTER_RESPNODATA = 0x80 +NVML_NVLINK_COUNTER_PKTFILTER_ALL = 0xFF + +_nvmlNvLinkUtilizationCountUnits_t = c_uint +NVML_NVLINK_COUNTER_UNIT_CYCLES = 0 +NVML_NVLINK_COUNTER_UNIT_PACKETS = 1 +NVML_NVLINK_COUNTER_UNIT_BYTES = 2 +NVML_NVLINK_COUNTER_UNIT_RESERVED = 3 +NVML_NVLINK_COUNTER_UNIT_COUNT = 4 + +_nvmlNvLinkDeviceType_t = c_uint +NVML_NVLINK_DEVICE_TYPE_GPU = 0x00 +NVML_NVLINK_DEVICE_TYPE_IBMNPU = 0x01 +NVML_NVLINK_DEVICE_TYPE_SWITCH = 0x02 +NVML_NVLINK_DEVICE_TYPE_UNKNOWN = 0xFF + +# These are deprecated, instead use _nvmlMemoryErrorType_t +_nvmlEccBitType_t = c_uint +NVML_SINGLE_BIT_ECC = 0 +NVML_DOUBLE_BIT_ECC = 1 +NVML_ECC_ERROR_TYPE_COUNT = 2 + +_nvmlEccCounterType_t = c_uint +NVML_VOLATILE_ECC = 0 +NVML_AGGREGATE_ECC = 1 +NVML_ECC_COUNTER_TYPE_COUNT = 2 + +_nvmlMemoryErrorType_t = c_uint +NVML_MEMORY_ERROR_TYPE_CORRECTED = 0 +NVML_MEMORY_ERROR_TYPE_UNCORRECTED = 1 +NVML_MEMORY_ERROR_TYPE_COUNT = 2 + +_nvmlClockType_t = c_uint +NVML_CLOCK_GRAPHICS = 0 +NVML_CLOCK_SM = 1 +NVML_CLOCK_MEM = 2 +NVML_CLOCK_VIDEO = 3 +NVML_CLOCK_COUNT = 4 + +_nvmlClockId_t = c_uint +NVML_CLOCK_ID_CURRENT = 0 +NVML_CLOCK_ID_APP_CLOCK_TARGET = 1 +NVML_CLOCK_ID_APP_CLOCK_DEFAULT = 2 +NVML_CLOCK_ID_CUSTOMER_BOOST_MAX = 3 +NVML_CLOCK_ID_COUNT = 4 + +_nvmlDriverModel_t = c_uint +NVML_DRIVER_WDDM = 0 +NVML_DRIVER_WDM = 1 +NVML_DRIVER_MCDM = 2 + +NVML_MAX_GPU_PERF_PSTATES = 16 + +_nvmlPstates_t = c_uint +NVML_PSTATE_0 = 0 +NVML_PSTATE_1 = 1 +NVML_PSTATE_2 = 2 +NVML_PSTATE_3 = 3 +NVML_PSTATE_4 = 4 +NVML_PSTATE_5 = 5 +NVML_PSTATE_6 = 6 +NVML_PSTATE_7 = 7 +NVML_PSTATE_8 = 8 +NVML_PSTATE_9 = 9 +NVML_PSTATE_10 = 10 +NVML_PSTATE_11 = 11 +NVML_PSTATE_12 = 12 +NVML_PSTATE_13 = 13 +NVML_PSTATE_14 = 14 +NVML_PSTATE_15 = 15 +NVML_PSTATE_UNKNOWN = 32 + +_nvmlInforomObject_t = c_uint +NVML_INFOROM_OEM = 0 +NVML_INFOROM_ECC = 1 +NVML_INFOROM_POWER = 2 +NVML_INFOROM_DEN = 3 +NVML_INFOROM_COUNT = 4 + +_nvmlReturn_t = c_uint +NVML_SUCCESS = 0 +NVML_ERROR_UNINITIALIZED = 1 +NVML_ERROR_INVALID_ARGUMENT = 2 +NVML_ERROR_NOT_SUPPORTED = 3 +NVML_ERROR_NO_PERMISSION = 4 +NVML_ERROR_ALREADY_INITIALIZED = 5 +NVML_ERROR_NOT_FOUND = 6 +NVML_ERROR_INSUFFICIENT_SIZE = 7 +NVML_ERROR_INSUFFICIENT_POWER = 8 +NVML_ERROR_DRIVER_NOT_LOADED = 9 +NVML_ERROR_TIMEOUT = 10 +NVML_ERROR_IRQ_ISSUE = 11 +NVML_ERROR_LIBRARY_NOT_FOUND = 12 +NVML_ERROR_FUNCTION_NOT_FOUND = 13 +NVML_ERROR_CORRUPTED_INFOROM = 14 +NVML_ERROR_GPU_IS_LOST = 15 +NVML_ERROR_RESET_REQUIRED = 16 +NVML_ERROR_OPERATING_SYSTEM = 17 +NVML_ERROR_LIB_RM_VERSION_MISMATCH = 18 +NVML_ERROR_IN_USE = 19 +NVML_ERROR_MEMORY = 20 +NVML_ERROR_NO_DATA = 21 +NVML_ERROR_VGPU_ECC_NOT_SUPPORTED = 22 +NVML_ERROR_INSUFFICIENT_RESOURCES = 23 +NVML_ERROR_FREQ_NOT_SUPPORTED = 24 +NVML_ERROR_ARGUMENT_VERSION_MISMATCH = 25 +NVML_ERROR_DEPRECATED = 26 +NVML_ERROR_NOT_READY = 27 +NVML_ERROR_GPU_NOT_FOUND = 28 +NVML_ERROR_INVALID_STATE = 29 +NVML_ERROR_UNKNOWN = 999 + +_nvmlFanState_t = c_uint +NVML_FAN_NORMAL = 0 +NVML_FAN_FAILED = 1 + +_nvmlFanControlPolicy_t = c_uint +NVML_FAN_POLICY_TEMPERATURE_CONTINOUS_SW = 0 +NVML_FAN_POLICY_MANUAL = 1 + +_nvmlLedColor_t = c_uint +NVML_LED_COLOR_GREEN = 0 +NVML_LED_COLOR_AMBER = 1 + +_nvmlGpuOperationMode_t = c_uint +NVML_GOM_ALL_ON = 0 +NVML_GOM_COMPUTE = 1 +NVML_GOM_LOW_DP = 2 + +_nvmlPageRetirementCause_t = c_uint +NVML_PAGE_RETIREMENT_CAUSE_MULTIPLE_SINGLE_BIT_ECC_ERRORS = 0 +NVML_PAGE_RETIREMENT_CAUSE_DOUBLE_BIT_ECC_ERROR = 1 +NVML_PAGE_RETIREMENT_CAUSE_COUNT = 2 + +_nvmlRestrictedAPI_t = c_uint +NVML_RESTRICTED_API_SET_APPLICATION_CLOCKS = 0 
+NVML_RESTRICTED_API_SET_AUTO_BOOSTED_CLOCKS = 1 +NVML_RESTRICTED_API_COUNT = 2 + +_nvmlBridgeChipType_t = c_uint +NVML_BRIDGE_CHIP_PLX = 0 +NVML_BRIDGE_CHIP_BRO4 = 1 +NVML_MAX_PHYSICAL_BRIDGE = 128 + +_nvmlValueType_t = c_uint +NVML_VALUE_TYPE_DOUBLE = 0 +NVML_VALUE_TYPE_UNSIGNED_INT = 1 +NVML_VALUE_TYPE_UNSIGNED_LONG = 2 +NVML_VALUE_TYPE_UNSIGNED_LONG_LONG = 3 +NVML_VALUE_TYPE_SIGNED_LONG_LONG = 4 +NVML_VALUE_TYPE_SIGNED_INT = 5 +NVML_VALUE_TYPE_UNSIGNED_SHORT = 6 +NVML_VALUE_TYPE_COUNT = 7 + +_nvmlNvlinkVersion_t = c_uint +NVML_NVLINK_VERSION_INVALID = 0 +NVML_NVLINK_VERSION_1_0 = 1 +NVML_NVLINK_VERSION_2_0 = 2 +NVML_NVLINK_VERSION_2_2 = 3 +NVML_NVLINK_VERSION_3_0 = 4 +NVML_NVLINK_VERSION_3_1 = 5 +NVML_NVLINK_VERSION_4_0 = 6 +NVML_NVLINK_VERSION_5_0 = 7 + +_nvmlPerfPolicyType_t = c_uint +NVML_PERF_POLICY_POWER = 0 +NVML_PERF_POLICY_THERMAL = 1 +NVML_PERF_POLICY_SYNC_BOOST = 2 +NVML_PERF_POLICY_BOARD_LIMIT = 3 +NVML_PERF_POLICY_LOW_UTILIZATION = 4 +NVML_PERF_POLICY_RELIABILITY = 5 +NVML_PERF_POLICY_TOTAL_APP_CLOCKS = 10 +NVML_PERF_POLICY_TOTAL_BASE_CLOCKS = 11 +NVML_PERF_POLICY_COUNT = 12 + +_nvmlEncoderQueryType_t = c_uint +NVML_ENCODER_QUERY_H264 = 0 +NVML_ENCODER_QUERY_HEVC = 1 +NVML_ENCODER_QUERY_AV1 = 2 +NVML_ENCODER_QUERY_UNKNOWN = 255 + +_nvmlFBCSessionType_t = c_uint +NVML_FBC_SESSION_TYPE_UNKNOWN = 0 +NVML_FBC_SESSION_TYPE_TOSYS = 1 +NVML_FBC_SESSION_TYPE_CUDA = 2 +NVML_FBC_SESSION_TYPE_VID = 3 +NVML_FBC_SESSION_TYPE_HWENC = 4 + +_nvmlDetachGpuState_t = c_uint +NVML_DETACH_GPU_KEEP = 0 +NVML_DETACH_GPU_REMOVE = 1 + +_nvmlPcieLinkState_t = c_uint +NVML_PCIE_LINK_KEEP = 0 +NVML_PCIE_LINK_SHUT_DOWN = 1 + +_nvmlSamplingType_t = c_uint +NVML_TOTAL_POWER_SAMPLES = 0 +NVML_GPU_UTILIZATION_SAMPLES = 1 +NVML_MEMORY_UTILIZATION_SAMPLES = 2 +NVML_ENC_UTILIZATION_SAMPLES = 3 +NVML_DEC_UTILIZATION_SAMPLES = 4 +NVML_PROCESSOR_CLK_SAMPLES = 5 +NVML_MEMORY_CLK_SAMPLES = 6 +NVML_MODULE_POWER_SAMPLES = 7 +NVML_JPG_UTILIZATION_SAMPLES = 8 +NVML_OFA_UTILIZATION_SAMPLES = 9 +NVML_SAMPLINGTYPE_COUNT = 10 + +_nvmlPcieUtilCounter_t = c_uint +NVML_PCIE_UTIL_TX_BYTES = 0 +NVML_PCIE_UTIL_RX_BYTES = 1 +NVML_PCIE_UTIL_COUNT = 2 + +_nvmlGpuTopologyLevel_t = c_uint +NVML_TOPOLOGY_INTERNAL = 0 +NVML_TOPOLOGY_SINGLE = 10 +NVML_TOPOLOGY_MULTIPLE = 20 +NVML_TOPOLOGY_HOSTBRIDGE = 30 +NVML_TOPOLOGY_NODE = 40 +NVML_TOPOLOGY_CPU = NVML_TOPOLOGY_NODE +NVML_TOPOLOGY_SYSTEM = 50 + +_nvmlGpuP2PCapsIndex_t = c_uint +NVML_P2P_CAPS_INDEX_READ = (0,) +NVML_P2P_CAPS_INDEX_WRITE = 1 +NVML_P2P_CAPS_INDEX_NVLINK = 2 +NVML_P2P_CAPS_INDEX_ATOMICS = 3 +# +# NVML_P2P_CAPS_INDEX_PROP is deprecated. +# Use NVML_P2P_CAPS_INDEX_PCI instead. 
+# +NVML_P2P_CAPS_INDEX_PROP = 4 +NVML_P2P_CAPS_INDEX_PCI = 4 +NVML_P2P_CAPS_INDEX_UNKNOWN = 5 + +_nvmlGpuP2PStatus_t = c_uint +NVML_P2P_STATUS_OK = 0 +NVML_P2P_STATUS_CHIPSET_NOT_SUPPORED = 1 +NVML_P2P_STATUS_CHIPSET_NOT_SUPPORTED = NVML_P2P_STATUS_CHIPSET_NOT_SUPPORED +NVML_P2P_STATUS_GPU_NOT_SUPPORTED = 2 +NVML_P2P_STATUS_IOH_TOPOLOGY_NOT_SUPPORTED = 3 +NVML_P2P_STATUS_DISABLED_BY_REGKEY = 4 +NVML_P2P_STATUS_NOT_SUPPORTED = 5 +NVML_P2P_STATUS_UNKNOWN = 6 + +_nvmlDeviceArchitecture_t = c_uint +NVML_DEVICE_ARCH_KEPLER = 2 +NVML_DEVICE_ARCH_MAXWELL = 3 +NVML_DEVICE_ARCH_PASCAL = 4 +NVML_DEVICE_ARCH_VOLTA = 5 +NVML_DEVICE_ARCH_TURING = 6 +NVML_DEVICE_ARCH_AMPERE = 7 +NVML_DEVICE_ARCH_ADA = 8 +NVML_DEVICE_ARCH_HOPPER = 9 +NVML_DEVICE_ARCH_BLACKWELL = 10 +NVML_DEVICE_ARCH_T23X = 11 +NVML_DEVICE_ARCH_UNKNOWN = 0xFFFFFFFF + +# PCI bus Types +_nvmlBusType_t = c_uint +NVML_BUS_TYPE_UNKNOWN = 0 +NVML_BUS_TYPE_PCI = 1 +NVML_BUS_TYPE_PCIE = 2 +NVML_BUS_TYPE_FPCI = 3 +NVML_BUS_TYPE_AGP = 4 + +_nvmlPowerSource_t = c_uint +NVML_POWER_SOURCE_AC = 0x00000000 +NVML_POWER_SOURCE_BATTERY = 0x00000001 +NVML_POWER_SOURCE_UNDERSIZED = 0x00000002 + +_nvmlAdaptiveClockInfoStatus_t = c_uint +NVML_ADAPTIVE_CLOCKING_INFO_STATUS_DISABLED = 0x00000000 +NVML_ADAPTIVE_CLOCKING_INFO_STATUS_ENABLED = 0x00000001 + +_nvmlClockLimitId_t = c_uint +NVML_CLOCK_LIMIT_ID_RANGE_START = 0xFFFFFF00 +NVML_CLOCK_LIMIT_ID_TDP = 0xFFFFFF01 +NVML_CLOCK_LIMIT_ID_UNLIMITED = 0xFFFFFF02 + +_nvmlPcieLinkMaxSpeed_t = c_uint +NVML_PCIE_LINK_MAX_SPEED_INVALID = 0x00000000 +NVML_PCIE_LINK_MAX_SPEED_2500MBPS = 0x00000001 +NVML_PCIE_LINK_MAX_SPEED_5000MBPS = 0x00000002 +NVML_PCIE_LINK_MAX_SPEED_8000MBPS = 0x00000003 +NVML_PCIE_LINK_MAX_SPEED_16000MBPS = 0x00000004 +NVML_PCIE_LINK_MAX_SPEED_32000MBPS = 0x00000005 +NVML_PCIE_LINK_MAX_SPEED_64000MBPS = 0x00000006 + +_nvmlPcieAtomicsCapability_t = c_uint +NVML_PCIE_ATOMICS_CAP_FETCHADD32 = 0x01 +NVML_PCIE_ATOMICS_CAP_FETCHADD64 = 0x02 +NVML_PCIE_ATOMICS_CAP_SWAP32 = 0x04 +NVML_PCIE_ATOMICS_CAP_SWAP64 = 0x08 +NVML_PCIE_ATOMICS_CAP_CAS32 = 0x10 +NVML_PCIE_ATOMICS_CAP_CAS64 = 0x20 +NVML_PCIE_ATOMICS_CAP_CAS128 = 0x40 +NVML_PCIE_ATOMICS_OPS_MAX = 7 + +_nvmlAffinityScope_t = c_uint +NVML_AFFINITY_SCOPE_NODE = 0 +NVML_AFFINITY_SCOPE_SOCKET = 1 + +_nvmlDeviceGpuRecoveryAction_t = c_uint +NVML_GPU_RECOVERY_ACTION_NONE = 0 +NVML_GPU_RECOVERY_ACTION_GPU_RESET = 1 +NVML_GPU_RECOVERY_ACTION_NODE_REBOOT = 2 +NVML_GPU_RECOVERY_ACTION_DRAIN_P2P = 3 +NVML_GPU_RECOVERY_ACTION_DRAIN_AND_RESET = 4 + +# C preprocessor defined values +nvmlFlagDefault = 0 +nvmlFlagForce = 1 +NVML_INIT_FLAG_NO_GPUS = 1 +NVML_INIT_FLAG_NO_ATTACH = 2 + +NVML_MAX_GPC_COUNT = 32 + +# buffer size +NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE = 16 +NVML_DEVICE_UUID_BUFFER_SIZE = 80 +NVML_DEVICE_UUID_V2_BUFFER_SIZE = 96 +NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE = 80 +NVML_SYSTEM_NVML_VERSION_BUFFER_SIZE = 80 +NVML_DEVICE_NAME_BUFFER_SIZE = 64 +NVML_DEVICE_NAME_V2_BUFFER_SIZE = 96 +NVML_DEVICE_SERIAL_BUFFER_SIZE = 30 +NVML_DEVICE_PART_NUMBER_BUFFER_SIZE = 80 +NVML_DEVICE_GPU_PART_NUMBER_BUFFER_SIZE = 80 +NVML_DEVICE_VBIOS_VERSION_BUFFER_SIZE = 32 +NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE = 32 +NVML_DEVICE_PCI_BUS_ID_BUFFER_V2_SIZE = 16 +NVML_GRID_LICENSE_BUFFER_SIZE = 128 +NVML_VGPU_NAME_BUFFER_SIZE = 64 +NVML_GRID_LICENSE_FEATURE_MAX_COUNT = 3 +NVML_VGPU_METADATA_OPAQUE_DATA_SIZE = sizeof(c_uint) + 256 +NVML_VGPU_PGPU_METADATA_OPAQUE_DATA_SIZE = 256 +NVML_DEVICE_GPU_FRU_PART_NUMBER_BUFFER_SIZE = ( + 0x14 # NV2080_GPU_MAX_PRODUCT_PART_NUMBER_LENGTH +) 
+NVML_PERF_MODES_BUFFER_SIZE = 2048 + +# Format strings +NVML_DEVICE_PCI_BUS_ID_LEGACY_FMT = "%04X:%02X:%02X.0" +NVML_DEVICE_PCI_BUS_ID_FMT = "%08X:%02X:%02X.0" + +NVML_VALUE_NOT_AVAILABLE_ulonglong = c_ulonglong(-1) +NVML_VALUE_NOT_AVAILABLE_uint = c_uint(-1) + +""" + Field Identifiers. + + All Identifiers pertain to a device. Each ID is only used once and is guaranteed never to change. +""" +NVML_FI_DEV_ECC_CURRENT = 1 # Current ECC mode. 1=Active. 0=Inactive +NVML_FI_DEV_ECC_PENDING = 2 # Pending ECC mode. 1=Active. 0=Inactive + +# ECC Count Totals +NVML_FI_DEV_ECC_SBE_VOL_TOTAL = 3 # Total single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_TOTAL = 4 # Total double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_AGG_TOTAL = 5 # Total single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_TOTAL = 6 # Total double bit aggregate (persistent) ECC errors +# Individual ECC locations +NVML_FI_DEV_ECC_SBE_VOL_L1 = 7 # L1 cache single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_L1 = 8 # L1 cache double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_VOL_L2 = 9 # L2 cache single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_L2 = 10 # L2 cache double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_VOL_DEV = 11 # Device memory single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_DEV = 12 # Device memory double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_VOL_REG = 13 # Register file single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_REG = 14 # Register file double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_VOL_TEX = 15 # Texture memory single bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_TEX = 16 # Texture memory double bit volatile ECC errors +NVML_FI_DEV_ECC_DBE_VOL_CBU = 17 # CBU double bit volatile ECC errors +NVML_FI_DEV_ECC_SBE_AGG_L1 = 18 # L1 cache single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_L1 = 19 # L1 cache double bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_SBE_AGG_L2 = 20 # L2 cache single bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_DBE_AGG_L2 = 21 # L2 cache double bit aggregate (persistent) ECC errors +NVML_FI_DEV_ECC_SBE_AGG_DEV = ( + 22 # Device memory single bit aggregate (persistent) ECC errors +) +NVML_FI_DEV_ECC_DBE_AGG_DEV = ( + 23 # Device memory double bit aggregate (persistent) ECC errors +) +NVML_FI_DEV_ECC_SBE_AGG_REG = ( + 24 # Register File single bit aggregate (persistent) ECC errors +) +NVML_FI_DEV_ECC_DBE_AGG_REG = ( + 25 # Register File double bit aggregate (persistent) ECC errors +) +NVML_FI_DEV_ECC_SBE_AGG_TEX = ( + 26 # Texture memory single bit aggregate (persistent) ECC errors +) +NVML_FI_DEV_ECC_DBE_AGG_TEX = ( + 27 # Texture memory double bit aggregate (persistent) ECC errors +) +NVML_FI_DEV_ECC_DBE_AGG_CBU = 28 # CBU double bit aggregate ECC errors + +# Page Retirement +NVML_FI_DEV_RETIRED_SBE = 29 # Number of retired pages because of single bit errors +NVML_FI_DEV_RETIRED_DBE = 30 # Number of retired pages because of double bit errors +NVML_FI_DEV_RETIRED_PENDING = 31 # If any pages are pending retirement. 1=yes. 0=no. 
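(Editor's illustration, not part of the patch.) The ECC and page-retirement identifiers above, like the NVLink counters that follow, are numeric field IDs rather than dedicated getters; callers typically batch them through nvmlDeviceGetFieldValues. Below is a minimal sketch under stated assumptions: the standard nvidia-ml-py entry points and field-value struct layout that this vendored copy reproduces, an installed NVIDIA driver, and the module being importable as sglang.multimodal_gen.third_party.pynvml (the import path is inferred from the file location in this patch).

    # Illustrative only: batch-query two of the field IDs defined above.
    from sglang.multimodal_gen.third_party import pynvml as nvml

    nvml.nvmlInit()
    try:
        handle = nvml.nvmlDeviceGetHandleByIndex(0)
        fields = [nvml.NVML_FI_DEV_ECC_DBE_VOL_TOTAL, nvml.NVML_FI_DEV_RETIRED_PENDING]
        for fv in nvml.nvmlDeviceGetFieldValues(handle, fields):
            if fv.nvmlReturn == nvml.NVML_SUCCESS:
                print(fv.fieldId, fv.value.ullVal)  # these counters are unsigned long long
            else:
                print(fv.fieldId, "unavailable")  # e.g. ECC fields on non-ECC GPUs

    finally:
        nvml.nvmlShutdown()
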
+ +# NvLink Flit Error Counters +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L0 = ( + 32 # NVLink flow control CRC Error Counter for Lane 0 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L1 = ( + 33 # NVLink flow control CRC Error Counter for Lane 1 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L2 = ( + 34 # NVLink flow control CRC Error Counter for Lane 2 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L3 = ( + 35 # NVLink flow control CRC Error Counter for Lane 3 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L4 = ( + 36 # NVLink flow control CRC Error Counter for Lane 4 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L5 = ( + 37 # NVLink flow control CRC Error Counter for Lane 5 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_TOTAL = ( + 38 # NVLink flow control CRC Error Counter total for all Lanes +) + +# NvLink CRC Data Error Counters +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L0 = ( + 39 # NVLink data CRC Error Counter for Lane 0 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L1 = ( + 40 # NVLink data CRC Error Counter for Lane 1 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L2 = ( + 41 # NVLink data CRC Error Counter for Lane 2 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L3 = ( + 42 # NVLink data CRC Error Counter for Lane 3 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L4 = ( + 43 # NVLink data CRC Error Counter for Lane 4 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L5 = ( + 44 # NVLink data CRC Error Counter for Lane 5 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_TOTAL = ( + 45 # NvLink data CRC Error Counter total for all Lanes +) + +# NvLink Replay Error Counters +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L0 = 46 # NVLink Replay Error Counter for Lane 0 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L1 = 47 # NVLink Replay Error Counter for Lane 1 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L2 = 48 # NVLink Replay Error Counter for Lane 2 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L3 = 49 # NVLink Replay Error Counter for Lane 3 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L4 = 50 # NVLink Replay Error Counter for Lane 4 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L5 = 51 # NVLink Replay Error Counter for Lane 5 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_TOTAL = ( + 52 # NVLink Replay Error Counter total for all Lanes +) + +# NvLink Recovery Error Counters +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L0 = ( + 53 # NVLink Recovery Error Counter for Lane 0 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L1 = ( + 54 # NVLink Recovery Error Counter for Lane 1 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L2 = ( + 55 # NVLink Recovery Error Counter for Lane 2 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L3 = ( + 56 # NVLink Recovery Error Counter for Lane 3 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L4 = ( + 57 # NVLink Recovery Error Counter for Lane 4 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L5 = ( + 58 # NVLink Recovery Error Counter for Lane 5 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_TOTAL = ( + 59 # NVLink Recovery Error Counter total for all Lanes +) + +# NvLink Bandwidth Counters +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L0 = ( + 60 # NVLink Bandwidth Counter for Counter Set 0, Lane 0 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L1 = ( + 61 # NVLink Bandwidth Counter for Counter Set 0, Lane 1 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L2 = ( + 62 # NVLink Bandwidth Counter for Counter Set 0, Lane 2 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L3 = ( + 63 # NVLink Bandwidth Counter for Counter Set 0, Lane 3 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L4 = ( + 64 # NVLink Bandwidth Counter for Counter Set 0, Lane 4 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L5 = ( + 65 # NVLink Bandwidth Counter 
for Counter Set 0, Lane 5 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_TOTAL = ( + 66 # NVLink Bandwidth Counter Total for Counter Set 0, All Lanes +) + +# NvLink Bandwidth Counters +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L0 = ( + 67 # NVLink Bandwidth Counter for Counter Set 1, Lane 0 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L1 = ( + 68 # NVLink Bandwidth Counter for Counter Set 1, Lane 1 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L2 = ( + 69 # NVLink Bandwidth Counter for Counter Set 1, Lane 2 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L3 = ( + 70 # NVLink Bandwidth Counter for Counter Set 1, Lane 3 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L4 = ( + 71 # NVLink Bandwidth Counter for Counter Set 1, Lane 4 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L5 = ( + 72 # NVLink Bandwidth Counter for Counter Set 1, Lane 5 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_TOTAL = ( + 73 # NVLink Bandwidth Counter Total for Counter Set 1, All Lanes +) + +# Perf Policy Counters +NVML_FI_DEV_PERF_POLICY_POWER = 74 # Perf Policy Counter for Power Policy +NVML_FI_DEV_PERF_POLICY_THERMAL = 75 # Perf Policy Counter for Thermal Policy +NVML_FI_DEV_PERF_POLICY_SYNC_BOOST = 76 # Perf Policy Counter for Sync boost Policy +NVML_FI_DEV_PERF_POLICY_BOARD_LIMIT = 77 # Perf Policy Counter for Board Limit +NVML_FI_DEV_PERF_POLICY_LOW_UTILIZATION = ( + 78 # Perf Policy Counter for Low GPU Utilization Policy +) +NVML_FI_DEV_PERF_POLICY_RELIABILITY = 79 # Perf Policy Counter for Reliability Policy +NVML_FI_DEV_PERF_POLICY_TOTAL_APP_CLOCKS = ( + 80 # Perf Policy Counter for Total App Clock Policy +) +NVML_FI_DEV_PERF_POLICY_TOTAL_BASE_CLOCKS = ( + 81 # Perf Policy Counter for Total Base Clocks Policy +) + +# Memory temperatures +NVML_FI_DEV_MEMORY_TEMP = 82 # Memory temperature for the device + +# Energy Counter +NVML_FI_DEV_TOTAL_ENERGY_CONSUMPTION = ( + 83 # Total energy consumption for the GPU in mJ since the driver was last reloaded +) + +# NVLink Speed +NVML_FI_DEV_NVLINK_SPEED_MBPS_L0 = 84 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L1 = 85 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L2 = 86 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L3 = 87 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L4 = 88 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L5 = 89 +NVML_FI_DEV_NVLINK_SPEED_MBPS_COMMON = 90 + +# NVLink Link Count +NVML_FI_DEV_NVLINK_LINK_COUNT = 91 + +# Page Retirement pending fields +NVML_FI_DEV_RETIRED_PENDING_SBE = 92 +NVML_FI_DEV_RETIRED_PENDING_DBE = 93 + +# PCIe replay and replay rollover counters +NVML_FI_DEV_PCIE_REPLAY_COUNTER = 94 +NVML_FI_DEV_PCIE_REPLAY_ROLLOVER_COUNTER = 95 + +# NvLink Flit Error Counters +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L6 = ( + 96 # NVLink flow control CRC Error Counter for Lane 6 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L7 = ( + 97 # NVLink flow control CRC Error Counter for Lane 7 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L8 = ( + 98 # NVLink flow control CRC Error Counter for Lane 8 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L9 = ( + 99 # NVLink flow control CRC Error Counter for Lane 9 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L10 = ( + 100 # NVLink flow control CRC Error Counter for Lane 10 +) +NVML_FI_DEV_NVLINK_CRC_FLIT_ERROR_COUNT_L11 = ( + 101 # NVLink flow control CRC Error Counter for Lane 11 +) + +# NvLink CRC Data Error Counters +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L6 = ( + 102 # NVLink data CRC Error Counter for Lane 6 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L7 = ( + 103 # NVLink data CRC Error Counter for Lane 7 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L8 = ( + 104 # NVLink data CRC Error Counter for Lane 8 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L9 = ( + 105 # 
NVLink data CRC Error Counter for Lane 9 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L10 = ( + 106 # NVLink data CRC Error Counter for Lane 10 +) +NVML_FI_DEV_NVLINK_CRC_DATA_ERROR_COUNT_L11 = ( + 107 # NVLink data CRC Error Counter for Lane 11 +) + +# NvLink Replay Error Counters +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L6 = 108 # NVLink Replay Error Counter for Lane 6 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L7 = 109 # NVLink Replay Error Counter for Lane 7 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L8 = 110 # NVLink Replay Error Counter for Lane 8 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L9 = 111 # NVLink Replay Error Counter for Lane 9 +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L10 = ( + 112 # NVLink Replay Error Counter for Lane 10 +) +NVML_FI_DEV_NVLINK_REPLAY_ERROR_COUNT_L11 = ( + 113 # NVLink Replay Error Counter for Lane 11 +) + +# NvLink Recovery Error Counters +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L6 = ( + 114 # NVLink Recovery Error Counter for Lane 6 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L7 = ( + 115 # NVLink Recovery Error Counter for Lane 7 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L8 = ( + 116 # NVLink Recovery Error Counter for Lane 8 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L9 = ( + 117 # NVLink Recovery Error Counter for Lane 9 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L10 = ( + 118 # NVLink Recovery Error Counter for Lane 10 +) +NVML_FI_DEV_NVLINK_RECOVERY_ERROR_COUNT_L11 = ( + 119 # NVLink Recovery Error Counter for Lane 11 +) + +# NvLink Bandwidth Counters +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L6 = ( + 120 # NVLink Bandwidth Counter for Counter Set 0, Lane 6 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L7 = ( + 121 # NVLink Bandwidth Counter for Counter Set 0, Lane 7 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L8 = ( + 122 # NVLink Bandwidth Counter for Counter Set 0, Lane 8 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L9 = ( + 123 # NVLink Bandwidth Counter for Counter Set 0, Lane 9 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L10 = ( + 124 # NVLink Bandwidth Counter for Counter Set 0, Lane 10 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C0_L11 = ( + 125 # NVLink Bandwidth Counter for Counter Set 0, Lane 11 +) + +# NvLink Bandwidth Counters +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L6 = ( + 126 # NVLink Bandwidth Counter for Counter Set 1, Lane 6 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L7 = ( + 127 # NVLink Bandwidth Counter for Counter Set 1, Lane 7 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L8 = ( + 128 # NVLink Bandwidth Counter for Counter Set 1, Lane 8 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L9 = ( + 129 # NVLink Bandwidth Counter for Counter Set 1, Lane 9 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L10 = ( + 130 # NVLink Bandwidth Counter for Counter Set 1, Lane 10 +) +NVML_FI_DEV_NVLINK_BANDWIDTH_C1_L11 = ( + 131 # NVLink Bandwidth Counter for Counter Set 1, Lane 11 +) + +# NVLink Speed +NVML_FI_DEV_NVLINK_SPEED_MBPS_L6 = 132 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L7 = 133 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L8 = 134 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L9 = 135 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L10 = 136 +NVML_FI_DEV_NVLINK_SPEED_MBPS_L11 = 137 + +# NVLink Throughput Counters +NVML_FI_DEV_NVLINK_THROUGHPUT_DATA_TX = 138 # NVLink TX Data throughput in KiB +NVML_FI_DEV_NVLINK_THROUGHPUT_DATA_RX = 139 # NVLink RX Data throughput in KiB +NVML_FI_DEV_NVLINK_THROUGHPUT_RAW_TX = 140 # NVLink TX Data + protocol overhead in KiB +NVML_FI_DEV_NVLINK_THROUGHPUT_RAW_RX = 141 # NVLink RX Data + protocol overhead in KiB + +# Row Remapper +NVML_FI_DEV_REMAPPED_COR = 142 +NVML_FI_DEV_REMAPPED_UNC = 143 +NVML_FI_DEV_REMAPPED_PENDING = 144 +NVML_FI_DEV_REMAPPED_FAILURE = 145 + +# 
Remote device NVLink ID +NVML_FI_DEV_NVLINK_REMOTE_NVLINK_ID = 146 + +# Number of NVLinks connected to NVSwitch +NVML_FI_DEV_NVSWITCH_CONNECTED_LINK_COUNT = 147 + +# NvLink ECC Data Error Counters +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L0 = ( + 148 # < NVLink data ECC Error Counter for Link 0 +) +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L1 = ( + 149 # < NVLink data ECC Error Counter for Link 1 +) +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L2 = ( + 150 # < NVLink data ECC Error Counter for Link 2 +) +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L3 = ( + 151 # < NVLink data ECC Error Counter for Link 3 +) +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L4 = ( + 152 # < NVLink data ECC Error Counter for Link 4 +) +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L5 = ( + 153 # < NVLink data ECC Error Counter for Link 5 +) +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L6 = ( + 154 # < NVLink data ECC Error Counter for Link 6 +) +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L7 = ( + 155 # < NVLink data ECC Error Counter for Link 7 +) +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L8 = ( + 156 # < NVLink data ECC Error Counter for Link 8 +) +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L9 = ( + 157 # < NVLink data ECC Error Counter for Link 9 +) +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L10 = ( + 158 # < NVLink data ECC Error Counter for Link 10 +) +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_L11 = ( + 159 # < NVLink data ECC Error Counter for Link 11 +) +NVML_FI_DEV_NVLINK_ECC_DATA_ERROR_COUNT_TOTAL = ( + 160 # < NvLink data ECC Error Counter total for all Links +) + +NVML_FI_DEV_NVLINK_ERROR_DL_REPLAY = 161 +NVML_FI_DEV_NVLINK_ERROR_DL_RECOVERY = 162 +NVML_FI_DEV_NVLINK_ERROR_DL_CRC = 163 +NVML_FI_DEV_NVLINK_GET_SPEED = 164 +NVML_FI_DEV_NVLINK_GET_STATE = 165 +NVML_FI_DEV_NVLINK_GET_VERSION = 166 + +NVML_FI_DEV_NVLINK_GET_POWER_STATE = 167 +NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD = 168 + +NVML_FI_DEV_PCIE_L0_TO_RECOVERY_COUNTER = 169 + +NVML_FI_DEV_C2C_LINK_COUNT = 170 +NVML_FI_DEV_C2C_LINK_GET_STATUS = 171 +NVML_FI_DEV_C2C_LINK_GET_MAX_BW = 172 + +NVML_FI_DEV_PCIE_COUNT_CORRECTABLE_ERRORS = 173 +NVML_FI_DEV_PCIE_COUNT_NAKS_RECEIVED = 174 +NVML_FI_DEV_PCIE_COUNT_RECEIVER_ERROR = 175 +NVML_FI_DEV_PCIE_COUNT_BAD_TLP = 176 +NVML_FI_DEV_PCIE_COUNT_NAKS_SENT = 177 +NVML_FI_DEV_PCIE_COUNT_BAD_DLLP = 178 +NVML_FI_DEV_PCIE_COUNT_NON_FATAL_ERROR = 179 +NVML_FI_DEV_PCIE_COUNT_FATAL_ERROR = 180 +NVML_FI_DEV_PCIE_COUNT_UNSUPPORTED_REQ = 181 +NVML_FI_DEV_PCIE_COUNT_LCRC_ERROR = 182 +NVML_FI_DEV_PCIE_COUNT_LANE_ERROR = 183 + +NVML_FI_DEV_IS_RESETLESS_MIG_SUPPORTED = 184 + +NVML_FI_DEV_POWER_AVERAGE = 185 +NVML_FI_DEV_POWER_INSTANT = 186 +NVML_FI_DEV_POWER_MIN_LIMIT = 187 +NVML_FI_DEV_POWER_MAX_LIMIT = 188 +NVML_FI_DEV_POWER_DEFAULT_LIMIT = 189 +NVML_FI_DEV_POWER_CURRENT_LIMIT = 190 +NVML_FI_DEV_ENERGY = 191 +NVML_FI_DEV_POWER_REQUESTED_LIMIT = 192 + +NVML_FI_DEV_TEMPERATURE_SHUTDOWN_TLIMIT = 193 +NVML_FI_DEV_TEMPERATURE_SLOWDOWN_TLIMIT = 194 +NVML_FI_DEV_TEMPERATURE_MEM_MAX_TLIMIT = 195 +NVML_FI_DEV_TEMPERATURE_GPU_MAX_TLIMIT = 196 + +NVML_FI_DEV_PCIE_COUNT_TX_BYTES = 197 +NVML_FI_DEV_PCIE_COUNT_RX_BYTES = 198 + +NVML_FI_DEV_IS_MIG_MODE_INDEPENDENT_MIG_QUERY_CAPABLE = 199 + +NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_MAX = 200 + +NVML_FI_DEV_NVLINK_COUNT_XMIT_PACKETS = 201 +NVML_FI_DEV_NVLINK_COUNT_XMIT_BYTES = 202 +NVML_FI_DEV_NVLINK_COUNT_RCV_PACKETS = 203 +NVML_FI_DEV_NVLINK_COUNT_RCV_BYTES = 204 +NVML_FI_DEV_NVLINK_COUNT_VL15_DROPPED = 205 # Deprecated, do not use +NVML_FI_DEV_NVLINK_COUNT_MALFORMED_PACKET_ERRORS = 206 
+NVML_FI_DEV_NVLINK_COUNT_BUFFER_OVERRUN_ERRORS = 207 +NVML_FI_DEV_NVLINK_COUNT_RCV_ERRORS = 208 +NVML_FI_DEV_NVLINK_COUNT_RCV_REMOTE_ERRORS = 209 +NVML_FI_DEV_NVLINK_COUNT_RCV_GENERAL_ERRORS = 210 +NVML_FI_DEV_NVLINK_COUNT_LOCAL_LINK_INTEGRITY_ERRORS = 211 +NVML_FI_DEV_NVLINK_COUNT_XMIT_DISCARDS = 212 + +NVML_FI_DEV_NVLINK_COUNT_LINK_RECOVERY_SUCCESSFUL_EVENTS = 213 +NVML_FI_DEV_NVLINK_COUNT_LINK_RECOVERY_FAILED_EVENTS = 214 +NVML_FI_DEV_NVLINK_COUNT_LINK_RECOVERY_EVENTS = 215 + +NVML_FI_DEV_NVLINK_COUNT_RAW_BER_LANE0 = 216 # Deprecated, do not use +NVML_FI_DEV_NVLINK_COUNT_RAW_BER_LANE1 = 217 # Deprecated, do not use +NVML_FI_DEV_NVLINK_COUNT_RAW_BER = 218 # Deprecated, do not use +NVML_FI_DEV_NVLINK_COUNT_EFFECTIVE_ERRORS = 219 +NVML_FI_DEV_NVLINK_COUNT_EFFECTIVE_BER = 220 +NVML_FI_DEV_NVLINK_COUNT_SYMBOL_ERRORS = 221 +NVML_FI_DEV_NVLINK_COUNT_SYMBOL_BER = 222 + +NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_MIN = 223 +NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_UNITS = ( + 224 # Values are in the form NVML_NVLINK_LOW_POWER_THRESHOLD_UNIT_* +) +NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_SUPPORTED = 225 + +NVML_FI_DEV_RESET_STATUS = ( + 226 # Deprecated use NVML_FI_DEV_GET_GPU_RECOVERY_ACTION instead +) +NVML_FI_DEV_DRAIN_AND_RESET_STATUS = ( + 227 # Deprecated use NVML_FI_DEV_GET_GPU_RECOVERY_ACTION instead +) +NVML_FI_DEV_PCIE_OUTBOUND_ATOMICS_MASK = 228 +NVML_FI_DEV_PCIE_INBOUND_ATOMICS_MASK = 229 +NVML_FI_DEV_GET_GPU_RECOVERY_ACTION = 230 + +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_0 = 235 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_1 = 236 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_2 = 237 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_3 = 238 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_4 = 239 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_5 = 240 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_6 = 241 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_7 = 242 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_8 = 243 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_9 = 244 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_10 = 245 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_11 = 246 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_12 = 247 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_13 = 248 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_14 = 249 +NVML_FI_DEV_NVLINK_COUNT_FEC_HISTORY_15 = 250 +NVML_FI_PWR_SMOOTHING_ENABLED = 251 # Enablement (0/DISABLED or 1/ENABLED) +NVML_FI_PWR_SMOOTHING_PRIV_LVL = 252 # Current privilege level +NVML_FI_PWR_SMOOTHING_IMM_RAMP_DOWN_ENABLED = ( + 253 # Immediate ramp down enablement (0/DISABLED or 1/ENABLED) +) +NVML_FI_PWR_SMOOTHING_APPLIED_TMP_CEIL = 254 # Applied TMP ceiling value +NVML_FI_PWR_SMOOTHING_APPLIED_TMP_FLOOR = 255 # Applied TMP floor value +NVML_FI_PWR_SMOOTHING_MAX_PERCENT_TMP_FLOOR_SETTING = 256 # Max % TMP Floor value +NVML_FI_PWR_SMOOTHING_MIN_PERCENT_TMP_FLOOR_SETTING = 257 # Min % TMP Floor value +NVML_FI_PWR_SMOOTHING_HW_CIRCUITRY_PERCENT_LIFETIME_REMAINING = ( + 258 # HW Circuitry % lifetime remaining +) +NVML_FI_PWR_SMOOTHING_MAX_NUM_PRESET_PROFILES = 259 # Max number of preset profiles +NVML_FI_PWR_SMOOTHING_PROFILE_PERCENT_TMP_FLOOR = 260 # % TMP floor for a given profile +NVML_FI_PWR_SMOOTHING_PROFILE_RAMP_UP_RATE = ( + 261 # Ramp up rate in mW/s for a given profile +) +NVML_FI_PWR_SMOOTHING_PROFILE_RAMP_DOWN_RATE = ( + 262 # Ramp down rate in mW/s for a given profile +) +NVML_FI_PWR_SMOOTHING_PROFILE_RAMP_DOWN_HYST_VAL = ( + 263 # Ramp down hysteresis value in ms for a given profile +) +NVML_FI_PWR_SMOOTHING_ACTIVE_PRESET_PROFILE = 264 # Active preset profile number +NVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_PERCENT_TMP_FLOOR = ( + 265 # % TMP floor for a given 
profile +) +NVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_RAMP_UP_RATE = ( + 266 # Ramp up rate in mW/s for a given profile +) +NVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_RAMP_DOWN_RATE = ( + 267 # Ramp down rate in mW/s for a given profile +) +NVML_FI_PWR_SMOOTHING_ADMIN_OVERRIDE_RAMP_DOWN_HYST_VAL = ( + 268 # Ramp down hysteresis value in ms for a given profile +) + +NVML_FI_MAX = 269 # One greater than the largest field ID defined above + +# NVML_FI_DEV_NVLINK_GET_STATE state enums +NVML_NVLINK_STATE_INACTIVE = 0x0 +NVML_NVLINK_STATE_ACTIVE = 0x1 +NVML_NVLINK_STATE_SLEEP = 0x2 + +NVML_NVLINK_LOW_POWER_THRESHOLD_UNIT_100US = ( + 0 # NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_UNITS +) +NVML_NVLINK_LOW_POWER_THRESHOLD_UNIT_50US = ( + 1 # NVML_FI_DEV_NVLINK_GET_POWER_THRESHOLD_UNITS +) + +## Enums needed for the method nvmlDeviceGetVirtualizationMode and nvmlDeviceSetVirtualizationMode +NVML_GPU_VIRTUALIZATION_MODE_NONE = 0 # Represents Bare Metal GPU +NVML_GPU_VIRTUALIZATION_MODE_PASSTHROUGH = ( + 1 # Device is associated with GPU-Passthorugh +) +NVML_GPU_VIRTUALIZATION_MODE_VGPU = ( + 2 # Device is associated with vGPU inside virtual machine. +) +NVML_GPU_VIRTUALIZATION_MODE_HOST_VGPU = ( + 3 # Device is associated with VGX hypervisor in vGPU mode +) +NVML_GPU_VIRTUALIZATION_MODE_HOST_VSGA = ( + 4 # Device is associated with VGX hypervisor in vSGA mode +) + +## Lib loading ## +nvmlLib = None +libLoadLock = threading.Lock() +_nvmlLib_refcount = 0 # Incremented on each nvmlInit and decremented on nvmlShutdown + +## vGPU Management +_nvmlVgpuTypeId_t = c_uint +_nvmlVgpuInstance_t = c_uint + +_nvmlVgpuVmIdType_t = c_uint +NVML_VGPU_VM_ID_DOMAIN_ID = 0 +NVML_VGPU_VM_ID_UUID = 1 + +_nvmlGridLicenseFeatureCode_t = c_uint +NVML_GRID_LICENSE_FEATURE_CODE_UNKNOWN = 0 +NVML_GRID_LICENSE_FEATURE_CODE_VGPU = 1 +NVML_GRID_LICENSE_FEATURE_CODE_NVIDIA_RTX = 2 +NVML_GRID_LICENSE_FEATURE_CODE_VWORKSTATION = ( + 2 # deprecated, use NVML_GRID_LICENSE_FEATURE_CODE_NVIDIA_RTX. 
+) +NVML_GRID_LICENSE_FEATURE_CODE_GAMING = 3 +NVML_GRID_LICENSE_FEATURE_CODE_COMPUTE = 4 + +_nvmlGridLicenseExpiryStatus_t = c_uint8 +NVML_GRID_LICENSE_EXPIRY_NOT_AVAILABLE = (0,) # Expiry information not available +NVML_GRID_LICENSE_EXPIRY_INVALID = (1,) # Invalid expiry or error fetching expiry +NVML_GRID_LICENSE_EXPIRY_VALID = (2,) # Valid expiry +NVML_GRID_LICENSE_EXPIRY_NOT_APPLICABLE = (3,) # Expiry not applicable +NVML_GRID_LICENSE_EXPIRY_PERMANENT = (4,) # Permanent expiry + +_nvmlVgpuCapability_t = c_uint +NVML_VGPU_CAP_NVLINK_P2P = 0 # vGPU P2P over NVLink is supported +NVML_VGPU_CAP_GPUDIRECT = 1 # GPUDirect capability is supported +NVML_VGPU_CAP_MULTI_VGPU_EXCLUSIVE = ( + 2 # vGPU profile cannot be mixed with other vGPU profiles in same VM +) +NVML_VGPU_CAP_EXCLUSIVE_TYPE = ( + 3 # vGPU profile cannot run on a GPU alongside other profiles of different type +) +NVML_VGPU_CAP_EXCLUSIVE_SIZE = ( + 4 # vGPU profile cannot run on a GPU alongside other profiles of different size +) +NVML_VGPU_CAP_COUNT = 5 + +_nvmlVgpuDriverCapability_t = c_uint +NVML_VGPU_DRIVER_CAP_HETEROGENEOUS_MULTI_VGPU = ( + 0 # Supports mixing of different vGPU profiles within one guest VM +) +NVML_VGPU_DRIVER_CAP_WARM_UPDATE = 1 # Supports FSR and warm update of vGPU host driver without terminating the running guest VM +NVML_VGPU_DRIVER_CAP_COUNT = 2 + +_nvmlDeviceVgpuCapability_t = c_uint +NVML_DEVICE_VGPU_CAP_FRACTIONAL_MULTI_VGPU = 0 # Query whether the fractional vGPU profiles on this GPU can be used in multi-vGPU configurations +NVML_DEVICE_VGPU_CAP_HETEROGENEOUS_TIMESLICE_PROFILES = 1 # Query whether the GPU supports concurrent execution of timesliced vGPU profiles of differing types +NVML_DEVICE_VGPU_CAP_HETEROGENEOUS_TIMESLICE_SIZES = 2 # Query whether the GPU supports concurrent execution of timesliced vGPU profiles of differing framebuffer sizes +NVML_DEVICE_VGPU_CAP_READ_DEVICE_BUFFER_BW = 3 # Query the GPU's read_device_buffer expected bandwidth capacity in megabytes per second +NVML_DEVICE_VGPU_CAP_WRITE_DEVICE_BUFFER_BW = 4 # Query the GPU's write_device_buffer expected bandwidth capacity in megabytes per second +NVML_DEVICE_VGPU_CAP_DEVICE_STREAMING = ( + 5 # Query whether the vGPU profiles on the GPU supports migration data streaming +) +NVML_DEVICE_VGPU_CAP_MINI_QUARTER_GPU = ( + 6 # Set/Get support of mini-quarter vGPU profiles +) +NVML_DEVICE_VGPU_CAP_COMPUTE_MEDIA_ENGINE_GPU = ( + 7 # Set/Get support for compute media engine vGPU profiles +) +NVML_DEVICE_VGPU_CAP_WARM_UPDATE = ( + 8 # Query whether the GPU supports FSR and warm update +) +NVML_DEVICE_VGPU_CAP_HOMOGENEOUS_PLACEMENTS = 9 # Query whether the GPU supports reporting of placements of timesliced vGPU profiles with identical framebuffer sizes +NVML_DEVICE_VGPU_CAP_COUNT = 10 + +_nvmlVgpuGuestInfoState_t = c_uint +NVML_VGPU_INSTANCE_GUEST_INFO_STATE_UNINITIALIZED = 0 +NVML_VGPU_INSTANCE_GUEST_INFO_STATE_INITIALIZED = 1 + +_nvmlVgpuVmCompatibility_t = c_uint +NVML_VGPU_VM_COMPATIBILITY_NONE = 0x0 +NVML_VGPU_VM_COMPATIBILITY_COLD = 0x1 +NVML_VGPU_VM_COMPATIBILITY_HIBERNATE = 0x2 +NVML_VGPU_VM_COMPATIBILITY_SLEEP = 0x4 +NVML_VGPU_VM_COMPATIBILITY_LIVE = 0x8 + +_nvmlVgpuPgpuCompatibilityLimitCode_t = c_uint +NVML_VGPU_COMPATIBILITY_LIMIT_NONE = 0x0 +NVML_VGPU_COMPATIBILITY_LIMIT_HOST_DRIVER = 0x1 +NVML_VGPU_COMPATIBILITY_LIMIT_GUEST_DRIVER = 0x2 +NVML_VGPU_COMPATIBILITY_LIMIT_GPU = 0x4 +NVML_VGPU_COMPATIBILITY_LIMIT_OTHER = 0x80000000 + +_nvmlHostVgpuMode_t = c_uint +NVML_HOST_VGPU_MODE_NON_SRIOV = 0 +NVML_HOST_VGPU_MODE_SRIOV = 1 + 
+_nvmlConfComputeGpusReadyState_t = c_uint +NVML_CC_ACCEPTING_CLIENT_REQUESTS_FALSE = 0 +NVML_CC_ACCEPTING_CLIENT_REQUESTS_TRUE = 1 + +_nvmlConfComputeGpuCaps_t = c_uint +NVML_CC_SYSTEM_GPUS_CC_NOT_CAPABLE = 0 +NVML_CC_SYSTEM_GPUS_CC_CAPABLE = 1 + +_nvmlConfComputeCpuCaps_t = c_uint +NVML_CC_SYSTEM_CPU_CAPS_NONE = 0 +NVML_CC_SYSTEM_CPU_CAPS_AMD_SEV = 1 +NVML_CC_SYSTEM_CPU_CAPS_INTEL_TDX = 2 +NVML_CC_SYSTEM_CPU_CAPS_AMD_SEV_SNP = 3 +NVML_CC_SYSTEM_CPU_CAPS_AMD_SNP_VTOM = 4 + +_nvmlConfComputeDevToolsMode_t = c_uint +NVML_CC_SYSTEM_DEVTOOLS_MODE_OFF = 0 +NVML_CC_SYSTEM_DEVTOOLS_MODE_ON = 1 + +NVML_CC_SYSTEM_MULTIGPU_NONE = 0 +NVML_CC_SYSTEM_MULTIGPU_PROTECTED_PCIE = 1 + +NVML_CC_SYSTEM_ENVIRONMENT_UNAVAILABLE = 0 +NVML_CC_SYSTEM_ENVIRONMENT_SIM = 1 +NVML_CC_SYSTEM_ENVIRONMENT_PROD = 2 + +_nvmlConfComputeCcFeature_t = c_uint +NVML_CC_SYSTEM_FEATURE_DISABLED = 0 +NVML_CC_SYSTEM_FEATURE_ENABLED = 1 + +_nvmlConfComputeCcKeyRotationThreshAttackerAdv_t = c_uint +NVML_CC_KEY_ROTATION_THRESH_ATTACKER_ADVANTAGE_MIN = 50 +NVML_CC_KEY_ROTATION_THRESH_ATTACKER_ADVANTAGE_MAX = 65 + +# GSP firmware +NVML_GSP_FIRMWARE_VERSION_BUF_SIZE = 0x40 + + +class NVMLLibraryMismatchError(Exception): + pass + + +## Error Checking ## +class NVMLError(Exception): + _valClassMapping = dict() + # List of currently known error codes + _errcode_to_string = { + NVML_ERROR_UNINITIALIZED: "Uninitialized", + NVML_ERROR_INVALID_ARGUMENT: "Invalid Argument", + NVML_ERROR_NOT_SUPPORTED: "Not Supported", + NVML_ERROR_NO_PERMISSION: "Insufficient Permissions", + NVML_ERROR_ALREADY_INITIALIZED: "Already Initialized", + NVML_ERROR_NOT_FOUND: "Not Found", + NVML_ERROR_INSUFFICIENT_SIZE: "Insufficient Size", + NVML_ERROR_INSUFFICIENT_POWER: "Insufficient External Power", + NVML_ERROR_DRIVER_NOT_LOADED: "Driver Not Loaded", + NVML_ERROR_TIMEOUT: "Timeout", + NVML_ERROR_IRQ_ISSUE: "Interrupt Request Issue", + NVML_ERROR_LIBRARY_NOT_FOUND: "NVML Shared Library Not Found", + NVML_ERROR_FUNCTION_NOT_FOUND: "Function Not Found", + NVML_ERROR_CORRUPTED_INFOROM: "Corrupted infoROM", + NVML_ERROR_GPU_IS_LOST: "GPU is lost", + NVML_ERROR_RESET_REQUIRED: "GPU requires restart", + NVML_ERROR_OPERATING_SYSTEM: "The operating system has blocked the request.", + NVML_ERROR_LIB_RM_VERSION_MISMATCH: "RM has detected an NVML/RM version mismatch.", + NVML_ERROR_MEMORY: "Insufficient Memory", + NVML_ERROR_UNKNOWN: "Unknown Error", + } + + def __new__(typ, value): + """ + Maps value to a proper subclass of NVMLError. + See _extractNVMLErrorsAsClasses function for more details + """ + if typ == NVMLError: + typ = NVMLError._valClassMapping.get(value, typ) + obj = Exception.__new__(typ) + obj.value = value + return obj + + def __str__(self): + try: + if self.value not in NVMLError._errcode_to_string: + NVMLError._errcode_to_string[self.value] = str( + nvmlErrorString(self.value) + ) + return NVMLError._errcode_to_string[self.value] + except NVMLError: + return "NVML Error with code %d" % self.value + + def __eq__(self, other): + return self.value == other.value + + +def nvmlExceptionClass(nvmlErrorCode): + if nvmlErrorCode not in NVMLError._valClassMapping: + raise ValueError("nvmlErrorCode %s is not valid" % nvmlErrorCode) + return NVMLError._valClassMapping[nvmlErrorCode] + + +def _extractNVMLErrorsAsClasses(): + """ + Generates a hierarchy of classes on top of NVMLError class. + + Each NVML Error gets a new NVMLError subclass. This way try,except blocks can filter appropriate + exceptions more easily. + + NVMLError is a parent class. 
Each NVML_ERROR_* gets it's own subclass. + e.g. NVML_ERROR_ALREADY_INITIALIZED will be turned into NVMLError_AlreadyInitialized + """ + this_module = sys.modules[__name__] + nvmlErrorsNames = [x for x in dir(this_module) if x.startswith("NVML_ERROR_")] + for err_name in nvmlErrorsNames: + # e.g. Turn NVML_ERROR_ALREADY_INITIALIZED into NVMLError_AlreadyInitialized + class_name = "NVMLError_" + string.capwords( + err_name.replace("NVML_ERROR_", ""), "_" + ).replace("_", "") + err_val = getattr(this_module, err_name) + + def gen_new(val): + def new(typ): + obj = NVMLError.__new__(typ, val) + return obj + + return new + + new_error_class = type(class_name, (NVMLError,), {"__new__": gen_new(err_val)}) + new_error_class.__module__ = __name__ + setattr(this_module, class_name, new_error_class) + NVMLError._valClassMapping[err_val] = new_error_class + + +_extractNVMLErrorsAsClasses() + + +def _nvmlCheckReturn(ret): + if ret != NVML_SUCCESS: + raise NVMLError(ret) + return ret + + +## Function access ## +_nvmlGetFunctionPointer_cache = ( + dict() +) # function pointers are cached to prevent unnecessary libLoadLock locking + + +def _nvmlGetFunctionPointer(name): + global nvmlLib + + if name in _nvmlGetFunctionPointer_cache: + return _nvmlGetFunctionPointer_cache[name] + + libLoadLock.acquire() + try: + # ensure library was loaded + if nvmlLib == None: + raise NVMLError(NVML_ERROR_UNINITIALIZED) + try: + _nvmlGetFunctionPointer_cache[name] = getattr(nvmlLib, name) + return _nvmlGetFunctionPointer_cache[name] + except AttributeError: + raise NVMLError(NVML_ERROR_FUNCTION_NOT_FOUND) + finally: + # lock is always freed + libLoadLock.release() + + +## Alternative object +# Allows the object to be printed +# Allows mismatched types to be assigned +# - like None when the Structure variant requires c_uint +class nvmlFriendlyObject(object): + def __init__(self, dictionary): + for x in dictionary: + setattr(self, x, dictionary[x]) + + def __str__(self): + return self.__dict__.__str__() + + +def nvmlStructToFriendlyObject(struct): + d = {} + for x in struct._fields_: + key = x[0] + value = getattr(struct, key) + # only need to convert from bytes if bytes, no need to check python version. + d[key] = value.decode() if isinstance(value, bytes) else value + obj = nvmlFriendlyObject(d) + return obj + + +# pack the object so it can be passed to the NVML library +def nvmlFriendlyObjectToStruct(obj, model): + for x in model._fields_: + key = x[0] + value = obj.__dict__[key] + # any c_char_p in python3 needs to be bytes, default encoding works fine. + if sys.version_info >= (3,): + setattr(model, key, value.encode()) + else: + setattr(model, key, value) + return model + + +## Unit structures +class struct_c_nvmlUnit_t(Structure): + pass # opaque handle + + +c_nvmlUnit_t = POINTER(struct_c_nvmlUnit_t) + + +class _PrintableStructure(Structure): + """ + Abstract class that produces nicer __str__ output than ctypes.Structure. + e.g. instead of: + >>> print str(obj) + + this class will print + class_name(field_name: formatted_value, field_name: formatted_value) + + _fmt_ dictionary of -> + e.g. class that has _field_ 'hex_value', c_uint could be formatted with + _fmt_ = {"hex_value" : "%08X"} + to produce nicer output. + Default formatting string for all fields can be set with key "" like: + _fmt_ = {"" : "%d MHz"} # e.g all values are numbers in MHz. + If not set it's assumed to be just "%s" + + Exact format of returned str from this class is subject to change in the future. 
+ """ + + _fmt_ = {} + + def __str__(self): + result = [] + for x in self._fields_: + key = x[0] + value = getattr(self, key) + fmt = "%s" + if key in self._fmt_: + fmt = self._fmt_[key] + elif "" in self._fmt_: + fmt = self._fmt_[""] + result.append(("%s: " + fmt) % (key, value)) + return self.__class__.__name__ + "(" + ", ".join(result) + ")" + + def __getattribute__(self, name): + res = super(_PrintableStructure, self).__getattribute__(name) + # need to convert bytes to unicode for python3 don't need to for python2 + # Python 2 strings are of both str and bytes + # Python 3 strings are not of type bytes + # ctypes should convert everything to the correct values otherwise + if isinstance(res, bytes): + if isinstance(res, str): + return res + return res.decode() + return res + + def __setattr__(self, name, value): + if isinstance(value, str): + # encoding a python2 string returns the same value, since python2 strings are bytes already + # bytes passed in python3 will be ignored. + value = value.encode() + super(_PrintableStructure, self).__setattr__(name, value) + + +class c_nvmlUnitInfo_t(_PrintableStructure): + _fields_ = [ + ("name", c_char * 96), + ("id", c_char * 96), + ("serial", c_char * 96), + ("firmwareVersion", c_char * 96), + ] + + +class c_nvmlC2cModeInfo_v1_t(_PrintableStructure): + _fields_ = [("isC2cEnabled", c_uint)] + + +nvmlC2cModeInfo_v1 = 0x1000008 + + +class c_nvmlLedState_t(_PrintableStructure): + _fields_ = [ + ("cause", c_char * 256), + ("color", _nvmlLedColor_t), + ] + + +class c_nvmlPSUInfo_t(_PrintableStructure): + _fields_ = [ + ("state", c_char * 256), + ("current", c_uint), + ("voltage", c_uint), + ("power", c_uint), + ] + + +class c_nvmlUnitFanInfo_t(_PrintableStructure): + _fields_ = [ + ("speed", c_uint), + ("state", _nvmlFanState_t), + ] + + +class c_nvmlUnitFanSpeeds_t(_PrintableStructure): + _fields_ = [("fans", c_nvmlUnitFanInfo_t * 24), ("count", c_uint)] + + +## Device structures +class struct_c_nvmlDevice_t(Structure): + pass # opaque handle + + +c_nvmlDevice_t = POINTER(struct_c_nvmlDevice_t) + + +class nvmlPciInfoExt_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("domain", c_uint), + ("bus", c_uint), + ("device", c_uint), + ("pciDeviceId", c_uint), + ("pciSubSystemId", c_uint), + ("baseClass", c_uint), + ("subClass", c_uint), + ("busId", c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE), + ] + _fmt_ = { + "version": "0x%04X", + "domain": "0x%04X", + "bus": "0x%02X", + "device": "0x%02X", + "pciDeviceId": "0x%08X", + "pciSubSystemId": "0x%08X", + "baseClass": "0x%01X", + "subClass": "0x%01X", + } + + +nvmlPciInfoExt_v1 = 0x1000040 + + +# Legacy pciInfo used for _v1 and _v2 +class nvmlPciInfo_v2_t(_PrintableStructure): + _fields_ = [ + ("busId", c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_V2_SIZE), + ("domain", c_uint), + ("bus", c_uint), + ("device", c_uint), + ("pciDeviceId", c_uint), + # Added in 2.285 + ("pciSubSystemId", c_uint), + ("reserved0", c_uint), + ("reserved1", c_uint), + ("reserved2", c_uint), + ("reserved3", c_uint), + ] + _fmt_ = { + "domain": "0x%04X", + "bus": "0x%02X", + "device": "0x%02X", + "pciDeviceId": "0x%08X", + "pciSubSystemId": "0x%08X", + } + + +class nvmlPciInfo_t(_PrintableStructure): + _fields_ = [ + # Moved to the new busId location below + ("busIdLegacy", c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_V2_SIZE), + ("domain", c_uint), + ("bus", c_uint), + ("device", c_uint), + ("pciDeviceId", c_uint), + # Added in 2.285 + ("pciSubSystemId", c_uint), + # New busId replaced the long deprecated and reserved fields 
with a + # field of the same size in 9.0 + ("busId", c_char * NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE), + ] + _fmt_ = { + "domain": "0x%08X", + "bus": "0x%02X", + "device": "0x%02X", + "pciDeviceId": "0x%08X", + "pciSubSystemId": "0x%08X", + } + + +class c_nvmlSystemDriverBranchInfo_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("branch", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE), + ] + + +SystemDriverBranchInfo_v1 = 0x1000054 + + +class c_nvmlExcludedDeviceInfo_t(_PrintableStructure): + _fields_ = [("pci", nvmlPciInfo_t), ("uuid", c_char * NVML_DEVICE_UUID_BUFFER_SIZE)] + + +class nvmlNvLinkUtilizationControl_t(_PrintableStructure): + _fields_ = [ + ("units", _nvmlNvLinkUtilizationCountUnits_t), + ("pktfilter", _nvmlNvLinkUtilizationCountPktTypes_t), + ] + + +class c_nvmlMemory_t(_PrintableStructure): + _fields_ = [ + ("total", c_ulonglong), + ("free", c_ulonglong), + ("used", c_ulonglong), + ] + _fmt_ = {"": "%d B"} + + +class c_nvmlMemory_v2_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("total", c_ulonglong), + ("reserved", c_ulonglong), + ("free", c_ulonglong), + ("used", c_ulonglong), + ] + _fmt_ = {"": "%d B"} + + +nvmlMemory_v2 = 0x02000028 + + +class c_nvmlBAR1Memory_t(_PrintableStructure): + _fields_ = [ + ("bar1Total", c_ulonglong), + ("bar1Free", c_ulonglong), + ("bar1Used", c_ulonglong), + ] + _fmt_ = {"": "%d B"} + + +class nvmlClkMonFaultInfo_t(Structure): + _fields_ = [("clkApiDomain", c_uint), ("clkDomainFaultMask", c_uint)] + + +MAX_CLK_DOMAINS = 32 + + +class nvmlClkMonStatus_t(Structure): + _fields_ = [ + ("bGlobalStatus", c_uint), + ("clkMonListSize", c_uint), + ("clkMonList", nvmlClkMonFaultInfo_t * MAX_CLK_DOMAINS), + ] + + +# On Windows with the WDDM driver, usedGpuMemory is reported as None +# Code that processes this structure should check for None, I.E. 
+# +# if (info.usedGpuMemory == None): +# # TODO handle the error +# pass +# else: +# print("Using %d MiB of memory" % (info.usedGpuMemory / 1024 / 1024)) +# endif +# +# See NVML documentation for more information +class c_nvmlProcessInfo_v2_t(_PrintableStructure): + _fields_ = [ + ("pid", c_uint), + ("usedGpuMemory", c_ulonglong), + ("gpuInstanceId", c_uint), + ("computeInstanceId", c_uint), + ] + _fmt_ = {"usedGpuMemory": "%d B"} + + +c_nvmlProcessInfo_v3_t = c_nvmlProcessInfo_v2_t + +c_nvmlProcessInfo_t = c_nvmlProcessInfo_v3_t + +_nvmlProcessMode_t = c_uint +NVML_PROCESS_MODE_COMPUTE = 0 +NVML_PROCESS_MODE_GRAPHICS = 1 +NVML_PROCESS_MODE_MPS = 2 + + +class c_nvmlProcessDetail_v1_t(Structure): + _fields_ = [ + ("pid", c_uint), + ("usedGpuMemory", c_ulonglong), + ("gpuInstanceId", c_uint), + ("computeInstanceId", c_uint), + ("usedGpuCcProtectedMemory", c_ulonglong), + ] + + +class c_nvmlProcessDetailList_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("mode", _nvmlProcessMode_t), + ("numProcArrayEntries", c_uint), + ("procArray", POINTER(c_nvmlProcessDetail_v1_t)), + ] + _fmt_ = {"numProcArrayEntries": "%d B"} + + +c_nvmlProcessDetailList_t = c_nvmlProcessDetailList_v1_t + +nvmlProcessDetailList_v1 = 0x1000018 + + +class c_nvmlBridgeChipInfo_t(_PrintableStructure): + _fields_ = [ + ("type", _nvmlBridgeChipType_t), + ("fwVersion", c_uint), + ] + + +class c_nvmlBridgeChipHierarchy_t(_PrintableStructure): + _fields_ = [ + ("bridgeCount", c_uint), + ("bridgeChipInfo", c_nvmlBridgeChipInfo_t * 128), + ] + + +class c_nvmlEccErrorCounts_t(_PrintableStructure): + _fields_ = [ + ("l1Cache", c_ulonglong), + ("l2Cache", c_ulonglong), + ("deviceMemory", c_ulonglong), + ("registerFile", c_ulonglong), + ] + + +class c_nvmlUtilization_t(_PrintableStructure): + _fields_ = [ + ("gpu", c_uint), + ("memory", c_uint), + ] + _fmt_ = {"": "%d %%"} + + +# Added in 2.285 +class c_nvmlHwbcEntry_t(_PrintableStructure): + _fields_ = [ + ("hwbcId", c_uint), + ("firmwareVersion", c_char * 32), + ] + + +class c_nvmlValue_t(Union): + _fields_ = [ + ("dVal", c_double), + ("uiVal", c_uint), + ("ulVal", c_ulong), + ("ullVal", c_ulonglong), + ("sllVal", c_longlong), + ("siVal", c_int), + ("usVal", c_ushort), + ] + + +class c_nvmlSample_t(_PrintableStructure): + _fields_ = [ + ("timeStamp", c_ulonglong), + ("sampleValue", c_nvmlValue_t), + ] + + +class c_nvmlViolationTime_t(_PrintableStructure): + _fields_ = [ + ("referenceTime", c_ulonglong), + ("violationTime", c_ulonglong), + ] + + +class c_nvmlFieldValue_t(_PrintableStructure): + _fields_ = [ + ("fieldId", c_uint32), + ("scopeId", c_uint32), + ("timestamp", c_int64), + ("latencyUsec", c_int64), + ("valueType", _nvmlValueType_t), + ("nvmlReturn", _nvmlReturn_t), + ("value", c_nvmlValue_t), + ] + + +NVML_NVLINK_TOTAL_SUPPORTED_BW_MODES = 23 + +nvmlNvlinkSupportedBwModes_v1 = 0x100001C + + +class c_nvmlNvlinkSupportedBwModes_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("bwModes", c_uint8 * NVML_NVLINK_TOTAL_SUPPORTED_BW_MODES), + ("totalBwModes", c_uint8), + ] + + def __init__(self): + super(c_nvmlNvlinkSupportedBwModes_v1_t, self).__init__( + version=nvmlNvlinkSupportedBwModes_v1 + ) + + +nvmlNvlinkGetBwMode_v1 = 0x100000C + + +class c_nvmlNvlinkGetBwMode_v1_t(_PrintableStructure): + _fields_ = [("version", c_uint), ("bIsBest", c_uint), ("bwMode", c_uint8)] + + def __init__(self): + super(c_nvmlNvlinkGetBwMode_v1_t, self).__init__(version=nvmlNvlinkGetBwMode_v1) + + +nvmlNvlinkSetBwMode_v1 = 0x100000C + + +class 
c_nvmlNvlinkSetBwMode_v1_t(_PrintableStructure): + _fields_ = [("version", c_uint), ("bSetBest", c_uint), ("bwMode", c_uint8)] + + def __init__(self): + super(c_nvmlNvlinkSetBwMode_v1_t, self).__init__(version=nvmlNvlinkSetBwMode_v1) + + +class c_nvmlVgpuHeterogeneousMode_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("mode", c_uint), + ] + + +VgpuHeterogeneousMode_v1 = 0x1000008 + + +class c_nvmlVgpuPlacementId_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("placementId", c_uint), + ] + + +VgpuPlacementId_v1 = 0x1000008 + + +class c_nvmlVgpuPlacementList_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("count", c_uint), + ("placementSize", c_uint), + ("placementIds", POINTER(c_uint)), + ] + + +VgpuPlacementList_v1 = 0x1000018 + +NVML_VGPU_PGPU_HETEROGENEOUS_MODE = 0 +NVML_VGPU_PGPU_HOMOGENEOUS_MODE = 1 + + +class c_nvmlVgpuPlacementList_v2_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("placementSize", c_uint), + ("count", c_uint), + ("placementIds", POINTER(c_uint)), + ("mode", c_uint), + ] + + +VgpuPlacementList_v2 = 0x2000020 + + +class c_nvmlVgpuTypeBar1Info_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("bar1Size", c_ulonglong), + ] + + +VgpuTypeBar1Info_v1 = 0x1000010 + + +class c_nvmlVgpuInstanceUtilizationSample_t(_PrintableStructure): + _fields_ = [ + ("vgpuInstance", _nvmlVgpuInstance_t), + ("timeStamp", c_ulonglong), + ("smUtil", c_nvmlValue_t), + ("memUtil", c_nvmlValue_t), + ("encUtil", c_nvmlValue_t), + ("decUtil", c_nvmlValue_t), + ] + + +class c_nvmlVgpuInstanceUtilizationInfo_v1_t(_PrintableStructure): + _fields_ = [ + ("timeStamp", c_ulonglong), + ("vgpuInstance", _nvmlVgpuInstance_t), + ("smUtil", c_nvmlValue_t), + ("memUtil", c_nvmlValue_t), + ("encUtil", c_nvmlValue_t), + ("decUtil", c_nvmlValue_t), + ("jpgUtil", c_nvmlValue_t), + ("ofaUtil", c_nvmlValue_t), + ] + + +class c_nvmlVgpuInstancesUtilizationInfo_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("sampleValType", _nvmlValueType_t), + ("vgpuInstanceCount", c_uint), + ("lastSeenTimeStamp", c_ulonglong), + ("vgpuUtilArray", POINTER(c_nvmlVgpuInstanceUtilizationInfo_v1_t)), + ] + + +VgpuInstancesUtilizationInfo_v1 = 0x01000020 + + +class c_nvmlVgpuProcessUtilizationSample_t(_PrintableStructure): + _fields_ = [ + ("vgpuInstance", _nvmlVgpuInstance_t), + ("pid", c_uint), + ("processName", c_char * NVML_VGPU_NAME_BUFFER_SIZE), + ("timeStamp", c_ulonglong), + ("smUtil", c_uint), + ("memUtil", c_uint), + ("encUtil", c_uint), + ("decUtil", c_uint), + ] + + +class c_nvmlVgpuProcessUtilizationInfo_v1_t(_PrintableStructure): + _fields_ = [ + ("processName", c_char * NVML_VGPU_NAME_BUFFER_SIZE), + ("timeStamp", c_ulonglong), + ("vgpuInstance", _nvmlVgpuInstance_t), + ("pid", c_uint), + ("smUtil", c_uint), + ("memUtil", c_uint), + ("encUtil", c_uint), + ("decUtil", c_uint), + ("jpgUtil", c_uint), + ("ofaUtil", c_uint), + ] + + +class c_nvmlVgpuProcessesUtilizationInfo_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("vgpuProcessCount", c_uint), + ("lastSeenTimeStamp", c_ulonglong), + ("vgpuProcUtilArray", POINTER(c_nvmlVgpuProcessUtilizationInfo_v1_t)), + ] + + +VgpuProcessesUtilizationInfo_v1 = 0x01000018 + + +class nvmlVgpuRuntimeState_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("size", c_ulonglong), + ] + + +VgpuRuntimeState_v1 = 0x1000010 + + +class c_nvmlVgpuLicenseExpiry_t(_PrintableStructure): + _fields_ = [ + ("year", c_uint32), + ("month", c_uint16), + 
("day", c_uint16), + ("hour", c_uint16), + ("min", c_uint16), + ("sec", c_uint16), + ("status", c_uint8), + ] + + +NVML_GRID_LICENSE_STATE_UNKNOWN = 0 +NVML_GRID_LICENSE_STATE_UNINITIALIZED = 1 +NVML_GRID_LICENSE_STATE_UNLICENSED_UNRESTRICTED = 2 +NVML_GRID_LICENSE_STATE_UNLICENSED_RESTRICTED = 3 +NVML_GRID_LICENSE_STATE_UNLICENSED = 4 +NVML_GRID_LICENSE_STATE_LICENSED = 5 + + +class c_nvmlVgpuLicenseInfo_t(_PrintableStructure): + _fields_ = [ + ("isLicensed", c_uint8), + ("licenseExpiry", c_nvmlVgpuLicenseExpiry_t), + ("currentState", c_uint), + ] + + +class c_nvmlEncoderSession_t(_PrintableStructure): + _fields_ = [ + ("sessionId", c_uint), + ("pid", c_uint), + ("vgpuInstance", _nvmlVgpuInstance_t), + ("codecType", c_uint), + ("hResolution", c_uint), + ("vResolution", c_uint), + ("averageFps", c_uint), + ("encodeLatency", c_uint), + ] + + +class c_nvmlProcessUtilizationSample_t(_PrintableStructure): + _fields_ = [ + ("pid", c_uint), + ("timeStamp", c_ulonglong), + ("smUtil", c_uint), + ("memUtil", c_uint), + ("encUtil", c_uint), + ("decUtil", c_uint), + ] + + +class c_nvmlProcessUtilizationInfo_v1_t(_PrintableStructure): + _fields_ = [ + ("timeStamp", c_ulonglong), + ("pid", c_uint), + ("smUtil", c_uint), + ("memUtil", c_uint), + ("encUtil", c_uint), + ("decUtil", c_uint), + ("jpgUtil", c_uint), + ("ofaUtil", c_uint), + ] + + +class c_nvmlProcessesUtilizationInfo_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("processSamplesCount", c_uint), + ("lastSeenTimeStamp", c_ulonglong), + ("procUtilArray", POINTER(c_nvmlProcessUtilizationInfo_v1_t)), + ] + + +ProcessesUtilizationInfo_v1 = 0x01000018 + + +class c_nvmlGridLicenseExpiry_t(_PrintableStructure): + _fields_ = [ + ("year", c_uint32), + ("month", c_uint16), + ("day", c_uint16), + ("hour", c_uint16), + ("min", c_uint16), + ("sec", c_uint16), + ("status", c_uint8), + ] + + +class c_nvmlGridLicensableFeature_v4_t(_PrintableStructure): + _fields_ = [ + ("featureCode", _nvmlGridLicenseFeatureCode_t), + ("featureState", c_uint), + ("licenseInfo", c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ("productName", c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ("featureEnabled", c_uint), + ("licenseExpiry", c_nvmlGridLicenseExpiry_t), + ] + + +class c_nvmlGridLicensableFeatures_v4_t(_PrintableStructure): + _fields_ = [ + ("isGridLicenseSupported", c_int), + ("licensableFeaturesCount", c_uint), + ( + "gridLicensableFeatures", + c_nvmlGridLicensableFeature_v4_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT, + ), + ] + + +class c_nvmlGridLicensableFeature_v3_t(_PrintableStructure): + _fields_ = [ + ("featureCode", _nvmlGridLicenseFeatureCode_t), + ("featureState", c_uint), + ("licenseInfo", c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ("productName", c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ("featureEnabled", c_uint), + ] + + +class c_nvmlGridLicensableFeatures_v3_t(_PrintableStructure): + _fields_ = [ + ("isGridLicenseSupported", c_int), + ("licensableFeaturesCount", c_uint), + ( + "gridLicensableFeatures", + c_nvmlGridLicensableFeature_v3_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT, + ), + ] + + +class c_nvmlGridLicensableFeature_v2_t(_PrintableStructure): + _fields_ = [ + ("featureCode", _nvmlGridLicenseFeatureCode_t), + ("featureState", c_uint), + ("licenseInfo", c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ("productName", c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ] + + +class c_nvmlGridLicensableFeatures_v2_t(_PrintableStructure): + _fields_ = [ + ("isGridLicenseSupported", c_int), + ("licensableFeaturesCount", c_uint), + ( + 
"gridLicensableFeatures", + c_nvmlGridLicensableFeature_v2_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT, + ), + ] + + +class c_nvmlGridLicensableFeature_t(_PrintableStructure): + _fields_ = [ + ("featureCode", _nvmlGridLicenseFeatureCode_t), + ("featureState", c_uint), + ("licenseInfo", c_char * NVML_GRID_LICENSE_BUFFER_SIZE), + ] + + +class c_nvmlGridLicensableFeatures_t(_PrintableStructure): + _fields_ = [ + ("isGridLicenseSupported", c_int), + ("licensableFeaturesCount", c_uint), + ( + "gridLicensableFeatures", + c_nvmlGridLicensableFeature_t * NVML_GRID_LICENSE_FEATURE_MAX_COUNT, + ), + ] + + +class c_nvmlMarginTemperature_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("marginTemperature", c_int), + ] + + +nvmlMarginTemperature_v1 = 0x1000008 + + +## Event structures +class struct_c_nvmlEventSet_t(Structure): + pass # opaque handle + + +c_nvmlEventSet_t = POINTER(struct_c_nvmlEventSet_t) + +nvmlEventTypeSingleBitEccError = 0x0000000000000001 +nvmlEventTypeDoubleBitEccError = 0x0000000000000002 +nvmlEventTypePState = 0x0000000000000004 +nvmlEventTypeXidCriticalError = 0x0000000000000008 +nvmlEventTypeClock = 0x0000000000000010 +nvmlEventTypePowerSourceChange = 0x0000000000000080 +nvmlEventMigConfigChange = 0x0000000000000100 +nvmlEventTypeSingleBitEccErrorStorm = 0x0000000000000200 +nvmlEventTypeDramRetirementEvent = 0x0000000000000400 +nvmlEventTypeDramRetirementFailure = 0x0000000000000800 +nvmlEventTypeNonFatalPoisonError = 0x0000000000001000 +nvmlEventTypeFatalPoisonError = 0x0000000000002000 +nvmlEventTypeGpuUnavailableError = 0x0000000000004000 +nvmlEventTypeGpuRecoveryAction = 0x0000000000008000 +nvmlEventTypeNone = 0x0000000000000000 +nvmlEventTypeAll = ( + nvmlEventTypeNone + | nvmlEventTypeSingleBitEccError + | nvmlEventTypeDoubleBitEccError + | nvmlEventTypePState + | nvmlEventTypeClock + | nvmlEventTypePowerSourceChange + | nvmlEventTypeXidCriticalError + | nvmlEventMigConfigChange + | nvmlEventTypeSingleBitEccErrorStorm + | nvmlEventTypeDramRetirementEvent + | nvmlEventTypeDramRetirementFailure + | nvmlEventTypeNonFatalPoisonError + | nvmlEventTypeFatalPoisonError + | nvmlEventTypeGpuUnavailableError + | nvmlEventTypeGpuRecoveryAction +) + +## Clock Event Reasons defines +nvmlClocksEventReasonGpuIdle = 0x0000000000000001 +nvmlClocksEventReasonApplicationsClocksSetting = 0x0000000000000002 +nvmlClocksEventReasonUserDefinedClocks = nvmlClocksEventReasonApplicationsClocksSetting # deprecated, use nvmlClocksEventReasonApplicationsClocksSetting +nvmlClocksEventReasonSwPowerCap = 0x0000000000000004 +nvmlClocksEventReasonHwSlowdown = 0x0000000000000008 +nvmlClocksEventReasonSyncBoost = 0x0000000000000010 +nvmlClocksEventReasonSwThermalSlowdown = 0x0000000000000020 +nvmlClocksEventReasonHwThermalSlowdown = 0x0000000000000040 +nvmlClocksEventReasonHwPowerBrakeSlowdown = 0x0000000000000080 +nvmlClocksEventReasonDisplayClockSetting = 0x0000000000000100 +nvmlClocksEventReasonNone = 0x0000000000000000 +nvmlClocksEventReasonAll = ( + nvmlClocksEventReasonNone + | nvmlClocksEventReasonGpuIdle + | nvmlClocksEventReasonApplicationsClocksSetting + | nvmlClocksEventReasonSwPowerCap + | nvmlClocksEventReasonHwSlowdown + | nvmlClocksEventReasonSyncBoost + | nvmlClocksEventReasonSwThermalSlowdown + | nvmlClocksEventReasonHwThermalSlowdown + | nvmlClocksEventReasonHwPowerBrakeSlowdown + | nvmlClocksEventReasonDisplayClockSetting +) + +## Following have been deprecated +nvmlClocksThrottleReasonGpuIdle = 0x0000000000000001 +nvmlClocksThrottleReasonApplicationsClocksSetting = 
0x0000000000000002 +nvmlClocksThrottleReasonUserDefinedClocks = nvmlClocksThrottleReasonApplicationsClocksSetting # deprecated, use nvmlClocksThrottleReasonApplicationsClocksSetting +nvmlClocksThrottleReasonSwPowerCap = 0x0000000000000004 +nvmlClocksThrottleReasonHwSlowdown = 0x0000000000000008 +nvmlClocksThrottleReasonSyncBoost = 0x0000000000000010 +nvmlClocksThrottleReasonSwThermalSlowdown = 0x0000000000000020 +nvmlClocksThrottleReasonHwThermalSlowdown = 0x0000000000000040 +nvmlClocksThrottleReasonHwPowerBrakeSlowdown = 0x0000000000000080 +nvmlClocksThrottleReasonDisplayClockSetting = 0x0000000000000100 +nvmlClocksThrottleReasonNone = 0x0000000000000000 +nvmlClocksThrottleReasonAll = ( + nvmlClocksThrottleReasonNone + | nvmlClocksThrottleReasonGpuIdle + | nvmlClocksThrottleReasonApplicationsClocksSetting + | nvmlClocksThrottleReasonSwPowerCap + | nvmlClocksThrottleReasonHwSlowdown + | nvmlClocksThrottleReasonSyncBoost + | nvmlClocksThrottleReasonSwThermalSlowdown + | nvmlClocksThrottleReasonHwThermalSlowdown + | nvmlClocksThrottleReasonHwPowerBrakeSlowdown + | nvmlClocksThrottleReasonDisplayClockSetting +) + + +class c_nvmlEventData_t(_PrintableStructure): + _fields_ = [ + ("device", c_nvmlDevice_t), + ("eventType", c_ulonglong), + ("eventData", c_ulonglong), + ("gpuInstanceId", c_uint), + ("computeInstanceId", c_uint), + ] + _fmt_ = {"eventType": "0x%08X"} + + +class c_nvmlAccountingStats_t(_PrintableStructure): + _fields_ = [ + ("gpuUtilization", c_uint), + ("memoryUtilization", c_uint), + ("maxMemoryUsage", c_ulonglong), + ("time", c_ulonglong), + ("startTime", c_ulonglong), + ("isRunning", c_uint), + ("reserved", c_uint * 5), + ] + + +class c_nvmlVgpuVersion_t(Structure): + _fields_ = [("minVersion", c_uint), ("maxVersion", c_uint)] + + +class c_nvmlVgpuMetadata_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("revision", c_uint), + ("guestInfoState", _nvmlVgpuGuestInfoState_t), + ("guestDriverVersion", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE), + ("hostDriverVersion", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE), + ("reserved", c_uint * 6), + ("vgpuVirtualizationCaps", c_uint), + ("guestVgpuVersion", c_uint), + ("opaqueDataSize", c_uint), + ("opaqueData", c_char * NVML_VGPU_METADATA_OPAQUE_DATA_SIZE), + ] + + +class c_nvmlVgpuPgpuMetadata_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("revision", c_uint), + ("hostDriverVersion", c_char * NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE), + ("pgpuVirtualizationCaps", c_uint), + ("reserved", c_uint * 5), + ("hostSupportedVgpuRange", c_nvmlVgpuVersion_t), + ("opaqueDataSize", c_uint), + ("opaqueData", c_char * NVML_VGPU_PGPU_METADATA_OPAQUE_DATA_SIZE), + ] + + +class c_nvmlVgpuPgpuCompatibility_t(Structure): + _fields_ = [ + ("vgpuVmCompatibility", _nvmlVgpuVmCompatibility_t), + ("compatibilityLimitCode", _nvmlVgpuPgpuCompatibilityLimitCode_t), + ] + + +## vGPU scheduler policy defines +NVML_VGPU_SCHEDULER_POLICY_UNKNOWN = 0 +NVML_VGPU_SCHEDULER_POLICY_BEST_EFFORT = 1 +NVML_VGPU_SCHEDULER_POLICY_EQUAL_SHARE = 2 +NVML_VGPU_SCHEDULER_POLICY_FIXED_SHARE = 3 + +## Supported vGPU scheduler policy count +NVML_SUPPORTED_VGPU_SCHEDULER_POLICY_COUNT = 3 + +NVML_SCHEDULER_SW_MAX_LOG_ENTRIES = 200 + +NVML_VGPU_SCHEDULER_ARR_DEFAULT = 0 +NVML_VGPU_SCHEDULER_ARR_DISABLE = 1 +NVML_VGPU_SCHEDULER_ARR_ENABLE = 2 + + +class c_nvmlVgpuSchedDataWithARR_t(_PrintableStructure): + _fields_ = [ + ("avgFactor", c_uint), + ("timeslice", c_uint), + ] + + +class c_nvmlVgpuSchedData_t(_PrintableStructure): + _fields_ = [ + 
("timeslice", c_uint), + ] + + +class c_nvmlVgpuSchedulerParams_t(Union): + _fields_ = [ + ("vgpuSchedDataWithARR", c_nvmlVgpuSchedDataWithARR_t), + ("vgpuSchedData", c_nvmlVgpuSchedData_t), + ] + + +class c_nvmlVgpuSchedulerLogEntry_t(_PrintableStructure): + _fields_ = [ + ("timestamp", c_ulonglong), + ("timeRunTotal", c_ulonglong), + ("timeRun", c_ulonglong), + ("swRunlistId", c_uint), + ("targetTimeSlice", c_ulonglong), + ("cumulativePreemptionTime", c_ulonglong), + ] + + +class c_nvmlVgpuSchedulerLog_t(_PrintableStructure): + _fields_ = [ + ("engineId", c_uint), + ("schedulerPolicy", c_uint), + ("arrMode", c_uint), + ("schedulerParams", c_nvmlVgpuSchedulerParams_t), + ("entriesCount", c_uint), + ( + "logEntries", + c_nvmlVgpuSchedulerLogEntry_t * NVML_SCHEDULER_SW_MAX_LOG_ENTRIES, + ), + ] + + +class c_nvmlVgpuSchedulerGetState_t(_PrintableStructure): + _fields_ = [ + ("schedulerPolicy", c_uint), + ("arrMode", c_uint), + ("schedulerParams", c_nvmlVgpuSchedulerParams_t), + ] + + +class c_nvmlVgpuSchedSetDataWithARR_t(_PrintableStructure): + _fields_ = [ + ("avgFactor", c_uint), + ("frequency", c_uint), + ] + + +class c_nvmlVgpuSchedSetData_t(_PrintableStructure): + _fields_ = [ + ("timeslice", c_uint), + ] + + +class c_nvmlVgpuSchedulerSetParams_t(Union): + _fields_ = [ + ("vgpuSchedDataWithARR", c_nvmlVgpuSchedSetDataWithARR_t), + ("vgpuSchedData", c_nvmlVgpuSchedSetData_t), + ] + + +class c_nvmlVgpuSchedulerSetState_t(_PrintableStructure): + _fields_ = [ + ("schedulerPolicy", c_uint), + ("enableARRMode", c_uint), + ("schedulerParams", c_nvmlVgpuSchedulerSetParams_t), + ] + + +class c_nvmlVgpuSchedulerCapabilities_t(_PrintableStructure): + _fields_ = [ + ("supportedSchedulers", c_uint * NVML_SUPPORTED_VGPU_SCHEDULER_POLICY_COUNT), + ("maxTimeslice", c_uint), + ("minTimeslice", c_uint), + ("isArrModeSupported", c_uint), + ("maxFrequencyForARR", c_uint), + ("minFrequencyForARR", c_uint), + ("maxAvgFactorForARR", c_uint), + ("minAvgFactorForARR", c_uint), + ] + + +class c_nvmlFBCStats_t(Structure): + _fields_ = [ + ("sessionsCount", c_uint), + ("averageFPS", c_uint), + ("averageLatency", c_uint), + ] + + +class c_nvmlFBCSession_t(_PrintableStructure): + _fields_ = [ + ("sessionId", c_uint), + ("pid", c_uint), + ("vgpuInstance", _nvmlVgpuInstance_t), + ("displayOrdinal", c_uint), + ("sessionType", c_uint), + ("sessionFlags", c_uint), + ("hMaxResolution", c_uint), + ("vMaxResolution", c_uint), + ("hResolution", c_uint), + ("vResolution", c_uint), + ("averageFPS", c_uint), + ("averageLatency", c_uint), + ] + + +NVML_DEVICE_MIG_DISABLE = 0x0 +NVML_DEVICE_MIG_ENABLE = 0x1 + +NVML_GPU_INSTANCE_PROFILE_1_SLICE = 0x0 +NVML_GPU_INSTANCE_PROFILE_2_SLICE = 0x1 +NVML_GPU_INSTANCE_PROFILE_3_SLICE = 0x2 +NVML_GPU_INSTANCE_PROFILE_4_SLICE = 0x3 +NVML_GPU_INSTANCE_PROFILE_7_SLICE = 0x4 +NVML_GPU_INSTANCE_PROFILE_8_SLICE = 0x5 +NVML_GPU_INSTANCE_PROFILE_6_SLICE = 0x6 +NVML_GPU_INSTANCE_PROFILE_1_SLICE_REV1 = 0x7 +NVML_GPU_INSTANCE_PROFILE_2_SLICE_REV1 = 0x8 +NVML_GPU_INSTANCE_PROFILE_1_SLICE_REV2 = 0x9 +NVML_GPU_INSTANCE_PROFILE_1_SLICE_GFX = 0xA +NVML_GPU_INSTANCE_PROFILE_2_SLICE_GFX = 0xB +NVML_GPU_INSTANCE_PROFILE_4_SLICE_GFX = 0xC +NVML_GPU_INSTANCE_PROFILE_COUNT = 0xD + + +class c_nvmlGpuInstancePlacement_t(Structure): + _fields_ = [("start", c_uint), ("size", c_uint)] + + +class c_nvmlGpuInstanceProfileInfo_t(Structure): + _fields_ = [ + ("id", c_uint), + ("isP2pSupported", c_uint), + ("sliceCount", c_uint), + ("instanceCount", c_uint), + ("multiprocessorCount", c_uint), + ("copyEngineCount", 
c_uint), + ("decoderCount", c_uint), + ("encoderCount", c_uint), + ("jpegCount", c_uint), + ("ofaCount", c_uint), + ("memorySizeMB", c_ulonglong), + ] + + +nvmlGpuInstanceProfileInfo_v2 = 0x02000098 + + +class c_nvmlGpuInstanceProfileInfo_v2_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("id", c_uint), + ("isP2pSupported", c_uint), + ("sliceCount", c_uint), + ("instanceCount", c_uint), + ("multiprocessorCount", c_uint), + ("copyEngineCount", c_uint), + ("decoderCount", c_uint), + ("encoderCount", c_uint), + ("jpegCount", c_uint), + ("ofaCount", c_uint), + ("memorySizeMB", c_ulonglong), + ("name", c_char * NVML_DEVICE_NAME_V2_BUFFER_SIZE), + ] + + def __init__(self): + super(c_nvmlGpuInstanceProfileInfo_v2_t, self).__init__( + version=nvmlGpuInstanceProfileInfo_v2 + ) + + +class c_nvmlGpuInstanceInfo_t(Structure): + _fields_ = [ + ("device", c_nvmlDevice_t), + ("id", c_uint), + ("profileId", c_uint), + ("placement", c_nvmlGpuInstancePlacement_t), + ] + + +class struct_c_nvmlGpuInstance_t(Structure): + pass # opaque handle + + +c_nvmlGpuInstance_t = POINTER(struct_c_nvmlGpuInstance_t) + +NVML_COMPUTE_INSTANCE_PROFILE_1_SLICE = 0x0 +NVML_COMPUTE_INSTANCE_PROFILE_2_SLICE = 0x1 +NVML_COMPUTE_INSTANCE_PROFILE_3_SLICE = 0x2 +NVML_COMPUTE_INSTANCE_PROFILE_4_SLICE = 0x3 +NVML_COMPUTE_INSTANCE_PROFILE_7_SLICE = 0x4 +NVML_COMPUTE_INSTANCE_PROFILE_8_SLICE = 0x5 +NVML_COMPUTE_INSTANCE_PROFILE_6_SLICE = 0x6 +NVML_COMPUTE_INSTANCE_PROFILE_1_SLICE_REV1 = 0x7 +NVML_COMPUTE_INSTANCE_PROFILE_COUNT = 0x8 + +NVML_COMPUTE_INSTANCE_ENGINE_PROFILE_SHARED = 0x0 +NVML_COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT = 0x1 + + +class c_nvmlComputeInstancePlacement_t(Structure): + _fields_ = [("start", c_uint), ("size", c_uint)] + + +class c_nvmlComputeInstanceProfileInfo_t(Structure): + _fields_ = [ + ("id", c_uint), + ("sliceCount", c_uint), + ("instanceCount", c_uint), + ("multiprocessorCount", c_uint), + ("sharedCopyEngineCount", c_uint), + ("sharedDecoderCount", c_uint), + ("sharedEncoderCount", c_uint), + ("sharedJpegCount", c_uint), + ("sharedOfaCount", c_uint), + ] + + +nvmlComputeInstanceProfileInfo_v2 = 0x02000088 + + +class c_nvmlComputeInstanceProfileInfo_v2_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("id", c_uint), + ("sliceCount", c_uint), + ("instanceCount", c_uint), + ("multiprocessorCount", c_uint), + ("sharedCopyEngineCount", c_uint), + ("sharedDecoderCount", c_uint), + ("sharedEncoderCount", c_uint), + ("sharedJpegCount", c_uint), + ("sharedOfaCount", c_uint), + ("name", c_char * NVML_DEVICE_NAME_V2_BUFFER_SIZE), + ] + + def __init__(self): + super(c_nvmlComputeInstanceProfileInfo_v2_t, self).__init__( + version=nvmlComputeInstanceProfileInfo_v2 + ) + + +class c_nvmlComputeInstanceInfo_t(Structure): + _fields_ = [ + ("device", c_nvmlDevice_t), + ("gpuInstance", c_nvmlGpuInstance_t), + ("id", c_uint), + ("profileId", c_uint), + ("placement", c_nvmlComputeInstancePlacement_t), + ] + + +NVML_MAX_GPU_UTILIZATIONS = 8 +NVML_GPU_UTILIZATION_DOMAIN_GPU = 0 +NVML_GPU_UTILIZATION_DOMAIN_FB = 1 +NVML_GPU_UTILIZATION_DOMAIN_VID = 2 +NVML_GPU_UTILIZATION_DOMAIN_BUS = 3 + + +class c_nvmlGpuDynamicPstatesUtilization_t(Structure): + _fields_ = [ + ("bIsPresent", c_uint, 1), + ("percentage", c_uint), + ("incThreshold", c_uint), + ("decThreshold", c_uint), + ] + + +class c_nvmlGpuDynamicPstatesInfo_t(Structure): + _fields_ = [ + ("flags", c_uint), + ( + "utilization", + c_nvmlGpuDynamicPstatesUtilization_t * NVML_MAX_GPU_UTILIZATIONS, + ), + ] + + +NVML_MAX_THERMAL_SENSORS_PER_GPU = 3 + 
+NVML_THERMAL_TARGET_NONE = 0 +NVML_THERMAL_TARGET_GPU = 1 +NVML_THERMAL_TARGET_MEMORY = 2 +NVML_THERMAL_TARGET_POWER_SUPPLY = 4 +NVML_THERMAL_TARGET_BOARD = 8 +NVML_THERMAL_TARGET_VCD_BOARD = 9 +NVML_THERMAL_TARGET_VCD_INLET = 10 +NVML_THERMAL_TARGET_VCD_OUTLET = 11 +NVML_THERMAL_TARGET_ALL = 15 +NVML_THERMAL_TARGET_UNKNOWN = -1 + +NVML_THERMAL_CONTROLLER_NONE = 0 +NVML_THERMAL_CONTROLLER_GPU_INTERNAL = 1 +NVML_THERMAL_CONTROLLER_ADM1032 = 2 +NVML_THERMAL_CONTROLLER_ADT7461 = 3 +NVML_THERMAL_CONTROLLER_MAX6649 = 4 +NVML_THERMAL_CONTROLLER_MAX1617 = 5 +NVML_THERMAL_CONTROLLER_LM99 = 6 +NVML_THERMAL_CONTROLLER_LM89 = 7 +NVML_THERMAL_CONTROLLER_LM64 = 8 +NVML_THERMAL_CONTROLLER_G781 = 9 +NVML_THERMAL_CONTROLLER_ADT7473 = 10 +NVML_THERMAL_CONTROLLER_SBMAX6649 = 11 +NVML_THERMAL_CONTROLLER_VBIOSEVT = 12 +NVML_THERMAL_CONTROLLER_OS = 13 +NVML_THERMAL_CONTROLLER_NVSYSCON_CANOAS = 14 +NVML_THERMAL_CONTROLLER_NVSYSCON_E551 = 15 +NVML_THERMAL_CONTROLLER_MAX6649R = 16 +NVML_THERMAL_CONTROLLER_ADT7473S = 17 +NVML_THERMAL_CONTROLLER_UNKNOWN = -1 + + +class c_nvmlGpuThermalSensor_t(Structure): + _fields_ = [ + ("controller", c_int), + ("defaultMinTemp", c_int), + ("defaultMaxTemp", c_int), + ("currentTemp", c_int), + ("target", c_int), + ] + + +class c_nvmlGpuThermalSettings_t(Structure): + _fields_ = [ + ("count", c_uint), + ("sensor", c_nvmlGpuThermalSensor_t * NVML_MAX_THERMAL_SENSORS_PER_GPU), + ] + + +_nvmlCoolerControl_t = c_uint +NVML_THERMAL_COOLER_SIGNAL_NONE = 0 +NVML_THERMAL_COOLER_SIGNAL_TOGGLE = 1 +NVML_THERMAL_COOLER_SIGNAL_VARIABLE = 2 +NVML_THERMAL_COOLER_SIGNAL_COUNT = 3 + +_nvmlCoolerTarget_t = c_uint +NVML_THERMAL_COOLER_TARGET_NONE = 1 << 0 +NVML_THERMAL_COOLER_TARGET_GPU = 1 << 1 +NVML_THERMAL_COOLER_TARGET_MEMORY = 1 << 2 +NVML_THERMAL_COOLER_TARGET_POWER_SUPPLY = 1 << 3 +NVML_THERMAL_COOLER_TARGET_GPU_RELATED = ( + NVML_THERMAL_COOLER_TARGET_GPU + | NVML_THERMAL_COOLER_TARGET_MEMORY + | NVML_THERMAL_COOLER_TARGET_POWER_SUPPLY +) + + +class c_nvmlCoolerInfo_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("index", c_uint), + ("coolerControlType", _nvmlCoolerControl_t), + ("coolerTarget", _nvmlCoolerTarget_t), + ] + + +nvmlCoolerInfo_v1 = 0x1000010 + + +def nvmlDeviceGetCoolerInfo(handle): + c_coolerInfo = c_nvmlCoolerInfo_t() + c_coolerInfo.version = nvmlCoolerInfo_v1 + c_coolerInfo.index = 0 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCoolerInfo") + ret = fn(handle, byref(c_coolerInfo)) + _nvmlCheckReturn(ret) + return [c_coolerInfo.coolerControlType, c_coolerInfo.coolerTarget] + + +class struct_c_nvmlComputeInstance_t(Structure): + pass # opaque handle + + +c_nvmlComputeInstance_t = POINTER(struct_c_nvmlComputeInstance_t) + + +class c_nvmlDeviceAttributes(Structure): + _fields_ = [ + ("multiprocessorCount", c_uint), + ("sharedCopyEngineCount", c_uint), + ("sharedDecoderCount", c_uint), + ("sharedEncoderCount", c_uint), + ("sharedJpegCount", c_uint), + ("sharedOfaCount", c_uint), + ("gpuInstanceSliceCount", c_uint), + ("computeInstanceSliceCount", c_uint), + ("memorySizeMB", c_ulonglong), + ] + + +class c_nvmlRowRemapperHistogramValues(Structure): + _fields_ = [ + ("max", c_uint), + ("high", c_uint), + ("partial", c_uint), + ("low", c_uint), + ("none", c_uint), + ] + + +NVML_GPU_CERT_CHAIN_SIZE = 0x1000 +NVML_GPU_ATTESTATION_CERT_CHAIN_SIZE = 0x1400 +NVML_CC_GPU_CEC_NONCE_SIZE = 0x20 +NVML_CC_GPU_ATTESTATION_REPORT_SIZE = 0x2000 +NVML_CC_GPU_CEC_ATTESTATION_REPORT_SIZE = 0x1000 +NVML_CC_CEC_ATTESTATION_REPORT_NOT_PRESENT = 0 
+NVML_CC_CEC_ATTESTATION_REPORT_PRESENT = 1 + + +class c_nvmlConfComputeSystemState_t(Structure): + _fields_ = [ + ("environment", c_uint), + ("ccFeature", c_uint), + ("devToolsMode", c_uint), + ] + + +nvmlSystemConfComputeSettings_v1 = 0x1000014 + + +class c_nvmlSystemConfComputeSettings_v1_t(Structure): + _fields_ = [ + ("version", c_uint), + ("environment", c_uint), + ("ccFeature", c_uint), + ("devToolsMode", c_uint), + ("multiGpuMode", c_uint), + ] + + def __init__(self): + super(c_nvmlSystemConfComputeSettings_v1_t, self).__init__( + version=nvmlSystemConfComputeSettings_v1 + ) + + +class c_nvmlConfComputeSystemCaps_t(Structure): + _fields_ = [ + ("cpuCaps", c_uint), + ("gpusCaps", c_uint), + ] + + +class c_nvmlConfComputeMemSizeInfo_t(Structure): + _fields_ = [ + ("protectedMemSizeKib", c_ulonglong), + ("unprotectedMemSizeKib", c_ulonglong), + ] + + +class c_nvmlConfComputeGpuCertificate_t(Structure): + _fields_ = [ + ("certChainSize", c_uint), + ("attestationCertChainSize", c_uint), + ("certChain", c_uint8 * NVML_GPU_CERT_CHAIN_SIZE), + ("attestationCertChain", c_uint8 * NVML_GPU_ATTESTATION_CERT_CHAIN_SIZE), + ] + + +class c_nvmlConfComputeGpuAttestationReport_t(Structure): + _fields_ = [ + ("isCecAttestationReportPresent", c_uint), + ("attestationReportSize", c_uint), + ("cecAttestationReportSize", c_uint), + ("nonce", c_uint8 * NVML_CC_GPU_CEC_NONCE_SIZE), + ("attestationReport", c_uint8 * NVML_CC_GPU_ATTESTATION_REPORT_SIZE), + ("cecAttestationReport", c_uint8 * NVML_CC_GPU_CEC_ATTESTATION_REPORT_SIZE), + ] + + +class c_nvmlConfComputeSetKeyRotationThresholdInfo_t(Structure): + _fields_ = [ + ("version", c_uint), + ("maxAttackerAdvantage", c_ulong), + ] + + +ConfComputeSetKeyRotationThresholdInfo_v1 = 0x1000010 + + +class c_nvmlConfComputeGetKeyRotationThresholdInfo_t(Structure): + _fields_ = [ + ("version", c_uint), + ("attackerAdvantage", c_ulong), + ] + + +ConfComputeGetKeyRotationThresholdInfo_v1 = 0x1000010 + + +## string/bytes conversion for ease of use +def convertStrBytes(func): + """ + In python 3, strings are unicode instead of bytes, and need to be converted for ctypes + Args from caller: (1, 'string', <__main__.c_nvmlDevice_t at 0xFFFFFFFF>) + Args passed to function: (1, b'string', <__main__.c_nvmlDevice_t at 0xFFFFFFFF)> + ---- + Returned from function: b'returned string' + Returned to caller: 'returned string' + """ + + @wraps(func) + def wrapper(*args, **kwargs): + # encoding a str returns bytes in python 2 and 3 + args = [arg.encode() if isinstance(arg, str) else arg for arg in args] + res = func(*args, **kwargs) + # In python 2, str and bytes are the same + # In python 3, str is unicode and should be decoded. + # Ctypes handles most conversions, this only effects c_char and char arrays. + if isinstance(res, bytes): + if isinstance(res, str): + return res + return res.decode() + return res + + if sys.version_info >= (3,): + return wrapper + return func + + +def throwOnVersionMismatch(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except NVMLError_FunctionNotFound: + raise NVMLLibraryMismatchError( + "Unversioned function called and the " + "pyNVML version does not match the NVML lib version. 
" + "Either use matching pyNVML and NVML lib versions or " + "use a versioned function such as " + func.__name__ + "_v2" + ) + + return wrapper + + +## C function wrappers ## +def nvmlInitWithFlags(flags): + _LoadNvmlLibrary() + + # + # Initialize the library + # + fn = _nvmlGetFunctionPointer("nvmlInitWithFlags") + ret = fn(flags) + _nvmlCheckReturn(ret) + + # Atomically update refcount + global _nvmlLib_refcount + libLoadLock.acquire() + _nvmlLib_refcount += 1 + libLoadLock.release() + return None + + +def nvmlInit(): + nvmlInitWithFlags(0) + return None + + +def _LoadNvmlLibrary(): + """ + Load the library if it isn't loaded already + """ + global nvmlLib + + if nvmlLib == None: + # lock to ensure only one caller loads the library + libLoadLock.acquire() + + try: + # ensure the library still isn't loaded + if nvmlLib == None: + try: + if sys.platform[:3] == "win": + # cdecl calling convention + try: + # Check for nvml.dll in System32 first for DCH drivers + nvmlLib = CDLL( + os.path.join( + os.getenv("WINDIR", "C:/Windows"), + "System32/nvml.dll", + ) + ) + except OSError as ose: + # If nvml.dll is not found in System32, it should be in ProgramFiles + # load nvml.dll from %ProgramFiles%/NVIDIA Corporation/NVSMI/nvml.dll + nvmlLib = CDLL( + os.path.join( + os.getenv("ProgramFiles", "C:/Program Files"), + "NVIDIA Corporation/NVSMI/nvml.dll", + ) + ) + else: + # assume linux + nvmlLib = CDLL("libnvidia-ml.so.1") + except OSError as ose: + _nvmlCheckReturn(NVML_ERROR_LIBRARY_NOT_FOUND) + if nvmlLib == None: + _nvmlCheckReturn(NVML_ERROR_LIBRARY_NOT_FOUND) + finally: + # lock is always freed + libLoadLock.release() + + +def nvmlShutdown(): + # + # Leave the library loaded, but shutdown the interface + # + fn = _nvmlGetFunctionPointer("nvmlShutdown") + ret = fn() + _nvmlCheckReturn(ret) + + # Atomically update refcount + global _nvmlLib_refcount + libLoadLock.acquire() + if 0 < _nvmlLib_refcount: + _nvmlLib_refcount -= 1 + libLoadLock.release() + return None + + +# Added in 2.285 +@convertStrBytes +def nvmlErrorString(result): + fn = _nvmlGetFunctionPointer("nvmlErrorString") + fn.restype = c_char_p # otherwise return is an int + ret = fn(result) + return ret + + +# Added in 2.285 +@convertStrBytes +def nvmlSystemGetNVMLVersion(): + c_version = create_string_buffer(NVML_SYSTEM_NVML_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlSystemGetNVMLVersion") + ret = fn(c_version, c_uint(NVML_SYSTEM_NVML_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + + +def nvmlSystemGetCudaDriverVersion(): + c_cuda_version = c_int() + fn = _nvmlGetFunctionPointer("nvmlSystemGetCudaDriverVersion") + ret = fn(byref(c_cuda_version)) + _nvmlCheckReturn(ret) + return c_cuda_version.value + + +def nvmlSystemGetCudaDriverVersion_v2(): + c_cuda_version = c_int() + fn = _nvmlGetFunctionPointer("nvmlSystemGetCudaDriverVersion_v2") + ret = fn(byref(c_cuda_version)) + _nvmlCheckReturn(ret) + return c_cuda_version.value + + +# Added in 2.285 +@convertStrBytes +def nvmlSystemGetProcessName(pid): + c_name = create_string_buffer(1024) + fn = _nvmlGetFunctionPointer("nvmlSystemGetProcessName") + ret = fn(c_uint(pid), c_name, c_uint(1024)) + _nvmlCheckReturn(ret) + return c_name.value + + +@convertStrBytes +def nvmlSystemGetDriverVersion(): + c_version = create_string_buffer(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlSystemGetDriverVersion") + ret = fn(c_version, c_uint(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + + +# 
Added in 2.285 +def nvmlSystemGetHicVersion(): + c_count = c_uint(0) + hics = None + fn = _nvmlGetFunctionPointer("nvmlSystemGetHicVersion") + + # get the count + ret = fn(byref(c_count), None) + + # this should only fail with insufficient size + if (ret != NVML_SUCCESS) and (ret != NVML_ERROR_INSUFFICIENT_SIZE): + raise NVMLError(ret) + + # If there are no hics + if c_count.value == 0: + return [] + + hic_array = c_nvmlHwbcEntry_t * c_count.value + hics = hic_array() + ret = fn(byref(c_count), hics) + _nvmlCheckReturn(ret) + return hics + + +def nvmlSystemGetDriverBranch(): + c_branchInfo = c_nvmlSystemDriverBranchInfo_v1_t(0) + c_branchInfo.version = SystemDriverBranchInfo_v1 + fn = _nvmlGetFunctionPointer("nvmlSystemGetDriverBranch") + ret = fn(byref(c_branchInfo), c_uint(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_branchInfo + + +## Unit get functions +def nvmlUnitGetCount(): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlUnitGetCount") + ret = fn(byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + + +def nvmlUnitGetHandleByIndex(index): + c_index = c_uint(index) + unit = c_nvmlUnit_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetHandleByIndex") + ret = fn(c_index, byref(unit)) + _nvmlCheckReturn(ret) + return unit + + +def nvmlUnitGetUnitInfo(unit): + c_info = c_nvmlUnitInfo_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetUnitInfo") + ret = fn(unit, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + + +def nvmlUnitGetLedState(unit): + c_state = c_nvmlLedState_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetLedState") + ret = fn(unit, byref(c_state)) + _nvmlCheckReturn(ret) + return c_state + + +def nvmlUnitGetPsuInfo(unit): + c_info = c_nvmlPSUInfo_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetPsuInfo") + ret = fn(unit, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + + +def nvmlUnitGetTemperature(unit, type): + c_temp = c_uint() + fn = _nvmlGetFunctionPointer("nvmlUnitGetTemperature") + ret = fn(unit, c_uint(type), byref(c_temp)) + _nvmlCheckReturn(ret) + return c_temp.value + + +def nvmlUnitGetFanSpeedInfo(unit): + c_speeds = c_nvmlUnitFanSpeeds_t() + fn = _nvmlGetFunctionPointer("nvmlUnitGetFanSpeedInfo") + ret = fn(unit, byref(c_speeds)) + _nvmlCheckReturn(ret) + return c_speeds + + +# added to API +def nvmlUnitGetDeviceCount(unit): + c_count = c_uint(0) + # query the unit to determine device count + fn = _nvmlGetFunctionPointer("nvmlUnitGetDevices") + ret = fn(unit, byref(c_count), None) + if ret == NVML_ERROR_INSUFFICIENT_SIZE: + ret = NVML_SUCCESS + _nvmlCheckReturn(ret) + return c_count.value + + +def nvmlUnitGetDevices(unit): + c_count = c_uint(nvmlUnitGetDeviceCount(unit)) + device_array = c_nvmlDevice_t * c_count.value + c_devices = device_array() + fn = _nvmlGetFunctionPointer("nvmlUnitGetDevices") + ret = fn(unit, byref(c_count), c_devices) + _nvmlCheckReturn(ret) + return c_devices + + +## Device get functions +def nvmlDeviceGetCount(): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCount_v2") + ret = fn(byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + + +def nvmlDeviceGetHandleByIndex(index): + c_index = c_uint(index) + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHandleByIndex_v2") + ret = fn(c_index, byref(device)) + _nvmlCheckReturn(ret) + return device + + +@convertStrBytes +def nvmlDeviceGetHandleBySerial(serial): + c_serial = c_char_p(serial) + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHandleBySerial") + ret = 
fn(c_serial, byref(device)) + _nvmlCheckReturn(ret) + return device + + +@convertStrBytes +def nvmlDeviceGetHandleByUUID(uuid): + c_uuid = c_char_p(uuid) + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHandleByUUID") + ret = fn(c_uuid, byref(device)) + _nvmlCheckReturn(ret) + return device + + +@convertStrBytes +def nvmlDeviceGetHandleByPciBusId(pciBusId): + c_busId = c_char_p(pciBusId) + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHandleByPciBusId_v2") + ret = fn(c_busId, byref(device)) + _nvmlCheckReturn(ret) + return device + + +@convertStrBytes +def nvmlDeviceGetName(handle): + c_name = create_string_buffer(NVML_DEVICE_NAME_V2_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetName") + ret = fn(handle, c_name, c_uint(NVML_DEVICE_NAME_V2_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_name.value + + +class c_nvmlDevicePerfModes_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("str", c_char * NVML_PERF_MODES_BUFFER_SIZE), + ] + + +nvmlDevicePerfModes_v1 = 0x1000804 + + +@convertStrBytes +def nvmlDeviceGetPerformanceModes(handle): + perfModes = c_nvmlDevicePerfModes_v1_t() + perfModes.version = nvmlDevicePerfModes_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPerformanceModes") + ret = fn(handle, byref(perfModes)) + _nvmlCheckReturn(ret) + return perfModes.str + + +class c_nvmlDeviceCurrentClockFreqs_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("str", c_char * NVML_PERF_MODES_BUFFER_SIZE), + ] + + +nvmlDeviceCurrentClockFreqs_v1 = 0x1000804 + + +@convertStrBytes +def nvmlDeviceGetCurrentClockFreqs(handle): + currentClockFreqs = c_nvmlDeviceCurrentClockFreqs_v1_t() + currentClockFreqs.version = nvmlDeviceCurrentClockFreqs_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrentClockFreqs") + ret = fn(handle, byref(currentClockFreqs)) + _nvmlCheckReturn(ret) + return currentClockFreqs.str + + +def nvmlDeviceGetBoardId(handle): + c_id = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBoardId") + ret = fn(handle, byref(c_id)) + _nvmlCheckReturn(ret) + return c_id.value + + +def nvmlDeviceGetMultiGpuBoard(handle): + c_multiGpu = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMultiGpuBoard") + ret = fn(handle, byref(c_multiGpu)) + _nvmlCheckReturn(ret) + return c_multiGpu.value + + +def nvmlDeviceGetBrand(handle): + c_type = _nvmlBrandType_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBrand") + ret = fn(handle, byref(c_type)) + _nvmlCheckReturn(ret) + return c_type.value + + +def nvmlDeviceGetC2cModeInfoV1(handle): + c_info = c_nvmlC2cModeInfo_v1_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetC2cModeInfoV") + ret = fn(handle, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + + +def nvmlDeviceGetC2cModeInfoV(handle): + return nvmlDeviceGetC2cModeInfoV1(handle) + + +@convertStrBytes +def nvmlDeviceGetBoardPartNumber(handle): + c_part_number = create_string_buffer(NVML_DEVICE_PART_NUMBER_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBoardPartNumber") + ret = fn(handle, c_part_number, c_uint(NVML_DEVICE_PART_NUMBER_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_part_number.value + + +@convertStrBytes +def nvmlDeviceGetSerial(handle): + c_serial = create_string_buffer(NVML_DEVICE_SERIAL_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSerial") + ret = fn(handle, c_serial, c_uint(NVML_DEVICE_SERIAL_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_serial.value + + +def nvmlDeviceGetModuleId(handle, moduleId=c_uint()): + isReference = type(moduleId) is 
not c_uint + moduleIdRef = moduleId if isReference else byref(moduleId) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetModuleId") + ret = fn(handle, moduleIdRef) + if isReference: + return ret + else: + _nvmlCheckReturn(ret) + return moduleId.value + + +def nvmlDeviceGetMemoryAffinity(handle, nodeSetSize, scope): + affinity_array = c_ulonglong * nodeSetSize + c_affinity = affinity_array() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryAffinity") + ret = fn(handle, nodeSetSize, byref(c_affinity), _nvmlAffinityScope_t(scope)) + _nvmlCheckReturn(ret) + return c_affinity + + +def nvmlDeviceGetCpuAffinityWithinScope(handle, cpuSetSize, scope): + affinity_array = c_ulonglong * cpuSetSize + c_affinity = affinity_array() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCpuAffinityWithinScope") + ret = fn(handle, cpuSetSize, byref(c_affinity), _nvmlAffinityScope_t(scope)) + _nvmlCheckReturn(ret) + return c_affinity + + +def nvmlDeviceGetCpuAffinity(handle, cpuSetSize): + affinity_array = c_ulonglong * cpuSetSize + c_affinity = affinity_array() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCpuAffinity") + ret = fn(handle, cpuSetSize, byref(c_affinity)) + _nvmlCheckReturn(ret) + return c_affinity + + +def nvmlDeviceSetCpuAffinity(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetCpuAffinity") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceClearCpuAffinity(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceClearCpuAffinity") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceGetNumaNodeId(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNumaNodeId") + node = c_int() + ret = fn(handle, byref(node)) + _nvmlCheckReturn(ret) + return node.value + + +def nvmlDeviceGetMinorNumber(handle): + c_minor_number = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMinorNumber") + ret = fn(handle, byref(c_minor_number)) + _nvmlCheckReturn(ret) + return c_minor_number.value + + +@convertStrBytes +def nvmlDeviceGetUUID(handle): + c_uuid = create_string_buffer(NVML_DEVICE_UUID_V2_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetUUID") + ret = fn(handle, c_uuid, c_uint(NVML_DEVICE_UUID_V2_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_uuid.value + + +@convertStrBytes +def nvmlDeviceGetInforomVersion(handle, infoRomObject): + c_version = create_string_buffer(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetInforomVersion") + ret = fn( + handle, + _nvmlInforomObject_t(infoRomObject), + c_version, + c_uint(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE), + ) + _nvmlCheckReturn(ret) + return c_version.value + + +# Added in 4.304 +@convertStrBytes +def nvmlDeviceGetInforomImageVersion(handle): + c_version = create_string_buffer(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetInforomImageVersion") + ret = fn(handle, c_version, c_uint(NVML_DEVICE_INFOROM_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + + +# Added in 4.304 +def nvmlDeviceGetInforomConfigurationChecksum(handle): + c_checksum = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetInforomConfigurationChecksum") + ret = fn(handle, byref(c_checksum)) + _nvmlCheckReturn(ret) + return c_checksum.value + + +# Added in 4.304 +def nvmlDeviceValidateInforom(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceValidateInforom") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceGetLastBBXFlushTime(handle): + c_timestamp = c_ulonglong() + c_durationUs = c_ulong() + fn = 
_nvmlGetFunctionPointer("nvmlDeviceGetLastBBXFlushTime") + ret = fn(handle, byref(c_timestamp), byref(c_durationUs)) + _nvmlCheckReturn(ret) + return [c_timestamp.value, c_durationUs.value] + + +def nvmlDeviceGetDisplayMode(handle): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDisplayMode") + ret = fn(handle, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + + +def nvmlDeviceGetDisplayActive(handle): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDisplayActive") + ret = fn(handle, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + + +def nvmlDeviceGetPersistenceMode(handle): + c_state = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPersistenceMode") + ret = fn(handle, byref(c_state)) + _nvmlCheckReturn(ret) + return c_state.value + + +def nvmlDeviceGetPciInfoExt(handle, c_info): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPciInfoExt") + ret = fn(handle, c_info) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceGetPciInfo_v3(handle): + c_info = nvmlPciInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPciInfo_v3") + ret = fn(handle, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + + +def nvmlDeviceGetPciInfo(handle): + return nvmlDeviceGetPciInfo_v3(handle) + + +def nvmlDeviceGetClockInfo(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetClockInfo") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + + +# Added in 2.285 +def nvmlDeviceGetMaxClockInfo(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxClockInfo") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + + +# Added in 4.304 +def nvmlDeviceGetApplicationsClock(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetApplicationsClock") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + + +def nvmlDeviceGetMaxCustomerBoostClock(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxCustomerBoostClock") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + + +def nvmlDeviceGetClock(handle, type, id): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetClock") + ret = fn(handle, _nvmlClockType_t(type), _nvmlClockId_t(id), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + + +# Added in 5.319 +def nvmlDeviceGetDefaultApplicationsClock(handle, type): + c_clock = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDefaultApplicationsClock") + ret = fn(handle, _nvmlClockType_t(type), byref(c_clock)) + _nvmlCheckReturn(ret) + return c_clock.value + + +# Added in 4.304 +def nvmlDeviceGetSupportedMemoryClocks(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedMemoryClocks") + ret = fn(handle, byref(c_count), None) + + if ret == NVML_SUCCESS: + # special case, no clocks + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + clocks_array = c_uint * c_count.value + c_clocks = clocks_array() + + # make the call again + ret = fn(handle, byref(c_count), c_clocks) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + procs.append(c_clocks[i]) + + return procs + else: + # error case + raise NVMLError(ret) + + +# Added in 4.304 +def 
nvmlDeviceGetSupportedGraphicsClocks(handle, memoryClockMHz): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedGraphicsClocks") + ret = fn(handle, c_uint(memoryClockMHz), byref(c_count), None) + + if ret == NVML_SUCCESS: + # special case, no clocks + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + clocks_array = c_uint * c_count.value + c_clocks = clocks_array() + + # make the call again + ret = fn(handle, c_uint(memoryClockMHz), byref(c_count), c_clocks) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + procs.append(c_clocks[i]) + + return procs + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetFanSpeed(handle): + c_speed = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFanSpeed") + ret = fn(handle, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + + +def nvmlDeviceGetFanSpeed_v2(handle, fan): + c_speed = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFanSpeed_v2") + ret = fn(handle, fan, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + + +class c_nvmlFanSpeedInfo_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("fan", c_uint), + ("speed", c_uint), + ] + + +nvmlFanSpeedInfo_v1 = 0x100000C + + +def nvmlDeviceGetFanSpeedRPM(handle): + c_fanSpeed = c_nvmlFanSpeedInfo_t() + c_fanSpeed.fan = 0 + c_fanSpeed.version = nvmlFanSpeedInfo_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFanSpeedRPM") + ret = fn(handle, byref(c_fanSpeed)) + _nvmlCheckReturn(ret) + return c_fanSpeed.speed + + +def nvmlDeviceGetTargetFanSpeed(handle, fan): + c_speed = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTargetFanSpeed") + ret = fn(handle, fan, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + + +def nvmlDeviceGetNumFans(device): + c_numFans = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNumFans") + ret = fn(device, byref(c_numFans)) + _nvmlCheckReturn(ret) + return c_numFans.value + + +def nvmlDeviceSetDefaultFanSpeed_v2(handle, index): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetDefaultFanSpeed_v2") + ret = fn(handle, index) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceGetMinMaxFanSpeed(handle, minSpeed=c_uint(), maxSpeed=c_uint()): + isReference = (type(minSpeed) is not c_uint) or (type(maxSpeed) is not c_uint) + minSpeedRef = minSpeed if isReference else byref(minSpeed) + maxSpeedRef = maxSpeed if isReference else byref(maxSpeed) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMinMaxFanSpeed") + ret = fn(handle, minSpeedRef, maxSpeedRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else [minSpeed.value, maxSpeed.value] + + +def nvmlDeviceGetFanControlPolicy_v2(handle, fan, fanControlPolicy=c_uint()): + isReference = type(fanControlPolicy) is not c_uint + fanControlPolicyRef = fanControlPolicy if isReference else byref(fanControlPolicy) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFanControlPolicy_v2") + ret = fn(handle, fan, fanControlPolicyRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else fanControlPolicy.value + + +def nvmlDeviceSetFanControlPolicy(handle, fan, fanControlPolicy): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetFanControlPolicy") + ret = fn(handle, fan, _nvmlFanControlPolicy_t(fanControlPolicy)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +class c_nvmlTemperature_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("sensorType", _nvmlTemperatureSensors_t), + ("temperature", c_int), + ] + + 
+nvmlTemperature_v1 = 0x100000C + + +def nvmlDeviceGetTemperatureV1(handle, sensor): + c_temp = c_nvmlTemperature_v1_t() + c_temp.version = nvmlTemperature_v1 + c_temp.sensorType = _nvmlTemperatureSensors_t(sensor) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTemperatureV") + ret = fn(handle, byref(c_temp)) + _nvmlCheckReturn(ret) + return c_temp.temperature + + +def nvmlDeviceGetTemperatureV(handle, sensor, version=nvmlTemperature_v1): + if version == nvmlTemperature_v1: + return nvmlDeviceGetTemperatureV1(handle, sensor) + else: + raise NVMLError(NVML_ERROR_ARGUMENT_VERSION_MISMATCH) + + +# DEPRECATED use nvmlDeviceGetTemperatureV instead +def nvmlDeviceGetTemperature(handle, sensor): + c_temp = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTemperature") + ret = fn(handle, _nvmlTemperatureSensors_t(sensor), byref(c_temp)) + _nvmlCheckReturn(ret) + return c_temp.value + + +def nvmlDeviceGetTemperatureThreshold(handle, threshold): + c_temp = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTemperatureThreshold") + ret = fn(handle, _nvmlTemperatureThresholds_t(threshold), byref(c_temp)) + _nvmlCheckReturn(ret) + return c_temp.value + + +def nvmlDeviceSetTemperatureThreshold(handle, threshold, temp): + c_temp = c_uint() + c_temp.value = temp + fn = _nvmlGetFunctionPointer("nvmlDeviceSetTemperatureThreshold") + ret = fn(handle, _nvmlTemperatureThresholds_t(threshold), byref(c_temp)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceGetMarginTemperature(handle): + c_marginTempInfo = c_nvmlMarginTemperature_v1_t() + c_marginTempInfo.version = nvmlMarginTemperature_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMarginTemperature") + ret = fn(handle, byref(c_marginTempInfo)) + _nvmlCheckReturn(ret) + return c_marginTempInfo.marginTemperature + + +# DEPRECATED use nvmlDeviceGetPerformanceState +def nvmlDeviceGetPowerState(handle): + c_pstate = _nvmlPstates_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerState") + ret = fn(handle, byref(c_pstate)) + _nvmlCheckReturn(ret) + return c_pstate.value + + +def nvmlDeviceGetPerformanceState(handle): + c_pstate = _nvmlPstates_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPerformanceState") + ret = fn(handle, byref(c_pstate)) + _nvmlCheckReturn(ret) + return c_pstate.value + + +def nvmlDeviceGetPowerManagementMode(handle): + c_pcapMode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerManagementMode") + ret = fn(handle, byref(c_pcapMode)) + _nvmlCheckReturn(ret) + return c_pcapMode.value + + +def nvmlDeviceGetPowerManagementLimit(handle): + c_limit = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerManagementLimit") + ret = fn(handle, byref(c_limit)) + _nvmlCheckReturn(ret) + return c_limit.value + + +# Added in 4.304 +def nvmlDeviceGetPowerManagementLimitConstraints(handle): + c_minLimit = c_uint() + c_maxLimit = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerManagementLimitConstraints") + ret = fn(handle, byref(c_minLimit), byref(c_maxLimit)) + _nvmlCheckReturn(ret) + return [c_minLimit.value, c_maxLimit.value] + + +# Added in 4.304 +def nvmlDeviceGetPowerManagementDefaultLimit(handle): + c_limit = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerManagementDefaultLimit") + ret = fn(handle, byref(c_limit)) + _nvmlCheckReturn(ret) + return c_limit.value + + +# Added in 331 +def nvmlDeviceGetEnforcedPowerLimit(handle): + c_limit = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEnforcedPowerLimit") + ret = fn(handle, byref(c_limit)) + _nvmlCheckReturn(ret) + return 
c_limit.value + + +def nvmlDeviceGetPowerUsage(handle): + c_watts = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerUsage") + ret = fn(handle, byref(c_watts)) + _nvmlCheckReturn(ret) + return c_watts.value + + +def nvmlDeviceGetTotalEnergyConsumption(handle): + c_millijoules = c_uint64() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTotalEnergyConsumption") + ret = fn(handle, byref(c_millijoules)) + _nvmlCheckReturn(ret) + return c_millijoules.value + + +# Added in 4.304 +def nvmlDeviceGetGpuOperationMode(handle): + c_currState = _nvmlGpuOperationMode_t() + c_pendingState = _nvmlGpuOperationMode_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuOperationMode") + ret = fn(handle, byref(c_currState), byref(c_pendingState)) + _nvmlCheckReturn(ret) + return [c_currState.value, c_pendingState.value] + + +# Added in 4.304 +def nvmlDeviceGetCurrentGpuOperationMode(handle): + return nvmlDeviceGetGpuOperationMode(handle)[0] + + +# Added in 4.304 +def nvmlDeviceGetPendingGpuOperationMode(handle): + return nvmlDeviceGetGpuOperationMode(handle)[1] + + +def nvmlDeviceGetMemoryInfo(handle, version=None): + if not version: + c_memory = c_nvmlMemory_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryInfo") + else: + c_memory = c_nvmlMemory_v2_t() + c_memory.version = version + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryInfo_v2") + ret = fn(handle, byref(c_memory)) + _nvmlCheckReturn(ret) + return c_memory + + +def nvmlDeviceGetBAR1MemoryInfo(handle): + c_bar1_memory = c_nvmlBAR1Memory_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBAR1MemoryInfo") + ret = fn(handle, byref(c_bar1_memory)) + _nvmlCheckReturn(ret) + return c_bar1_memory + + +def nvmlDeviceGetComputeMode(handle): + c_mode = _nvmlComputeMode_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeMode") + ret = fn(handle, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + + +def nvmlDeviceGetCudaComputeCapability(handle): + c_major = c_int() + c_minor = c_int() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCudaComputeCapability") + ret = fn(handle, byref(c_major), byref(c_minor)) + _nvmlCheckReturn(ret) + return (c_major.value, c_minor.value) + + +def nvmlDeviceGetEccMode(handle): + c_currState = _nvmlEnableState_t() + c_pendingState = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEccMode") + ret = fn(handle, byref(c_currState), byref(c_pendingState)) + _nvmlCheckReturn(ret) + return [c_currState.value, c_pendingState.value] + + +# added to API +def nvmlDeviceGetCurrentEccMode(handle): + return nvmlDeviceGetEccMode(handle)[0] + + +# added to API +def nvmlDeviceGetPendingEccMode(handle): + return nvmlDeviceGetEccMode(handle)[1] + + +def nvmlDeviceGetDefaultEccMode(handle): + c_defaultState = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDefaultEccMode") + ret = fn(handle, byref(c_defaultState)) + _nvmlCheckReturn(ret) + return [c_defaultState.value] + + +def nvmlDeviceGetTotalEccErrors(handle, errorType, counterType): + c_count = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTotalEccErrors") + ret = fn( + handle, + _nvmlMemoryErrorType_t(errorType), + _nvmlEccCounterType_t(counterType), + byref(c_count), + ) + _nvmlCheckReturn(ret) + return c_count.value + + +# This is deprecated, instead use nvmlDeviceGetMemoryErrorCounter +def nvmlDeviceGetDetailedEccErrors(handle, errorType, counterType): + c_counts = c_nvmlEccErrorCounts_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDetailedEccErrors") + ret = fn( + handle, + _nvmlMemoryErrorType_t(errorType), + 
_nvmlEccCounterType_t(counterType), + byref(c_counts), + ) + _nvmlCheckReturn(ret) + return c_counts + + +# Added in 4.304 +def nvmlDeviceGetMemoryErrorCounter(handle, errorType, counterType, locationType): + c_count = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryErrorCounter") + ret = fn( + handle, + _nvmlMemoryErrorType_t(errorType), + _nvmlEccCounterType_t(counterType), + _nvmlMemoryLocation_t(locationType), + byref(c_count), + ) + _nvmlCheckReturn(ret) + return c_count.value + + +def nvmlDeviceGetUtilizationRates(handle): + c_util = c_nvmlUtilization_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetUtilizationRates") + ret = fn(handle, byref(c_util)) + _nvmlCheckReturn(ret) + return c_util + + +def nvmlDeviceGetEncoderUtilization(handle): + c_util = c_uint() + c_samplingPeriod = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEncoderUtilization") + ret = fn(handle, byref(c_util), byref(c_samplingPeriod)) + _nvmlCheckReturn(ret) + return [c_util.value, c_samplingPeriod.value] + + +def nvmlDeviceGetDecoderUtilization(handle): + c_util = c_uint() + c_samplingPeriod = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDecoderUtilization") + ret = fn(handle, byref(c_util), byref(c_samplingPeriod)) + _nvmlCheckReturn(ret) + return [c_util.value, c_samplingPeriod.value] + + +def nvmlDeviceGetJpgUtilization(handle): + c_util = c_uint() + c_samplingPeriod = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetJpgUtilization") + ret = fn(handle, byref(c_util), byref(c_samplingPeriod)) + _nvmlCheckReturn(ret) + return [c_util.value, c_samplingPeriod.value] + + +def nvmlDeviceGetOfaUtilization(handle): + c_util = c_uint() + c_samplingPeriod = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetOfaUtilization") + ret = fn(handle, byref(c_util), byref(c_samplingPeriod)) + _nvmlCheckReturn(ret) + return [c_util.value, c_samplingPeriod.value] + + +def nvmlDeviceGetPcieReplayCounter(handle): + c_replay = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPcieReplayCounter") + ret = fn(handle, byref(c_replay)) + _nvmlCheckReturn(ret) + return c_replay.value + + +def nvmlDeviceGetDriverModel(handle): + c_currModel = _nvmlDriverModel_t() + c_pendingModel = _nvmlDriverModel_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDriverModel") + ret = fn(handle, byref(c_currModel), byref(c_pendingModel)) + _nvmlCheckReturn(ret) + return [c_currModel.value, c_pendingModel.value] + + +# added to API +def nvmlDeviceGetCurrentDriverModel(handle): + return nvmlDeviceGetDriverModel(handle)[0] + + +# added to API +def nvmlDeviceGetPendingDriverModel(handle): + return nvmlDeviceGetDriverModel(handle)[1] + + +# Added in 2.285 +@convertStrBytes +def nvmlDeviceGetVbiosVersion(handle): + c_version = create_string_buffer(NVML_DEVICE_VBIOS_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVbiosVersion") + ret = fn(handle, c_version, c_uint(NVML_DEVICE_VBIOS_VERSION_BUFFER_SIZE)) + _nvmlCheckReturn(ret) + return c_version.value + + +# Added in 2.285 +def nvmlDeviceGetComputeRunningProcesses_v2(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeRunningProcesses_v2") + ret = fn(handle, byref(c_count), None) + if ret == NVML_SUCCESS: + # special case, no running processes + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + # oversize the array in case more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v2_t * c_count.value + c_procs = proc_array() + 
# make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value: + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + return procs + else: + # error case + raise NVMLError(ret) + + +# Added in 2.285 +def nvmlDeviceGetComputeRunningProcesses_v3(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeRunningProcesses_v3") + ret = fn(handle, byref(c_count), None) + + if ret == NVML_SUCCESS: + # special case, no running processes + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + # oversize the array in case more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v3_t * c_count.value + c_procs = proc_array() + + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value: + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + + +@throwOnVersionMismatch +def nvmlDeviceGetComputeRunningProcesses(handle): + return nvmlDeviceGetComputeRunningProcesses_v3(handle) + + +def nvmlDeviceGetGraphicsRunningProcesses_v2(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGraphicsRunningProcesses_v2") + ret = fn(handle, byref(c_count), None) + if ret == NVML_SUCCESS: + # special case, no running processes + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + # oversize the array in case more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v2_t * c_count.value + c_procs = proc_array() + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value: + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + return procs + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetGraphicsRunningProcesses_v3(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGraphicsRunningProcesses_v3") + ret = fn(handle, byref(c_count), None) + + if ret == NVML_SUCCESS: + # special case, no running processes + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + # oversize the array in case more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v3_t * c_count.value + c_procs = proc_array() + + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value: + # special case for WDDM on Windows, see comment above + 
obj.usedGpuMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + + +@throwOnVersionMismatch +def nvmlDeviceGetGraphicsRunningProcesses(handle): + return nvmlDeviceGetGraphicsRunningProcesses_v3(handle) + + +@throwOnVersionMismatch +def nvmlDeviceGetMPSComputeRunningProcesses(handle): + return nvmlDeviceGetMPSComputeRunningProcesses_v3(handle) + + +def nvmlDeviceGetMPSComputeRunningProcesses_v2(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMPSComputeRunningProcesses_v2") + ret = fn(handle, byref(c_count), None) + + if ret == NVML_SUCCESS: + # special case, no running processes + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + # oversize the array in case more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v2_t * c_count.value + c_procs = proc_array() + + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value: + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetMPSComputeRunningProcesses_v3(handle): + # first call to get the size + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMPSComputeRunningProcesses_v3") + ret = fn(handle, byref(c_count), None) + + if ret == NVML_SUCCESS: + # special case, no running processes + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + # oversize the array in case more processes are created + c_count.value = c_count.value * 2 + 5 + proc_array = c_nvmlProcessInfo_v3_t * c_count.value + c_procs = proc_array() + + # make the call again + ret = fn(handle, byref(c_count), c_procs) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_count.value): + # use an alternative struct for this object + obj = nvmlStructToFriendlyObject(c_procs[i]) + if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value: + # special case for WDDM on Windows, see comment above + obj.usedGpuMemory = None + procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetRunningProcessDetailList(handle, version, mode): + c_processDetailList = c_nvmlProcessDetailList_t() + c_processDetailList.version = version + c_processDetailList.mode = mode + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRunningProcessDetailList") + + # first call to get the size + ret = fn(handle, byref(c_processDetailList)) + if ret == NVML_SUCCESS: + # special case, no running processes + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + c_procs = c_nvmlProcessDetail_v1_t * c_processDetailList.numProcArrayEntries + c_processDetailList.procArray = cast( + (c_procs)(), POINTER(c_nvmlProcessDetail_v1_t) + ) + + # make the call again + ret = fn(handle, byref(c_processDetailList)) + _nvmlCheckReturn(ret) + + procs = [] + for i in range(c_processDetailList.numProcArrayEntries): + # use an alternative struct for this object + obj = c_processDetailList.procArray[i] + if obj.usedGpuMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value: + obj.usedGpuMemory = None + if obj.usedGpuCcProtectedMemory == NVML_VALUE_NOT_AVAILABLE_ulonglong.value: + obj.usedGpuCcProtectedMemory = None + 
procs.append(obj) + + return procs + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetAutoBoostedClocksEnabled(handle): + c_isEnabled = _nvmlEnableState_t() + c_defaultIsEnabled = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAutoBoostedClocksEnabled") + ret = fn(handle, byref(c_isEnabled), byref(c_defaultIsEnabled)) + _nvmlCheckReturn(ret) + return [c_isEnabled.value, c_defaultIsEnabled.value] + # Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks + + +## Set functions +def nvmlUnitSetLedState(unit, color): + fn = _nvmlGetFunctionPointer("nvmlUnitSetLedState") + ret = fn(unit, _nvmlLedColor_t(color)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceSetPersistenceMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetPersistenceMode") + ret = fn(handle, _nvmlEnableState_t(mode)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceSetComputeMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetComputeMode") + ret = fn(handle, _nvmlComputeMode_t(mode)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceSetEccMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetEccMode") + ret = fn(handle, _nvmlEnableState_t(mode)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceClearEccErrorCounts(handle, counterType): + fn = _nvmlGetFunctionPointer("nvmlDeviceClearEccErrorCounts") + ret = fn(handle, _nvmlEccCounterType_t(counterType)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceSetDriverModel(handle, model): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetDriverModel") + ret = fn(handle, _nvmlDriverModel_t(model)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceSetAutoBoostedClocksEnabled(handle, enabled): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetAutoBoostedClocksEnabled") + ret = fn(handle, _nvmlEnableState_t(enabled)) + _nvmlCheckReturn(ret) + return None + # Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks + + +def nvmlDeviceSetDefaultAutoBoostedClocksEnabled(handle, enabled, flags): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetDefaultAutoBoostedClocksEnabled") + ret = fn(handle, _nvmlEnableState_t(enabled), c_uint(flags)) + _nvmlCheckReturn(ret) + return None + # Throws NVML_ERROR_NOT_SUPPORTED if hardware doesn't support setting auto boosted clocks + + +def nvmlDeviceSetGpuLockedClocks(handle, minGpuClockMHz, maxGpuClockMHz): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetGpuLockedClocks") + ret = fn(handle, c_uint(minGpuClockMHz), c_uint(maxGpuClockMHz)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceResetGpuLockedClocks(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetGpuLockedClocks") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceSetMemoryLockedClocks(handle, minMemClockMHz, maxMemClockMHz): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetMemoryLockedClocks") + ret = fn(handle, c_uint(minMemClockMHz), c_uint(maxMemClockMHz)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceResetMemoryLockedClocks(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetMemoryLockedClocks") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceGetClkMonStatus(handle, c_clkMonInfo=nvmlClkMonStatus_t()): + isReference = type(c_clkMonInfo) is not nvmlClkMonStatus_t + c_clkMonInfoRef = c_clkMonInfo if isReference else byref(c_clkMonInfo) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetClkMonStatus") + ret = fn(handle, c_clkMonInfoRef) + _nvmlCheckReturn(ret) + 
return NVML_SUCCESS if isReference else c_clkMonInfo + + +# Added in 4.304 +def nvmlDeviceSetApplicationsClocks(handle, maxMemClockMHz, maxGraphicsClockMHz): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetApplicationsClocks") + ret = fn(handle, c_uint(maxMemClockMHz), c_uint(maxGraphicsClockMHz)) + _nvmlCheckReturn(ret) + return None + + +# Added in 4.304 +def nvmlDeviceResetApplicationsClocks(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetApplicationsClocks") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + + +# Added in 4.304 +def nvmlDeviceSetPowerManagementLimit(handle, limit): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetPowerManagementLimit") + ret = fn(handle, c_uint(limit)) + _nvmlCheckReturn(ret) + return None + + +# Added in 4.304 +def nvmlDeviceSetGpuOperationMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetGpuOperationMode") + ret = fn(handle, _nvmlGpuOperationMode_t(mode)) + _nvmlCheckReturn(ret) + return None + + +# Added in 2.285 +def nvmlEventSetCreate(): + fn = _nvmlGetFunctionPointer("nvmlEventSetCreate") + eventSet = c_nvmlEventSet_t() + ret = fn(byref(eventSet)) + _nvmlCheckReturn(ret) + return eventSet + + +# Added in 2.285 +def nvmlDeviceRegisterEvents(handle, eventTypes, eventSet): + fn = _nvmlGetFunctionPointer("nvmlDeviceRegisterEvents") + ret = fn(handle, c_ulonglong(eventTypes), eventSet) + _nvmlCheckReturn(ret) + return None + + +# Added in 2.285 +def nvmlDeviceGetSupportedEventTypes(handle): + c_eventTypes = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedEventTypes") + ret = fn(handle, byref(c_eventTypes)) + _nvmlCheckReturn(ret) + return c_eventTypes.value + + +# raises NVML_ERROR_TIMEOUT exception on timeout +def nvmlEventSetWait_v2(eventSet, timeoutms): + fn = _nvmlGetFunctionPointer("nvmlEventSetWait_v2") + data = c_nvmlEventData_t() + ret = fn(eventSet, byref(data), c_uint(timeoutms)) + _nvmlCheckReturn(ret) + return data + + +def nvmlEventSetWait(eventSet, timeoutms): + return nvmlEventSetWait_v2(eventSet, timeoutms) + + +# Added in 2.285 +def nvmlEventSetFree(eventSet): + fn = _nvmlGetFunctionPointer("nvmlEventSetFree") + ret = fn(eventSet) + _nvmlCheckReturn(ret) + return None + + +# Added in 3.295 +def nvmlDeviceOnSameBoard(handle1, handle2): + fn = _nvmlGetFunctionPointer("nvmlDeviceOnSameBoard") + onSameBoard = c_int() + ret = fn(handle1, handle2, byref(onSameBoard)) + _nvmlCheckReturn(ret) + return onSameBoard.value != 0 + + +# Added in 3.295 +def nvmlDeviceGetCurrPcieLinkGeneration(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrPcieLinkGeneration") + gen = c_uint() + ret = fn(handle, byref(gen)) + _nvmlCheckReturn(ret) + return gen.value + + +# Added in 3.295 +def nvmlDeviceGetMaxPcieLinkGeneration(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxPcieLinkGeneration") + gen = c_uint() + ret = fn(handle, byref(gen)) + _nvmlCheckReturn(ret) + return gen.value + + +# Added in 3.295 +def nvmlDeviceGetCurrPcieLinkWidth(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrPcieLinkWidth") + width = c_uint() + ret = fn(handle, byref(width)) + _nvmlCheckReturn(ret) + return width.value + + +# Added in 3.295 +def nvmlDeviceGetMaxPcieLinkWidth(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxPcieLinkWidth") + width = c_uint() + ret = fn(handle, byref(width)) + _nvmlCheckReturn(ret) + return width.value + + +def nvmlDeviceGetGpuMaxPcieLinkGeneration(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuMaxPcieLinkGeneration") + gen = c_uint() + ret = fn(handle, byref(gen)) 
+ _nvmlCheckReturn(ret) + return gen.value + + +# Added in 4.304 +def nvmlDeviceGetSupportedClocksThrottleReasons(handle): + c_reasons = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedClocksThrottleReasons") + ret = fn(handle, byref(c_reasons)) + _nvmlCheckReturn(ret) + return c_reasons.value + + +def nvmlDeviceGetSupportedClocksEventReasons(handle): + c_reasons = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedClocksEventReasons") + ret = fn(handle, byref(c_reasons)) + _nvmlCheckReturn(ret) + return c_reasons.value + + +# Added in 4.304 +def nvmlDeviceGetCurrentClocksThrottleReasons(handle): + c_reasons = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrentClocksThrottleReasons") + ret = fn(handle, byref(c_reasons)) + _nvmlCheckReturn(ret) + return c_reasons.value + + +def nvmlDeviceGetCurrentClocksEventReasons(handle): + c_reasons = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCurrentClocksEventReasons") + ret = fn(handle, byref(c_reasons)) + _nvmlCheckReturn(ret) + return c_reasons.value + + +# Added in 5.319 +def nvmlDeviceGetIndex(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetIndex") + c_index = c_uint() + ret = fn(handle, byref(c_index)) + _nvmlCheckReturn(ret) + return c_index.value + + +# Added in 5.319 +def nvmlDeviceGetAccountingMode(handle): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAccountingMode") + ret = fn(handle, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + + +def nvmlDeviceSetAccountingMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetAccountingMode") + ret = fn(handle, _nvmlEnableState_t(mode)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceClearAccountingPids(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceClearAccountingPids") + ret = fn(handle) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceGetAccountingStats(handle, pid): + stats = c_nvmlAccountingStats_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAccountingStats") + ret = fn(handle, c_uint(pid), byref(stats)) + _nvmlCheckReturn(ret) + if stats.maxMemoryUsage == NVML_VALUE_NOT_AVAILABLE_ulonglong.value: + # special case for WDDM on Windows, see comment above + stats.maxMemoryUsage = None + return stats + + +def nvmlDeviceGetAccountingPids(handle): + count = c_uint(nvmlDeviceGetAccountingBufferSize(handle)) + pids = (c_uint * count.value)() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAccountingPids") + ret = fn(handle, byref(count), pids) + _nvmlCheckReturn(ret) + return list(map(int, pids[0 : count.value])) + + +def nvmlDeviceGetAccountingBufferSize(handle): + bufferSize = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAccountingBufferSize") + ret = fn(handle, byref(bufferSize)) + _nvmlCheckReturn(ret) + return int(bufferSize.value) + + +def nvmlDeviceGetRetiredPages(device, sourceFilter): + c_source = _nvmlPageRetirementCause_t(sourceFilter) + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRetiredPages") + + # First call will get the size + ret = fn(device, c_source, byref(c_count), None) + + # this should only fail with insufficient size + if (ret != NVML_SUCCESS) and (ret != NVML_ERROR_INSUFFICIENT_SIZE): + raise NVMLError(ret) + + # call again with a buffer + # oversize the array for the rare cases where additional pages + # are retired between NVML calls + c_count.value = c_count.value * 2 + 5 + page_array = c_ulonglong * c_count.value + c_pages = page_array() + ret = fn(device, c_source, byref(c_count), c_pages) + 
_nvmlCheckReturn(ret) + return list(map(int, c_pages[0 : c_count.value])) + + +def nvmlDeviceGetRetiredPages_v2(device, sourceFilter): + c_source = _nvmlPageRetirementCause_t(sourceFilter) + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRetiredPages_v2") + + # First call will get the size + ret = fn(device, c_source, byref(c_count), None) + + # this should only fail with insufficient size + if (ret != NVML_SUCCESS) and (ret != NVML_ERROR_INSUFFICIENT_SIZE): + raise NVMLError(ret) + + # call again with a buffer + # oversize the array for the rare cases where additional pages + # are retired between NVML calls + c_count.value = c_count.value * 2 + 5 + page_array = c_ulonglong * c_count.value + c_pages = page_array() + times_array = c_ulonglong * c_count.value + c_times = times_array() + ret = fn(device, c_source, byref(c_count), c_pages, c_times) + _nvmlCheckReturn(ret) + return [ + {"address": int(c_pages[i]), "timestamp": int(c_times[i])} + for i in range(c_count.value) + ] + + +def nvmlDeviceGetRetiredPagesPendingStatus(device): + c_pending = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRetiredPagesPendingStatus") + ret = fn(device, byref(c_pending)) + _nvmlCheckReturn(ret) + return int(c_pending.value) + + +def nvmlDeviceGetAPIRestriction(device, apiType): + c_permission = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAPIRestriction") + ret = fn(device, _nvmlRestrictedAPI_t(apiType), byref(c_permission)) + _nvmlCheckReturn(ret) + return int(c_permission.value) + + +def nvmlDeviceSetAPIRestriction(handle, apiType, isRestricted): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetAPIRestriction") + ret = fn(handle, _nvmlRestrictedAPI_t(apiType), _nvmlEnableState_t(isRestricted)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceGetBridgeChipInfo(handle): + bridgeHierarchy = c_nvmlBridgeChipHierarchy_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBridgeChipInfo") + ret = fn(handle, byref(bridgeHierarchy)) + _nvmlCheckReturn(ret) + return bridgeHierarchy + + +def nvmlDeviceGetSamples(device, sampling_type, timeStamp): + c_sampling_type = _nvmlSamplingType_t(sampling_type) + c_time_stamp = c_ulonglong(timeStamp) + c_sample_count = c_uint(0) + c_sample_value_type = _nvmlValueType_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSamples") + + ## First Call gets the size + ret = fn( + device, + c_sampling_type, + c_time_stamp, + byref(c_sample_value_type), + byref(c_sample_count), + None, + ) + + # Stop if this fails + if ret != NVML_SUCCESS: + raise NVMLError(ret) + + sampleArray = c_sample_count.value * c_nvmlSample_t + c_samples = sampleArray() + ret = fn( + device, + c_sampling_type, + c_time_stamp, + byref(c_sample_value_type), + byref(c_sample_count), + c_samples, + ) + _nvmlCheckReturn(ret) + return (c_sample_value_type.value, c_samples[0 : c_sample_count.value]) + + +def nvmlDeviceGetViolationStatus(device, perfPolicyType): + c_perfPolicy_type = _nvmlPerfPolicyType_t(perfPolicyType) + c_violTime = c_nvmlViolationTime_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetViolationStatus") + + ## Invoke the method to get violation time + ret = fn(device, c_perfPolicy_type, byref(c_violTime)) + _nvmlCheckReturn(ret) + return c_violTime + + +def nvmlDeviceGetPcieThroughput(device, counter): + c_util = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPcieThroughput") + ret = fn(device, _nvmlPcieUtilCounter_t(counter), byref(c_util)) + _nvmlCheckReturn(ret) + return c_util.value + + +def nvmlSystemGetTopologyGpuSet(cpuNumber): + 
c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlSystemGetTopologyGpuSet") + + # First call will get the size + ret = fn(cpuNumber, byref(c_count), None) + + if ret != NVML_SUCCESS: + raise NVMLError(ret) + # call again with a buffer + device_array = c_nvmlDevice_t * c_count.value + c_devices = device_array() + ret = fn(cpuNumber, byref(c_count), c_devices) + _nvmlCheckReturn(ret) + return list(c_devices[0 : c_count.value]) + + +def nvmlDeviceGetTopologyNearestGpus(device, level): + c_count = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTopologyNearestGpus") + + # First call will get the size + ret = fn(device, level, byref(c_count), None) + + if ret != NVML_SUCCESS: + raise NVMLError(ret) + + # call again with a buffer + device_array = c_nvmlDevice_t * c_count.value + c_devices = device_array() + ret = fn(device, level, byref(c_count), c_devices) + _nvmlCheckReturn(ret) + return list(c_devices[0 : c_count.value]) + + +def nvmlDeviceGetTopologyCommonAncestor(device1, device2): + c_level = _nvmlGpuTopologyLevel_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetTopologyCommonAncestor") + ret = fn(device1, device2, byref(c_level)) + _nvmlCheckReturn(ret) + return c_level.value + + +def nvmlDeviceGetNvLinkUtilizationCounter(device, link, counter): + c_rxcounter = c_ulonglong() + c_txcounter = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkUtilizationCounter") + ret = fn(device, link, counter, byref(c_rxcounter), byref(c_txcounter)) + _nvmlCheckReturn(ret) + return (c_rxcounter.value, c_txcounter.value) + + +def nvmlDeviceFreezeNvLinkUtilizationCounter(device, link, counter, freeze): + fn = _nvmlGetFunctionPointer("nvmlDeviceFreezeNvLinkUtilizationCounter") + ret = fn(device, link, counter, freeze) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceResetNvLinkUtilizationCounter(device, link, counter): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetNvLinkUtilizationCounter") + ret = fn(device, link, counter) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceSetNvLinkUtilizationControl(device, link, counter, control, reset): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetNvLinkUtilizationControl") + ret = fn(device, link, counter, byref(control), reset) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceGetNvLinkUtilizationControl(device, link, counter): + c_control = nvmlNvLinkUtilizationControl_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkUtilizationControl") + ret = fn(device, link, counter, byref(c_control)) + _nvmlCheckReturn(ret) + return c_control + + +def nvmlDeviceGetNvLinkCapability(device, link, capability): + c_capResult = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkCapability") + ret = fn(device, link, capability, byref(c_capResult)) + _nvmlCheckReturn(ret) + return c_capResult.value + + +def nvmlDeviceGetNvLinkErrorCounter(device, link, counter): + c_result = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkErrorCounter") + ret = fn(device, link, counter, byref(c_result)) + _nvmlCheckReturn(ret) + return c_result.value + + +def nvmlDeviceResetNvLinkErrorCounters(device, link): + fn = _nvmlGetFunctionPointer("nvmlDeviceResetNvLinkErrorCounters") + ret = fn(device, link) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceGetNvLinkRemotePciInfo(device, link): + c_pci = nvmlPciInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkRemotePciInfo_v2") + ret = fn(device, link, byref(c_pci)) + _nvmlCheckReturn(ret) + return c_pci + + +def nvmlDeviceGetNvLinkRemoteDeviceType(handle, link): 
+ c_type = _nvmlNvLinkDeviceType_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkRemoteDeviceType") + ret = fn(handle, link, byref(c_type)) + _nvmlCheckReturn(ret) + return c_type.value + + +def nvmlDeviceGetNvLinkState(device, link): + c_isActive = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkState") + ret = fn(device, link, byref(c_isActive)) + _nvmlCheckReturn(ret) + return c_isActive.value + + +def nvmlDeviceGetNvLinkVersion(device, link): + c_version = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvLinkVersion") + ret = fn(device, link, byref(c_version)) + _nvmlCheckReturn(ret) + return c_version.value + + +def nvmlDeviceModifyDrainState(pciInfo, newState): + fn = _nvmlGetFunctionPointer("nvmlDeviceModifyDrainState") + ret = fn(pointer(pciInfo), newState) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceQueryDrainState(pciInfo): + c_newState = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceQueryDrainState") + ret = fn(pointer(pciInfo), byref(c_newState)) + _nvmlCheckReturn(ret) + return c_newState.value + + +def nvmlDeviceRemoveGpu(pciInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceRemoveGpu") + ret = fn(pointer(pciInfo)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceDiscoverGpus(pciInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceDiscoverGpus") + ret = fn(pointer(pciInfo)) + _nvmlCheckReturn(ret) + return None + + +def nvmlDeviceGetFieldValues(handle, fieldIds): + values_arr = c_nvmlFieldValue_t * len(fieldIds) + values = values_arr() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFieldValues") + + for i, fieldId in enumerate(fieldIds): + try: + (values[i].fieldId, values[i].scopeId) = fieldId + except TypeError: + values[i].fieldId = fieldId + + ret = fn(handle, c_int32(len(fieldIds)), byref(values)) + _nvmlCheckReturn(ret) + return values + + +def nvmlDeviceClearFieldValues(handle, fieldIds): + values_arr = c_nvmlFieldValue_t * len(fieldIds) + values = values_arr() + fn = _nvmlGetFunctionPointer("nvmlDeviceClearFieldValues") + + for i, fieldId in enumerate(fieldIds): + try: + (values[i].fieldId, values[i].scopeId) = fieldId + except TypeError: + values[i].fieldId = fieldId + + ret = fn(handle, c_int32(len(fieldIds)), byref(values)) + _nvmlCheckReturn(ret) + return values + + +def nvmlDeviceGetVirtualizationMode(handle): + c_virtualization_mode = c_ulonglong() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVirtualizationMode") + ret = fn(handle, byref(c_virtualization_mode)) + _nvmlCheckReturn(ret) + return c_virtualization_mode.value + + +def nvmlDeviceSetVirtualizationMode(handle, virtualization_mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetVirtualizationMode") + return fn(handle, virtualization_mode) + + +def nvmlDeviceGetVgpuHeterogeneousMode(handle): + c_vgpuHeterogeneousMode = c_nvmlVgpuHeterogeneousMode_v1_t(0) + c_vgpuHeterogeneousMode.version = VgpuHeterogeneousMode_v1 + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuHeterogeneousMode") + ret = fn(handle, byref(c_vgpuHeterogeneousMode)) + _nvmlCheckReturn(ret) + return c_vgpuHeterogeneousMode.mode + + +def nvmlDeviceSetVgpuHeterogeneousMode(handle, heterogeneous_mode): + c_vgpuHeterogeneousMode = c_nvmlVgpuHeterogeneousMode_v1_t(0) + c_vgpuHeterogeneousMode.version = VgpuHeterogeneousMode_v1 + c_vgpuHeterogeneousMode.mode = heterogeneous_mode + fn = _nvmlGetFunctionPointer("nvmlDeviceSetVgpuHeterogeneousMode") + ret = fn(handle, byref(c_vgpuHeterogeneousMode)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlVgpuInstanceGetPlacementId(vgpuInstance): + 
c_placement = c_nvmlVgpuPlacementId_v1_t(0) + c_placement.version = VgpuPlacementId_v1 + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetPlacementId") + ret = fn(vgpuInstance, byref(c_placement)) + _nvmlCheckReturn(ret) + return c_placement.placementId + + +def nvmlDeviceGetVgpuTypeSupportedPlacements(handle, vgpuTypeId, mode=0, version=1): + c_max_instances = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetMaxInstances") + ret = fn(handle, vgpuTypeId, byref(c_max_instances)) + _nvmlCheckReturn(ret) + + if version == 2: + c_vgpu_placements = c_nvmlVgpuPlacementList_v2_t() + c_vgpu_placements.version = VgpuPlacementList_v2 + c_vgpu_placements.count = c_max_instances.value + c_vgpu_placements.mode = mode + elif version == 1: + c_vgpu_placements = c_nvmlVgpuPlacementList_v1_t() + c_vgpu_placements.version = VgpuPlacementList_v1 + else: + raise NVMLError(NVML_ERROR_ARGUMENT_VERSION_MISMATCH) + + c_placements = c_uint * c_max_instances.value + c_vgpu_placements.placementIds = c_placements() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuTypeSupportedPlacements") + ret = fn(handle, vgpuTypeId, byref(c_vgpu_placements)) + _nvmlCheckReturn(ret) + return c_vgpu_placements + + +def nvmlDeviceGetVgpuTypeCreatablePlacements(handle, vgpuTypeId, version=1): + c_max_instances = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetMaxInstances") + ret = fn(handle, vgpuTypeId, byref(c_max_instances)) + _nvmlCheckReturn(ret) + + if version == 2: + c_vgpu_placements = c_nvmlVgpuPlacementList_v2_t() + c_vgpu_placements.version = VgpuPlacementList_v2 + c_vgpu_placements.count = c_max_instances.value + elif version == 1: + c_vgpu_placements = c_nvmlVgpuPlacementList_v1_t() + c_vgpu_placements.version = VgpuPlacementList_v1 + + c_placements = c_uint * c_max_instances.value + c_vgpu_placements.placementIds = c_placements() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuTypeCreatablePlacements") + ret = fn(handle, vgpuTypeId, byref(c_vgpu_placements)) + _nvmlCheckReturn(ret) + return c_vgpu_placements + + +def nvmlGetVgpuDriverCapabilities(capability): + c_capResult = c_uint() + fn = _nvmlGetFunctionPointer("nvmlGetVgpuDriverCapabilities") + ret = fn(_nvmlVgpuDriverCapability_t(capability), byref(c_capResult)) + _nvmlCheckReturn(ret) + return c_capResult.value + + +def nvmlDeviceGetVgpuCapabilities(handle, capability): + c_capResult = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuCapabilities") + ret = fn(handle, _nvmlDeviceVgpuCapability_t(capability), byref(c_capResult)) + _nvmlCheckReturn(ret) + return c_capResult.value + + +def nvmlDeviceSetVgpuCapabilities(handle, capability, state): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetVgpuCapabilities") + ret = fn(handle, _nvmlDeviceVgpuCapability_t(capability), state) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceGetSupportedVgpus(handle): + # first call to get the size + c_vgpu_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedVgpus") + ret = fn(handle, byref(c_vgpu_count), None) + + if ret == NVML_SUCCESS: + # special case, no supported vGPUs + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + vgpu_type_ids_array = _nvmlVgpuTypeId_t * c_vgpu_count.value + c_vgpu_type_ids = vgpu_type_ids_array() + + # make the call again + ret = fn(handle, byref(c_vgpu_count), c_vgpu_type_ids) + _nvmlCheckReturn(ret) + vgpus = [] + for i in range(c_vgpu_count.value): + vgpus.append(c_vgpu_type_ids[i]) + return vgpus + else: + # error case + raise NVMLError(ret) + + +def 
nvmlDeviceGetCreatableVgpus(handle): + # first call to get the size + c_vgpu_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCreatableVgpus") + ret = fn(handle, byref(c_vgpu_count), None) + + if ret == NVML_SUCCESS: + # special case, no supported vGPUs + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + vgpu_type_ids_array = _nvmlVgpuTypeId_t * c_vgpu_count.value + c_vgpu_type_ids = vgpu_type_ids_array() + + # make the call again + ret = fn(handle, byref(c_vgpu_count), c_vgpu_type_ids) + _nvmlCheckReturn(ret) + vgpus = [] + for i in range(c_vgpu_count.value): + vgpus.append(c_vgpu_type_ids[i]) + return vgpus + else: + # error case + raise NVMLError(ret) + + +def nvmlVgpuTypeGetGpuInstanceProfileId(vgpuTypeId): + c_profile_id = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetGpuInstanceProfileId") + ret = fn(vgpuTypeId, byref(c_profile_id)) + _nvmlCheckReturn(ret) + return c_profile_id.value + + +@convertStrBytes +def nvmlVgpuTypeGetClass(vgpuTypeId): + c_class = create_string_buffer(NVML_DEVICE_NAME_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_DEVICE_NAME_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetClass") + ret = fn(vgpuTypeId, c_class, byref(c_buffer_size)) + _nvmlCheckReturn(ret) + return c_class.value + + +@convertStrBytes +def nvmlVgpuTypeGetName(vgpuTypeId): + c_name = create_string_buffer(NVML_DEVICE_NAME_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_DEVICE_NAME_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetName") + ret = fn(vgpuTypeId, c_name, byref(c_buffer_size)) + _nvmlCheckReturn(ret) + return c_name.value + + +def nvmlVgpuTypeGetDeviceID(vgpuTypeId): + c_device_id = c_ulonglong(0) + c_subsystem_id = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetDeviceID") + ret = fn(vgpuTypeId, byref(c_device_id), byref(c_subsystem_id)) + _nvmlCheckReturn(ret) + return (c_device_id.value, c_subsystem_id.value) + + +def nvmlVgpuTypeGetFramebufferSize(vgpuTypeId): + c_fb_size = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetFramebufferSize") + ret = fn(vgpuTypeId, byref(c_fb_size)) + _nvmlCheckReturn(ret) + return c_fb_size.value + + +def nvmlVgpuTypeGetNumDisplayHeads(vgpuTypeId): + c_num_heads = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetNumDisplayHeads") + ret = fn(vgpuTypeId, byref(c_num_heads)) + _nvmlCheckReturn(ret) + return c_num_heads.value + + +def nvmlVgpuTypeGetResolution(vgpuTypeId): + c_xdim = c_uint(0) + c_ydim = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetResolution") + ret = fn(vgpuTypeId, 0, byref(c_xdim), byref(c_ydim)) + _nvmlCheckReturn(ret) + return (c_xdim.value, c_ydim.value) + + +@convertStrBytes +def nvmlVgpuTypeGetLicense(vgpuTypeId): + c_license = create_string_buffer(NVML_GRID_LICENSE_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_GRID_LICENSE_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetLicense") + ret = fn(vgpuTypeId, c_license, c_buffer_size) + _nvmlCheckReturn(ret) + return c_license.value + + +def nvmlVgpuTypeGetFrameRateLimit(vgpuTypeId): + c_frl_config = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetFrameRateLimit") + ret = fn(vgpuTypeId, byref(c_frl_config)) + _nvmlCheckReturn(ret) + return c_frl_config.value + + +def nvmlVgpuTypeGetGspHeapSize(vgpuTypeId): + c_gsp_heap = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetGspHeapSize") + ret = fn(vgpuTypeId, byref(c_gsp_heap)) + _nvmlCheckReturn(ret) + return c_gsp_heap.value + + +def nvmlVgpuTypeGetFbReservation(vgpuTypeId): + c_fb_reservation = 
c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetFbReservation") + ret = fn(vgpuTypeId, byref(c_fb_reservation)) + _nvmlCheckReturn(ret) + return c_fb_reservation.value + + +def nvmlVgpuInstanceGetRuntimeStateSize(vgpuInstance): + c_runtime_state = nvmlVgpuRuntimeState_v1_t() + c_runtime_state.version = VgpuRuntimeState_v1 + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetRuntimeStateSize") + ret = fn(vgpuInstance, byref(c_runtime_state)) + _nvmlCheckReturn(ret) + return c_runtime_state + + +def nvmlVgpuTypeGetMaxInstances(handle, vgpuTypeId): + c_max_instances = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetMaxInstances") + ret = fn(handle, vgpuTypeId, byref(c_max_instances)) + _nvmlCheckReturn(ret) + return c_max_instances.value + + +def nvmlVgpuTypeGetMaxInstancesPerVm(vgpuTypeId): + c_max_instances_per_vm = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetMaxInstancesPerVm") + ret = fn(vgpuTypeId, byref(c_max_instances_per_vm)) + _nvmlCheckReturn(ret) + return c_max_instances_per_vm.value + + +def nvmlVgpuTypeGetBAR1Info(vgpuTypeId): + c_bar1Info = c_nvmlVgpuTypeBar1Info_v1_t(0) + c_bar1Info.version = VgpuTypeBar1Info_v1 + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetBAR1Info") + ret = fn(vgpuTypeId, byref(c_bar1Info)) + _nvmlCheckReturn(ret) + return c_bar1Info + + +def nvmlDeviceGetActiveVgpus(handle): + # first call to get the size + c_vgpu_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetActiveVgpus") + ret = fn(handle, byref(c_vgpu_count), None) + + if ret == NVML_SUCCESS: + # special case, no active vGPUs + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + vgpu_instance_array = _nvmlVgpuInstance_t * c_vgpu_count.value + c_vgpu_instances = vgpu_instance_array() + + # make the call again + ret = fn(handle, byref(c_vgpu_count), c_vgpu_instances) + _nvmlCheckReturn(ret) + vgpus = [] + for i in range(c_vgpu_count.value): + vgpus.append(c_vgpu_instances[i]) + return vgpus + else: + # error case + raise NVMLError(ret) + + +@convertStrBytes +def nvmlVgpuInstanceGetVmID(vgpuInstance): + c_vm_id = create_string_buffer(NVML_DEVICE_UUID_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_GRID_LICENSE_BUFFER_SIZE) + c_vm_id_type = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetVmID") + ret = fn(vgpuInstance, byref(c_vm_id), c_buffer_size, byref(c_vm_id_type)) + _nvmlCheckReturn(ret) + return (c_vm_id.value, c_vm_id_type.value) + + +@convertStrBytes +def nvmlVgpuInstanceGetUUID(vgpuInstance): + c_uuid = create_string_buffer(NVML_DEVICE_UUID_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_DEVICE_UUID_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetUUID") + ret = fn(vgpuInstance, byref(c_uuid), c_buffer_size) + _nvmlCheckReturn(ret) + return c_uuid.value + + +@convertStrBytes +def nvmlVgpuInstanceGetMdevUUID(vgpuInstance): + c_uuid = create_string_buffer(NVML_DEVICE_UUID_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_DEVICE_UUID_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetMdevUUID") + ret = fn(vgpuInstance, byref(c_uuid), c_buffer_size) + _nvmlCheckReturn(ret) + return c_uuid.value + + +@convertStrBytes +def nvmlVgpuInstanceGetVmDriverVersion(vgpuInstance): + c_driver_version = create_string_buffer(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE) + c_buffer_size = c_uint(NVML_SYSTEM_DRIVER_VERSION_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetVmDriverVersion") + ret = fn(vgpuInstance, byref(c_driver_version), c_buffer_size) + _nvmlCheckReturn(ret) + return c_driver_version.value + + +def 
nvmlVgpuInstanceGetLicenseStatus(vgpuInstance): + c_license_status = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetLicenseStatus") + ret = fn(vgpuInstance, byref(c_license_status)) + _nvmlCheckReturn(ret) + return c_license_status.value + + +def nvmlVgpuInstanceGetLicenseInfo_v2(vgpuInstance): + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetLicenseInfo_v2") + c_license_info = c_nvmlVgpuLicenseInfo_t() + ret = fn(vgpuInstance, byref(c_license_info)) + _nvmlCheckReturn(ret) + return c_license_info + + +def nvmlVgpuInstanceGetLicenseInfo(vgpuInstance): + return nvmlVgpuInstanceGetLicenseInfo_v2(vgpuInstance) + + +def nvmlVgpuInstanceGetFrameRateLimit(vgpuInstance): + c_frl = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetFrameRateLimit") + ret = fn(vgpuInstance, byref(c_frl)) + _nvmlCheckReturn(ret) + return c_frl.value + + +def nvmlVgpuInstanceGetEccMode(vgpuInstance): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetEccMode") + ret = fn(vgpuInstance, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + + +def nvmlVgpuInstanceGetType(vgpuInstance): + c_vgpu_type = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetType") + ret = fn(vgpuInstance, byref(c_vgpu_type)) + _nvmlCheckReturn(ret) + return c_vgpu_type.value + + +def nvmlVgpuInstanceGetEncoderCapacity(vgpuInstance): + c_encoder_capacity = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetEncoderCapacity") + ret = fn(vgpuInstance, byref(c_encoder_capacity)) + _nvmlCheckReturn(ret) + return c_encoder_capacity.value + + +def nvmlVgpuInstanceSetEncoderCapacity(vgpuInstance, encoder_capacity): + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceSetEncoderCapacity") + return fn(vgpuInstance, encoder_capacity) + + +def nvmlVgpuInstanceGetFbUsage(vgpuInstance): + c_fb_usage = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetFbUsage") + ret = fn(vgpuInstance, byref(c_fb_usage)) + _nvmlCheckReturn(ret) + return c_fb_usage.value + + +def nvmlVgpuTypeGetCapabilities(vgpuTypeId, capability): + c_cap_result = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuTypeGetCapabilities") + ret = fn(vgpuTypeId, _nvmlVgpuCapability_t(capability), byref(c_cap_result)) + _nvmlCheckReturn(ret) + return c_cap_result.value + + +def nvmlVgpuInstanceGetGpuInstanceId(vgpuInstance): + c_id = c_uint(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetGpuInstanceId") + ret = fn(vgpuInstance, byref(c_id)) + _nvmlCheckReturn(ret) + return c_id.value + + +@convertStrBytes +def nvmlVgpuInstanceGetGpuPciId(vgpuInstance): + c_vgpuPciId = create_string_buffer(NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetGpuPciId") + ret = fn( + vgpuInstance, c_vgpuPciId, byref(c_uint(NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE)) + ) + _nvmlCheckReturn(ret) + return c_vgpuPciId.value + + +def nvmlDeviceGetVgpuUtilization(handle, timeStamp): + # first call to get the size + c_vgpu_count = c_uint(0) + c_time_stamp = c_ulonglong(timeStamp) + c_sample_value_type = _nvmlValueType_t() + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuUtilization") + ret = fn( + handle, c_time_stamp, byref(c_sample_value_type), byref(c_vgpu_count), None + ) + + if ret == NVML_SUCCESS: + # special case, no active vGPUs + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + sampleArray = c_vgpu_count.value * c_nvmlVgpuInstanceUtilizationSample_t + c_samples = sampleArray() + + # make the call again + ret = fn( + handle, + c_time_stamp, + 
byref(c_sample_value_type), + byref(c_vgpu_count), + c_samples, + ) + _nvmlCheckReturn(ret) + + return c_samples[0 : c_vgpu_count.value] + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetVgpuInstancesUtilizationInfo(handle, timeStamp): + # first call to get the size + c_time_stamp = c_ulonglong(timeStamp) + c_vgpuUtilInfo = c_nvmlVgpuInstancesUtilizationInfo_v1_t(0) + c_vgpuUtilInfo.version = VgpuInstancesUtilizationInfo_v1 + c_vgpuUtilInfo.sampleValType = _nvmlValueType_t() + c_vgpuUtilInfo.vgpuInstanceCount = c_uint(0) + c_vgpuUtilInfo.lastSeenTimeStamp = c_time_stamp + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuInstancesUtilizationInfo") + ret = fn(handle, byref(c_vgpuUtilInfo)) + + if ret == NVML_SUCCESS: + # special case, no active vGPUs + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + sampleArray = ( + c_vgpuUtilInfo.vgpuInstanceCount * c_nvmlVgpuInstanceUtilizationInfo_v1_t + ) + c_samples = sampleArray() + c_vgpuUtilInfo.vgpuUtilArray = c_samples + + # make the call again + ret = fn(handle, byref(c_vgpuUtilInfo)) + _nvmlCheckReturn(ret) + + return c_samples[0 : c_vgpuUtilInfo.vgpuInstanceCount] + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetP2PStatus(device1, device2, p2pIndex): + c_p2pstatus = _nvmlGpuP2PStatus_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetP2PStatus") + ret = fn(device1, device2, p2pIndex, byref(c_p2pstatus)) + _nvmlCheckReturn(ret) + return c_p2pstatus.value + + +def nvmlDeviceGetGridLicensableFeatures_v4(handle): + c_get_grid_licensable_features = c_nvmlGridLicensableFeatures_v4_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGridLicensableFeatures_v4") + ret = fn(handle, byref(c_get_grid_licensable_features)) + _nvmlCheckReturn(ret) + + return c_get_grid_licensable_features + + +def nvmlDeviceGetGridLicensableFeatures(handle): + return nvmlDeviceGetGridLicensableFeatures_v4(handle) + + +def nvmlDeviceGetGspFirmwareVersion(handle, version=None): + isUserDefined = version is not None + if not isUserDefined: + version = (c_char * NVML_GSP_FIRMWARE_VERSION_BUF_SIZE)() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGspFirmwareVersion") + ret = fn(handle, version) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isUserDefined else version.value + + +def nvmlDeviceGetGspFirmwareMode(handle, isEnabled=c_uint(), defaultMode=c_uint()): + isReference = type(isEnabled) is not c_uint + isEnabledRef = isEnabled if isReference else byref(isEnabled) + defaultModeRef = defaultMode if isReference else byref(defaultMode) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGspFirmwareMode") + ret = fn(handle, isEnabledRef, defaultModeRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else [isEnabled.value, defaultMode.value] + + +def nvmlDeviceGetEncoderCapacity(handle, encoderQueryType): + c_encoder_capacity = c_ulonglong(0) + c_encoderQuery_type = _nvmlEncoderQueryType_t(encoderQueryType) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEncoderCapacity") + ret = fn(handle, c_encoderQuery_type, byref(c_encoder_capacity)) + _nvmlCheckReturn(ret) + return c_encoder_capacity.value + + +def nvmlDeviceGetVgpuProcessUtilization(handle, timeStamp): + # first call to get the size + c_vgpu_count = c_uint(0) + c_time_stamp = c_ulonglong(timeStamp) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuProcessUtilization") + ret = fn(handle, c_time_stamp, byref(c_vgpu_count), None) + + if ret == NVML_SUCCESS: + # special case, no active vGPUs + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + 
sampleArray = c_vgpu_count.value * c_nvmlVgpuProcessUtilizationSample_t + c_samples = sampleArray() + + # make the call again + ret = fn(handle, c_time_stamp, byref(c_vgpu_count), c_samples) + _nvmlCheckReturn(ret) + + return c_samples[0 : c_vgpu_count.value] + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetVgpuProcessesUtilizationInfo(handle, timeStamp): + # first call to get the size + c_time_stamp = c_ulonglong(timeStamp) + c_vgpuProcUtilInfo = c_nvmlVgpuProcessesUtilizationInfo_v1_t(0) + c_vgpuProcUtilInfo.version = VgpuProcessesUtilizationInfo_v1 + c_vgpuProcUtilInfo.vgpuProcessCount = c_uint(0) + c_vgpuProcUtilInfo.lastSeenTimeStamp = c_time_stamp + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuProcessesUtilizationInfo") + ret = fn(handle, byref(c_vgpuProcUtilInfo)) + + if ret == NVML_SUCCESS: + # special case, no active vGPUs + return [] + elif ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + sampleArray = ( + c_vgpuProcUtilInfo.vgpuProcessCount * c_nvmlVgpuProcessUtilizationInfo_v1_t + ) + c_samples = sampleArray() + c_vgpuProcUtilInfo.vgpuProcUtilArray = c_samples + + # make the call again + ret = fn(handle, byref(c_vgpuProcUtilInfo)) + _nvmlCheckReturn(ret) + + return c_samples[0 : c_vgpuProcUtilInfo.vgpuProcessCount] + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetEncoderStats(handle): + c_encoderCount = c_ulonglong(0) + c_encodeFps = c_ulonglong(0) + c_encoderLatency = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEncoderStats") + ret = fn(handle, byref(c_encoderCount), byref(c_encodeFps), byref(c_encoderLatency)) + _nvmlCheckReturn(ret) + return (c_encoderCount.value, c_encodeFps.value, c_encoderLatency.value) + + +def nvmlDeviceGetEncoderSessions(handle): + # first call to get the size + c_session_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetEncoderSessions") + ret = fn(handle, byref(c_session_count), None) + + if ret == NVML_SUCCESS: + if c_session_count.value != 0: + # typical case + session_array = c_nvmlEncoderSession_t * c_session_count.value + c_sessions = session_array() + + # make the call again + ret = fn(handle, byref(c_session_count), c_sessions) + _nvmlCheckReturn(ret) + sessions = [] + for i in range(c_session_count.value): + sessions.append(c_sessions[i]) + return sessions + else: + return [] # no active sessions + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetFBCStats(handle): + c_fbcStats = c_nvmlFBCStats_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFBCStats") + ret = fn(handle, byref(c_fbcStats)) + _nvmlCheckReturn(ret) + return c_fbcStats + + +def nvmlDeviceGetFBCSessions(handle): + # first call to get the size + c_session_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetFBCSessions") + ret = fn(handle, byref(c_session_count), None) + + if ret == NVML_SUCCESS: + if c_session_count.value != 0: + # typical case + session_array = c_nvmlFBCSession_t * c_session_count.value + c_sessions = session_array() + + # make the call again + ret = fn(handle, byref(c_session_count), c_sessions) + _nvmlCheckReturn(ret) + sessions = [] + for i in range(c_session_count.value): + sessions.append(c_sessions[i]) + return sessions + else: + return [] # no active sessions + else: + # error case + raise NVMLError(ret) + + +def nvmlVgpuInstanceGetEncoderStats(vgpuInstance): + c_encoderCount = c_ulonglong(0) + c_encodeFps = c_ulonglong(0) + c_encoderLatency = c_ulonglong(0) + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetEncoderStats") + ret = fn( + 
vgpuInstance, byref(c_encoderCount), byref(c_encodeFps), byref(c_encoderLatency) + ) + _nvmlCheckReturn(ret) + return (c_encoderCount.value, c_encodeFps.value, c_encoderLatency.value) + + +def nvmlVgpuInstanceGetEncoderSessions(vgpuInstance): + # first call to get the size + c_session_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetEncoderSessions") + ret = fn(vgpuInstance, byref(c_session_count), None) + + if ret == NVML_SUCCESS: + if c_session_count.value != 0: + # typical case + session_array = c_nvmlEncoderSession_t * c_session_count.value + c_sessions = session_array() + + # make the call again + ret = fn(vgpuInstance, byref(c_session_count), c_sessions) + _nvmlCheckReturn(ret) + sessions = [] + for i in range(c_session_count.value): + sessions.append(c_sessions[i]) + return sessions + else: + return [] # no active sessions + else: + # error case + raise NVMLError(ret) + + +def nvmlVgpuInstanceGetFBCStats(vgpuInstance): + c_fbcStats = c_nvmlFBCStats_t() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetFBCStats") + ret = fn(vgpuInstance, byref(c_fbcStats)) + _nvmlCheckReturn(ret) + return c_fbcStats + + +def nvmlVgpuInstanceGetFBCSessions(vgpuInstance): + # first call to get the size + c_session_count = c_uint(0) + + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetFBCSessions") + ret = fn(vgpuInstance, byref(c_session_count), None) + + if ret == NVML_SUCCESS: + if c_session_count.value != 0: + # typical case + session_array = c_nvmlFBCSession_t * c_session_count.value + c_sessions = session_array() + + # make the call again + ret = fn(vgpuInstance, byref(c_session_count), c_sessions) + _nvmlCheckReturn(ret) + sessions = [] + for i in range(c_session_count.value): + sessions.append(c_sessions[i]) + return sessions + else: + return [] # no active sessions + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetProcessUtilization(handle, timeStamp): + # first call to get the size + c_count = c_uint(0) + c_time_stamp = c_ulonglong(timeStamp) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetProcessUtilization") + ret = fn(handle, None, byref(c_count), c_time_stamp) + + if ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + sampleArray = c_count.value * c_nvmlProcessUtilizationSample_t + c_samples = sampleArray() + + # make the call again + ret = fn(handle, c_samples, byref(c_count), c_time_stamp) + _nvmlCheckReturn(ret) + + return c_samples[0 : c_count.value] + else: + # error case + raise NVMLError(ret) + + +def nvmlDeviceGetProcessesUtilizationInfo(handle, timeStamp): + # first call to get the size + c_time_stamp = c_ulonglong(timeStamp) + c_processesUtilInfo = c_nvmlProcessesUtilizationInfo_v1_t(0) + c_processesUtilInfo.version = ProcessesUtilizationInfo_v1 + c_processesUtilInfo.processSamplesCount = c_uint(0) + c_processesUtilInfo.lastSeenTimeStamp = c_time_stamp + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetProcessesUtilizationInfo") + ret = fn(handle, byref(c_processesUtilInfo)) + + if ret == NVML_ERROR_INSUFFICIENT_SIZE: + # typical case + sampleArray = ( + c_processesUtilInfo.processSamplesCount * c_nvmlProcessUtilizationInfo_v1_t + ) + c_samples = sampleArray() + c_processesUtilInfo.procUtilArray = c_samples + + # make the call again + ret = fn(handle, byref(c_processesUtilInfo)) + _nvmlCheckReturn(ret) + + return c_samples[0 : c_processesUtilInfo.processSamplesCount] + else: + # error case + raise NVMLError(ret) + + +def nvmlVgpuInstanceGetMetadata(vgpuInstance): + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetMetadata") + c_vgpuMetadata 
= c_nvmlVgpuMetadata_t() + c_bufferSize = c_uint(0) + # Make the first NVML API call to get the c_bufferSize value. + # We have already allocated required buffer above. + ret = fn(vgpuInstance, byref(c_vgpuMetadata), byref(c_bufferSize)) + if ret == NVML_ERROR_INSUFFICIENT_SIZE: + ret = fn(vgpuInstance, byref(c_vgpuMetadata), byref(c_bufferSize)) + _nvmlCheckReturn(ret) + else: + raise NVMLError(ret) + return c_vgpuMetadata + + +def nvmlDeviceGetVgpuMetadata(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuMetadata") + c_vgpuPgpuMetadata = c_nvmlVgpuPgpuMetadata_t() + c_bufferSize = c_uint(0) + # Make the first NVML API call to get the c_bufferSize value. + # We have already allocated required buffer above. + ret = fn(handle, byref(c_vgpuPgpuMetadata), byref(c_bufferSize)) + if ret == NVML_ERROR_INSUFFICIENT_SIZE: + ret = fn(handle, byref(c_vgpuPgpuMetadata), byref(c_bufferSize)) + _nvmlCheckReturn(ret) + else: + raise NVMLError(ret) + return c_vgpuPgpuMetadata + + +def nvmlGetVgpuCompatibility(vgpuMetadata, pgpuMetadata): + fn = _nvmlGetFunctionPointer("nvmlGetVgpuCompatibility") + c_vgpuPgpuCompatibility = c_nvmlVgpuPgpuCompatibility_t() + ret = fn(byref(vgpuMetadata), byref(pgpuMetadata), byref(c_vgpuPgpuCompatibility)) + _nvmlCheckReturn(ret) + return c_vgpuPgpuCompatibility + + +@convertStrBytes +def nvmlDeviceGetPgpuMetadataString(handle): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPgpuMetadataString") + c_pgpuMetadata = create_string_buffer(NVML_VGPU_PGPU_METADATA_OPAQUE_DATA_SIZE) + c_bufferSize = c_uint(0) + # Make the first NVML API call to get the c_bufferSize value. + # We have already allocated required buffer above. + ret = fn(handle, byref(c_pgpuMetadata), byref(c_bufferSize)) + if ret == NVML_ERROR_INSUFFICIENT_SIZE: + ret = fn(handle, byref(c_pgpuMetadata), byref(c_bufferSize)) + _nvmlCheckReturn(ret) + else: + raise NVMLError(ret) + return (c_pgpuMetadata.value, c_bufferSize.value) + + +def nvmlDeviceGetVgpuSchedulerLog(handle): + c_vgpu_sched_log = c_nvmlVgpuSchedulerLog_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuSchedulerLog") + ret = fn(handle, byref(c_vgpu_sched_log)) + _nvmlCheckReturn(ret) + return c_vgpu_sched_log + + +def nvmlDeviceGetVgpuSchedulerState(handle): + c_vgpu_sched_state = c_nvmlVgpuSchedulerGetState_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuSchedulerState") + ret = fn(handle, byref(c_vgpu_sched_state)) + _nvmlCheckReturn(ret) + return c_vgpu_sched_state + + +def nvmlDeviceGetVgpuSchedulerCapabilities(handle): + c_vgpu_sched_caps = c_nvmlVgpuSchedulerCapabilities_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetVgpuSchedulerCapabilities") + ret = fn(handle, byref(c_vgpu_sched_caps)) + _nvmlCheckReturn(ret) + return c_vgpu_sched_caps + + +def nvmlDeviceSetVgpuSchedulerState(handle, sched_state): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetVgpuSchedulerState") + ret = fn(handle, byref(sched_state)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlSetVgpuVersion(vgpuVersion): + fn = _nvmlGetFunctionPointer("nvmlSetVgpuVersion") + ret = fn(byref(vgpuVersion)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlGetVgpuVersion(supported=None, current=None): + isUserDefined = (supported is not None) or (current is not None) + if not isUserDefined: + supported = c_nvmlVgpuVersion_t() + current = c_nvmlVgpuVersion_t() + fn = _nvmlGetFunctionPointer("nvmlGetVgpuVersion") + ret = fn(byref(supported), byref(current)) + _nvmlCheckReturn(ret) + return ( + NVML_SUCCESS + if isUserDefined + else [ + 
(supported.minVersion, supported.maxVersion), + (current.minVersion, current.maxVersion), + ] + ) + + +def nvmlVgpuInstanceGetAccountingMode(vgpuInstance): + c_mode = _nvmlEnableState_t() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetAccountingMode") + ret = fn(vgpuInstance, byref(c_mode)) + _nvmlCheckReturn(ret) + return c_mode.value + + +def nvmlVgpuInstanceGetAccountingPids(vgpuInstance): + c_pidCount = c_uint() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetAccountingPids") + ret = fn(vgpuInstance, byref(c_pidCount), None) + if ret == NVML_ERROR_INSUFFICIENT_SIZE: + sampleArray = c_pidCount.value * c_uint + c_pidArray = sampleArray() + ret = fn(vgpuInstance, byref(c_pidCount), byref(c_pidArray)) + _nvmlCheckReturn(ret) + else: + raise NVMLError(ret) + return (c_pidCount, c_pidArray) + + +def nvmlVgpuInstanceGetAccountingStats(vgpuInstance, pid): + c_accountingStats = c_nvmlAccountingStats_t() + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceGetAccountingStats") + ret = fn(vgpuInstance, pid, byref(c_accountingStats)) + _nvmlCheckReturn(ret) + return c_accountingStats + + +def nvmlVgpuInstanceClearAccountingPids(vgpuInstance): + fn = _nvmlGetFunctionPointer("nvmlVgpuInstanceClearAccountingPids") + ret = fn(vgpuInstance) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlGetExcludedDeviceCount(): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlGetExcludedDeviceCount") + ret = fn(byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + + +def nvmlGetExcludedDeviceInfoByIndex(index): + c_index = c_uint(index) + info = c_nvmlExcludedDeviceInfo_t() + fn = _nvmlGetFunctionPointer("nvmlGetExcludedDeviceInfoByIndex") + ret = fn(c_index, byref(info)) + _nvmlCheckReturn(ret) + return info + + +def nvmlDeviceGetHostVgpuMode(handle): + c_host_vgpu_mode = _nvmlHostVgpuMode_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetHostVgpuMode") + ret = fn(handle, byref(c_host_vgpu_mode)) + _nvmlCheckReturn(ret) + return c_host_vgpu_mode.value + + +def nvmlDeviceSetMigMode(device, mode): + c_activationStatus = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceSetMigMode") + ret = fn(device, mode, byref(c_activationStatus)) + _nvmlCheckReturn(ret) + return c_activationStatus.value + + +def nvmlDeviceGetMigMode(device): + c_currentMode = c_uint() + c_pendingMode = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMigMode") + ret = fn(device, byref(c_currentMode), byref(c_pendingMode)) + _nvmlCheckReturn(ret) + return [c_currentMode.value, c_pendingMode.value] + + +def nvmlDeviceGetGpuInstanceProfileInfo(device, profile, version=2): + if version == 2: + c_info = c_nvmlGpuInstanceProfileInfo_v2_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceProfileInfoV") + elif version == 1: + c_info = c_nvmlGpuInstanceProfileInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceProfileInfo") + else: + raise NVMLError(NVML_ERROR_FUNCTION_NOT_FOUND) + ret = fn(device, profile, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + + +# Define function alias for the API exposed by NVML +nvmlDeviceGetGpuInstanceProfileInfoV = nvmlDeviceGetGpuInstanceProfileInfo + + +def nvmlDeviceGetGpuInstanceRemainingCapacity(device, profileId): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceRemainingCapacity") + ret = fn(device, profileId, byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + + +def nvmlDeviceGetGpuInstancePossiblePlacements( + device, profileId, placementsRef, countRef +): + fn = 
_nvmlGetFunctionPointer("nvmlDeviceGetGpuInstancePossiblePlacements_v2") + ret = fn(device, profileId, placementsRef, countRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceCreateGpuInstance(device, profileId): + c_instance = c_nvmlGpuInstance_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceCreateGpuInstance") + ret = fn(device, profileId, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + + +def nvmlDeviceCreateGpuInstanceWithPlacement(device, profileId, placement): + c_instance = c_nvmlGpuInstance_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceCreateGpuInstanceWithPlacement") + ret = fn(device, profileId, placement, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + + +def nvmlGpuInstanceDestroy(gpuInstance): + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceDestroy") + ret = fn(gpuInstance) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceGetGpuInstances(device, profileId, gpuInstancesRef, countRef): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstances") + ret = fn(device, profileId, gpuInstancesRef, countRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceGetGpuInstanceById(device, gpuInstanceId): + c_instance = c_nvmlGpuInstance_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceById") + ret = fn(device, gpuInstanceId, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + + +def nvmlGpuInstanceGetInfo(gpuInstance): + c_info = c_nvmlGpuInstanceInfo_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetInfo") + ret = fn(gpuInstance, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + + +def nvmlGpuInstanceGetComputeInstanceProfileInfo( + device, profile, engProfile, version=2 +): + if version == 2: + c_info = c_nvmlComputeInstanceProfileInfo_v2_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstanceProfileInfoV") + elif version == 1: + c_info = c_nvmlComputeInstanceProfileInfo_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstanceProfileInfo") + else: + raise NVMLError(NVML_ERROR_FUNCTION_NOT_FOUND) + ret = fn(device, profile, engProfile, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + + +# Define function alias for the API exposed by NVML +nvmlGpuInstanceGetComputeInstanceProfileInfoV = ( + nvmlGpuInstanceGetComputeInstanceProfileInfo +) + + +def nvmlGpuInstanceGetComputeInstanceRemainingCapacity(gpuInstance, profileId): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstanceRemainingCapacity") + ret = fn(gpuInstance, profileId, byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + + +def nvmlGpuInstanceGetComputeInstancePossiblePlacements( + gpuInstance, profileId, placementsRef, countRef +): + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstancePossiblePlacements") + ret = fn(gpuInstance, profileId, placementsRef, countRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlGpuInstanceCreateComputeInstance(gpuInstance, profileId): + c_instance = c_nvmlComputeInstance_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceCreateComputeInstance") + ret = fn(gpuInstance, profileId, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + + +def nvmlGpuInstanceCreateComputeInstanceWithPlacement( + gpuInstance, profileId, placement +): + c_instance = c_nvmlComputeInstance_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceCreateComputeInstanceWithPlacement") + ret = fn(gpuInstance, profileId, placement, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + + +def 
nvmlComputeInstanceDestroy(computeInstance): + fn = _nvmlGetFunctionPointer("nvmlComputeInstanceDestroy") + ret = fn(computeInstance) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlGpuInstanceGetComputeInstances( + gpuInstance, profileId, computeInstancesRef, countRef +): + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstances") + ret = fn(gpuInstance, profileId, computeInstancesRef, countRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlGpuInstanceGetComputeInstanceById(gpuInstance, computeInstanceId): + c_instance = c_nvmlComputeInstance_t() + fn = _nvmlGetFunctionPointer("nvmlGpuInstanceGetComputeInstanceById") + ret = fn(gpuInstance, computeInstanceId, byref(c_instance)) + _nvmlCheckReturn(ret) + return c_instance + + +def nvmlComputeInstanceGetInfo_v2(computeInstance): + c_info = c_nvmlComputeInstanceInfo_t() + fn = _nvmlGetFunctionPointer("nvmlComputeInstanceGetInfo_v2") + ret = fn(computeInstance, byref(c_info)) + _nvmlCheckReturn(ret) + return c_info + + +def nvmlComputeInstanceGetInfo(computeInstance): + return nvmlComputeInstanceGetInfo_v2(computeInstance) + + +def nvmlDeviceIsMigDeviceHandle(device): + c_isMigDevice = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceIsMigDeviceHandle") + ret = fn(device, byref(c_isMigDevice)) + _nvmlCheckReturn(ret) + return c_isMigDevice + + +def nvmlDeviceGetGpuInstanceId(device): + c_gpuInstanceId = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuInstanceId") + ret = fn(device, byref(c_gpuInstanceId)) + _nvmlCheckReturn(ret) + return c_gpuInstanceId.value + + +def nvmlDeviceGetComputeInstanceId(device): + c_computeInstanceId = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetComputeInstanceId") + ret = fn(device, byref(c_computeInstanceId)) + _nvmlCheckReturn(ret) + return c_computeInstanceId.value + + +def nvmlDeviceGetMaxMigDeviceCount(device): + c_count = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMaxMigDeviceCount") + ret = fn(device, byref(c_count)) + _nvmlCheckReturn(ret) + return c_count.value + + +def nvmlDeviceGetMigDeviceHandleByIndex(device, index): + c_index = c_uint(index) + migDevice = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMigDeviceHandleByIndex") + ret = fn(device, c_index, byref(migDevice)) + _nvmlCheckReturn(ret) + return migDevice + + +def nvmlDeviceGetDeviceHandleFromMigDeviceHandle(migDevice): + device = c_nvmlDevice_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDeviceHandleFromMigDeviceHandle") + ret = fn(migDevice, byref(device)) + _nvmlCheckReturn(ret) + return device + + +def nvmlDeviceGetAttributes_v2(device): + c_attrs = c_nvmlDeviceAttributes() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAttributes_v2") + ret = fn(device, byref(c_attrs)) + _nvmlCheckReturn(ret) + return c_attrs + + +def nvmlDeviceGetAttributes(device): + return nvmlDeviceGetAttributes_v2(device) + + +def nvmlDeviceGetRemappedRows(device): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRemappedRows") + c_corr = c_uint() + c_unc = c_uint() + c_bpending = c_uint() + c_bfailure = c_uint() + ret = fn(device, byref(c_corr), byref(c_unc), byref(c_bpending), byref(c_bfailure)) + _nvmlCheckReturn(ret) + return (c_corr.value, c_unc.value, c_bpending.value, c_bfailure.value) + + +def nvmlDeviceGetRowRemapperHistogram(device): + c_vals = c_nvmlRowRemapperHistogramValues() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetRowRemapperHistogram") + ret = fn(device, byref(c_vals)) + _nvmlCheckReturn(ret) + return c_vals + + +def nvmlDeviceGetArchitecture(device): + arch = 
_nvmlDeviceArchitecture_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetArchitecture") + ret = fn(device, byref(arch)) + _nvmlCheckReturn(ret) + return arch.value + + +def nvmlDeviceGetBusType(device): + c_busType = _nvmlBusType_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetBusType") + ret = fn(device, byref(c_busType)) + _nvmlCheckReturn(ret) + return c_busType.value + + +def nvmlDeviceGetIrqNum(device): + c_irqNum = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetIrqNum") + ret = fn(device, byref(c_irqNum)) + _nvmlCheckReturn(ret) + return c_irqNum.value + + +def nvmlDeviceGetNumGpuCores(device): + c_numCores = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNumGpuCores") + ret = fn(device, byref(c_numCores)) + _nvmlCheckReturn(ret) + return c_numCores.value + + +def nvmlDeviceGetPowerSource(device): + c_powerSource = _nvmlPowerSource_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPowerSource") + ret = fn(device, byref(c_powerSource)) + _nvmlCheckReturn(ret) + return c_powerSource.value + + +def nvmlDeviceGetMemoryBusWidth(device): + c_memBusWidth = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemoryBusWidth") + ret = fn(device, byref(c_memBusWidth)) + _nvmlCheckReturn(ret) + return c_memBusWidth.value + + +def nvmlDeviceGetPcieLinkMaxSpeed(device): + c_speed = _nvmlPcieLinkMaxSpeed_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPcieLinkMaxSpeed") + ret = fn(device, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + + +def nvmlDeviceGetAdaptiveClockInfoStatus(device): + c_adaptiveClockInfoStatus = _nvmlAdaptiveClockInfoStatus_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetAdaptiveClockInfoStatus") + ret = fn(device, byref(c_adaptiveClockInfoStatus)) + _nvmlCheckReturn(ret) + return c_adaptiveClockInfoStatus.value + + +def nvmlDeviceGetPcieSpeed(device): + c_speed = c_uint() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPcieSpeed") + ret = fn(device, byref(c_speed)) + _nvmlCheckReturn(ret) + return c_speed.value + + +def nvmlDeviceGetDynamicPstatesInfo( + device, c_dynamicpstatesinfo=c_nvmlGpuDynamicPstatesInfo_t() +): + isReference = type(c_dynamicpstatesinfo) is not c_nvmlGpuDynamicPstatesInfo_t + dynamicpstatesinfoRef = ( + c_dynamicpstatesinfo if isReference else byref(c_dynamicpstatesinfo) + ) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDynamicPstatesInfo") + ret = fn(device, dynamicpstatesinfoRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else c_dynamicpstatesinfo + + +def nvmlDeviceSetFanSpeed_v2(handle, index, speed): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetFanSpeed_v2") + ret = fn(handle, index, speed) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceGetThermalSettings( + device, sensorindex, c_thermalsettings=c_nvmlGpuThermalSettings_t() +): + isReference = type(c_thermalsettings) is not c_nvmlGpuThermalSettings_t + thermalsettingsRef = c_thermalsettings if isReference else byref(c_thermalsettings) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetThermalSettings") + ret = fn(device, sensorindex, thermalsettingsRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else c_thermalsettings.sensor[:] + + +def nvmlDeviceGetMinMaxClockOfPState( + device, clockType, pstate, minClockMHz=c_uint(), maxClockMHz=c_uint() +): + isReference = (type(minClockMHz) is not c_uint) or (type(maxClockMHz) is not c_uint) + minClockMHzRef = minClockMHz if isReference else byref(minClockMHz) + maxClockMHzRef = maxClockMHz if isReference else byref(maxClockMHz) + fn = 
_nvmlGetFunctionPointer("nvmlDeviceGetMinMaxClockOfPState") + ret = fn( + device, + _nvmlClockType_t(clockType), + _nvmlClockType_t(pstate), + minClockMHzRef, + maxClockMHzRef, + ) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else (minClockMHz.value, maxClockMHz.value) + + +class c_nvmlClockOffset_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("type", _nvmlClockType_t), + ("pstate", _nvmlPstates_t), + ("clockOffsetMHz", c_int), + ("minClockOffsetMHz", c_int), + ("maxClockOffsetMHz", c_int), + ] + + +nvmlClockOffset_v1 = 0x1000018 + + +def nvmlDeviceGetClockOffsets(device, info): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetClockOffsets") + ret = fn(device, info) + return NVML_SUCCESS + + +def nvmlDeviceSetClockOffsets(device, info): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetClockOffsets") + ret = fn(device, info) + return NVML_SUCCESS + + +def nvmlDeviceGetSupportedPerformanceStates(device): + pstates = [] + c_count = c_uint(NVML_MAX_GPU_PERF_PSTATES) + c_size = sizeof(c_uint) * c_count.value + + # NOTE: use 'c_uint' to represent the size of the nvmlPstate_t enumeration. + pstates_array = _nvmlPstates_t * c_count.value + c_pstates = pstates_array() + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSupportedPerformanceStates") + ret = fn(device, c_pstates, c_size) + _nvmlCheckReturn(ret) + + for value in c_pstates: + if value != NVML_PSTATE_UNKNOWN: + pstates.append(value) + + return pstates + + +def nvmlDeviceGetGpcClkVfOffset(device): + offset = c_int32() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpcClkVfOffset") + ret = fn(device, byref(offset)) + _nvmlCheckReturn(ret) + return offset.value + + +def nvmlDeviceSetGpcClkVfOffset(device, offset): + c_offset = c_int32(offset) + fn = _nvmlGetFunctionPointer("nvmlDeviceSetGpcClkVfOffset") + ret = fn(device, c_offset) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceGetGpcClkMinMaxVfOffset(device, minOffset=c_int(), maxOffset=c_int()): + isReference = (type(minOffset) is not c_int) or (type(maxOffset) is not c_int) + minOffsetRef = minOffset if isReference else byref(minOffset) + maxOffsetRef = maxOffset if isReference else byref(maxOffset) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpcClkMinMaxVfOffset") + ret = fn(device, minOffsetRef, maxOffsetRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else (minOffset.value, maxOffset.value) + + +def nvmlDeviceGetMemClkVfOffset(device): + offset = c_int32() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemClkVfOffset") + ret = fn(device, byref(offset)) + _nvmlCheckReturn(ret) + return offset.value + + +def nvmlDeviceSetMemClkVfOffset(device, offset): + c_offset = c_int32(offset) + fn = _nvmlGetFunctionPointer("nvmlDeviceSetMemClkVfOffset") + ret = fn(device, c_offset) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceGetMemClkMinMaxVfOffset(device, minOffset=c_int(), maxOffset=c_int()): + isReference = (type(minOffset) is not c_int) or (type(maxOffset) is not c_int) + minOffsetRef = minOffset if isReference else byref(minOffset) + maxOffsetRef = maxOffset if isReference else byref(maxOffset) + + fn = _nvmlGetFunctionPointer("nvmlDeviceGetMemClkMinMaxVfOffset") + ret = fn(device, minOffsetRef, maxOffsetRef) + _nvmlCheckReturn(ret) + return NVML_SUCCESS if isReference else (minOffset.value, maxOffset.value) + + +def nvmlSystemSetConfComputeGpusReadyState(state): + c_state = c_uint(state) + fn = _nvmlGetFunctionPointer("nvmlSystemSetConfComputeGpusReadyState") + ret = fn(c_state) + _nvmlCheckReturn(ret) + return 
NVML_SUCCESS + + +def nvmlSystemGetConfComputeGpusReadyState(): + c_state = c_uint() + fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeGpusReadyState") + ret = fn(byref(c_state)) + _nvmlCheckReturn(ret) + return c_state.value + + +def nvmlSystemGetConfComputeCapabilities(): + c_ccSysCaps = c_nvmlConfComputeSystemCaps_t() + fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeCapabilities") + ret = fn(byref(c_ccSysCaps)) + _nvmlCheckReturn(ret) + return c_ccSysCaps + + +def nvmlSystemGetConfComputeState(): + c_state = c_nvmlConfComputeSystemState_t() + fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeState") + ret = fn(byref(c_state)) + _nvmlCheckReturn(ret) + return c_state + + +def nvmlSystemGetConfComputeSettings(settings): + fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeSettings") + return fn(settings) + + +def nvmlDeviceSetConfComputeUnprotectedMemSize(device, c_ccMemSize): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetConfComputeUnprotectedMemSize") + ret = fn(device, c_ccMemSize) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceGetConfComputeMemSizeInfo(device): + c_ccMemSize = c_nvmlConfComputeMemSizeInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetConfComputeMemSizeInfo") + ret = fn(device, byref(c_ccMemSize)) + _nvmlCheckReturn(ret) + return c_ccMemSize + + +def nvmlDeviceGetConfComputeProtectedMemoryUsage(device): + c_memory = c_nvmlMemory_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetConfComputeProtectedMemoryUsage") + ret = fn(device, byref(c_memory)) + _nvmlCheckReturn(ret) + return c_memory + + +def nvmlDeviceGetConfComputeGpuCertificate(device): + c_cert = c_nvmlConfComputeGpuCertificate_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetConfComputeGpuCertificate") + ret = fn(device, byref(c_cert)) + _nvmlCheckReturn(ret) + return c_cert + + +def nvmlDeviceGetConfComputeGpuAttestationReport(device, c_nonce): + c_attestReport = c_nvmlConfComputeGpuAttestationReport_t() + c_nonce_arr = (c_uint8 * len(c_nonce))(*(c_nonce)) + setattr(c_attestReport, "nonce", c_nonce_arr) + fn = _nvmlGetFunctionPointer("nvmlDeviceGetConfComputeGpuAttestationReport") + ret = fn(device, byref(c_attestReport)) + _nvmlCheckReturn(ret) + return c_attestReport + + +def nvmlSystemSetConfComputeKeyRotationThresholdInfo(max_atk_adv): + c_keyRotationThrInfo = c_nvmlConfComputeSetKeyRotationThresholdInfo_t(0) + c_keyRotationThrInfo.version = ConfComputeSetKeyRotationThresholdInfo_v1 + c_keyRotationThrInfo.maxAttackerAdvantage = max_atk_adv + fn = _nvmlGetFunctionPointer("nvmlSystemSetConfComputeKeyRotationThresholdInfo") + ret = fn(byref(c_keyRotationThrInfo)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlSystemGetConfComputeKeyRotationThresholdInfo(): + c_keyRotationThrInfo = c_nvmlConfComputeGetKeyRotationThresholdInfo_t(0) + c_keyRotationThrInfo.version = ConfComputeGetKeyRotationThresholdInfo_v1 + fn = _nvmlGetFunctionPointer("nvmlSystemGetConfComputeKeyRotationThresholdInfo") + ret = fn(byref(c_keyRotationThrInfo)) + _nvmlCheckReturn(ret) + return c_keyRotationThrInfo + + +## GPM ## +######### + +## Enums/defines + +#### GPM Metric Identifiers +NVML_GPM_METRIC_GRAPHICS_UTIL = ( + 1 # Percentage of time any compute/graphics app was active on the GPU. 0.0 - 100.0 +) +NVML_GPM_METRIC_SM_UTIL = 2 # Percentage of SMs that were busy. 0.0 - 100.0 +NVML_GPM_METRIC_SM_OCCUPANCY = ( + 3 # Percentage of warps that were active vs theoretical maximum. 
0.0 - 100.0 +) +NVML_GPM_METRIC_INTEGER_UTIL = ( + 4 # Percentage of time the GPU's SMs were doing integer operations. 0.0 - 100.0 +) +NVML_GPM_METRIC_ANY_TENSOR_UTIL = ( + 5 # Percentage of time the GPU's SMs were doing ANY tensor operations. 0.0 - 100.0 +) +NVML_GPM_METRIC_DFMA_TENSOR_UTIL = ( + 6 # Percentage of time the GPU's SMs were doing DFMA tensor operations. 0.0 - 100.0 +) +NVML_GPM_METRIC_HMMA_TENSOR_UTIL = ( + 7 # Percentage of time the GPU's SMs were doing HMMA tensor operations. 0.0 - 100.0 +) +NVML_GPM_METRIC_IMMA_TENSOR_UTIL = ( + 9 # Percentage of time the GPU's SMs were doing IMMA tensor operations. 0.0 - 100.0 +) +NVML_GPM_METRIC_DRAM_BW_UTIL = ( + 10 # Percentage of DRAM bw used vs theoretical maximum. 0.0 - 100.0 +) +NVML_GPM_METRIC_FP64_UTIL = ( + 11 # Percentage of time the GPU's SMs were doing non-tensor FP64 math. 0.0 - 100.0 +) +NVML_GPM_METRIC_FP32_UTIL = ( + 12 # Percentage of time the GPU's SMs were doing non-tensor FP32 math. 0.0 - 100.0 +) +NVML_GPM_METRIC_FP16_UTIL = ( + 13 # Percentage of time the GPU's SMs were doing non-tensor FP16 math. 0.0 - 100.0 +) +NVML_GPM_METRIC_PCIE_TX_PER_SEC = 20 # PCIe traffic from this GPU in MiB/sec +NVML_GPM_METRIC_PCIE_RX_PER_SEC = 21 # PCIe traffic to this GPU in MiB/sec +NVML_GPM_METRIC_NVDEC_0_UTIL = 30 # Percent utilization of NVDEC 0. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_1_UTIL = 31 # Percent utilization of NVDEC 1. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_2_UTIL = 32 # Percent utilization of NVDEC 2. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_3_UTIL = 33 # Percent utilization of NVDEC 3. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_4_UTIL = 34 # Percent utilization of NVDEC 4. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_5_UTIL = 35 # Percent utilization of NVDEC 5. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_6_UTIL = 36 # Percent utilization of NVDEC 6. 0.0 - 100.0 +NVML_GPM_METRIC_NVDEC_7_UTIL = 37 # Percent utilization of NVDEC 7. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_0_UTIL = 40 # Percent utilization of NVJPG 0. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_1_UTIL = 41 # Percent utilization of NVJPG 1. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_2_UTIL = 42 # Percent utilization of NVJPG 2. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_3_UTIL = 43 # Percent utilization of NVJPG 3. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_4_UTIL = 44 # Percent utilization of NVJPG 4. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_5_UTIL = 45 # Percent utilization of NVJPG 5. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_6_UTIL = 46 # Percent utilization of NVJPG 6. 0.0 - 100.0 +NVML_GPM_METRIC_NVJPG_7_UTIL = 47 # Percent utilization of NVJPG 7. 0.0 - 100.0 +NVML_GPM_METRIC_NVOFA_0_UTIL = 50 # Percent utilization of NVOFA 0. 0.0 - 100.0 +NVML_GPM_METRIC_NVOFA_1_UTIL = 51 # Percent utilization of NVOFA 1. 
0.0 - 100.0 +NVML_GPM_METRIC_NVLINK_TOTAL_RX_PER_SEC = ( + 60 # NvLink read bandwidth for all links in MiB/sec +) +NVML_GPM_METRIC_NVLINK_TOTAL_TX_PER_SEC = ( + 61 # NvLink write bandwidth for all links in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L0_RX_PER_SEC = 62 # NvLink read bandwidth for link 0 in MiB/sec +NVML_GPM_METRIC_NVLINK_L0_TX_PER_SEC = ( + 63 # NvLink write bandwidth for link 0 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L1_RX_PER_SEC = 64 # NvLink read bandwidth for link 1 in MiB/sec +NVML_GPM_METRIC_NVLINK_L1_TX_PER_SEC = ( + 65 # NvLink write bandwidth for link 1 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L2_RX_PER_SEC = 66 # NvLink read bandwidth for link 2 in MiB/sec +NVML_GPM_METRIC_NVLINK_L2_TX_PER_SEC = ( + 67 # NvLink write bandwidth for link 2 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L3_RX_PER_SEC = 68 # NvLink read bandwidth for link 3 in MiB/sec +NVML_GPM_METRIC_NVLINK_L3_TX_PER_SEC = ( + 69 # NvLink write bandwidth for link 3 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L4_RX_PER_SEC = 70 # NvLink read bandwidth for link 4 in MiB/sec +NVML_GPM_METRIC_NVLINK_L4_TX_PER_SEC = ( + 71 # NvLink write bandwidth for link 4 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L5_RX_PER_SEC = 72 # NvLink read bandwidth for link 5 in MiB/sec +NVML_GPM_METRIC_NVLINK_L5_TX_PER_SEC = ( + 73 # NvLink write bandwidth for link 5 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L6_RX_PER_SEC = 74 # NvLink read bandwidth for link 6 in MiB/sec +NVML_GPM_METRIC_NVLINK_L6_TX_PER_SEC = ( + 75 # NvLink write bandwidth for link 6 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L7_RX_PER_SEC = 76 # NvLink read bandwidth for link 7 in MiB/sec +NVML_GPM_METRIC_NVLINK_L7_TX_PER_SEC = ( + 77 # NvLink write bandwidth for link 7 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L8_RX_PER_SEC = 78 # NvLink read bandwidth for link 8 in MiB/sec +NVML_GPM_METRIC_NVLINK_L8_TX_PER_SEC = ( + 79 # NvLink write bandwidth for link 8 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L9_RX_PER_SEC = 80 # NvLink read bandwidth for link 9 in MiB/sec +NVML_GPM_METRIC_NVLINK_L9_TX_PER_SEC = ( + 81 # NvLink write bandwidth for link 9 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L10_RX_PER_SEC = ( + 82 # NvLink read bandwidth for link 10 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L10_TX_PER_SEC = ( + 83 # NvLink write bandwidth for link 10 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L11_RX_PER_SEC = ( + 84 # NvLink read bandwidth for link 11 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L11_TX_PER_SEC = ( + 85 # NvLink write bandwidth for link 11 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L12_RX_PER_SEC = ( + 86 # NvLink read bandwidth for link 12 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L12_TX_PER_SEC = ( + 87 # NvLink write bandwidth for link 12 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L13_RX_PER_SEC = ( + 88 # NvLink read bandwidth for link 13 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L13_TX_PER_SEC = ( + 89 # NvLink write bandwidth for link 13 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L14_RX_PER_SEC = ( + 90 # NvLink read bandwidth for link 14 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L14_TX_PER_SEC = ( + 91 # NvLink write bandwidth for link 14 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L15_RX_PER_SEC = ( + 92 # NvLink read bandwidth for link 15 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L15_TX_PER_SEC = ( + 93 # NvLink write bandwidth for link 15 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L16_RX_PER_SEC = ( + 94 # NvLink read bandwidth for link 16 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L16_TX_PER_SEC = ( + 95 # NvLink write bandwidth for link 16 in MiB/sec +) +NVML_GPM_METRIC_NVLINK_L17_RX_PER_SEC = ( + 96 # NvLink read bandwidth for link 17 in MiB/sec +) 
+NVML_GPM_METRIC_NVLINK_L17_TX_PER_SEC = ( + 97 # NvLink write bandwidth for link 17 in MiB/sec +) +NVML_GPM_METRIC_MAX = 98 + +## Structs + + +class c_nvmlUnitInfo_t(_PrintableStructure): + _fields_ = [ + ("name", c_char * 96), + ("id", c_char * 96), + ("serial", c_char * 96), + ("firmwareVersion", c_char * 96), + ] + + +class struct_c_nvmlGpmSample_t(Structure): + pass # opaque handle + + +c_nvmlGpmSample_t = POINTER(struct_c_nvmlGpmSample_t) + + +class c_metricInfo_t(Structure): + _fields_ = [ + ("shortName", c_char_p), + ("longName", c_char_p), + ("unit", c_char_p), + ] + + +class c_nvmlGpmMetric_t(_PrintableStructure): + _fields_ = [ + ("metricId", c_uint), + ("nvmlReturn", _nvmlReturn_t), + ("value", c_double), + ("metricInfo", c_metricInfo_t), + ] + + +class c_nvmlGpmMetricsGet_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("numMetrics", c_uint), + ("sample1", c_nvmlGpmSample_t), + ("sample2", c_nvmlGpmSample_t), + ("metrics", c_nvmlGpmMetric_t * NVML_GPM_METRIC_MAX), + ] + + +NVML_GPM_METRICS_GET_VERSION = 1 + + +class c_nvmlGpmSupport_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("isSupportedDevice", c_uint), + ] + + +NVML_GPM_SUPPORT_VERSION = 1 + +## Functions + + +def nvmlGpmMetricsGet(metricsGet): + fn = _nvmlGetFunctionPointer("nvmlGpmMetricsGet") + ret = fn(byref(metricsGet)) + _nvmlCheckReturn(ret) + return metricsGet + + +def nvmlGpmSampleFree(gpmSample): + fn = _nvmlGetFunctionPointer("nvmlGpmSampleFree") + ret = fn(gpmSample) + _nvmlCheckReturn(ret) + return + + +def nvmlGpmSampleAlloc(): + gpmSample = c_nvmlGpmSample_t() + fn = _nvmlGetFunctionPointer("nvmlGpmSampleAlloc") + ret = fn(byref(gpmSample)) + _nvmlCheckReturn(ret) + return gpmSample + + +def nvmlGpmSampleGet(device, gpmSample): + fn = _nvmlGetFunctionPointer("nvmlGpmSampleGet") + ret = fn(device, gpmSample) + _nvmlCheckReturn(ret) + return gpmSample + + +def nvmlGpmMigSampleGet(device, gpuInstanceId, gpmSample): + fn = _nvmlGetFunctionPointer("nvmlGpmMigSampleGet") + ret = fn(device, gpuInstanceId, gpmSample) + _nvmlCheckReturn(ret) + return gpmSample + + +def nvmlGpmQueryDeviceSupport(device): + gpmSupport = c_nvmlGpmSupport_t() + gpmSupport.version = NVML_GPM_SUPPORT_VERSION + fn = _nvmlGetFunctionPointer("nvmlGpmQueryDeviceSupport") + ret = fn(device, byref(gpmSupport)) + _nvmlCheckReturn(ret) + return gpmSupport + + +def nvmlGpmSetStreamingEnabled(device, state): + c_state = c_uint(state) + fn = _nvmlGetFunctionPointer("nvmlGpmSetStreamingEnabled") + ret = fn(device, c_state) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlGpmQueryIfStreamingEnabled(device): + c_state = c_uint() + fn = _nvmlGetFunctionPointer("nvmlGpmQueryIfStreamingEnabled") + ret = fn(device, byref(c_state)) + _nvmlCheckReturn(ret) + return c_state.value + + +# Low Power Structure and Function + +NVML_NVLINK_POWER_STATE_HIGH_SPEED = 0x0 +NVML_NVLINK_POWER_STATE_LOW = 0x1 + +NVML_NVLINK_LOW_POWER_THRESHOLD_MIN = 0x1 +NVML_NVLINK_LOW_POWER_THRESHOLD_MAX = 0x1FFF +NVML_NVLINK_LOW_POWER_THRESHOLD_RESET = 0xFFFFFFFF +NVML_NVLINK_LOW_POWER_THRESHOLD_DEFAULT = NVML_NVLINK_LOW_POWER_THRESHOLD_RESET + + +class c_nvmlNvLinkPowerThres_t(Structure): + _fields_ = [ + ("lowPwrThreshold", c_uint), + ] + + +def nvmlDeviceSetNvLinkDeviceLowPowerThreshold(device, l1threshold): + c_info = c_nvmlNvLinkPowerThres_t() + c_info.lowPwrThreshold = l1threshold + fn = _nvmlGetFunctionPointer("nvmlDeviceSetNvLinkDeviceLowPowerThreshold") + ret = fn(device, byref(c_info)) + _nvmlCheckReturn(ret) + return NVML_SUCCESS 
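+
+
+# Usage note (illustrative sketch, not part of the upstream nvidia-ml-py code):
+# the GPM metrics defined above are computed over an interval, so two samples
+# taken some time apart are required. Assuming NVML has already been
+# initialized with nvmlInit() and `time` is imported:
+#
+#     device = nvmlDeviceGetHandleByIndex(0)
+#     if nvmlGpmQueryDeviceSupport(device).isSupportedDevice:
+#         s1, s2 = nvmlGpmSampleAlloc(), nvmlGpmSampleAlloc()
+#         nvmlGpmSampleGet(device, s1)
+#         time.sleep(1.0)
+#         nvmlGpmSampleGet(device, s2)
+#         metrics_get = c_nvmlGpmMetricsGet_t()
+#         metrics_get.version = NVML_GPM_METRICS_GET_VERSION
+#         metrics_get.numMetrics = 1
+#         metrics_get.sample1, metrics_get.sample2 = s1, s2
+#         metrics_get.metrics[0].metricId = NVML_GPM_METRIC_SM_UTIL
+#         nvmlGpmMetricsGet(metrics_get)
+#         print(metrics_get.metrics[0].value)  # SM utilization over the interval
+#         nvmlGpmSampleFree(s1)
+#         nvmlGpmSampleFree(s2)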
+ + +NVML_GPU_FABRIC_UUID_LEN = 16 + +_nvmlGpuFabricState_t = c_uint +NVML_GPU_FABRIC_STATE_NOT_SUPPORTED = 0 +NVML_GPU_FABRIC_STATE_NOT_STARTED = 1 +NVML_GPU_FABRIC_STATE_IN_PROGRESS = 2 +NVML_GPU_FABRIC_STATE_COMPLETED = 3 + + +class c_nvmlGpuFabricInfo_t(_PrintableStructure): + _fields_ = [ + ("clusterUuid", c_char * NVML_DEVICE_UUID_BUFFER_SIZE), + ("status", _nvmlReturn_t), + ("cliqueId", c_uint32), + ("state", _nvmlGpuFabricState_t), + ] + + +NVML_GPU_FABRIC_HEALTH_MASK_DEGRADED_BW_NOT_SUPPORTED = 0 +NVML_GPU_FABRIC_HEALTH_MASK_DEGRADED_BW_TRUE = 1 +NVML_GPU_FABRIC_HEALTH_MASK_DEGRADED_BW_FALSE = 2 +NVML_GPU_FABRIC_HEALTH_MASK_SHIFT_DEGRADED_BW = 0 +NVML_GPU_FABRIC_HEALTH_MASK_WIDTH_DEGRADED_BW = 0x11 + +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_RECOVERY_NOT_SUPPORTED = 0 +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_RECOVERY_TRUE = 1 +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_RECOVERY_FALSE = 2 +NVML_GPU_FABRIC_HEALTH_MASK_SHIFT_ROUTE_RECOVERY = 2 +NVML_GPU_FABRIC_HEALTH_MASK_WIDTH_ROUTE_RECOVERY = 0x11 + +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_UNHEALTHY_NOT_SUPPORTED = 0 +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_UNHEALTHY_TRUE = 1 +NVML_GPU_FABRIC_HEALTH_MASK_ROUTE_UNHEALTHY_FALSE = 2 +NVML_GPU_FABRIC_HEALTH_MASK_SHIFT_ROUTE_UNHEALTHY = 4 +NVML_GPU_FABRIC_HEALTH_MASK_WIDTH_ROUTE_UNHEALTHY = 0x11 + +NVML_GPU_FABRIC_HEALTH_MASK_ACCESS_TIMEOUT_RECOVERY_NOT_SUPPORTED = 0 +NVML_GPU_FABRIC_HEALTH_MASK_ACCESS_TIMEOUT_RECOVERY_TRUE = 1 +NVML_GPU_FABRIC_HEALTH_MASK_ACCESS_TIMEOUT_RECOVERY_FALSE = 2 +NVML_GPU_FABRIC_HEALTH_MASK_SHIFT_ACCESS_TIMEOUT_RECOVERY = 6 +NVML_GPU_FABRIC_HEALTH_MASK_WIDTH_ACCESS_TIMEOUT_RECOVERY = 0x11 + +nvmlGpuFabricInfo_v2 = 0x02000024 + + +class c_nvmlGpuFabricInfoV_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("clusterUuid", c_char * NVML_GPU_FABRIC_UUID_LEN), + ("status", _nvmlReturn_t), + ("cliqueId", c_uint32), + ("state", _nvmlGpuFabricState_t), + ("healthMask", c_uint32), + ] + + def __init__(self): + super(c_nvmlGpuFabricInfoV_t, self).__init__(version=nvmlGpuFabricInfo_v2) + + +def nvmlDeviceGetGpuFabricInfo(device, gpuFabricInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuFabricInfo") + ret = fn(device, gpuFabricInfo) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceGetGpuFabricInfoV(device, gpuFabricInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetGpuFabricInfoV") + ret = fn(device, gpuFabricInfo) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +###################### +## Enums/defines +#### NVML GPU NVLINK BW MODE +NVML_GPU_NVLINK_BW_MODE_FULL = 0x0 +NVML_GPU_NVLINK_BW_MODE_OFF = 0x1 +NVML_GPU_NVLINK_BW_MODE_MIN = 0x2 +NVML_GPU_NVLINK_BW_MODE_HALF = 0x3 +NVML_GPU_NVLINK_BW_MODE_3QUARTER = 0x4 +NVML_GPU_NVLINK_BW_MODE_COUNT = 0x5 + + +def nvmlSystemSetNvlinkBwMode(mode): + fn = _nvmlGetFunctionPointer("nvmlSystemSetNvlinkBwMode") + ret = fn(mode) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlSystemGetNvlinkBwMode(): + mode = c_uint() + fn = _nvmlGetFunctionPointer("nvmlSystemGetNvlinkBwMode") + ret = fn(byref(mode)) + _nvmlCheckReturn(ret) + return mode.value + + +_nvmlPowerScopeType_t = c_uint +NVML_POWER_SCOPE_GPU = 0 +NVML_POWER_SCOPE_MODULE = 1 +NVML_POWER_SCOPE_MEMORY = 2 + + +class c_nvmlPowerValue_v2_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("powerScope", _nvmlPowerScopeType_t), + ("powerValueMw", c_uint), + ] + _fmt_ = {"": "%d B"} + + +nvmlPowerValue_v2 = 0x0200000C + + +def nvmlDeviceSetPowerManagementLimit_v2( + device, powerScope, powerLimit, version=nvmlPowerValue_v2 +): + c_powerScope = 
_nvmlPowerScopeType_t(powerScope) + c_powerValue = c_nvmlPowerValue_v2_t() + c_powerValue.version = c_uint(version) + c_powerValue.powerScope = c_powerScope + c_powerValue.powerValueMw = c_uint(powerLimit) + fn = _nvmlGetFunctionPointer("nvmlDeviceSetPowerManagementLimit_v2") + ret = fn(device, byref(c_powerValue)) + return NVML_SUCCESS + + +class c_nvmlEccSramErrorStatus_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("aggregateUncParity", c_ulonglong), + ("aggregateUncSecDed", c_ulonglong), + ("aggregateCor", c_ulonglong), + ("volatileUncParity", c_ulonglong), + ("volatileUncSecDed", c_ulonglong), + ("volatileCor", c_ulonglong), + ("aggregateUncBucketL2", c_ulonglong), + ("aggregateUncBucketSm", c_ulonglong), + ("aggregateUncBucketPcie", c_ulonglong), + ("aggregateUncBucketMcu", c_ulonglong), + ("aggregateUncBucketOther", c_ulonglong), + ("bThresholdExceeded", c_uint), + ] + + def __init__(self): + super(c_nvmlEccSramErrorStatus_v1_t, self).__init__( + version=nvmlEccSramErrorStatus_v1 + ) + + +nvmlEccSramErrorStatus_v1 = 0x1000068 + + +def nvmlDeviceGetSramEccErrorStatus(device, status): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetSramEccErrorStatus") + ret = fn(device, status) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +NVML_DEV_CAP_EGM = 1 << 0 +nvmlDeviceCapabilities_v1 = 0x1000008 + + +class c_nvmlDeviceCapabilities_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("capMask", c_uint), + ] + + def __init__(self): + super(c_nvmlDeviceCapabilities_v1_t, self).__init__( + version=nvmlDeviceCapabilities_v1 + ) + + +def nvmlDeviceGetCapabilities(device, caps): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetCapabilities") + return fn(device, caps) + + +class c_nvmlPlatformInfo_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("ibGuid", c_char * 16), + ("rackGuid", c_char * 16), + ("chassisPhysicalSlotNumber", c_char), + ("computeSlotIndex", c_char), + ("nodeIndex", c_char), + ("peerType", c_char), + ("moduleId", c_char), + ] + + def __init__(self): + super(c_nvmlPlatformInfo_v1_t, self).__init__(version=nvmlPlatformInfo_v1) + + +nvmlPlatformInfo_v1 = 0x100002C + + +def nvmlDeviceGetPlatformInfo(device, platformInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetPlatformInfo") + ret = fn(device, platformInfo) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +class c_nvmlMask255_t(_PrintableStructure): + _fields_ = [ + ("mask", c_uint * 8), + ] + + +NVML_WORKLOAD_POWER_MAX_PROFILES = 255 +NVML_POWER_PROFILE_MAX_P = 0 +NVML_POWER_PROFILE_MAX_Q = 1 +NVML_POWER_PROFILE_COMPUTE = 2 +NVML_POWER_PROFILE_MEMORY_BOUND = 3 +NVML_POWER_PROFILE_NETWORK = 4 +NVML_POWER_PROFILE_BALANCED = 5 +NVML_POWER_PROFILE_LLM_INFERENCE = 6 +NVML_POWER_PROFILE_LLM_TRAINING = 7 +NVML_POWER_PROFILE_RBM = 8 +NVML_POWER_PROFILE_DCPCIE = 9 +NVML_POWER_PROFILE_HMMA_SPARSE = 10 +NVML_POWER_PROFILE_HMMA_DENSE = 11 +NVML_POWER_PROFILE_SYNC_BALANCED = 12 +NVML_POWER_PROFILE_HPC = 13 +NVML_POWER_PROFILE_MIG = 14 +NVML_POWER_PROFILE_MAX = 15 + +nvmlWorkloadPowerProfileInfo_v1 = 0x100002C + + +class c_nvmlWorkloadPowerProfileInfo_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("profileId", c_uint), + ("priority", c_uint), + ("conflictingmask", c_nvmlMask255_t), + ] + + def __init__(self): + super(c_nvmlWorkloadPowerProfileInfo_v1_t, self).__init__( + version=nvmlWorkloadPowerProfileInfo_v1 + ) + + +nvmlWorkloadPowerProfileProfilesInfo_v1 = 0x1002BF8 + + +class c_nvmlWorkloadPowerProfileProfilesInfo_v1_t(_PrintableStructure): + _fields_ = [ + 
("version", c_uint), + ("perfProfilesMask", c_nvmlMask255_t), + ( + "perfProfile", + c_nvmlWorkloadPowerProfileInfo_v1_t * NVML_WORKLOAD_POWER_MAX_PROFILES, + ), + ] + + def __init__(self): + super(c_nvmlWorkloadPowerProfileProfilesInfo_v1_t, self).__init__( + version=nvmlWorkloadPowerProfileProfilesInfo_v1 + ) + + +nvmlWorkloadPowerProfileCurrentProfiles_v1 = 0x1000064 + + +class c_nvmlWorkloadPowerProfileCurrentProfiles_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("perfProfilesMask", c_nvmlMask255_t), + ("requestedProfilesMask", c_nvmlMask255_t), + ("enforcedProfilesMask", c_nvmlMask255_t), + ] + + def __init__(self): + super(c_nvmlWorkloadPowerProfileCurrentProfiles_v1_t, self).__init__( + version=nvmlWorkloadPowerProfileCurrentProfiles_v1 + ) + + +nvmlWorkloadPowerProfileRequestedProfiles_v1 = 0x1000024 + + +class c_nvmlWorkloadPowerProfileRequestedProfiles_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("requestedProfilesMask", c_nvmlMask255_t), + ] + + def __init__(self): + super(c_nvmlWorkloadPowerProfileRequestedProfiles_v1_t, self).__init__( + version=nvmlWorkloadPowerProfileRequestedProfiles_v1 + ) + + +def nvmlDeviceWorkloadPowerProfileGetProfilesInfo(device, profilesInfo): + fn = _nvmlGetFunctionPointer("nvmlDeviceWorkloadPowerProfileGetProfilesInfo") + ret = fn(device, profilesInfo) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceWorkloadPowerProfileGetCurrentProfiles(device, currentProfiles): + fn = _nvmlGetFunctionPointer("nvmlDeviceWorkloadPowerProfileGetCurrentProfiles") + ret = fn(device, currentProfiles) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceWorkloadPowerProfileSetRequestedProfiles(device, requestedProfiles): + fn = _nvmlGetFunctionPointer("nvmlDeviceWorkloadPowerProfileSetRequestedProfiles") + ret = fn(device, requestedProfiles) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceWorkloadPowerProfileClearRequestedProfiles(device, requestedProfiles): + fn = _nvmlGetFunctionPointer("nvmlDeviceWorkloadPowerProfileClearRequestedProfiles") + ret = fn(device, requestedProfiles) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceGetNvlinkSupportedBwModes(device, supportedBwModes): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvlinkSupportedBwModes") + ret = fn(device, supportedBwModes) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceGetNvlinkBwMode(device, getBwMode): + fn = _nvmlGetFunctionPointer("nvmlDeviceGetNvlinkBwMode") + ret = fn(device, getBwMode) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +def nvmlDeviceSetNvlinkBwMode(device, setBwMode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetNvlinkBwMode") + ret = fn(device, setBwMode) + _nvmlCheckReturn(ret) + return NVML_SUCCESS + + +nvmlDramEncryptionInfo_v1 = 0x01000008 + + +class c_nvmlDramEncryptionInfo_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("encryptionState", _nvmlEnableState_t), + ] + + def __init__(self): + super(c_nvmlDramEncryptionInfo_t, self).__init__( + version=nvmlDramEncryptionInfo_v1 + ) + + +def nvmlDeviceGetDramEncryptionMode(handle): + c_currState = c_nvmlDramEncryptionInfo_t() + c_pendingState = c_nvmlDramEncryptionInfo_t() + fn = _nvmlGetFunctionPointer("nvmlDeviceGetDramEncryptionMode") + ret = fn(handle, byref(c_currState), byref(c_pendingState)) + _nvmlCheckReturn(ret) + return [c_currState.encryptionState, c_pendingState.encryptionState] + + +# added to API +def nvmlDeviceGetCurrentDramEncryptionMode(handle): + return 
nvmlDeviceGetDramEncryptionMode(handle)[0] + + +# added to API +def nvmlDeviceGetPendingDramEncryptionMode(handle): + return nvmlDeviceGetDramEncryptionMode(handle)[1] + + +def nvmlDeviceSetDramEncryptionMode(handle, mode): + fn = _nvmlGetFunctionPointer("nvmlDeviceSetDramEncryptionMode") + c_dramEncryptionMode = c_nvmlDramEncryptionInfo_t() + c_dramEncryptionMode.encryptionState = mode + ret = fn(handle, byref(c_dramEncryptionMode)) + _nvmlCheckReturn(ret) + return None + + +# Power Smoothing defines +NVML_POWER_SMOOTHING_MAX_NUM_PROFILES = 5 +NVML_POWER_SMOOTHING_ADMIN_OVERRIDE_NOT_SET = 0xFFFFFFFF +NVML_POWER_SMOOTHING_PROFILE_PARAM_PERCENT_TMP_FLOOR = 0 +NVML_POWER_SMOOTHING_PROFILE_PARAM_RAMP_UP_RATE = 1 +NVML_POWER_SMOOTHING_PROFILE_PARAM_RAMP_DOWN_RATE = 2 +NVML_POWER_SMOOTHING_PROFILE_PARAM_RAMP_DOWN_HYSTERESIS = 3 + +nvmlPowerSmoothingState_v1 = 0x1000008 + + +class c_nvmlPowerSmoothingState_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("state", c_uint), + ] + + def __init__(self): + super(c_nvmlPowerSmoothingState_v1_t, self).__init__( + version=nvmlPowerSmoothingState_v1 + ) + + +nvmlPowerSmoothingProfile_v1 = 0x1000018 + + +class c_nvmlPowerSmoothingProfile_v1_t(_PrintableStructure): + _fields_ = [ + ("version", c_uint), + ("profileId", c_uint), + ("paramId", c_uint), + ("value", c_double), + ] + + def __init__(self): + super(c_nvmlPowerSmoothingProfile_v1_t, self).__init__( + version=nvmlPowerSmoothingProfile_v1 + ) + + +def nvmlDevicePowerSmoothingActivatePresetProfile(device, profile): + fn = _nvmlGetFunctionPointer("nvmlDevicePowerSmoothingActivatePresetProfile") + ret = fn(device, profile) + _nvmlCheckReturn(ret) + + +def nvmlDevicePowerSmoothingUpdatePresetProfileParam(device, profile): + fn = _nvmlGetFunctionPointer("nvmlDevicePowerSmoothingUpdatePresetProfileParam") + ret = fn(device, profile) + _nvmlCheckReturn(ret) + + +def nvmlDevicePowerSmoothingSetState(device, state): + fn = _nvmlGetFunctionPointer("nvmlDevicePowerSmoothingSetState") + ret = fn(device, state) + _nvmlCheckReturn(ret) diff --git a/python/sglang/multimodal_gen/utils.py b/python/sglang/multimodal_gen/utils.py new file mode 100644 index 000000000..655af2c1e --- /dev/null +++ b/python/sglang/multimodal_gen/utils.py @@ -0,0 +1,777 @@ +# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo + +# SPDX-License-Identifier: Apache-2.0 +# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/utils.py + +import argparse +import ctypes +import importlib +import importlib.util +import inspect +import math +import os +import signal +import socket +import sys +import threading +import traceback +from collections.abc import Callable +from dataclasses import dataclass, fields, is_dataclass +from functools import lru_cache, partial, wraps +from typing import Any, TypeVar, cast + +import cloudpickle +import imageio +import numpy as np +import torch +import torchvision +import yaml +from einops import rearrange +from remote_pdb import RemotePdb +from torch.distributed.fsdp import MixedPrecisionPolicy + +import sglang.multimodal_gen.envs as envs +from sglang.multimodal_gen.runtime.utils.logging_utils import ( + SortedHelpFormatter, + init_logger, +) + +logger = init_logger(__name__) + +T = TypeVar("T") + +# TODO(will): used to convert server_args.precision to torch.dtype. Find a +# cleaner way to do this. 
+PRECISION_TO_TYPE = {
+    "fp32": torch.float32,
+    "fp16": torch.float16,
+    "bf16": torch.bfloat16,
+}
+
+STR_BACKEND_ENV_VAR: str = "SGL_DIFFUSION_ATTENTION_BACKEND"
+STR_ATTN_CONFIG_ENV_VAR: str = "SGL_DIFFUSION_ATTENTION_CONFIG"
+
+
+def find_nccl_library() -> str:
+    """
+    We either use the library file specified by the `SGL_DIFFUSION_NCCL_SO_PATH`
+    environment variable, or we find the library file brought by PyTorch.
+    After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be
+    found by `ctypes` automatically.
+    """
+    so_file = envs.SGL_DIFFUSION_NCCL_SO_PATH
+
+    # manually load the nccl library
+    if so_file:
+        logger.info(
+            "Found nccl from environment variable SGL_DIFFUSION_NCCL_SO_PATH=%s",
+            so_file,
+        )
+    else:
+        if torch.version.cuda is not None:
+            so_file = "libnccl.so.2"
+        elif torch.version.hip is not None:
+            so_file = "librccl.so.1"
+        else:
+            raise ValueError("NCCL only supports CUDA and ROCm backends.")
+        logger.info("Found nccl from library %s", so_file)
+    return str(so_file)
+
+
+prev_set_stream = torch.cuda.set_stream
+
+_current_stream = None
+
+
+def _patched_set_stream(stream: torch.cuda.Stream | None) -> None:
+    global _current_stream
+    _current_stream = stream
+    if stream is not None:
+        prev_set_stream(stream)
+
+
+torch.cuda.set_stream = _patched_set_stream
+
+
+def current_stream() -> torch.cuda.Stream | None:
+    """
+    Replace `torch.cuda.current_stream()` with `sglang.multimodal_gen.utils.current_stream()`.
+    It turns out that `torch.cuda.current_stream()` is quite expensive,
+    as it will construct a new stream object at each call.
+    Here we patch `torch.cuda.set_stream` to keep track of the current stream
+    directly, so that we can avoid calling `torch.cuda.current_stream()`.
+
+    The underlying hypothesis is that we do not call `torch._C._cuda_setStream`
+    from C/C++ code.
+    """
+    from sglang.multimodal_gen.runtime.platforms import current_platform
+
+    # For non-CUDA platforms, return None
+    if not current_platform.is_cuda_alike():
+        return None
+
+    global _current_stream
+    if _current_stream is None:
+        # When this function is called before any stream is set,
+        # we return the default stream.
+        # On ROCm, using the default 0 stream in combination with RCCL
+        # hurts performance, so we create a dedicated stream per process.
+        _current_stream = (
+            torch.cuda.Stream()
+            if current_platform.is_rocm()
+            else torch.cuda.current_stream()
+        )
+    return _current_stream
+
+
+class StoreBoolean(argparse.Action):
+
+    def __init__(self, option_strings, dest, default=False, required=False, help=None):
+        super().__init__(
+            option_strings=option_strings,
+            dest=dest,
+            nargs="?",
+            const=True,
+            default=default,
+            required=required,
+            help=help,
+        )
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        if values is None:
+            setattr(namespace, self.dest, True)
+        elif isinstance(values, str):
+            if values.lower() == "true":
+                setattr(namespace, self.dest, True)
+            elif values.lower() == "false":
+                setattr(namespace, self.dest, False)
+            else:
+                raise ValueError(
+                    f"Invalid boolean value: {values}. Expected 'true' or 'false'."
+                )
+        else:
+            setattr(namespace, self.dest, bool(values))
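+
+
+# Illustrative behaviour of the CLI helpers in this module (a sketch only; the
+# flag names below are hypothetical and not defined anywhere in sgl-diffusion):
+#
+#     parser = FlexibleArgumentParser()  # defined below
+#     parser.add_argument("--num-gpus", type=int)
+#     parser.add_argument("--trust-remote-code", action=StoreBoolean)
+#     # Dashes and underscores are interchangeable on the command line, and
+#     # StoreBoolean accepts an optional true/false value:
+#     ns = parser.parse_args(["--num_gpus", "2", "--trust-remote-code", "false"])
+#     assert ns.num_gpus == 2 and ns.trust_remote_code is False
+#     assert "num_gpus" in ns._provided  # explicitly-passed args are tracked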
+
+
+class FlexibleArgumentParser(argparse.ArgumentParser):
+    """ArgumentParser that allows both underscore and dash in names."""
+
+    def __init__(self, *args, **kwargs) -> None:
+        # Set the default 'formatter_class' to SortedHelpFormatter
+        if "formatter_class" not in kwargs:
+            kwargs["formatter_class"] = SortedHelpFormatter
+        super().__init__(*args, **kwargs)
+
+    def parse_args(  # type: ignore[override]
+        self, args=None, namespace=None
+    ) -> argparse.Namespace:
+        if args is None:
+            args = sys.argv[1:]
+
+        if any(arg.startswith("--config") for arg in args):
+            args = self._pull_args_from_config(args)
+
+        # Convert underscores to dashes and vice versa in argument names
+        processed_args = []
+        for arg in args:
+            if arg.startswith("--"):
+                if "=" in arg:
+                    key, value = arg.split("=", 1)
+                    key = "--" + key[len("--") :].replace("_", "-")
+                    processed_args.append(f"{key}={value}")
+                else:
+                    processed_args.append("--" + arg[len("--") :].replace("_", "-"))
+            elif arg.startswith("-O") and arg != "-O" and len(arg) > 2:
+                # allow the -O flag to be used without a space, e.g. -O3
+                processed_args.append("-O")
+                processed_args.append(arg[2:])
+            else:
+                processed_args.append(arg)
+
+        namespace = super().parse_args(processed_args, namespace)
+
+        # Track which arguments were explicitly provided
+        namespace._provided = set()
+
+        i = 0
+        while i < len(args):
+            arg = args[i]
+            if arg.startswith("--"):
+                # Handle --key=value format
+                if "=" in arg:
+                    key = arg.split("=")[0][2:].replace("-", "_")
+                    namespace._provided.add(key)
+                    i += 1
+                # Handle --key value format
+                else:
+                    key = arg[2:].replace("-", "_")
+                    namespace._provided.add(key)
+                    # Skip the value if there is one
+                    if i + 1 < len(args) and not args[i + 1].startswith("-"):
+                        i += 2
+                    else:
+                        i += 1
+            else:
+                i += 1
+
+        return namespace  # type: ignore[no-any-return]
+
+    def _pull_args_from_config(self, args: list[str]) -> list[str]:
+        """Method to pull arguments specified in the config file
+        into the command-line args variable.
+
+        The arguments in the config file will be inserted into
+        the argument list.
+
+        example:
+        ```yaml
+        port: 12323
+        tensor-parallel-size: 4
+        ```
+        ```python
+        $: vllm {serve,chat,complete} "facebook/opt-12B" \
+            --config config.yaml -tp 2
+        $: args = [
+            "serve,chat,complete",
+            "facebook/opt-12B",
+            '--config', 'config.yaml',
+            '-tp', '2'
+        ]
+        $: args = [
+            "serve,chat,complete",
+            "facebook/opt-12B",
+            '--port', '12323',
+            '--tensor-parallel-size', '4',
+            '-tp', '2'
+        ]
+        ```
+
+        Please note how the config args are inserted after the sub command.
+        This way the order of priorities is maintained when these args are
+        parsed by super().
+        """
+        index = -1
+        config_arg = None
+        for i, arg in enumerate(args):
+            if arg.startswith("--config"):
+                if index != -1:
+                    raise ValueError("More than one config file specified!")
+                index = i
+                config_arg = arg
+
+        if config_arg is None:
+            return args
+        args_before_config = args[:index]
+        if "=" in config_arg:
+            file_path = config_arg.split("=", 1)[1]
+            args_after_config = args[index + 1 :]
+        else:
+            if index == len(args) - 1:
+                raise ValueError(
+                    "No config file specified! "
+                    "Please check your command-line arguments."
+                )
+            file_path = args[index + 1]
+            args_after_config = args[index + 2 :]
+
+        config_args = self._load_config_file(file_path)
+
+        # 0th index is for {serve,chat,complete}
+        # followed by model_tag (only for serve)
+        # followed by config args
+        # followed by rest of cli args.
+        # maintaining this order enforces the precedence
+        # of cli > config > defaults
+        if args[0] == "serve":
+            if index == 1:
+                raise ValueError(
+                    "No model_tag specified! Please check your command-line"
+                    " arguments."
+                )
+            command = args_before_config[0]
+            model_tag = args_before_config[1]
+            other_args_before = args_before_config[2:]
+            args = (
+                [command, model_tag]
+                + config_args
+                + other_args_before
+                + args_after_config
+            )
+        else:
+            command = args_before_config[0]
+            other_args_before = args_before_config[1:]
+            args = [command] + config_args + other_args_before + args_after_config
+
+        return args
+
+    def _load_config_file(self, file_path: str) -> list[str]:
+        """Load a yaml/json config file and return its key-value pairs as a
+        flattened list of argparse-style tokens.
+        ```yaml
+        port: 12323
+        tensor-parallel-size: 4
+        vae_config:
+            load_encoder: false
+            load_decoder: true
+        ```
+        returns:
+            processed_args: list[str] = [
+                '--port', '12323',
+                '--tensor-parallel-size', '4',
+                '--vae_config.load_encoder', 'false',
+                '--vae_config.load_decoder'
+            ]
+
+        Note: `true` booleans are emitted as bare flags, and underscores/dashes
+        in the keys are normalized later by `parse_args`.
+        """
+
+        extension: str = file_path.split(".")[-1]
+        if extension not in ("yaml", "yml", "json"):
+            raise ValueError(
+                f"Config file must be of yaml/yml/json type, but got '{extension}'."
+            )
+
+        processed_args: list[str] = []
+
+        config: dict[str, Any] = {}
+        try:
+            with open(file_path) as config_file:
+                config = yaml.safe_load(config_file)
+        except Exception as ex:
+            logger.error(
+                "Unable to read the config file at %s. Make sure the path is correct.",
+                file_path,
+            )
+            raise ex
+
+        store_boolean_arguments = [
+            action.dest for action in self._actions if isinstance(action, StoreBoolean)
+        ]
+
+        def process_dict(prefix: str, d: dict[str, Any]):
+            for key, value in d.items():
+                full_key = f"{prefix}.{key}" if prefix else key
+
+                if isinstance(value, bool) and full_key not in store_boolean_arguments:
+                    if value:
+                        processed_args.append("--" + full_key)
+                    else:
+                        processed_args.append("--" + full_key)
+                        processed_args.append("false")
+                elif isinstance(value, list):
+                    processed_args.append("--" + full_key)
+                    for item in value:
+                        processed_args.append(str(item))
+                elif isinstance(value, dict):
+                    process_dict(full_key, value)
+                else:
+                    processed_args.append("--" + full_key)
+                    processed_args.append(str(value))
+
+        process_dict("", config)
+
+        return processed_args
+
+
+def warn_for_unimplemented_methods(cls: type[T]) -> type[T]:
+    """
+    A replacement for `abc.ABC`.
+    When we use `abc.ABC`, subclasses will fail to instantiate
+    if they do not implement all abstract methods.
+    Here, we only require `raise NotImplementedError` in the
+    base class, and log a warning if the method is not implemented
+    in the subclass.
+    """
+
+    original_init = cls.__init__
+
+    def find_unimplemented_methods(self: object):
+        unimplemented_methods = []
+        for attr_name in dir(self):
+            # bypass inner method
+            if attr_name.startswith("_"):
+                continue
+
+            try:
+                attr = getattr(self, attr_name)
+                # only inspect callable (bound) methods
+                if not callable(attr):
+                    continue
+                attr_func = attr.__func__
+            except AttributeError:
+                continue
+            src = inspect.getsource(attr_func)
+            if "NotImplementedError" in src:
+                unimplemented_methods.append(attr_name)
+        if unimplemented_methods:
+            method_names = ",".join(unimplemented_methods)
+            msg = f"Methods {method_names} not implemented in {self}"
+            logger.warning(msg)
+
+    @wraps(original_init)
+    def wrapped_init(self, *args, **kwargs) -> None:
+        original_init(self, *args, **kwargs)
+        find_unimplemented_methods(self)
+
+    type.__setattr__(cls, "__init__", wrapped_init)
+    return cls
+
+
+def align_to(value: int, alignment: int) -> int:
+    """Align a value (e.g. height or width) up to the given alignment factor.
+
+    Args:
+        value (int): height or width
+        alignment (int): target alignment factor
+
+    Returns:
+        int: the aligned value
+    """
+    return int(math.ceil(value / alignment) * alignment)
+
+
+def resolve_obj_by_qualname(qualname: str) -> Any:
+    """
+    Resolve an object by its fully qualified name.
+    """
+    module_name, obj_name = qualname.rsplit(".", 1)
+    module = importlib.import_module(module_name)
+    return getattr(module, obj_name)
+
+
+# From vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/utils.py
+def import_pynvml():
+    """
+    Historical comments:
+
+    libnvml.so is the library behind nvidia-smi, and
+    pynvml is a Python wrapper around it. We use it to get GPU
+    status without initializing CUDA context in the current process.
+    Historically, there are two packages that provide pynvml:
+    - `nvidia-ml-py` (https://pypi.org/project/nvidia-ml-py/): The official
+      wrapper. It is a dependency of sgl-diffusion, and is installed when users
+      install sgl-diffusion. It provides a Python module named `pynvml`.
+    - `pynvml` (https://pypi.org/project/pynvml/): An unofficial wrapper.
+      Prior to version 12.0, it also provided a Python module `pynvml`,
+      and therefore conflicted with the official one, which is a standalone
+      Python file. This caused errors when both of them were installed.
+      Starting from version 12.0, it migrated to a new module
+      named `pynvml_utils` to avoid the conflict.
+    It is so confusing that many packages in the community use the
+    unofficial one by mistake, and we have to handle this case.
+    For example, `nvcr.io/nvidia/pytorch:24.12-py3` ships the unofficial
+    one, which causes errors; see
+    https://github.com/vllm-project/vllm/issues/12847 for an example.
+    After all this trouble, we decided to copy the official `pynvml`
+    module into our codebase and use it directly.
+    """
+    import sglang.multimodal_gen.third_party.pynvml as pynvml
+
+    return pynvml
+
+
+def update_environment_variables(envs: dict[str, str]):
+    for k, v in envs.items():
+        if k in os.environ and os.environ[k] != v:
+            logger.warning(
+                "Overwriting environment variable %s from '%s' to '%s'",
+                k,
+                os.environ[k],
+                v,
+            )
+        os.environ[k] = v
+
+
+def run_method(
+    obj: Any, method: str | bytes | Callable, args: tuple[Any], kwargs: dict[str, Any]
+) -> Any:
+    """
+    Run a method of an object with the given arguments and keyword arguments.
+    If the method is a string, it is resolved to a bound method via getattr.
+    If the method is serialized bytes, it is deserialized using
+    cloudpickle.
+    If the method is a callable, it is called directly.
+    """
+    if isinstance(method, bytes):
+        func = partial(cloudpickle.loads(method), obj)
+    elif isinstance(method, str):
+        try:
+            func = getattr(obj, method)
+        except AttributeError:
+            raise NotImplementedError(
+                f"Method {method!r} is not implemented."
+            ) from None
+    else:
+        func = partial(method, obj)  # type: ignore
+    return func(*args, **kwargs)
+
+
+def shallow_asdict(obj) -> dict[str, Any]:
+    if not is_dataclass(obj):
+        raise TypeError("Expected dataclass instance")
+    return {f.name: getattr(obj, f.name) for f in fields(obj)}
+
+
+# TODO: validate that this is fine
+def kill_itself_when_parent_died() -> None:
+    import platform
+
+    if platform.system() == "Linux":
+        # SIGKILL this process when the parent worker manager dies.
+        PR_SET_PDEATHSIG = 1
+        libc = ctypes.CDLL("libc.so.6")
+        libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL)
+    else:
+        logger.warning("kill_itself_when_parent_died is only supported on Linux.")
+
+
+def get_exception_traceback() -> str:
+    etype, value, tb = sys.exc_info()
+    err_str = "".join(traceback.format_exception(etype, value, tb))
+    return err_str
+
+
+class TypeBasedDispatcher:
+
+    def __init__(self, mapping: list[tuple[type, Callable]]):
+        self._mapping = mapping
+
+    def __call__(self, obj: Any):
+        for ty, fn in self._mapping:
+            if isinstance(obj, ty):
+                return fn(obj)
+        raise ValueError(f"Invalid object: {obj}")
+
+
+# For non-torch.distributed debugging
+def remote_breakpoint() -> None:
+    # Reserve an ephemeral port, then release it before RemotePdb binds to it.
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        s.bind(("localhost", 0))  # Let the OS pick an ephemeral port.
+        port = s.getsockname()[1]
+    RemotePdb(host="localhost", port=port).set_trace()
+
+
+@dataclass
+class MixedPrecisionState:
+    param_dtype: torch.dtype | None = None
+    reduce_dtype: torch.dtype | None = None
+    output_dtype: torch.dtype | None = None
+    compute_dtype: torch.dtype | None = None
+    mp_policy: MixedPrecisionPolicy | None = None
+
+
+# Thread-local storage for mixed precision state
+_mixed_precision_state = threading.local()
+
+
+def get_mixed_precision_state() -> MixedPrecisionState:
+    """Get the current mixed precision state."""
+    if not hasattr(_mixed_precision_state, "state"):
+        raise ValueError("Mixed precision state not set")
+    return cast(MixedPrecisionState, _mixed_precision_state.state)
+
+
+def set_mixed_precision_policy(
+    param_dtype: torch.dtype,
+    reduce_dtype: torch.dtype,
+    output_dtype: torch.dtype | None = None,
+    mp_policy: MixedPrecisionPolicy | None = None,
+):
+    """Set the mixed precision policy for the current thread.
+
+    Args:
+        param_dtype: Parameter dtype, also used as the compute dtype
+        reduce_dtype: Reduction dtype used for gradients
+        output_dtype: Optional output dtype
+        mp_policy: Optional MixedPrecisionPolicy stored alongside the dtypes
+    """
+    state = MixedPrecisionState(
+        param_dtype=param_dtype,
+        reduce_dtype=reduce_dtype,
+        output_dtype=output_dtype,
+        mp_policy=mp_policy,
+    )
+    _mixed_precision_state.state = state
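+
+
+# Minimal usage sketch for the thread-local mixed-precision helpers above and
+# get_compute_dtype below (illustrative only; the dtypes are examples, not
+# defaults):
+#
+#     set_mixed_precision_policy(
+#         param_dtype=torch.bfloat16,
+#         reduce_dtype=torch.float32,
+#     )
+#     assert get_compute_dtype() == torch.bfloat16
+#     assert get_mixed_precision_state().reduce_dtype == torch.float32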
+
+
+def get_compute_dtype() -> torch.dtype:
+    """Get the current compute dtype from the mixed precision policy.
+
+    Returns:
+        torch.dtype: The compute dtype to use; falls back to
+            torch.get_default_dtype() if no policy has been set.
+    """
+    if not hasattr(_mixed_precision_state, "state"):
+        return torch.get_default_dtype()
+    else:
+        state = get_mixed_precision_state()
+        return state.param_dtype
+
+
+def dict_to_3d_list(
+    mask_strategy: dict[str, Any] | None = None,
+    t_max: int | None = None,
+    l_max: int | None = None,
+    h_max: int | None = None,
+) -> list[list[list[torch.Tensor | None]]]:
+    """
+    Convert a dictionary of mask indices to a 3D list of tensors.
+    Args:
+        mask_strategy: keys are "t_l_h", values are torch.Tensor masks.
+        t_max, l_max, h_max: if provided (all three), force the output shape
+            to (t_max, l_max, h_max). If all three are None, infer the shape
+            from the data.
+    """
+    # Case 1: no data, but fixed shape requested
+    if mask_strategy is None:
+        assert (
+            t_max is not None and l_max is not None and h_max is not None
+        ), "If mask_strategy is None, you must provide t_max, l_max, and h_max"
+        return [
+            [[None for _ in range(h_max)] for _ in range(l_max)] for _ in range(t_max)
+        ]
+
+    # Parse all keys into integer tuples
+    indices = [tuple(map(int, key.split("_"))) for key in mask_strategy]
+
+    # Decide on dimensions
+    if t_max is None and l_max is None and h_max is None:
+        # fully dynamic: infer from data
+        max_timesteps_idx = max(t for t, _, _ in indices) + 1
+        max_layer_idx = max(l for _, l, _ in indices) + 1  # noqa: E741
+        max_head_idx = max(h for _, _, h in indices) + 1
+    else:
+        # require all three to be provided
+        assert t_max is not None and l_max is not None and h_max is not None, (
+            "Either supply none of (t_max, l_max, h_max) to infer dimensions, "
+            "or supply all three to fix the shape."
+        )
+        max_timesteps_idx = t_max
+        max_layer_idx = l_max
+        max_head_idx = h_max
+
+    # Preallocate
+    result = [
+        [[None for _ in range(max_head_idx)] for _ in range(max_layer_idx)]
+        for _ in range(max_timesteps_idx)
+    ]
+
+    # Fill in, skipping any out-of-bounds entries
+    for key, value in mask_strategy.items():
+        t, l, h = map(int, key.split("_"))  # noqa: E741
+        if (
+            0 <= t < max_timesteps_idx
+            and 0 <= l < max_layer_idx
+            and 0 <= h < max_head_idx
+        ):
+            result[t][l][h] = value
+        # else: silently ignore any key that doesn't fit
+
+    return result
+
+
+def set_random_seed(seed: int) -> None:
+    from sglang.multimodal_gen.runtime.platforms import current_platform
+
+    current_platform.seed_everything(seed)
+
+
+@lru_cache(maxsize=1)
+def is_vsa_available() -> bool:
+    return importlib.util.find_spec("vsa") is not None
+
+
+@lru_cache(maxsize=1)
+def is_vmoba_available() -> bool:
+    if importlib.util.find_spec("kernel.csrc.attn.vmoba_attn.vmoba") is None:
+        return False
+    try:
+        import flash_attn
+        from packaging.version import Version
+
+        # Compare versions numerically; a plain string comparison would
+        # mis-order releases such as "2.10.0" vs "2.7.4".
+        return Version(flash_attn.__version__) >= Version("2.7.4")
+    except Exception:
+        return False
+
+
+# adapted from: https://github.com/Wan-Video/Wan2.2/blob/main/wan/utils/utils.py
+def masks_like(
+    tensor, zero=False, generator=None, p=0.2
+) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
+    assert isinstance(tensor, list)
+    out1 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
+
+    out2 = [torch.ones(u.shape, dtype=u.dtype, device=u.device) for u in tensor]
+
+    if zero:
+        if generator is not None:
+            for u, v in zip(out1, out2, strict=False):
+                random_num = torch.rand(
+                    1, generator=generator, device=generator.device
+                ).item()
+                if random_num < p:
+                    u[:, 0] = (
+                        torch.normal(
+                            mean=-3.5,
+                            std=0.5,
+                            size=(1,),
+                            device=u.device,
+                            generator=generator,
+                        )
+                        .expand_as(u[:, 0])
+                        .exp()
+                    )
+                    v[:, 0] = torch.zeros_like(v[:, 0])
+                else:
+                    u[:, 0] = u[:, 0]
+                    v[:, 0] = v[:, 0]
+
+        else:
+            for u, v in zip(out1, out2, strict=False):
+                u[:, 0] = torch.zeros_like(u[:, 0])
+                v[:, 0] = torch.zeros_like(v[:, 0])
+
+    return out1, out2
+
+
+# adapted from: https://github.com/Wan-Video/Wan2.2/blob/main/wan/utils/utils.py
+def best_output_size(w, h, dw, dh, expected_area):
+    # float output size
+    ratio = w / h
+    ow = (expected_area * ratio) ** 0.5
+    oh = expected_area / ow
+
+    # process width first
+    ow1 = int(ow // dw * dw)
+    oh1 = int(expected_area / ow1 // dh * dh)
+    assert ow1 % dw == 0 and oh1 % dh == 0 and ow1 * oh1 <= expected_area
+    ratio1 = ow1 / oh1
+
+    # process height first
+    oh2 = int(oh // dh * dh)
+    ow2 = int(expected_area / oh2 // dw * dw)
+    assert oh2 % dh == 0 and ow2 % dw == 0 and ow2 * oh2 <= expected_area
+    ratio2 = ow2 / oh2
+
+    # compare ratios
+    if max(ratio / ratio1, ratio1 / ratio) < max(ratio / ratio2, ratio2 / ratio):
+        return ow1, oh1
+    else:
+        return ow2, oh2
+
+
+def save_decoded_latents_as_video(
+    decoded_latents: list[torch.Tensor], output_path: str, fps: int
+):
+    # Process outputs
+    videos = rearrange(decoded_latents, "b c t h w -> t b c h w")
+    frames = []
+    for x in videos:
+        x = torchvision.utils.make_grid(x, nrow=6)
+        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
+        frames.append((x * 255).numpy().astype(np.uint8))
+
+    # dirname() is empty when output_path is a bare filename; skip makedirs then.
+    out_dir = os.path.dirname(output_path)
+    if out_dir:
+        os.makedirs(out_dir, exist_ok=True)
+    imageio.mimsave(output_path, frames, fps=fps, format="mp4")
-- 
GitLab