Commit f2a3c894 authored by helloyongyang's avatar helloyongyang
Browse files

fix sparge bugs

parent 53c0d05c
......@@ -58,6 +58,11 @@ class AttnWeightTemplate(metaclass=ABCMeta):
def to_cuda(self, non_blocking=False):
self.weight = self.weight.cuda(non_blocking=non_blocking)
def state_dict(self, destination=None):
if destination is None:
destination = {}
return destination
@ATTN_WEIGHT_REGISTER("flash_attn2")
class FlashAttn2Weight(AttnWeightTemplate):
......
......@@ -20,6 +20,7 @@ def get_default_config():
"strength_model": 1.0,
"mm_config": {},
"use_prompt_enhancer": False,
"sparge": False,
}
return default_config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment