Commit 9950f6ec authored by Michael Yang's avatar Michael Yang
Browse files

gpt-oss

parent f1c73840
......@@ -202,6 +202,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &bertModel{}
case "CohereForCausalLM":
conv = &commandrModel{}
case "GptOssForCausalLM":
conv = &gptossModel{}
default:
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}
......
package convert
import (
"cmp"
"github.com/ollama/ollama/fs/ggml"
)
type gptossModel struct {
ModelParameters
HiddenLayers uint32 `json:"num_hidden_layers"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
AttentionHeads uint32 `json:"num_attention_heads"`
KeyValueHeads uint32 `json:"num_key_value_heads"`
HeadDim uint32 `json:"head_dim"`
Experts uint32 `json:"num_experts"`
ExpertsPerToken uint32 `json:"experts_per_token"`
RMSNormEpsilon float32 `json:"rms_norm_eps"`
InitialContextLength uint32 `json:"initial_context_length"`
RopeTheta float32 `json:"rope_theta"`
RopeScalingFactor float32 `json:"rope_scaling_factor"`
SlidingWindow uint32 `json:"sliding_window"`
}
var _ ModelConverter = (*gptossModel)(nil)
func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gptoss"
kv["general.file_type"] = uint32(4)
kv["gptoss.context_length"] = uint32(m.RopeScalingFactor * float32(m.InitialContextLength))
kv["gptoss.block_count"] = m.HiddenLayers
kv["gptoss.embedding_length"] = m.HiddenSize
kv["gptoss.feed_forward_length"] = m.IntermediateSize
kv["gptoss.expert_count"] = m.Experts
kv["gptoss.expert_used_count"] = m.ExpertsPerToken
kv["gptoss.attention.head_count"] = m.AttentionHeads
kv["gptoss.attention.head_count_kv"] = m.KeyValueHeads
kv["gptoss.attention.key_length"] = m.HeadDim
kv["gptoss.attention.value_length"] = m.HeadDim
kv["gptoss.attention.layer_norm_rms_epsilon"] = cmp.Or(m.RMSNormEpsilon, 1e-5)
kv["gptoss.attention.sliding_window"] = m.SlidingWindow
kv["gptoss.rope.freq_base"] = m.RopeTheta
kv["gptoss.rope.scaling.factor"] = m.RopeScalingFactor
kv["gptoss.rope.scaling.original_context_length"] = m.InitialContextLength
kv["tokenizer.ggml.bos_token_id"] = uint32(199998) // <|startoftext|>
kv["tokenizer.ggml.add_bos_token"] = false
kv["tokenizer.ggml.eos_token_id"] = uint32(199999) // <|endoftext|>
kv["tokenizer.ggml.eos_token_ids"] = []int32{
199999, /* <|endoftext|> */
200002, /* <|return|> */
200012, /* <|call|> */
}
kv["tokenizer.ggml.add_eos_token"] = false
return kv
}
func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (m *gptossModel) Replacements() []string {
return []string{
"block", "blk",
"attn.norm", "attn_norm",
"attn.qkv", "attn_qkv",
"attn.sinks", "attn_sinks",
"attn.out", "attn_out",
"mlp.norm", "ffn_norm",
"mlp.gate", "ffn_gate_inp",
"mlp.mlp1_", "ffn_gate_up_exps.",
"mlp.mlp2_", "ffn_down_exps.",
"embedding", "token_embd",
"norm", "output_norm",
"unembedding", "output",
"scale", "weight",
}
}
......@@ -276,6 +276,7 @@ type Tensor interface {
Cos(ctx Context) Tensor
Tanh(ctx Context) Tensor
GELU(ctx Context) Tensor
QuickGELU(ctx Context) Tensor
SILU(ctx Context) Tensor
RELU(ctx Context) Tensor
Sigmoid(ctx Context) Tensor
......@@ -283,7 +284,7 @@ type Tensor interface {
Reshape(ctx Context, shape ...int) Tensor
View(ctx Context, offset int, shape ...int) Tensor
Permute(ctx Context, shape ...int) Tensor
Contiguous(ctx Context) Tensor
Contiguous(ctx Context, shape ...int) Tensor
Set(ctx Context, t2 Tensor, offset int, strides ...int) Tensor
Pad(ctx Context, shape ...int) Tensor
......
......@@ -958,10 +958,35 @@ func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
}
}
func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_cont(ctx.(*Context).ctx, t.t),
func (t *Tensor) Contiguous(ctx ml.Context, shape ...int) ml.Tensor {
switch len(shape) {
case 0:
return &Tensor{
b: t.b,
t: C.ggml_cont(ctx.(*Context).ctx, t.t),
}
case 1:
return &Tensor{
b: t.b,
t: C.ggml_cont_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
}
case 2:
return &Tensor{
b: t.b,
t: C.ggml_cont_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
}
case 3:
return &Tensor{
b: t.b,
t: C.ggml_cont_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
}
case 4:
return &Tensor{
b: t.b,
t: C.ggml_cont_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
}
default:
panic("unsupported number of dimensions")
}
}
......@@ -1176,11 +1201,18 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase, ropeScale float32, options ...func(*rope.Options)) ml.Tensor {
// Default options
opts := &rope.Options{OriginalContextLength: 131072, Factors: &Tensor{}}
opts := rope.Options{
Factors: &Tensor{},
OriginalContextLength: 131072,
ExtrapolationFactor: 0.,
AttentionFactor: 1.,
BetaFast: 32.,
BetaSlow: 1.,
}
// Apply any provided options
for _, option := range options {
option(opts)
option(&opts)
}
dequant := t.t
......@@ -1200,10 +1232,10 @@ func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase
C.int(opts.OriginalContextLength),
C.float(ropeBase),
C.float(ropeScale),
C.float(0.0),
C.float(1.0),
C.float(32.0),
C.float(1.0),
C.float(opts.ExtrapolationFactor),
C.float(opts.AttentionFactor),
C.float(opts.BetaFast),
C.float(opts.BetaSlow),
),
}
}
......@@ -1222,6 +1254,13 @@ func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
}
}
func (t *Tensor) QuickGELU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
......
......@@ -15,3 +15,26 @@ func (m *Linear) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
return t
}
type LinearBatch struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (m *LinearBatch) Forward(ctx ml.Context, t, indices ml.Tensor) ml.Tensor {
t = m.Weight.MulmatID(ctx, t, indices)
if m.Bias != nil {
var bias ml.Tensor
if len(indices.Shape()) > 1 {
// FIXME: Rows does not support 2D indices for a 2D input tensor so reshape indices to 1D.
bias = m.Bias.Rows(ctx, indices.Contiguous(ctx, indices.Dim(0)*indices.Dim(1))).
Duplicate(ctx).
Reshape(ctx, m.Bias.Dim(0), indices.Dim(0), indices.Dim(1))
} else {
bias = m.Bias.Rows(ctx, indices)
}
t = t.Add(ctx, bias)
}
return t
}
......@@ -4,9 +4,15 @@ import "github.com/ollama/ollama/ml"
// Options contains optional parameters for RoPE function
type Options struct {
OriginalContextLength int
Type int
Factors ml.Tensor
OriginalContextLength int
// YaRN options
ExtrapolationFactor,
AttentionFactor,
BetaFast,
BetaSlow float32
}
// WithOriginalContextLength sets a custom context length
......@@ -31,3 +37,15 @@ func WithFactors(factors ml.Tensor) func(*Options) {
}
}
}
func WithExtrapolationFactor(extrapolationFactor float32) func(*Options) {
return func(opts *Options) {
opts.ExtrapolationFactor = extrapolationFactor
}
}
func WithAttentionFactor(attentionFactor float32) func(*Options) {
return func(opts *Options) {
opts.AttentionFactor = attentionFactor
}
}
package gptoss
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Transformer struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TransformerBlocks []TransformerBlock `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Options
}
// Forward implements model.Model.
func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
one := ctx.Input().FromFloatSlice([]float32{1}, 1)
for i, block := range m.TransformerBlocks {
m.Cache.SetLayer(i)
if c, ok := m.Cache.(*kvcache.WrapperCache); ok {
// Even layers are sliding window attention.
c.SetLayerType(i % 2)
}
var outputs ml.Tensor
if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 {
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
}
hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func (m *Transformer) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, m.RoPEOptions()...), nil
}
type Options struct {
hiddenSize,
numHeads,
numKVHeads,
keyLength,
valueLength,
numExperts,
numExpertsUsed,
originalContextLength int
eps,
ropeBase,
ropeScale float32
}
func (o Options) RoPEOptions() []func(*rope.Options) {
return []func(*rope.Options){
rope.WithTypeNeoX(),
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithExtrapolationFactor(1.),
// NOTE: ggml sets this implicitly so there's no need to set it here
// rope.WithAttentionFactor(0.1*float32(math.Log(float64(o.ropeScale))) + 1.0),
}
}
func (o Options) headDim() int {
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
}
type TransformerBlock struct {
Attention *AttentionBlock
MLP *MLPBlock
}
func (d *TransformerBlock) Forward(ctx ml.Context, hiddenStates, positions, outputs, one ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
}
hiddenStates = d.MLP.Forward(ctx, hiddenStates, one, opts)
return hiddenStates
}
type AttentionBlock struct {
Norm *nn.RMSNorm `gguf:"attn_norm"`
QKV *nn.Linear `gguf:"attn_qkv"`
Output *nn.Linear `gguf:"attn_out"`
Sinks ml.Tensor `gguf:"attn_sinks"`
}
func (attn *AttentionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenStates.Dim(1)
residual := hiddenStates
hiddenStates = attn.Norm.Forward(ctx, hiddenStates, opts.eps)
qkv := attn.QKV.Forward(ctx, hiddenStates)
// query = qkv[..., : num_attention_heads * head_dim].reshape(batch_size, num_attention_heads, head_dim)
query := qkv.View(ctx,
0,
opts.headDim(), qkv.Stride(0)*opts.headDim(),
opts.numHeads, qkv.Stride(1),
batchSize,
)
query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
// key = qkv[..., num_attention_heads * head_dim:(num_attention_heads + num_key_value_heads) * head_dim].reshape(batch_size, num_key_value_heads, head_dim)
key := qkv.View(ctx,
qkv.Stride(0)*opts.headDim()*opts.numHeads,
opts.headDim(), qkv.Stride(0)*opts.headDim(),
opts.numKVHeads, qkv.Stride(1),
batchSize,
)
key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
// value = qkv[..., (num_attention_heads + num_key_value_heads) * head_dim:].reshape(batch_size, num_key_value_heads, head_dim)
value := qkv.View(ctx,
qkv.Stride(0)*opts.headDim()*(opts.numHeads+opts.numKVHeads),
opts.headDim(), qkv.Stride(0)*opts.headDim(),
opts.numKVHeads, qkv.Stride(1),
batchSize,
)
cache.Put(ctx, key, value)
key, value, mask := cache.Get(ctx)
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
scores := key.MulmatFullPrec(ctx, query)
scores = scores.Scale(ctx, 1./math.Sqrt(float64(opts.headDim())))
scores = scores.Add(ctx, mask)
scores = scores.Concat(ctx, attn.Sinks.Reshape(ctx, 1, 1, opts.numHeads, 1).Repeat(ctx, 1, batchSize), 0)
scores = scores.Softmax(ctx)
scores = scores.Pad(ctx, -1, 0, 0, 0)
attention := value.Mulmat(ctx, scores)
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return attn.Output.Forward(ctx, attention).Add(ctx, residual)
}
type MLPBlock struct {
Norm *nn.RMSNorm `gguf:"ffn_norm"`
Router *nn.Linear `gguf:"ffn_gate_inp"`
GateUp *nn.LinearBatch `gguf:"ffn_gate_up_exps"`
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
}
func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates, one ml.Tensor, opts *Options) ml.Tensor {
hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
residual := hiddenStates
hiddenStates = mlp.Norm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
routingWeights := mlp.Router.Forward(ctx, hiddenStates)
selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed)
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, sequenceLength*batchSize).Rows(ctx, selectedExperts)
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, sequenceLength*batchSize).Softmax(ctx)
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, sequenceLength*batchSize)
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
hiddenStates = mlp.GateUp.Forward(ctx, hiddenStates, selectedExperts)
hiddenStates = hiddenStates.Reshape(ctx, 2, hiddenStates.Dim(0)/2, hiddenStates.Dim(1), hiddenStates.Dim(2))
dimStride := []int{hiddenStates.Dim(0) / 2, hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), hiddenStates.Dim(2), hiddenStates.Stride(3), hiddenStates.Dim(3)}
hiddenStates = hiddenStates.View(ctx, 0, dimStride...).
Contiguous(ctx).
QuickGELU(ctx).
Mul(ctx, hiddenStates.View(ctx, hiddenStates.Stride(0), dimStride...).Add(ctx, one))
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3))
experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
experts = experts.Mul(ctx, routingWeights)
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
}
return nextStates.Add(ctx, residual)
}
func New(c fs.Config) (model.Model, error) {
m := Transformer{
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
keyLength: int(c.Uint("attention.key_length")),
valueLength: int(c.Uint("attention.value_length")),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1.),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
},
}
m.Cache = kvcache.NewWrapperCache(
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
kvcache.NewCausalCache(m.Shift),
)
m.Cache.SetConfig(ml.CacheConfig{})
return &m, nil
}
func init() {
model.Register("gptoss", New)
}
......@@ -4,6 +4,7 @@ import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"
_ "github.com/ollama/ollama/model/models/gptoss"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment