Unverified Commit ab96b9ae authored by OlivierDehaene, committed by GitHub

feat(server): support new falcon config (#712)

parent 2efd46ef
@@ -200,13 +200,10 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )
 
-    if model_type in ["RefinedWeb", "RefinedWebModel"]:
+    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
         if sharded:
             if FLASH_ATTENTION:
-                if config_dict.get("alibi", False) or (
-                    model_type == "RefinedWebModel"
-                    and config_dict.get("multi_query", True)
-                ):
+                if config_dict.get("alibi", False):
                     raise NotImplementedError("sharded is not supported for this model")
                 return FlashRWSharded(
                     model_id,
@@ -215,9 +212,7 @@ def get_model(
                     dtype=dtype,
                     trust_remote_code=trust_remote_code,
                 )
-            raise NotImplementedError(
-                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded RefinedWeb")
-            )
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon"))
         else:
             if FLASH_ATTENTION and not config_dict.get("alibi", False):
                 return FlashRWSharded(
...
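For context, the sharding guard above now rejects only checkpoints that use alibi position encoding; the extra condition that previously blocked sharded multi-query `RefinedWebModel` weights is dropped, and `model_type == "falcon"` is routed through the same branch. A minimal sketch of the new guard, not part of the patch, with purely illustrative `config_dict` contents:

```python
# Illustrative paraphrase of the updated guard in get_model above.
# The example dicts are made up, not real checkpoint configs.
def sharded_falcon_supported(config_dict: dict) -> bool:
    # Under the new condition, only alibi position encoding blocks sharding.
    return not config_dict.get("alibi", False)


print(sharded_falcon_supported({"model_type": "falcon"}))                            # True
print(sharded_falcon_supported({"model_type": "RefinedWebModel", "multi_query": True}))  # True (previously rejected when sharded)
print(sharded_falcon_supported({"model_type": "RefinedWeb", "alibi": True}))         # False
```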
@@ -49,8 +49,8 @@ class RWConfig(PretrainedConfig):
         model_type="RefinedWeb",
         vocab_size=250880,
         hidden_size=64,
-        n_layer=2,
-        n_head=8,
+        num_hidden_layers=None,
+        num_attention_heads=None,
         layer_norm_epsilon=1e-5,
         initializer_range=0.02,
         use_cache=True,
@@ -58,9 +58,10 @@ class RWConfig(PretrainedConfig):
         eos_token_id=2,
         hidden_dropout=0.0,
         attention_dropout=0.0,
-        n_head_kv=None,
+        num_kv_heads=None,
         multi_query=False,
         alibi=False,
+        new_decoder_architecture=None,
         bias=False,
         parallel_attn=False,
         **kwargs,
@@ -78,8 +79,16 @@ class RWConfig(PretrainedConfig):
         # Backward compatibility with n_embed kwarg
         n_embed = kwargs.pop("n_embed", None)
         self.hidden_size = hidden_size if n_embed is None else n_embed
-        self.n_layer = n_layer
-        self.n_head = n_head
+        self.n_layer = (
+            num_hidden_layers
+            if num_hidden_layers is not None
+            else kwargs.pop("n_layer", 2)
+        )
+        self.n_head = (
+            num_attention_heads
+            if num_attention_heads is not None
+            else kwargs.pop("n_head", 8)
+        )
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_range = initializer_range
         self.use_cache = use_cache
@@ -91,10 +100,21 @@ class RWConfig(PretrainedConfig):
         self.bos_token_id = bos_token_id
         self.eos_token_id = eos_token_id
 
-        if n_head_kv is not None:
-            self.n_head_kv = n_head_kv
+        if num_kv_heads is not None:
+            self.n_head_kv = num_kv_heads
         else:
-            self.n_head_kv = 1 if multi_query else n_head
+            old_n_head_kv = kwargs.pop("n_head_kv", None)
+            if old_n_head_kv is not None:
+                self.n_head_kv = old_n_head_kv
+            else:
+                self.n_head_kv = 1 if multi_query else self.n_head
+
+        if new_decoder_architecture is not None:
+            self.new_decoder_architecture = new_decoder_architecture
+        elif model_type == "RefinedWeb":
+            self.new_decoder_architecture = True
+        else:
+            self.new_decoder_architecture = False
 
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
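To make the fallback chain above concrete, here is a small standalone sketch (not part of the patch) that mirrors how `RWConfig` now resolves the new transformers-style keys while still accepting the legacy `n_layer` / `n_head` / `n_head_kv` kwargs. The example values are illustrative.

```python
# Standalone paraphrase of the RWConfig backward-compatibility logic above.
def resolve_rw_config(model_type="RefinedWeb", num_hidden_layers=None,
                      num_attention_heads=None, num_kv_heads=None,
                      multi_query=False, new_decoder_architecture=None, **kwargs):
    # New-style keys win; otherwise fall back to the legacy kwargs and defaults.
    n_layer = num_hidden_layers if num_hidden_layers is not None else kwargs.pop("n_layer", 2)
    n_head = num_attention_heads if num_attention_heads is not None else kwargs.pop("n_head", 8)

    if num_kv_heads is not None:
        n_head_kv = num_kv_heads
    else:
        old_n_head_kv = kwargs.pop("n_head_kv", None)
        n_head_kv = old_n_head_kv if old_n_head_kv is not None else (1 if multi_query else n_head)

    if new_decoder_architecture is None:
        # Old "RefinedWeb" checkpoints imply the grouped-KV decoder layout.
        new_decoder_architecture = model_type == "RefinedWeb"

    return dict(n_layer=n_layer, n_head=n_head, n_head_kv=n_head_kv,
                new_decoder_architecture=new_decoder_architecture)


# Legacy-style keys (values illustrative):
print(resolve_rw_config(model_type="RefinedWeb", n_layer=60, n_head=128, n_head_kv=8))
# New transformers-style keys (values illustrative) resolve to the same attributes:
print(resolve_rw_config(model_type="falcon", num_hidden_layers=60,
                        num_attention_heads=128, num_kv_heads=8,
                        new_decoder_architecture=True))
```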
@@ -530,26 +550,23 @@ class FlashRWModel(FlashRWPreTrainedModel):
         self.word_embeddings = TensorParallelEmbedding(
             prefix="transformer.word_embeddings", weights=weights
         )
-        if config.model_type == "RefinedWebModel":
+
+        if config.new_decoder_architecture:
             self.h = nn.ModuleList(
                 [
-                    FlashRWLayer(layer_id, config, weights)
+                    FlashRWLargeLayer(layer_id, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
-            self.cache_size = self.h[0].self_attention.num_heads_kv
-        elif config.model_type == "RefinedWeb":
+            self.cache_size = self.h[0].self_attention.num_groups
+        else:
             self.h = nn.ModuleList(
                 [
-                    FlashRWLargeLayer(layer_id, config, weights)
+                    FlashRWLayer(layer_id, config, weights)
                     for layer_id in range(config.num_hidden_layers)
                 ]
             )
-            self.cache_size = self.h[0].self_attention.num_groups
-        else:
-            raise NotImplementedError(
-                f"model_type {config.model_type} is not supported."
-            )
+            self.cache_size = self.h[0].self_attention.num_heads_kv
 
         self.ln_f = FastLayerNorm.load(
             prefix="transformer.ln_f",
...
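The model constructor above now keys the decoder-layer choice on `config.new_decoder_architecture` instead of the checkpoint's `model_type`. A hypothetical summary of that branch, returning names rather than instantiating the real `FlashRWLargeLayer` / `FlashRWLayer` classes:

```python
# Illustrative summary of the layer selection above; not part of the patch.
def pick_decoder_layer(new_decoder_architecture: bool) -> tuple[str, str]:
    if new_decoder_architecture:
        # Grouped-KV layout; the KV cache is sized by num_groups.
        return "FlashRWLargeLayer", "self_attention.num_groups"
    # Classic layout; the KV cache is sized by num_heads_kv.
    return "FlashRWLayer", "self_attention.num_heads_kv"


print(pick_decoder_layer(True))   # ('FlashRWLargeLayer', 'self_attention.num_groups')
print(pick_decoder_layer(False))  # ('FlashRWLayer', 'self_attention.num_heads_kv')
```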