Commit aa47c1c5 authored by xuxzh1's avatar xuxzh1 🎱
Browse files

update

parent 0cb78a2f
package android.llama.cpp
import android.util.Log
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.withContext
import java.util.concurrent.Executors
import kotlin.concurrent.thread
class LLamaAndroid {
private val tag: String? = this::class.simpleName
private val threadLocalState: ThreadLocal<State> = ThreadLocal.withInitial { State.Idle }
private val runLoop: CoroutineDispatcher = Executors.newSingleThreadExecutor {
thread(start = false, name = "Llm-RunLoop") {
Log.d(tag, "Dedicated thread for native code: ${Thread.currentThread().name}")
// No-op if called more than once.
System.loadLibrary("llama-android")
// Set llama log handler to Android
log_to_android()
backend_init(false)
Log.d(tag, system_info())
it.run()
}.apply {
uncaughtExceptionHandler = Thread.UncaughtExceptionHandler { _, exception: Throwable ->
Log.e(tag, "Unhandled exception", exception)
}
}
}.asCoroutineDispatcher()
private val nlen: Int = 64
private external fun log_to_android()
private external fun load_model(filename: String): Long
private external fun free_model(model: Long)
private external fun new_context(model: Long): Long
private external fun free_context(context: Long)
private external fun backend_init(numa: Boolean)
private external fun backend_free()
private external fun free_batch(batch: Long)
private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long
private external fun bench_model(
context: Long,
model: Long,
batch: Long,
pp: Int,
tg: Int,
pl: Int,
nr: Int
): String
private external fun system_info(): String
private external fun completion_init(
context: Long,
batch: Long,
text: String,
nLen: Int
): Int
private external fun completion_loop(
context: Long,
batch: Long,
nLen: Int,
ncur: IntVar
): String?
private external fun kv_cache_clear(context: Long)
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
return withContext(runLoop) {
when (val state = threadLocalState.get()) {
is State.Loaded -> {
Log.d(tag, "bench(): $state")
bench_model(state.context, state.model, state.batch, pp, tg, pl, nr)
}
else -> throw IllegalStateException("No model loaded")
}
}
}
suspend fun load(pathToModel: String) {
withContext(runLoop) {
when (threadLocalState.get()) {
is State.Idle -> {
val model = load_model(pathToModel)
if (model == 0L) throw IllegalStateException("load_model() failed")
val context = new_context(model)
if (context == 0L) throw IllegalStateException("new_context() failed")
val batch = new_batch(512, 0, 1)
if (batch == 0L) throw IllegalStateException("new_batch() failed")
Log.i(tag, "Loaded model $pathToModel")
threadLocalState.set(State.Loaded(model, context, batch))
}
else -> throw IllegalStateException("Model already loaded")
}
}
}
fun send(message: String): Flow<String> = flow {
when (val state = threadLocalState.get()) {
is State.Loaded -> {
val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
while (ncur.value <= nlen) {
val str = completion_loop(state.context, state.batch, nlen, ncur)
if (str == null) {
break
}
emit(str)
}
kv_cache_clear(state.context)
}
else -> {}
}
}.flowOn(runLoop)
/**
* Unloads the model and frees resources.
*
* This is a no-op if there's no model loaded.
*/
suspend fun unload() {
withContext(runLoop) {
when (val state = threadLocalState.get()) {
is State.Loaded -> {
free_context(state.context)
free_model(state.model)
free_batch(state.batch)
threadLocalState.set(State.Idle)
}
else -> {}
}
}
}
companion object {
private class IntVar(value: Int) {
@Volatile
var value: Int = value
private set
fun inc() {
synchronized(this) {
value += 1
}
}
}
private sealed interface State {
data object Idle: State
data class Loaded(val model: Long, val context: Long, val batch: Long): State
}
// Enforce only one instance of Llm.
private val _instance: LLamaAndroid = LLamaAndroid()
fun instance(): LLamaAndroid = _instance
}
}
package android.llama.cpp
import org.junit.Test
import org.junit.Assert.*
/**
* Example local unit test, which will execute on the development machine (host).
*
* See [testing documentation](http://d.android.com/tools/testing).
*/
class ExampleUnitTest {
@Test
fun addition_isCorrect() {
assertEquals(4, 2 + 2)
}
}
pluginManagement {
repositories {
google()
mavenCentral()
gradlePluginPortal()
}
}
dependencyResolutionManagement {
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
repositories {
google()
mavenCentral()
}
}
rootProject.name = "LlamaAndroid"
include(":app")
include(":llama")
# llama.cpp/examples/llama.swiftui
Local inference of llama.cpp on an iPhone. This is a sample app that can be used as a starting
point for more advanced projects.
For usage instructions and performance stats, check the following discussion: https://github.com/ggerganov/llama.cpp/discussions/4508
![image](https://github.com/ggerganov/llama.cpp/assets/1991296/2b40284f-8421-47a2-b634-74eece09a299)
Video demonstration:
https://github.com/bachittle/llama.cpp/assets/39804642/e290827a-4edb-4093-9642-2a5e399ec545
import Foundation
import llama
enum LlamaError: Error {
case couldNotInitializeContext
}
func llama_batch_clear(_ batch: inout llama_batch) {
batch.n_tokens = 0
}
func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ logits: Bool) {
batch.token [Int(batch.n_tokens)] = id
batch.pos [Int(batch.n_tokens)] = pos
batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count)
for i in 0..<seq_ids.count {
batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
}
batch.logits [Int(batch.n_tokens)] = logits ? 1 : 0
batch.n_tokens += 1
}
actor LlamaContext {
private var model: OpaquePointer
private var context: OpaquePointer
private var batch: llama_batch
private var tokens_list: [llama_token]
var is_done: Bool = false
/// This variable is used to store temporarily invalid cchars
private var temporary_invalid_cchars: [CChar]
var n_len: Int32 = 1024
var n_cur: Int32 = 0
var n_decode: Int32 = 0
init(model: OpaquePointer, context: OpaquePointer) {
self.model = model
self.context = context
self.tokens_list = []
self.batch = llama_batch_init(512, 0, 1)
self.temporary_invalid_cchars = []
}
deinit {
llama_batch_free(batch)
llama_free(context)
llama_free_model(model)
llama_backend_free()
}
static func create_context(path: String) throws -> LlamaContext {
llama_backend_init()
var model_params = llama_model_default_params()
#if targetEnvironment(simulator)
model_params.n_gpu_layers = 0
print("Running on simulator, force use n_gpu_layers = 0")
#endif
let model = llama_load_model_from_file(path, model_params)
guard let model else {
print("Could not load model at \(path)")
throw LlamaError.couldNotInitializeContext
}
let n_threads = max(1, min(8, ProcessInfo.processInfo.processorCount - 2))
print("Using \(n_threads) threads")
var ctx_params = llama_context_default_params()
ctx_params.seed = 1234
ctx_params.n_ctx = 2048
ctx_params.n_threads = UInt32(n_threads)
ctx_params.n_threads_batch = UInt32(n_threads)
let context = llama_new_context_with_model(model, ctx_params)
guard let context else {
print("Could not load context!")
throw LlamaError.couldNotInitializeContext
}
return LlamaContext(model: model, context: context)
}
func model_info() -> String {
let result = UnsafeMutablePointer<Int8>.allocate(capacity: 256)
result.initialize(repeating: Int8(0), count: 256)
defer {
result.deallocate()
}
// TODO: this is probably very stupid way to get the string from C
let nChars = llama_model_desc(model, result, 256)
let bufferPointer = UnsafeBufferPointer(start: result, count: Int(nChars))
var SwiftString = ""
for char in bufferPointer {
SwiftString.append(Character(UnicodeScalar(UInt8(char))))
}
return SwiftString
}
func get_n_tokens() -> Int32 {
return batch.n_tokens;
}
func completion_init(text: String) {
print("attempting to complete \"\(text)\"")
tokens_list = tokenize(text: text, add_bos: true)
temporary_invalid_cchars = []
let n_ctx = llama_n_ctx(context)
let n_kv_req = tokens_list.count + (Int(n_len) - tokens_list.count)
print("\n n_len = \(n_len), n_ctx = \(n_ctx), n_kv_req = \(n_kv_req)")
if n_kv_req > n_ctx {
print("error: n_kv_req > n_ctx, the required KV cache size is not big enough")
}
for id in tokens_list {
print(String(cString: token_to_piece(token: id) + [0]))
}
llama_batch_clear(&batch)
for i1 in 0..<tokens_list.count {
let i = Int(i1)
llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
}
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
if llama_decode(context, batch) != 0 {
print("llama_decode() failed")
}
n_cur = batch.n_tokens
}
func completion_loop() -> String {
var new_token_id: llama_token = 0
let n_vocab = llama_n_vocab(model)
let logits = llama_get_logits_ith(context, batch.n_tokens - 1)
var candidates = Array<llama_token_data>()
candidates.reserveCapacity(Int(n_vocab))
for token_id in 0..<n_vocab {
candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
}
candidates.withUnsafeMutableBufferPointer() { buffer in
var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)
new_token_id = llama_sample_token_greedy(context, &candidates_p)
}
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
print("\n")
is_done = true
let new_token_str = String(cString: temporary_invalid_cchars + [0])
temporary_invalid_cchars.removeAll()
return new_token_str
}
let new_token_cchars = token_to_piece(token: new_token_id)
temporary_invalid_cchars.append(contentsOf: new_token_cchars)
let new_token_str: String
if let string = String(validatingUTF8: temporary_invalid_cchars + [0]) {
temporary_invalid_cchars.removeAll()
new_token_str = string
} else if (0 ..< temporary_invalid_cchars.count).contains(where: {$0 != 0 && String(validatingUTF8: Array(temporary_invalid_cchars.suffix($0)) + [0]) != nil}) {
// in this case, at least the suffix of the temporary_invalid_cchars can be interpreted as UTF8 string
let string = String(cString: temporary_invalid_cchars + [0])
temporary_invalid_cchars.removeAll()
new_token_str = string
} else {
new_token_str = ""
}
print(new_token_str)
// tokens_list.append(new_token_id)
llama_batch_clear(&batch)
llama_batch_add(&batch, new_token_id, n_cur, [0], true)
n_decode += 1
n_cur += 1
if llama_decode(context, batch) != 0 {
print("failed to evaluate llama!")
}
return new_token_str
}
func bench(pp: Int, tg: Int, pl: Int, nr: Int = 1) -> String {
var pp_avg: Double = 0
var tg_avg: Double = 0
var pp_std: Double = 0
var tg_std: Double = 0
for _ in 0..<nr {
// bench prompt processing
llama_batch_clear(&batch)
let n_tokens = pp
for i in 0..<n_tokens {
llama_batch_add(&batch, 0, Int32(i), [0], false)
}
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
llama_kv_cache_clear(context)
let t_pp_start = ggml_time_us()
if llama_decode(context, batch) != 0 {
print("llama_decode() failed during prompt")
}
llama_synchronize(context)
let t_pp_end = ggml_time_us()
// bench text generation
llama_kv_cache_clear(context)
let t_tg_start = ggml_time_us()
for i in 0..<tg {
llama_batch_clear(&batch)
for j in 0..<pl {
llama_batch_add(&batch, 0, Int32(i), [Int32(j)], true)
}
if llama_decode(context, batch) != 0 {
print("llama_decode() failed during text generation")
}
llama_synchronize(context)
}
let t_tg_end = ggml_time_us()
llama_kv_cache_clear(context)
let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
let speed_pp = Double(pp) / t_pp
let speed_tg = Double(pl*tg) / t_tg
pp_avg += speed_pp
tg_avg += speed_tg
pp_std += speed_pp * speed_pp
tg_std += speed_tg * speed_tg
print("pp \(speed_pp) t/s, tg \(speed_tg) t/s")
}
pp_avg /= Double(nr)
tg_avg /= Double(nr)
if nr > 1 {
pp_std = sqrt(pp_std / Double(nr - 1) - pp_avg * pp_avg * Double(nr) / Double(nr - 1))
tg_std = sqrt(tg_std / Double(nr - 1) - tg_avg * tg_avg * Double(nr) / Double(nr - 1))
} else {
pp_std = 0
tg_std = 0
}
let model_desc = model_info();
let model_size = String(format: "%.2f GiB", Double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0);
let model_n_params = String(format: "%.2f B", Double(llama_model_n_params(model)) / 1e9);
let backend = "Metal";
let pp_avg_str = String(format: "%.2f", pp_avg);
let tg_avg_str = String(format: "%.2f", tg_avg);
let pp_std_str = String(format: "%.2f", pp_std);
let tg_std_str = String(format: "%.2f", tg_std);
var result = ""
result += String("| model | size | params | backend | test | t/s |\n")
result += String("| --- | --- | --- | --- | --- | --- |\n")
result += String("| \(model_desc) | \(model_size) | \(model_n_params) | \(backend) | pp \(pp) | \(pp_avg_str) ± \(pp_std_str) |\n")
result += String("| \(model_desc) | \(model_size) | \(model_n_params) | \(backend) | tg \(tg) | \(tg_avg_str) ± \(tg_std_str) |\n")
return result;
}
func clear() {
tokens_list.removeAll()
temporary_invalid_cchars.removeAll()
llama_kv_cache_clear(context)
}
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
let utf8Count = text.utf8.count
let n_tokens = utf8Count + (add_bos ? 1 : 0) + 1
let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false)
var swiftTokens: [llama_token] = []
for i in 0..<tokenCount {
swiftTokens.append(tokens[Int(i)])
}
tokens.deallocate()
return swiftTokens
}
/// - note: The result does not contain null-terminator
private func token_to_piece(token: llama_token) -> [CChar] {
let result = UnsafeMutablePointer<Int8>.allocate(capacity: 8)
result.initialize(repeating: Int8(0), count: 8)
defer {
result.deallocate()
}
let nTokens = llama_token_to_piece(model, token, result, 8, 0, false)
if nTokens < 0 {
let newResult = UnsafeMutablePointer<Int8>.allocate(capacity: Int(-nTokens))
newResult.initialize(repeating: Int8(0), count: Int(-nTokens))
defer {
newResult.deallocate()
}
let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens, 0, false)
let bufferPointer = UnsafeBufferPointer(start: newResult, count: Int(nNewTokens))
return Array(bufferPointer)
} else {
let bufferPointer = UnsafeBufferPointer(start: result, count: Int(nTokens))
return Array(bufferPointer)
}
}
}
// !$*UTF8*$!
{
archiveVersion = 1;
classes = {
};
objectVersion = 56;
objects = {
/* Begin PBXBuildFile section */
549479CB2AC9E16000E0F78B /* Metal.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 549479CA2AC9E16000E0F78B /* Metal.framework */; };
79E1D9CD2B4CD16E005F8E46 /* InputButton.swift in Sources */ = {isa = PBXBuildFile; fileRef = 79E1D9CC2B4CD16E005F8E46 /* InputButton.swift */; };
7FA3D2B32B2EA2F600543F92 /* DownloadButton.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7FA3D2B22B2EA2F600543F92 /* DownloadButton.swift */; };
8A1C83772AC328BD0096AF73 /* llama_swiftuiApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8A1C83762AC328BD0096AF73 /* llama_swiftuiApp.swift */; };
8A1C83792AC328BD0096AF73 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8A1C83782AC328BD0096AF73 /* ContentView.swift */; };
8A1C837B2AC328BE0096AF73 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 8A1C837A2AC328BE0096AF73 /* Assets.xcassets */; };
8A39BE0A2AC7601100BFEB40 /* Accelerate.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 8A39BE092AC7601000BFEB40 /* Accelerate.framework */; };
8A3F84242AC4C891005E2EE8 /* models in Resources */ = {isa = PBXBuildFile; fileRef = 8A3F84232AC4C891005E2EE8 /* models */; };
8A907F332AC7138A006146EA /* LibLlama.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8A907F322AC7134E006146EA /* LibLlama.swift */; };
8A9F7C4D2AC332EE008AE1EA /* LlamaState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8A9F7C4C2AC332EE008AE1EA /* LlamaState.swift */; };
DF810E132B4A5BA200301144 /* llama in Frameworks */ = {isa = PBXBuildFile; productRef = DF810E122B4A5BA200301144 /* llama */; };
F1FE20E22B465ECA00B45541 /* LoadCustomButton.swift in Sources */ = {isa = PBXBuildFile; fileRef = F1FE20E12B465EC900B45541 /* LoadCustomButton.swift */; };
/* End PBXBuildFile section */
/* Begin PBXFileReference section */
549479CA2AC9E16000E0F78B /* Metal.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Metal.framework; path = System/Library/Frameworks/Metal.framework; sourceTree = SDKROOT; };
79E1D9CC2B4CD16E005F8E46 /* InputButton.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InputButton.swift; sourceTree = "<group>"; };
7FA3D2B22B2EA2F600543F92 /* DownloadButton.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DownloadButton.swift; sourceTree = "<group>"; };
8A1C83732AC328BD0096AF73 /* llama.swiftui.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = llama.swiftui.app; sourceTree = BUILT_PRODUCTS_DIR; };
8A1C83762AC328BD0096AF73 /* llama_swiftuiApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = llama_swiftuiApp.swift; sourceTree = "<group>"; };
8A1C83782AC328BD0096AF73 /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = "<group>"; };
8A1C837A2AC328BE0096AF73 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
8A39BE092AC7601000BFEB40 /* Accelerate.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Accelerate.framework; path = System/Library/Frameworks/Accelerate.framework; sourceTree = SDKROOT; };
8A3F84232AC4C891005E2EE8 /* models */ = {isa = PBXFileReference; lastKnownFileType = folder; name = models; path = llama.swiftui/Resources/models; sourceTree = "<group>"; };
8A907F322AC7134E006146EA /* LibLlama.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LibLlama.swift; sourceTree = "<group>"; };
8A9F7C4C2AC332EE008AE1EA /* LlamaState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LlamaState.swift; sourceTree = "<group>"; };
DF2D2FE72B4A59BE00FCB72D /* llama.cpp */ = {isa = PBXFileReference; lastKnownFileType = wrapper; name = llama.cpp; path = ../..; sourceTree = "<group>"; };
F1FE20E12B465EC900B45541 /* LoadCustomButton.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LoadCustomButton.swift; sourceTree = "<group>"; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
8A1C83702AC328BD0096AF73 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
DF810E132B4A5BA200301144 /* llama in Frameworks */,
549479CB2AC9E16000E0F78B /* Metal.framework in Frameworks */,
8A39BE0A2AC7601100BFEB40 /* Accelerate.framework in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXFrameworksBuildPhase section */
/* Begin PBXGroup section */
8A1C836A2AC328BD0096AF73 = {
isa = PBXGroup;
children = (
DF2D2FE72B4A59BE00FCB72D /* llama.cpp */,
8A907F312AC7134E006146EA /* llama.cpp.swift */,
8A3F84232AC4C891005E2EE8 /* models */,
8A1C83752AC328BD0096AF73 /* llama.swiftui */,
8A1C83742AC328BD0096AF73 /* Products */,
8A39BE082AC7601000BFEB40 /* Frameworks */,
);
sourceTree = "<group>";
};
8A1C83742AC328BD0096AF73 /* Products */ = {
isa = PBXGroup;
children = (
8A1C83732AC328BD0096AF73 /* llama.swiftui.app */,
);
name = Products;
sourceTree = "<group>";
};
8A1C83752AC328BD0096AF73 /* llama.swiftui */ = {
isa = PBXGroup;
children = (
8A3F84102AC4BD85005E2EE8 /* Resources */,
8A9F7C4B2AC332DC008AE1EA /* Models */,
8A9F7C4A2AC332BF008AE1EA /* UI */,
8A1C83762AC328BD0096AF73 /* llama_swiftuiApp.swift */,
8A1C837A2AC328BE0096AF73 /* Assets.xcassets */,
);
path = llama.swiftui;
sourceTree = "<group>";
};
8A39BE082AC7601000BFEB40 /* Frameworks */ = {
isa = PBXGroup;
children = (
549479CA2AC9E16000E0F78B /* Metal.framework */,
8A39BE092AC7601000BFEB40 /* Accelerate.framework */,
);
name = Frameworks;
sourceTree = "<group>";
};
8A3F84102AC4BD85005E2EE8 /* Resources */ = {
isa = PBXGroup;
children = (
8A3F84112AC4BD8C005E2EE8 /* models */,
);
path = Resources;
sourceTree = "<group>";
};
8A3F84112AC4BD8C005E2EE8 /* models */ = {
isa = PBXGroup;
children = (
);
path = models;
sourceTree = "<group>";
};
8A907F312AC7134E006146EA /* llama.cpp.swift */ = {
isa = PBXGroup;
children = (
8A907F322AC7134E006146EA /* LibLlama.swift */,
);
path = llama.cpp.swift;
sourceTree = "<group>";
};
8A9F7C4A2AC332BF008AE1EA /* UI */ = {
isa = PBXGroup;
children = (
7FA3D2B22B2EA2F600543F92 /* DownloadButton.swift */,
8A1C83782AC328BD0096AF73 /* ContentView.swift */,
F1FE20E12B465EC900B45541 /* LoadCustomButton.swift */,
79E1D9CC2B4CD16E005F8E46 /* InputButton.swift */,
);
path = UI;
sourceTree = "<group>";
};
8A9F7C4B2AC332DC008AE1EA /* Models */ = {
isa = PBXGroup;
children = (
8A9F7C4C2AC332EE008AE1EA /* LlamaState.swift */,
);
path = Models;
sourceTree = "<group>";
};
/* End PBXGroup section */
/* Begin PBXNativeTarget section */
8A1C83722AC328BD0096AF73 /* llama.swiftui */ = {
isa = PBXNativeTarget;
buildConfigurationList = 8A1C83812AC328BE0096AF73 /* Build configuration list for PBXNativeTarget "llama.swiftui" */;
buildPhases = (
8A1C836F2AC328BD0096AF73 /* Sources */,
8A1C83702AC328BD0096AF73 /* Frameworks */,
8A1C83712AC328BD0096AF73 /* Resources */,
);
buildRules = (
);
dependencies = (
);
name = llama.swiftui;
packageProductDependencies = (
DF810E122B4A5BA200301144 /* llama */,
);
productName = llama.swiftui;
productReference = 8A1C83732AC328BD0096AF73 /* llama.swiftui.app */;
productType = "com.apple.product-type.application";
};
/* End PBXNativeTarget section */
/* Begin PBXProject section */
8A1C836B2AC328BD0096AF73 /* Project object */ = {
isa = PBXProject;
attributes = {
BuildIndependentTargetsInParallel = 1;
LastSwiftUpdateCheck = 1500;
LastUpgradeCheck = 1500;
TargetAttributes = {
8A1C83722AC328BD0096AF73 = {
CreatedOnToolsVersion = 15.0;
LastSwiftMigration = 1500;
};
};
};
buildConfigurationList = 8A1C836E2AC328BD0096AF73 /* Build configuration list for PBXProject "llama.swiftui" */;
compatibilityVersion = "Xcode 14.0";
developmentRegion = en;
hasScannedForEncodings = 0;
knownRegions = (
en,
Base,
);
mainGroup = 8A1C836A2AC328BD0096AF73;
packageReferences = (
);
productRefGroup = 8A1C83742AC328BD0096AF73 /* Products */;
projectDirPath = "";
projectRoot = "";
targets = (
8A1C83722AC328BD0096AF73 /* llama.swiftui */,
);
};
/* End PBXProject section */
/* Begin PBXResourcesBuildPhase section */
8A1C83712AC328BD0096AF73 /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
8A3F84242AC4C891005E2EE8 /* models in Resources */,
8A1C837B2AC328BE0096AF73 /* Assets.xcassets in Resources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXResourcesBuildPhase section */
/* Begin PBXSourcesBuildPhase section */
8A1C836F2AC328BD0096AF73 /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
F1FE20E22B465ECA00B45541 /* LoadCustomButton.swift in Sources */,
8A907F332AC7138A006146EA /* LibLlama.swift in Sources */,
8A9F7C4D2AC332EE008AE1EA /* LlamaState.swift in Sources */,
8A1C83792AC328BD0096AF73 /* ContentView.swift in Sources */,
8A1C83772AC328BD0096AF73 /* llama_swiftuiApp.swift in Sources */,
7FA3D2B32B2EA2F600543F92 /* DownloadButton.swift in Sources */,
79E1D9CD2B4CD16E005F8E46 /* InputButton.swift in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXSourcesBuildPhase section */
/* Begin XCBuildConfiguration section */
8A1C837F2AC328BE0096AF73 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
DEBUG_INFORMATION_FORMAT = dwarf;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_TESTABILITY = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_DYNAMIC_NO_PIC = NO;
GCC_NO_COMMON_BLOCKS = YES;
GCC_OPTIMIZATION_LEVEL = 0;
GCC_PREPROCESSOR_DEFINITIONS = (
"DEBUG=1",
"$(inherited)",
);
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 17.0;
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
MTL_FAST_MATH = YES;
ONLY_ACTIVE_ARCH = YES;
SDKROOT = iphoneos;
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
};
name = Debug;
};
8A1C83802AC328BE0096AF73 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_ENABLE_OBJC_WEAK = YES;
CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_COMMA = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
CLANG_WARN_STRICT_PROTOTYPES = YES;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
COPY_PHASE_STRIP = NO;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
ENABLE_NS_ASSERTIONS = NO;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_USER_SCRIPT_SANDBOXING = YES;
GCC_C_LANGUAGE_STANDARD = gnu17;
GCC_NO_COMMON_BLOCKS = YES;
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 17.0;
LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
MTL_ENABLE_DEBUG_INFO = NO;
MTL_FAST_MATH = YES;
SDKROOT = iphoneos;
SWIFT_COMPILATION_MODE = wholemodule;
VALIDATE_PRODUCT = YES;
};
name = Release;
};
8A1C83822AC328BE0096AF73 /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
CLANG_ENABLE_MODULES = YES;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_TEAM = K5UQJPP73A;
ENABLE_PREVIEWS = YES;
GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
INFOPLIST_KEY_UILaunchScreen_Generation = YES;
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
IPHONEOS_DEPLOYMENT_TARGET = 16.0;
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
);
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.bachittle.llama-swift";
PRODUCT_NAME = "$(TARGET_NAME)";
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator xros xrsimulator";
SUPPORTS_XR_DESIGNED_FOR_IPHONE_IPAD = NO;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2,7";
};
name = Debug;
};
8A1C83832AC328BE0096AF73 /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
CLANG_ENABLE_MODULES = YES;
CODE_SIGN_STYLE = Automatic;
CURRENT_PROJECT_VERSION = 1;
DEVELOPMENT_TEAM = K5UQJPP73A;
ENABLE_PREVIEWS = YES;
GENERATE_INFOPLIST_FILE = YES;
INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
INFOPLIST_KEY_UILaunchScreen_Generation = YES;
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
IPHONEOS_DEPLOYMENT_TARGET = 16.0;
LD_RUNPATH_SEARCH_PATHS = (
"$(inherited)",
"@executable_path/Frameworks",
);
MARKETING_VERSION = 1.0;
PRODUCT_BUNDLE_IDENTIFIER = "com.bachittle.llama-swift";
PRODUCT_NAME = "$(TARGET_NAME)";
SUPPORTED_PLATFORMS = "iphoneos iphonesimulator xros xrsimulator";
SUPPORTS_XR_DESIGNED_FOR_IPHONE_IPAD = NO;
SWIFT_EMIT_LOC_STRINGS = YES;
SWIFT_VERSION = 5.0;
TARGETED_DEVICE_FAMILY = "1,2,7";
};
name = Release;
};
/* End XCBuildConfiguration section */
/* Begin XCConfigurationList section */
8A1C836E2AC328BD0096AF73 /* Build configuration list for PBXProject "llama.swiftui" */ = {
isa = XCConfigurationList;
buildConfigurations = (
8A1C837F2AC328BE0096AF73 /* Debug */,
8A1C83802AC328BE0096AF73 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
8A1C83812AC328BE0096AF73 /* Build configuration list for PBXNativeTarget "llama.swiftui" */ = {
isa = XCConfigurationList;
buildConfigurations = (
8A1C83822AC328BE0096AF73 /* Debug */,
8A1C83832AC328BE0096AF73 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
/* End XCConfigurationList section */
/* Begin XCSwiftPackageProductDependency section */
DF810E122B4A5BA200301144 /* llama */ = {
isa = XCSwiftPackageProductDependency;
productName = llama;
};
/* End XCSwiftPackageProductDependency section */
};
rootObject = 8A1C836B2AC328BD0096AF73 /* Project object */;
}
<?xml version="1.0" encoding="UTF-8"?>
<Workspace
version = "1.0">
<FileRef
location = "self:">
</FileRef>
</Workspace>
{
"images" : [
{
"idiom" : "universal",
"platform" : "ios",
"size" : "1024x1024"
}
],
"info" : {
"author" : "xcode",
"version" : 1
}
}
import Foundation
struct Model: Identifiable {
var id = UUID()
var name: String
var url: String
var filename: String
var status: String?
}
@MainActor
class LlamaState: ObservableObject {
@Published var messageLog = ""
@Published var cacheCleared = false
@Published var downloadedModels: [Model] = []
@Published var undownloadedModels: [Model] = []
let NS_PER_S = 1_000_000_000.0
private var llamaContext: LlamaContext?
private var defaultModelUrl: URL? {
Bundle.main.url(forResource: "ggml-model", withExtension: "gguf", subdirectory: "models")
// Bundle.main.url(forResource: "llama-2-7b-chat", withExtension: "Q2_K.gguf", subdirectory: "models")
}
init() {
loadModelsFromDisk()
loadDefaultModels()
}
private func loadModelsFromDisk() {
do {
let documentsURL = getDocumentsDirectory()
let modelURLs = try FileManager.default.contentsOfDirectory(at: documentsURL, includingPropertiesForKeys: nil, options: [.skipsHiddenFiles, .skipsSubdirectoryDescendants])
for modelURL in modelURLs {
let modelName = modelURL.deletingPathExtension().lastPathComponent
downloadedModels.append(Model(name: modelName, url: "", filename: modelURL.lastPathComponent, status: "downloaded"))
}
} catch {
print("Error loading models from disk: \(error)")
}
}
private func loadDefaultModels() {
do {
try loadModel(modelUrl: defaultModelUrl)
} catch {
messageLog += "Error!\n"
}
for model in defaultModels {
let fileURL = getDocumentsDirectory().appendingPathComponent(model.filename)
if FileManager.default.fileExists(atPath: fileURL.path) {
} else {
var undownloadedModel = model
undownloadedModel.status = "download"
undownloadedModels.append(undownloadedModel)
}
}
}
func getDocumentsDirectory() -> URL {
let paths = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)
return paths[0]
}
private let defaultModels: [Model] = [
Model(name: "TinyLlama-1.1B (Q4_0, 0.6 GiB)",url: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true",filename: "tinyllama-1.1b-1t-openorca.Q4_0.gguf", status: "download"),
Model(
name: "TinyLlama-1.1B Chat (Q8_0, 1.1 GiB)",
url: "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q8_0.gguf?download=true",
filename: "tinyllama-1.1b-chat-v1.0.Q8_0.gguf", status: "download"
),
Model(
name: "TinyLlama-1.1B (F16, 2.2 GiB)",
url: "https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf?download=true",
filename: "tinyllama-1.1b-f16.gguf", status: "download"
),
Model(
name: "Phi-2.7B (Q4_0, 1.6 GiB)",
url: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf?download=true",
filename: "phi-2-q4_0.gguf", status: "download"
),
Model(
name: "Phi-2.7B (Q8_0, 2.8 GiB)",
url: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q8_0.gguf?download=true",
filename: "phi-2-q8_0.gguf", status: "download"
),
Model(
name: "Mistral-7B-v0.1 (Q4_0, 3.8 GiB)",
url: "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_0.gguf?download=true",
filename: "mistral-7b-v0.1.Q4_0.gguf", status: "download"
),
Model(
name: "OpenHermes-2.5-Mistral-7B (Q3_K_M, 3.52 GiB)",
url: "https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF/resolve/main/openhermes-2.5-mistral-7b.Q3_K_M.gguf?download=true",
filename: "openhermes-2.5-mistral-7b.Q3_K_M.gguf", status: "download"
)
]
func loadModel(modelUrl: URL?) throws {
if let modelUrl {
messageLog += "Loading model...\n"
llamaContext = try LlamaContext.create_context(path: modelUrl.path())
messageLog += "Loaded model \(modelUrl.lastPathComponent)\n"
// Assuming that the model is successfully loaded, update the downloaded models
updateDownloadedModels(modelName: modelUrl.lastPathComponent, status: "downloaded")
} else {
messageLog += "Load a model from the list below\n"
}
}
private func updateDownloadedModels(modelName: String, status: String) {
undownloadedModels.removeAll { $0.name == modelName }
}
func complete(text: String) async {
guard let llamaContext else {
return
}
let t_start = DispatchTime.now().uptimeNanoseconds
await llamaContext.completion_init(text: text)
let t_heat_end = DispatchTime.now().uptimeNanoseconds
let t_heat = Double(t_heat_end - t_start) / NS_PER_S
messageLog += "\(text)"
Task.detached {
while await !llamaContext.is_done {
let result = await llamaContext.completion_loop()
await MainActor.run {
self.messageLog += "\(result)"
}
}
let t_end = DispatchTime.now().uptimeNanoseconds
let t_generation = Double(t_end - t_heat_end) / self.NS_PER_S
let tokens_per_second = Double(await llamaContext.n_len) / t_generation
await llamaContext.clear()
await MainActor.run {
self.messageLog += """
\n
Done
Heat up took \(t_heat)s
Generated \(tokens_per_second) t/s\n
"""
}
}
}
func bench() async {
guard let llamaContext else {
return
}
messageLog += "\n"
messageLog += "Running benchmark...\n"
messageLog += "Model info: "
messageLog += await llamaContext.model_info() + "\n"
let t_start = DispatchTime.now().uptimeNanoseconds
let _ = await llamaContext.bench(pp: 8, tg: 4, pl: 1) // heat up
let t_end = DispatchTime.now().uptimeNanoseconds
let t_heat = Double(t_end - t_start) / NS_PER_S
messageLog += "Heat up time: \(t_heat) seconds, please wait...\n"
// if more than 5 seconds, then we're probably running on a slow device
if t_heat > 5.0 {
messageLog += "Heat up time is too long, aborting benchmark\n"
return
}
let result = await llamaContext.bench(pp: 512, tg: 128, pl: 1, nr: 3)
messageLog += "\(result)"
messageLog += "\n"
}
func clear() async {
guard let llamaContext else {
return
}
await llamaContext.clear()
messageLog = ""
}
}
import SwiftUI
struct ContentView: View {
@StateObject var llamaState = LlamaState()
@State private var multiLineText = ""
@State private var showingHelp = false // To track if Help Sheet should be shown
var body: some View {
NavigationView {
VStack {
ScrollView(.vertical, showsIndicators: true) {
Text(llamaState.messageLog)
.font(.system(size: 12))
.frame(maxWidth: .infinity, alignment: .leading)
.padding()
.onTapGesture {
UIApplication.shared.sendAction(#selector(UIResponder.resignFirstResponder), to: nil, from: nil, for: nil)
}
}
TextEditor(text: $multiLineText)
.frame(height: 80)
.padding()
.border(Color.gray, width: 0.5)
HStack {
Button("Send") {
sendText()
}
Button("Bench") {
bench()
}
Button("Clear") {
clear()
}
Button("Copy") {
UIPasteboard.general.string = llamaState.messageLog
}
}
.buttonStyle(.bordered)
.padding()
NavigationLink(destination: DrawerView(llamaState: llamaState)) {
Text("View Models")
}
.padding()
}
.padding()
.navigationBarTitle("Model Settings", displayMode: .inline)
}
}
func sendText() {
Task {
await llamaState.complete(text: multiLineText)
multiLineText = ""
}
}
func bench() {
Task {
await llamaState.bench()
}
}
func clear() {
Task {
await llamaState.clear()
}
}
struct DrawerView: View {
@ObservedObject var llamaState: LlamaState
@State private var showingHelp = false
func delete(at offsets: IndexSet) {
offsets.forEach { offset in
let model = llamaState.downloadedModels[offset]
let fileURL = getDocumentsDirectory().appendingPathComponent(model.filename)
do {
try FileManager.default.removeItem(at: fileURL)
} catch {
print("Error deleting file: \(error)")
}
}
// Remove models from downloadedModels array
llamaState.downloadedModels.remove(atOffsets: offsets)
}
func getDocumentsDirectory() -> URL {
let paths = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)
return paths[0]
}
var body: some View {
List {
Section(header: Text("Download Models From Hugging Face")) {
HStack {
InputButton(llamaState: llamaState)
}
}
Section(header: Text("Downloaded Models")) {
ForEach(llamaState.downloadedModels) { model in
DownloadButton(llamaState: llamaState, modelName: model.name, modelUrl: model.url, filename: model.filename)
}
.onDelete(perform: delete)
}
Section(header: Text("Default Models")) {
ForEach(llamaState.undownloadedModels) { model in
DownloadButton(llamaState: llamaState, modelName: model.name, modelUrl: model.url, filename: model.filename)
}
}
}
.listStyle(GroupedListStyle())
.navigationBarTitle("Model Settings", displayMode: .inline).toolbar {
ToolbarItem(placement: .navigationBarTrailing) {
Button("Help") {
showingHelp = true
}
}
}.sheet(isPresented: $showingHelp) { // Sheet for help modal
VStack(alignment: .leading) {
VStack(alignment: .leading) {
Text("1. Make sure the model is in GGUF Format")
.padding()
Text("2. Copy the download link of the quantized model")
.padding()
}
Spacer()
}
}
}
}
}
struct ContentView_Previews: PreviewProvider {
static var previews: some View {
ContentView()
}
}
import SwiftUI
struct DownloadButton: View {
@ObservedObject private var llamaState: LlamaState
private var modelName: String
private var modelUrl: String
private var filename: String
@State private var status: String
@State private var downloadTask: URLSessionDownloadTask?
@State private var progress = 0.0
@State private var observation: NSKeyValueObservation?
private static func getFileURL(filename: String) -> URL {
FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)[0].appendingPathComponent(filename)
}
private func checkFileExistenceAndUpdateStatus() {
}
init(llamaState: LlamaState, modelName: String, modelUrl: String, filename: String) {
self.llamaState = llamaState
self.modelName = modelName
self.modelUrl = modelUrl
self.filename = filename
let fileURL = DownloadButton.getFileURL(filename: filename)
status = FileManager.default.fileExists(atPath: fileURL.path) ? "downloaded" : "download"
}
private func download() {
status = "downloading"
print("Downloading model \(modelName) from \(modelUrl)")
guard let url = URL(string: modelUrl) else { return }
let fileURL = DownloadButton.getFileURL(filename: filename)
downloadTask = URLSession.shared.downloadTask(with: url) { temporaryURL, response, error in
if let error = error {
print("Error: \(error.localizedDescription)")
return
}
guard let response = response as? HTTPURLResponse, (200...299).contains(response.statusCode) else {
print("Server error!")
return
}
do {
if let temporaryURL = temporaryURL {
try FileManager.default.copyItem(at: temporaryURL, to: fileURL)
print("Writing to \(filename) completed")
llamaState.cacheCleared = false
let model = Model(name: modelName, url: modelUrl, filename: filename, status: "downloaded")
llamaState.downloadedModels.append(model)
status = "downloaded"
}
} catch let err {
print("Error: \(err.localizedDescription)")
}
}
observation = downloadTask?.progress.observe(\.fractionCompleted) { progress, _ in
self.progress = progress.fractionCompleted
}
downloadTask?.resume()
}
var body: some View {
VStack {
if status == "download" {
Button(action: download) {
Text("Download " + modelName)
}
} else if status == "downloading" {
Button(action: {
downloadTask?.cancel()
status = "download"
}) {
Text("\(modelName) (Downloading \(Int(progress * 100))%)")
}
} else if status == "downloaded" {
Button(action: {
let fileURL = DownloadButton.getFileURL(filename: filename)
if !FileManager.default.fileExists(atPath: fileURL.path) {
download()
return
}
do {
try llamaState.loadModel(modelUrl: fileURL)
} catch let err {
print("Error: \(err.localizedDescription)")
}
}) {
Text("Load \(modelName)")
}
} else {
Text("Unknown status")
}
}
.onDisappear() {
downloadTask?.cancel()
}
.onChange(of: llamaState.cacheCleared) { newValue in
if newValue {
downloadTask?.cancel()
let fileURL = DownloadButton.getFileURL(filename: filename)
status = FileManager.default.fileExists(atPath: fileURL.path) ? "downloaded" : "download"
}
}
}
}
// #Preview {
// DownloadButton(
// llamaState: LlamaState(),
// modelName: "TheBloke / TinyLlama-1.1B-1T-OpenOrca-GGUF (Q4_0)",
// modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true",
// filename: "tinyllama-1.1b-1t-openorca.Q4_0.gguf"
// )
// }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment