Commit aebda3d8 authored by Christina Floristean
Browse files

Config fixes for when using the DeepSpeed kernel

parent f1563999
......@@ -28,6 +28,7 @@ def enforce_config_constraints(config):
(
"globals.use_lma",
"globals.use_flash",
"globals.use_deepspeed_evo_attention"
),
]
......@@ -38,9 +39,18 @@ def enforce_config_constraints(config):
raise ValueError(f"Only one of {s1} and {s2} may be set at a time")
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if(config.globals.use_flash and not fa_is_installed):
if config.globals.use_flash and not fa_is_installed:
raise ValueError("use_flash requires that FlashAttention is installed")
deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec(
"deepspeed.ops.deepspeed4science") is not None
if config.globals.use_deepspeed_evo_attention and not ds4s_is_installed:
raise ValueError(
"use_deepspeed_evo_attention requires that DeepSpeed be installed "
"and that the deepspeed.ops.deepspeed4science package exists"
)
if(
config.globals.offload_inference and
not config.model.template.average_templates
......@@ -193,6 +203,7 @@ def model_config(
c.globals.offload_inference = True
c.globals.use_lma = True
c.globals.use_flash = False
c.globals.use_deepspeed_evo_attention = False
c.model.template.offload_inference = True
c.model.template.template_pair_stack.tune_chunk_size = False
c.model.extra_msa.extra_msa_stack.tune_chunk_size = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment