Commit f5287e90 authored by qiyuxinlin's avatar qiyuxinlin
Browse files

Fix import error when balance_serve is not available

parent 03a65d6b
......@@ -12,7 +12,10 @@ import torch.nn as nn
import transformers
from transformers import Cache, PretrainedConfig
from typing import List, Optional, Dict, Any, Tuple
from ktransformers.server.balance_serve.settings import sched_ext
try:
from ktransformers.server.balance_serve.settings import sched_ext
except:
print("no balance_serve")
class StaticCache(transformers.StaticCache):
"""
Static Cache class to be used with `torch.compile(model)`.
......@@ -210,7 +213,7 @@ class KDeepSeekV3Cache(nn.Module):
self.v_caches = []
def load(self, inference_context: sched_ext.InferenceContext):
def load(self, inference_context: "sched_ext.InferenceContext"):
for i in range(self.config.num_hidden_layers):
self.k_caches.append(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment