Commit aebda3d8 authored by Christina Floristean's avatar Christina Floristean
Browse files

Config fixes for when using ds kernel

parent f1563999
...@@ -28,6 +28,7 @@ def enforce_config_constraints(config): ...@@ -28,6 +28,7 @@ def enforce_config_constraints(config):
( (
"globals.use_lma", "globals.use_lma",
"globals.use_flash", "globals.use_flash",
"globals.use_deepspeed_evo_attention"
), ),
] ]
...@@ -38,9 +39,18 @@ def enforce_config_constraints(config): ...@@ -38,9 +39,18 @@ def enforce_config_constraints(config):
raise ValueError(f"Only one of {s1} and {s2} may be set at a time") raise ValueError(f"Only one of {s1} and {s2} may be set at a time")
fa_is_installed = importlib.util.find_spec("flash_attn") is not None fa_is_installed = importlib.util.find_spec("flash_attn") is not None
if(config.globals.use_flash and not fa_is_installed): if config.globals.use_flash and not fa_is_installed:
raise ValueError("use_flash requires that FlashAttention is installed") raise ValueError("use_flash requires that FlashAttention is installed")
deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec(
"deepspeed.ops.deepspeed4science") is not None
if config.globals.use_deepspeed_evo_attention and not ds4s_is_installed:
raise ValueError(
"use_deepspeed_evo_attention requires that DeepSpeed be installed "
"and that the deepspeed.ops.deepspeed4science package exists"
)
if( if(
config.globals.offload_inference and config.globals.offload_inference and
not config.model.template.average_templates not config.model.template.average_templates
...@@ -193,6 +203,7 @@ def model_config( ...@@ -193,6 +203,7 @@ def model_config(
c.globals.offload_inference = True c.globals.offload_inference = True
c.globals.use_lma = True c.globals.use_lma = True
c.globals.use_flash = False c.globals.use_flash = False
c.globals.use_deepspeed_evo_attention = False
c.model.template.offload_inference = True c.model.template.offload_inference = True
c.model.template.template_pair_stack.tune_chunk_size = False c.model.template.template_pair_stack.tune_chunk_size = False
c.model.extra_msa.extra_msa_stack.tune_chunk_size = False c.model.extra_msa.extra_msa_stack.tune_chunk_size = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment