import torch

from ..scheduler import HunyuanScheduler

# Transformer block counts for the HunyuanVideo backbone.
# NOTE(review): 20 double-stream / 40 single-stream matches the upstream
# loops; confirm against the actual model config if the backbone changes.
NUM_DOUBLE_STREAM_BLOCKS = 20
NUM_SINGLE_STREAM_BLOCKS = 40

# Settings shared by every caching mode.
_BASE_CONFIG = {
    'cache_type': 'random',
    'fresh_ratio_schedule': 'ToCa',
    'fresh_ratio': 0.0,
    'force_fresh': 'global',
    'soft_fresh_weight': 0.0,
    'max_order': 0,
    'first_enhance': 1,
    'taylor_cache': False,
    'duca': False,
}

# Per-mode overrides on top of ``_BASE_CONFIG`` (values taken verbatim from
# the original branch-per-mode initialization).
_MODE_OVERRIDES = {
    'original': {'fresh_threshold': 1},
    'ToCa': {'fresh_ratio': 0.10, 'fresh_threshold': 5},
    'DuCa': {'fresh_ratio': 0.10, 'fresh_threshold': 5, 'duca': True},
    'Taylor': {'fresh_threshold': 5, 'max_order': 1, 'taylor_cache': True},
}


def cache_init(num_steps, model_kwargs=None, mode='Taylor'):
    """Build the feature-cache bookkeeping structures for one sampling run.

    Args:
        num_steps: Total number of denoising steps in the run.
        model_kwargs: Unused; kept for interface compatibility with callers.
        mode: Caching strategy — one of ``'original'``, ``'ToCa'``, ``'DuCa'``,
            ``'Taylor'``. Previously hard-coded to ``'Taylor'``; the default
            preserves that behavior.

    Returns:
        Tuple ``(cache_dic, current)`` where ``cache_dic`` holds all cache
        state/configuration and ``current`` tracks per-step progress.

    Raises:
        ValueError: If ``mode`` is not a recognized strategy.
    """
    if mode not in _MODE_OVERRIDES:
        raise ValueError(
            f"Unknown cache mode {mode!r}; expected one of {sorted(_MODE_OVERRIDES)}"
        )

    cache = {-1: {'double_stream': {}, 'single_stream': {}}}
    cache_index = {-1: {}, 'layer_index': {}}

    cache_dic = {
        'cache_counter': 0,
        'test_FLOPs': False,
        # Per-layer attention maps and key/value norms, split by stream type.
        'attn_map': {-1: {'double_stream': {}, 'single_stream': {}}},
        'k-norm': {-1: {'double_stream': {}, 'single_stream': {}}},
        'v-norm': {-1: {'double_stream': {}, 'single_stream': {}}},
        'cross_attn_map': {-1: {}},
    }

    # Double-stream blocks track text and image MLP branches separately.
    for j in range(NUM_DOUBLE_STREAM_BLOCKS):
        cache[-1]['double_stream'][j] = {}
        cache_index[-1][j] = {}
        cache_dic['attn_map'][-1]['double_stream'][j] = {
            'total': {}, 'txt_mlp': {}, 'img_mlp': {},
        }
        cache_dic['k-norm'][-1]['double_stream'][j] = {'txt_mlp': {}, 'img_mlp': {}}
        cache_dic['v-norm'][-1]['double_stream'][j] = {'txt_mlp': {}, 'img_mlp': {}}

    # Single-stream blocks only track the combined ('total') branch.
    for j in range(NUM_SINGLE_STREAM_BLOCKS):
        cache[-1]['single_stream'][j] = {}
        # NOTE(review): this re-initializes cache_index[-1][j] for j < 20,
        # so both streams share one index space — kept as in the original.
        cache_index[-1][j] = {}
        cache_dic['attn_map'][-1]['single_stream'][j] = {'total': {}}
        cache_dic['k-norm'][-1]['single_stream'][j] = {'total': {}}
        cache_dic['v-norm'][-1]['single_stream'][j] = {'total': {}}

    cache_dic['cache'] = cache
    cache_dic['cache_index'] = cache_index
    cache_dic.update(_BASE_CONFIG)
    cache_dic.update(_MODE_OVERRIDES[mode])

    current = {
        'num_steps': num_steps,
        'activated_steps': [0],  # steps on which a full computation ran
    }
    return cache_dic, current


def force_scheduler(cache_dic, current):
    """Recompute the forced-refresh interval and store it in ``cache_dic``.

    Writes ``cache_dic['cal_threshold']`` (a 0-d torch tensor, as downstream
    comparisons rely on it) based on the configured ``fresh_threshold``.
    """
    # Both the FORA (fresh_ratio == 0) and TokenCache branches of the
    # original used 0.0 here; the linear ramp formula is kept so the weight
    # can be tuned without restructuring the code.
    linear_step_weight = 0.0
    step_factor = torch.tensor(
        1 - linear_step_weight
        + 2 * linear_step_weight * current['step'] / current['num_steps']
    )
    # No forced constraint for sensitive steps — quality is good enough
    # without it, but it is an easy knob to experiment with.
    cache_dic['cal_threshold'] = torch.round(
        cache_dic['fresh_threshold'] / step_factor
    )


def cal_type(cache_dic, current):
    """Decide the calculation type for the current step.

    Sets ``current['type']`` to one of ``'full'``, ``'taylor_cache'``,
    ``'ToCa'`` or ``'aggressive'`` and updates the cache counter /
    activated-step bookkeeping accordingly.
    """
    if cache_dic['fresh_ratio'] == 0.0 and not cache_dic['taylor_cache']:
        # FORA: uniform refresh — only the very first step is special.
        first_step = current['step'] == 0
    else:
        # ToCa: the first `first_enhance` steps are always fully computed.
        first_step = current['step'] < cache_dic['first_enhance']

    if first_step:
        fresh_interval = cache_dic['fresh_threshold']
    else:
        fresh_interval = cache_dic['cal_threshold']

    if first_step or cache_dic['cache_counter'] == fresh_interval - 1:
        # Full computation: reset the counter and reschedule the next refresh.
        current['type'] = 'full'
        cache_dic['cache_counter'] = 0
        current['activated_steps'].append(current['step'])
        force_scheduler(cache_dic, current)
    elif cache_dic['taylor_cache']:
        cache_dic['cache_counter'] += 1
        current['type'] = 'taylor_cache'
    else:
        cache_dic['cache_counter'] += 1
        if cache_dic['duca']:
            # Alternate ToCa and aggressive steps between full refreshes.
            if cache_dic['cache_counter'] % 2 == 1:
                current['type'] = 'ToCa'
            else:
                current['type'] = 'aggressive'
        else:
            current['type'] = 'ToCa'


class HunyuanSchedulerFeatureCaching(HunyuanScheduler):
    """HunyuanScheduler variant that maintains feature-caching state.

    Before each scheduler step, decides (via ``cal_type``) whether the model
    should run a full forward pass or reuse cached features.
    """

    def __init__(self, args):
        super().__init__(args)
        # cache_dic: cache storage/config; current: per-step progress state.
        self.cache_dic, self.current = cache_init(self.infer_steps)

    def step_pre(self, step_index):
        super().step_pre(step_index)
        self.current['step'] = step_index
        cal_type(self.cache_dic, self.current)